# Retrieval Pipeline Plugin
Implement a custom retrieval pipeline and distribute it as a plugin.
## Overview
| Field | Value |
|---|---|
| Entry Point Group | autorag_research.pipelines |
| Config Base Class | BaseRetrievalPipelineConfig |
| Pipeline Base Class | BaseRetrievalPipeline |
| Config Module | autorag_research.config |
| Pipeline Module | autorag_research.pipelines.retrieval.base |
A retrieval pipeline plugin consists of two classes: a config dataclass that declares parameters and a pipeline class that implements search logic. The config tells the executor how to build the pipeline; the pipeline performs the actual retrieval.
## Scaffold

Generate boilerplate with the CLI:

```bash
autorag-research plugin create my_search --type=retrieval
```

This creates a ready-to-edit package under my_search_plugin/ with config,
pipeline, YAML, and pyproject.toml pre-wired.
## Config Class
Subclass BaseRetrievalPipelineConfig and define your custom parameters as
dataclass fields. Implement get_pipeline_class() and get_pipeline_kwargs().
```python
from dataclasses import dataclass, field
from typing import Any

from autorag_research.config import BaseRetrievalPipelineConfig, PipelineType


@dataclass(kw_only=True)
class MySearchPipelineConfig(BaseRetrievalPipelineConfig):
    """Configuration for MySearch retrieval pipeline."""

    pipeline_type: PipelineType = field(default=PipelineType.RETRIEVAL, init=False)

    # Add custom config fields
    index_path: str = "/data/index"
    similarity_threshold: float = 0.5

    def get_pipeline_class(self) -> type["MySearchPipeline"]:
        return MySearchPipeline

    def get_pipeline_kwargs(self) -> dict[str, Any]:
        return {
            "index_path": self.index_path,
            "similarity_threshold": self.similarity_threshold,
        }
```
### Inherited Fields

Every retrieval config inherits these fields from `BasePipelineConfig`:

| Field | Type | Default | Description |
|---|---|---|---|
| `name` | `str` | required | Unique pipeline instance name |
| `description` | `str` | `""` | Optional description |
| `top_k` | `int` | `10` | Results per query |
| `batch_size` | `int` | `128` | Queries per DB batch |
| `max_concurrency` | `int` | `16` | Max concurrent async operations |
| `max_retries` | `int` | `3` | Retry attempts for failed queries |
| `retry_delay` | `float` | `1.0` | Base delay for exponential backoff |
### Abstract Methods

| Method | Returns | Purpose |
|---|---|---|
| `get_pipeline_class()` | `type[BaseRetrievalPipeline]` | Pipeline class to instantiate |
| `get_pipeline_kwargs()` | `dict[str, Any]` | Custom kwargs passed to the pipeline constructor (beyond `session_factory`, `name`, and `schema`, which are injected automatically) |
| `get_run_kwargs()` | `dict[str, Any]` | Already implemented by `BaseRetrievalPipelineConfig`; returns `top_k`, `batch_size`, `max_concurrency`, `max_retries`, and `retry_delay` |
You must implement get_pipeline_class() and get_pipeline_kwargs().
get_run_kwargs() is provided by the base class and normally does not need
overriding.
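Putting the hooks together: the executor resolves the pipeline class from the config, then constructs it from the injected arguments plus `get_pipeline_kwargs()`. The sketch below illustrates that wiring only; the `Toy*` classes are stand-ins invented for this example, not framework APIs.

```python
from typing import Any

# Illustrative stand-ins; the real classes come from autorag_research.
class ToyPipeline:
    def __init__(self, session_factory: Any, name: str, schema: Any = None, **kwargs: Any):
        self.name = name
        self.kwargs = kwargs  # custom kwargs from the config

class ToyConfig:
    name = "my_search"

    def get_pipeline_class(self) -> type:
        return ToyPipeline

    def get_pipeline_kwargs(self) -> dict[str, Any]:
        return {"index_path": "/data/index"}

    def get_run_kwargs(self) -> dict[str, Any]:
        return {"top_k": 10, "batch_size": 128}

config = ToyConfig()
pipeline_cls = config.get_pipeline_class()
# session_factory, name, and schema are injected by the framework;
# get_pipeline_kwargs() supplies the custom constructor arguments.
pipeline = pipeline_cls(session_factory=None, name=config.name, **config.get_pipeline_kwargs())
# pipeline.run(**config.get_run_kwargs()) would then drive batch retrieval.
print(pipeline.kwargs)  # {'index_path': '/data/index'}
```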
## Pipeline Class
Subclass BaseRetrievalPipeline and implement the three abstract methods.
```python
from typing import Any

from sqlalchemy.orm import Session, sessionmaker

from autorag_research.pipelines.retrieval.base import BaseRetrievalPipeline


class MySearchPipeline(BaseRetrievalPipeline):
    """MySearch retrieval pipeline."""

    def __init__(
        self,
        session_factory: sessionmaker[Session],
        name: str,
        schema: Any | None = None,
        index_path: str = "/data/index",
        similarity_threshold: float = 0.5,
    ):
        super().__init__(session_factory, name, schema)
        self.index_path = index_path
        self.similarity_threshold = similarity_threshold

    def _get_pipeline_config(self) -> dict[str, Any]:
        return {
            "type": "my_search",
            "index_path": self.index_path,
            "similarity_threshold": self.similarity_threshold,
        }

    async def _retrieve_by_id(self, query_id: int | str, top_k: int) -> list[dict[str, Any]]:
        """Retrieve using query ID (query exists in database).

        Used for batch processing where queries have pre-computed embeddings.
        """
        # Access query embedding from DB via self._service
        # Perform your search logic
        return [{"doc_id": chunk_id, "score": score}]

    async def _retrieve_by_text(self, query_text: str, top_k: int) -> list[dict[str, Any]]:
        """Retrieve using raw query text (may need to compute embedding).

        Used for ad-hoc retrieval and by generation pipelines.
        """
        # Compute embedding on-the-fly if needed
        # Perform your search logic
        return [{"doc_id": chunk_id, "score": score}]
```
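The placeholder bodies above can be filled in many ways. As an illustration only, here is a self-contained sketch of threshold-plus-top-k scoring over a toy in-memory index; `INDEX`, `cosine`, and `retrieve` are hypothetical names invented for this example, not part of the framework.

```python
import asyncio
import math

# Hypothetical in-memory index: chunk_id -> embedding vector. In a real
# pipeline these vectors would come from the index at index_path.
INDEX = {
    1: [1.0, 0.0],
    2: [0.7, 0.7],
    3: [0.0, 1.0],
}

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
    return dot / norm if norm else 0.0

async def retrieve(query_vec: list[float], top_k: int, threshold: float) -> list[dict]:
    scored = [
        {"doc_id": cid, "score": cosine(query_vec, vec)}
        for cid, vec in INDEX.items()
    ]
    # Apply the similarity threshold, then keep the top_k best scores.
    hits = [h for h in scored if h["score"] >= threshold]
    hits.sort(key=lambda h: h["score"], reverse=True)
    return hits[:top_k]

results = asyncio.run(retrieve([1.0, 0.0], top_k=2, threshold=0.5))
print([h["doc_id"] for h in results])  # [1, 2]
```

Note that the result shape matches the contract described below: a list of `{"doc_id": ..., "score": ...}` dicts with higher scores ranked first.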
### Abstract Methods

| Method | When Called | Use Case |
|---|---|---|
| `_retrieve_by_id(query_id, top_k)` | `pipeline.run()` batch processing | Queries exist in DB with stored embeddings |
| `_retrieve_by_text(query_text, top_k)` | `pipeline.retrieve()` single query | Ad-hoc queries, used by generation pipelines |
| `_get_pipeline_config()` | Pipeline initialization | Returns dict stored in DB for reproducibility |
### Return Format

Both `_retrieve_by_id` and `_retrieve_by_text` return a list of dicts. Each dict contains:

| Key | Type | Description |
|---|---|---|
| `doc_id` | int \| str | Chunk ID in the database |
| `score` | `float` | Relevance score (higher is better) |
The base class handles persisting these results to ChunkRetrievedResult rows
automatically.
## YAML Configuration
Place a YAML file inside your plugin package. The executor loads it via Hydra-style instantiation.
```yaml
# src/my_search_plugin/retrieval/my_search.yaml
_target_: my_search_plugin.pipeline.MySearchPipelineConfig
description: "MySearch retrieval pipeline"
name: my_search
top_k: 10
batch_size: 128
max_concurrency: 16
index_path: /data/index
similarity_threshold: 0.5
```
`_target_` is the fully-qualified class name of your config dataclass. The
remaining keys map directly to dataclass fields. When the executor loads this
file, it instantiates MySearchPipelineConfig with these values.
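For intuition, `_target_` resolution can be sketched in a few lines of stdlib Python. This toy `instantiate` mimics Hydra-style behavior, using a plain dict in place of the parsed YAML and `fractions.Fraction` as a stand-in target; the executor's actual loader may differ.

```python
import importlib

# Stand-in for a parsed YAML config; the real file would name your
# config dataclass in _target_.
cfg = {
    "_target_": "fractions.Fraction",
    "numerator": 3,
    "denominator": 4,
}

def instantiate(cfg: dict):
    # Split "pkg.module.ClassName" into module path and class name.
    module_path, _, class_name = cfg["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    # Every key except _target_ becomes a constructor keyword argument.
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return cls(**kwargs)

obj = instantiate(cfg)
print(obj)  # 3/4
```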
## Entry Points
Register your plugin in pyproject.toml so the framework discovers it at
runtime:
```toml
[project.entry-points."autorag_research.pipelines"]
my_search = "my_search_plugin"
```
The key (my_search) is the plugin name used in plugin sync. The value is the
top-level package that contains your YAML configs.
After installing the package, run:

```bash
autorag-research plugin sync
```
This copies your YAML files into the project's configs/ directory.
## Testing
Test the config independently of the database:
```python
from my_search_plugin.pipeline import MySearchPipelineConfig


def test_config():
    config = MySearchPipelineConfig(name="my_search")
    assert config.name == "my_search"
    assert config.get_pipeline_class() is not None


def test_config_custom_fields():
    config = MySearchPipelineConfig(
        name="my_search",
        index_path="/custom/index",
        similarity_threshold=0.8,
    )
    kwargs = config.get_pipeline_kwargs()
    assert kwargs["index_path"] == "/custom/index"
    assert kwargs["similarity_threshold"] == 0.8
```
For integration tests that exercise `_retrieve_by_id` and `_retrieve_by_text`,
use the `db_session` fixture from the test `conftest.py` and seed test data per
the project testing guidelines.
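In between pure config tests and full DB integration tests, you can also exercise the retrieval contract itself against a stand-in. `FakeSearchPipeline` below is purely illustrative (it returns canned scores rather than subclassing `BaseRetrievalPipeline`), but it checks the same threshold and `top_k` behavior your real pipeline should honor.

```python
import asyncio
from typing import Any

# Illustrative stand-in with the same method shape as MySearchPipeline;
# a real pipeline would subclass BaseRetrievalPipeline and query an index.
class FakeSearchPipeline:
    def __init__(self, similarity_threshold: float = 0.5):
        self.similarity_threshold = similarity_threshold
        self._scores = {"q": [(1, 0.9), (2, 0.4), (3, 0.7)]}  # canned results

    async def _retrieve_by_text(self, query_text: str, top_k: int) -> list[dict[str, Any]]:
        hits = [
            {"doc_id": cid, "score": s}
            for cid, s in self._scores.get(query_text, [])
            if s >= self.similarity_threshold
        ]
        hits.sort(key=lambda h: h["score"], reverse=True)
        return hits[:top_k]


def test_retrieve_by_text_respects_threshold_and_top_k():
    pipeline = FakeSearchPipeline(similarity_threshold=0.5)
    results = asyncio.run(pipeline._retrieve_by_text("q", top_k=2))
    # Chunk 2 falls below the threshold; the rest come back best-first.
    assert [h["doc_id"] for h in results] == [1, 3]
    assert all(h["score"] >= 0.5 for h in results)


test_retrieve_by_text_respects_threshold_and_top_k()
```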
## Next
- Generation Pipeline -- build a generation pipeline plugin
- Best Practices -- packaging, versioning, and distribution tips