from __future__ import annotations
import os
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr, root_validator
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
[docs]
class Ranking(BaseModel):
    index: int
    logit: float 
[docs]
class NVIDIARerank(BaseDocumentCompressor):
    """
    LangChain Document Compressor that uses the NVIDIA NeMo Retriever Reranking API.
    """
    class Config:
        validate_assignment = True
    _client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
    _default_batch_size: int = 32
    _default_model_name: str = "nv-rerank-qa-mistral-4b:1"
    _default_base_url: str = "https://integrate.api.nvidia.com/v1"
    base_url: str = Field(
        description="Base url for model listing an invocation",
    )
    top_n: int = Field(5, ge=0, description="The number of documents to return.")
    model: Optional[str] = Field(description="The model to use for reranking.")
    truncate: Optional[Literal["NONE", "END"]] = Field(
        description=(
            "Truncate input text if it exceeds the model's maximum token length. "
            "Default is model dependent and is likely to raise error if an "
            "input is too long."
        ),
    )
    max_batch_size: int = Field(
        _default_batch_size, ge=1, description="The maximum batch size."
    )
    _base_url_var = "NVIDIA_BASE_URL"
    @root_validator(pre=True)
    def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        values["base_url"] = (
            values.get(cls._base_url_var.lower())
            or values.get("base_url")
            or os.getenv(cls._base_url_var)
            or cls._default_base_url
        )
        return values
    def __init__(self, **kwargs: Any):
        """
        Create a new NVIDIARerank document compressor.
        This class provides access to a NVIDIA NIM for reranking. By default, it
        connects to a hosted NIM, but can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.
        Args:
            model (str): The model to use for reranking.
            nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.
            truncate (str): "NONE", "END", truncate input text if it exceeds
                            the model's context length. Default is model dependent and
                            is likely to raise an error if an input is too long.
        API Key:
        - The recommended way to provide the API key is through the `NVIDIA_API_KEY`
            environment variable.
        """
        super().__init__(**kwargs)
        self._client = _NVIDIAClient(
            base_url=self.base_url,
            model_name=self.model,
            default_hosted_model_name=self._default_model_name,
            api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
            infer_path="{base_url}/ranking",
            cls=self.__class__.__name__,
        )
        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.model_name
    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with NVIDIARerank.
        """
        return self._client.get_available_models(self.__class__.__name__)
[docs]
    @classmethod
    def get_available_models(
        cls,
        **kwargs: Any,
    ) -> List[Model]:
        """
        Get a list of available models that work with NVIDIARerank.
        """
        return cls(**kwargs).available_models 
    # todo: batching when len(documents) > endpoint's max batch size
    def _rank(self, documents: List[str], query: str) -> List[Ranking]:
        payload = {
            "model": self.model,
            "query": {"text": query},
            "passages": [{"text": passage} for passage in documents],
        }
        if self.truncate:
            payload["truncate"] = self.truncate
        response = self._client.get_req(payload=payload)
        if response.status_code != 200:
            response.raise_for_status()
        # todo: handle errors
        rankings = response.json()["rankings"]
        # todo: callback support
        return [Ranking(**ranking) for ranking in rankings[: self.top_n]]
[docs]
    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using the NVIDIA NeMo Retriever Reranking microservice API.
        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
        Returns:
            A sequence of compressed documents.
        """
        if len(documents) == 0 or self.top_n < 1:
            return []
        def batch(ls: list, size: int) -> Generator[List[Document], None, None]:
            for i in range(0, len(ls), size):
                yield ls[i : i + size]
        doc_list = list(documents)
        results = []
        for doc_batch in batch(doc_list, self.max_batch_size):
            rankings = self._rank(
                query=query, documents=[d.page_content for d in doc_batch]
            )
            for ranking in rankings:
                assert (
                    0 <= ranking.index < len(doc_batch)
                ), "invalid response from server: index out of range"
                doc = doc_batch[ranking.index]
                doc.metadata["relevance_score"] = ranking.logit
                results.append(doc)
        # if we batched, we need to sort the results
        if len(doc_list) > self.max_batch_size:
            results.sort(key=lambda x: x.metadata["relevance_score"], reverse=True)
        return results[: self.top_n]