"""Wrapper around Together AI's Embeddings API."""
import logging
import os
import warnings
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)
import openai
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)
logger = logging.getLogger(__name__)


class TogetherEmbeddings(BaseModel, Embeddings):
    """Embedding model that calls Together AI through the OpenAI SDK clients.

    To use, set the environment variable `TOGETHER_API_KEY` with your API key or
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_together import TogetherEmbeddings

            model = TogetherEmbeddings()
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
    """Embeddings model name to use,
    for example 'togethercomputer/m2-bert-80M-8k-retrieval'.
    """
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Not yet supported.
    """
    together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """API key for the Together AI API."""
    together_api_base: str = Field(
        default="https://api.together.ai/v1/", alias="base_url"
    )
    """Endpoint URL to use."""
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Together embedding API. Can be float, httpx.Timeout or
        None."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.

    Defaults to not skipping.
    Not yet supported."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for sync invocations. Must specify
        http_async_client as well if you'd like a custom client for async invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
        http_client as well if you'd like a custom client for sync invocations."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any key in ``values`` that is not a declared field name is moved into
        ``model_kwargs`` (with a warning) so that it is forwarded to the
        ``create`` call.

        Args:
            values: Raw constructor keyword arguments.

        Returns:
            ``values`` with unknown keys relocated under ``model_kwargs``.

        Raises:
            ValueError: If a key appears both at top level and in
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        # Work on a copy so the dict object supplied by the caller is not
        # mutated when unknown kwargs are folded into it.
        extra = dict(values.get("model_kwargs", {}))
        # Iterate over a snapshot of the keys: entries are popped in the loop.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials and build the sync and async embeddings clients.

        The API key is looked up in ``values`` or the ``TOGETHER_API_KEY``
        environment variable and stored as a secret string. The base URL falls
        back to ``TOGETHER_API_BASE`` only when the configured value is empty.
        Clients already present in ``values`` are left untouched.

        Args:
            values: Validated field values.

        Returns:
            ``values`` with ``together_api_key``, ``together_api_base``,
            ``client`` and ``async_client`` populated.
        """
        together_api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = (
            convert_to_secret_str(together_api_key) if together_api_key else None
        )
        values["together_api_base"] = values["together_api_base"] or os.getenv(
            "TOGETHER_API_BASE"
        )
        # Parameters shared by the sync and async OpenAI SDK clients.
        client_params = {
            "api_key": (
                values["together_api_key"].get_secret_value()
                if values["together_api_key"]
                else None
            ),
            "base_url": values["together_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }
        if not values.get("client"):
            # Only pass http_client when one was supplied.
            sync_specific = (
                {"http_client": values["http_client"]} if values["http_client"] else {}
            )
            values["client"] = openai.OpenAI(
                **client_params, **sync_specific
            ).embeddings
        if not values.get("async_client"):
            async_specific = (
                {"http_client": values["http_async_client"]}
                if values["http_async_client"]
                else {}
            )
            values["async_client"] = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).embeddings
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Keyword arguments sent with every ``create`` call.

        A fresh dict is built on each access: the model name, everything in
        ``model_kwargs``, and ``dimensions`` when it is set.
        """
        params: Dict = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params

    @staticmethod
    def _response_to_dict(response: Any) -> Dict[str, Any]:
        """Return ``response`` as a dict, dumping non-dict response objects."""
        if isinstance(response, dict):
            return response
        return response.model_dump()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Each text is sent in its own ``create`` request.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings: List[List[float]] = []
        params = self._invocation_params
        for text in texts:
            response = self._response_to_dict(
                self.client.create(input=text, **params)
            )
            # Collect embeddings whether the client returned a dict or a
            # response object; previously dict responses were dropped.
            embeddings.extend(item["embedding"] for item in response["data"])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        response = self._response_to_dict(
            self.client.create(input=text, **self._invocation_params)
        )
        return response["data"][0]["embedding"]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts asynchronously.

        Each text is sent in its own ``create`` request, awaited in order.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings: List[List[float]] = []
        params = self._invocation_params
        for text in texts:
            response = self._response_to_dict(
                await self.async_client.create(input=text, **params)
            )
            # Collect embeddings whether the client returned a dict or a
            # response object; previously dict responses were dropped.
            embeddings.extend(item["embedding"] for item in response["data"])
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text asynchronously.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        response = self._response_to_dict(
            await self.async_client.create(input=text, **self._invocation_params)
        )
        return response["data"][0]["embedding"]