# Source code for langchain_mongodb.retrievers.hybrid_search
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.pipelines import (
    combine_pipelines,
    final_hybrid_stage,
    reciprocal_rank_stage,
    text_search_stage,
    vector_search_stage,
)
from langchain_mongodb.utils import make_serializable
# [docs]
class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
    """Hybrid Search Retriever that combines vector and full-text searches,
    weighting them via the Reciprocal Rank Fusion (RRF) algorithm.

    Each search contributes a score of ``1 / (rank + penalty)`` for the
    documents it returns; the per-search scores are then summed and sorted.

    Increasing the vector_penalty will reduce the importance of the vector search.
    Increasing the fulltext_penalty will correspondingly reduce the importance of
    the full-text search.
    For more on the algorithm, see
    https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking
    """

    vectorstore: MongoDBAtlasVectorSearch
    """MongoDBAtlas VectorStore"""
    search_index_name: str
    """Atlas Search Index (full-text) name"""
    top_k: int = 4
    """Number of documents to return."""
    oversampling_factor: int = 10
    """Multiplied by top_k to give the number of candidates considered by the
    vector search stage."""
    pre_filter: Optional[Dict[str, Any]] = None
    """(Optional) Any MQL match expression comparing an indexed field.
    Applied to both the vector and the full-text search."""
    post_filter: Optional[List[Dict[str, Any]]] = None
    """(Optional) Pipeline of MongoDB aggregation stages for postprocessing."""
    vector_penalty: float = 60.0
    """Penalty applied to vector search results in RRF: scores=1/(rank + penalty)"""
    fulltext_penalty: float = 60.0
    """Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)"""
    show_embeddings: bool = False
    """If true, returned Document metadata will include vectors."""

    @property
    def collection(self) -> Collection:
        """The MongoDB collection backing ``vectorstore``."""
        return self.vectorstore._collection

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents that are highest scoring / most similar to query.

        Note that the same query is used in both searches,
        embedded for vector search, and as-is for full-text search.

        Args:
            query: String to find relevant documents for
            run_manager: The callback handler to use

        Returns:
            List of relevant documents. Each Document's page_content is the
            text field; all remaining returned fields become metadata.
        """
        query_vector = self.vectorstore._embedding.embed_query(query)

        scores_fields = ["vector_score", "fulltext_score"]
        pipeline: List[Any] = []

        # First we build up the aggregation pipeline,
        # then it is passed to the server to execute.

        # Vector Search stage, followed by its RRF scoring into "vector_score".
        vector_pipeline = [
            vector_search_stage(
                query_vector=query_vector,
                search_field=self.vectorstore._embedding_key,
                index_name=self.vectorstore._index_name,
                top_k=self.top_k,
                filter=self.pre_filter,
                oversampling_factor=self.oversampling_factor,
            )
        ]
        vector_pipeline.extend(
            reciprocal_rank_stage("vector_score", self.vector_penalty)
        )
        combine_pipelines(pipeline, vector_pipeline, self.collection.name)

        # Full-Text Search stage, followed by its RRF scoring into "fulltext_score".
        text_pipeline = text_search_stage(
            query=query,
            search_field=self.vectorstore._text_key,
            index_name=self.search_index_name,
            limit=self.top_k,
            filter=self.pre_filter,
        )
        text_pipeline.extend(
            reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
        )
        combine_pipelines(pipeline, text_pipeline, self.collection.name)

        # Sum the per-search scores and sort, limiting to top_k results.
        pipeline.extend(
            final_hybrid_stage(scores_fields=scores_fields, limit=self.top_k)
        )

        # Removal of embeddings unless requested.
        if not self.show_embeddings:
            pipeline.append({"$project": {self.vectorstore._embedding_key: 0}})

        # Post filtering: caller-supplied stages run last, on the fused results.
        if self.post_filter is not None:
            pipeline.extend(self.post_filter)

        # Execution
        cursor = self.collection.aggregate(pipeline)  # type: ignore[arg-type]

        # Formatting: the text field becomes page_content; everything else
        # (including the RRF score fields) is kept as metadata.
        docs = []
        for res in cursor:
            text = res.pop(self.vectorstore._text_key)
            make_serializable(res)
            docs.append(Document(page_content=text, metadata=res))
        return docs