import asyncio
from functools import partial
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,
    generate_from_stream,
)
from langchain_core.messages import (
    BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_ai21.ai21_base import AI21Base
from langchain_ai21.chat.chat_adapter import ChatAdapter
from langchain_ai21.chat.chat_factory import create_chat_adapter
class ChatAI21(BaseChatModel, AI21Base):
    """ChatAI21 chat model. Different model types support different parameters and
    different parameter values. Please read the [AI21 reference documentation]
    (https://docs.ai21.com/reference) for your model to understand which parameters
    are available.
    Example:
        .. code-block:: python
            from langchain_ai21 import ChatAI21
            model = ChatAI21(
                # defaults to os.environ.get("AI21_API_KEY")
                api_key="my_api_key"
            )
    """
    model: str
    """Model type you wish to interact with. 
        You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
    num_results: int = 1
    """The number of responses to generate for a given prompt."""
    stop: Optional[List[str]] = None
    """Default stop sequences."""
    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""
    min_tokens: int = 0
    """The minimum number of tokens to generate for each response.
    _Not supported for all models._"""
    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""
    top_p: float = 1
    """A value controlling the diversity of the model's responses."""
    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step.
    _Not supported for all models._"""
    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated.
    _Not supported for all models._"""
    presence_penalty: Optional[Any] = None
    """ A penalty applied to tokens that are already present in the prompt.
    _Not supported for all models._"""
    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency 
    in the generated responses. _Not supported for all models._"""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    streaming: bool = False
    """Whether to stream the results or not."""
    _chat_adapter: ChatAdapter
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values = super().validate_environment(values)
        model = values.get("model")
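        # The factory is expected to pick the chat adapter matching this model
        # family and to raise for unsupported model names.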
        values["_chat_adapter"] = create_chat_adapter(model)  # type: ignore
        return values
    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True
    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-ai21"
    @property
    def _default_params(self) -> Mapping[str, Any]:
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
            "n": self.n,
        }
        if self.stop:
            base_params["stop_sequences"] = self.stop
        if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()
        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self.presence_penalty.to_dict()
        return base_params
    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="ai21",
            ls_model_name=self.model,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params
    def _build_params_for_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Mapping[str, Any]:
        params = {}
        converted_messages = self._chat_adapter.convert_messages(messages)
        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop
        return {
            **converted_messages,
            **self._default_params,
            **params,
            **kwargs,
        }
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream or self.streaming
        if should_stream:
            return self._handle_stream_from_generate(
                messages=messages,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )
        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=should_stream,
            **kwargs,
        )
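        # Non-streaming path: the adapter is expected to return complete chat
        # messages, one per requested generation.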
        messages = self._chat_adapter.call(self.client, **params)
        generations = [ChatGeneration(message=message) for message in messages]
        return ChatResult(generations=generations)
    def _handle_stream_from_generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        stream_iter = self._stream(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        return generate_from_stream(stream_iter)
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=True,
            **kwargs,
        )
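        # Forward each streamed token to the callback manager before yielding
        # the chunk downstream.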
        for chunk in self._chat_adapter.call(self.client, **params):
            if run_manager and isinstance(chunk.message.content, str):
                run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
            yield chunk
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
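        # There is no native async path here, so run the synchronous _generate
        # in the default executor to avoid blocking the event loop.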
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), messages, stop, run_manager
        )
    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )
        return message.content
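

# A minimal usage sketch, not part of the library code above: the model name
# and API key are placeholders, and the available models and parameters depend
# on your AI21 account (see https://docs.ai21.com/reference).
if __name__ == "__main__":
    chat = ChatAI21(
        model="jamba-instruct",  # placeholder model name
        api_key="my_api_key",  # defaults to os.environ.get("AI21_API_KEY")
        temperature=0.7,
        max_tokens=256,
    )

    # Single-shot invocation returns an AIMessage.
    print(chat.invoke("What is the capital of France?").content)

    # Streaming yields message chunks incrementally.
    for chunk in chat.stream("Write a one-line haiku about the sea."):
        print(chunk.content, end="", flush=True)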