# Source code for langchain_ai21.ai21_base
import os
from typing import Any, Dict, Optional
from ai21 import AI21Client
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
_DEFAULT_TIMEOUT_SEC = 300
# [docs]
class AI21Base(BaseModel):
    """Base class for AI21 models.

    Resolves connection settings (API key, host, timeout) from explicit field
    values or environment variables, and builds an ``AI21Client`` from them
    when no ``client`` is supplied.
    """

    class Config:
        arbitrary_types_allowed = True

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    # API key; falls back to the AI21_API_KEY environment variable.
    api_key: Optional[SecretStr] = None
    # Base URL; falls back to AI21_API_URL, then "https://api.ai21.com".
    api_host: Optional[str] = None
    # Request timeout in seconds; falls back to AI21_TIMEOUT_SEC, then
    # _DEFAULT_TIMEOUT_SEC.
    timeout_sec: Optional[float] = None
    # Retry count forwarded as-is to the AI21Client that this class builds.
    num_retries: Optional[int] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve settings from fields/environment and create the client.

        Args:
            values: Field values after per-field validation.

        Returns:
            ``values`` with ``api_key``, ``api_host`` and ``timeout_sec``
            resolved, and ``client`` populated if it was not provided.

        Raises:
            ValueError: if ``AI21_TIMEOUT_SEC`` is set to a non-numeric string
                (raised by ``float``).
        """
        # An empty string is used when no key is found anywhere.
        api_key = convert_to_secret_str(
            values.get("api_key") or os.getenv("AI21_API_KEY") or ""
        )
        values["api_key"] = api_key

        api_host = (
            values.get("api_host")
            or os.getenv("AI21_API_URL")
            or "https://api.ai21.com"
        )
        values["api_host"] = api_host

        # Note: because of `or`, a falsy explicit value (e.g. 0) also falls
        # back to the environment variable / default.
        timeout_sec = values.get("timeout_sec") or float(
            os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
        )
        values["timeout_sec"] = timeout_sec

        # Only build a client when the caller did not inject one.
        if values.get("client") is None:
            values["client"] = AI21Client(
                api_key=api_key.get_secret_value(),
                api_host=api_host,
                # Never None here: the fallback above always yields a float.
                timeout_sec=float(timeout_sec),
                # Forward the configured retry count so the field takes effect.
                num_retries=values.get("num_retries"),
                via="langchain",
            )
        return values