import os
from typing import Any, Dict, List, Optional, cast
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatResult,
    Generation,
    LLMResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_google_vertexai._base import _BaseVertexAIModelGarden
from langchain_google_vertexai._utils import enforce_stop_tokens
from langchain_google_vertexai.model_garden import VertexAIModelGarden
# Gemma chat-turn markup: a turn is "<start_of_turn>{role}\n...<end_of_turn>\n".
# ``{prompt}`` is substituted with the message text via ``str.format``.
USER_CHAT_TEMPLATE = (
    "<start_of_turn>user\n"
    "{prompt}<end_of_turn>\n"
)
MODEL_CHAT_TEMPLATE = (
    "<start_of_turn>model\n"
    "{prompt}<end_of_turn>\n"
)
[docs]
def gemma_messages_to_prompt(history: List[BaseMessage]) -> str:
    """Converts a list of messages to a chat prompt for Gemma.

    Each ``HumanMessage`` is formatted with ``USER_CHAT_TEMPLATE`` and each
    ``AIMessage`` with ``MODEL_CHAT_TEMPLATE``, in order. An opening model turn
    marker is appended at the end so the model continues as the assistant.

    Args:
        history: The conversation so far, oldest message first.

    Returns:
        The formatted prompt string. An empty ``history`` yields just the
        opening model turn marker.

    Raises:
        ValueError: If ``history`` contains a ``SystemMessage`` (not supported
            by Gemma) or a message of any other unsupported type.
    """
    turns: List[str] = []
    # Every message, including a lone one, is wrapped in the chat template.
    for message in history:
        # NOTE(review): content is assumed to be a plain string; list-valued
        # (multimodal) content is not converted here.
        content = cast(str, message.content)
        if isinstance(message, SystemMessage):
            raise ValueError("Gemma currently doesn't support system message!")
        if isinstance(message, AIMessage):
            turns.append(MODEL_CHAT_TEMPLATE.format(prompt=content))
        elif isinstance(message, HumanMessage):
            turns.append(USER_CHAT_TEMPLATE.format(prompt=content))
        else:
            raise ValueError(f"Unexpected message with type {type(message)}")
    turns.append("<start_of_turn>model\n")
    return "".join(turns)
def _parse_gemma_chat_response(response: str) -> str:
    """Removes chat history from the response."""
    pattern = "<start_of_turn>model\n"
    pos = response.rfind(pattern)
    if pos == -1:
        return response
    text = response[(pos + len(pattern)) :]
    pos = text.find("<start_of_turn>user\n")
    if pos > 0:
        return text[:pos]
    return text
class _GemmaBase(BaseModel):
    """Sampling parameters shared by the Gemma model wrappers in this module."""

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    temperature: Optional[float] = None
    """The temperature to use for sampling."""
    top_p: Optional[float] = None
    """The top-p value to use for sampling."""
    top_k: Optional[int] = None
    """The top-k value to use for sampling."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma.

        Only parameters that are set (not ``None``) are included.
        """
        params = {
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        return {k: v for k, v in params.items() if v is not None}

    def _get_params(self, **kwargs: Any) -> Dict[str, Any]:
        """Return the request parameters for a single call.

        Starts from ``_default_params`` and applies per-call overrides from
        ``kwargs``. Any key of ``_default_params`` and any of the sampling
        fields (``max_tokens``, ``temperature``, ``top_p``, ``top_k``) may be
        overridden, even when it was not set on the instance. ``None``
        overrides are ignored, and unrelated kwargs (e.g. ``parse_response``)
        are not forwarded.

        Args:
            **kwargs: Per-call keyword arguments.

        Returns:
            A new dict of parameters to send with the request.
        """
        params = dict(self._default_params)
        # Subclasses may rename keys in _default_params, so both their keys and
        # the base sampling field names are accepted as overrides.
        for key in (*params, "max_tokens", "temperature", "top_p", "top_k"):
            value = kwargs.get(key)
            if value is not None:
                params[key] = value
        return params
[docs]
class GemmaVertexAIModelGarden(VertexAIModelGarden):
    """Gemma wrapper over ``VertexAIModelGarden`` (text completion)."""

    # Request arguments this wrapper permits callers to pass through.
    allowed_model_args: Optional[List[str]] = ["temperature", "top_p", "top_k", "max_tokens"]

    # An explicit __init__ keeps mypy from flagging missing aliased init args.
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    @property
    def _llm_type(self) -> str:
        return "gemma_vertexai_model_garden"
[docs]
class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseChatModel):
    """Gemma chat model served from a Vertex AI Model Garden endpoint."""

    allowed_model_args: Optional[List[str]] = [
        "temperature",
        "top_p",
        "top_k",
        "max_tokens",
    ]
    parse_response: bool = False
    """Whether to post-process the chat response and strip repetitions or
    multi-turn statements (see ``_parse_gemma_chat_response``)."""

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def _llm_type(self) -> str:
        return "gemma_vertexai_model_garden"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma."""
        # NOTE(review): this sends "max_length" while allowed_model_args lists
        # "max_tokens"; confirm which key the deployed endpoint expects.
        params = {"max_length": self.max_tokens}
        return {k: v for k, v in params.items() if v is not None}

    def _gemma_request(
        self, messages: List[BaseMessage], **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the single prediction instance sent to the endpoint.

        Params come from ``_get_params`` and the prompt from
        ``gemma_messages_to_prompt``, in that order.
        """
        request = self._get_params(**kwargs)
        request["prompt"] = gemma_messages_to_prompt(messages)
        return request

    def _gemma_chat_result(
        self, text: str, stop: Optional[List[str]], **kwargs: Any
    ) -> ChatResult:
        """Post-process one prediction's text and wrap it in a ``ChatResult``.

        Optional history stripping (instance ``parse_response`` or a truthy
        ``parse_response`` kwarg) runs before stop-token enforcement.
        """
        if self.parse_response or kwargs.get("parse_response"):
            text = _parse_gemma_chat_response(text)
        if stop:
            text = enforce_stop_tokens(text, stop)
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the endpoint synchronously and return the first prediction."""
        request = self._gemma_request(messages, **kwargs)
        output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
        return self._gemma_chat_result(output.predictions[0], stop, **kwargs)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the endpoint asynchronously and return the first prediction."""
        request = self._gemma_request(messages, **kwargs)
        output = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=[request]
        )
        return self._gemma_chat_result(output.predictions[0], stop, **kwargs)
class _GemmaLocalKaggleBase(_GemmaBase):
    """Local gemma model loaded from Kaggle."""

    client: Any = None  #: :meta private:
    keras_backend: str = "jax"
    model_name: str = Field(default="gemma_2b_en", alias="model")
    """Gemma model name."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Load the keras-nlp ``GemmaCausalLM`` preset named by ``model_name``.

        Sets the ``KERAS_BACKEND`` environment variable from ``keras_backend``
        (a process-wide side effect), imports keras-nlp, and stores the loaded
        model in ``values["client"]``.

        Raises:
            ImportError: If ``keras_nlp`` cannot be imported.
        """
        # Keras 3 selects its backend from KERAS_BACKEND at import time, so the
        # variable must be set before the import below.
        os.environ["KERAS_BACKEND"] = values["keras_backend"]
        try:
            from keras_nlp.models import GemmaCausalLM  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Could not import GemmaCausalLM library. "
                "Please install the GemmaCausalLM library to "
                "use this  model: pip install keras-nlp keras>=3 kaggle"
            ) from e
        values["client"] = GemmaCausalLM.from_preset(values["model_name"])
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma.

        ``max_tokens`` is forwarded as ``max_length``; it is omitted when unset.
        """
        # NOTE(review): keras-nlp's max_length presumably bounds prompt plus
        # generated tokens, not only new tokens; confirm against the
        # documented meaning of max_tokens.
        params = {"max_length": self.max_tokens}
        return {k: v for k, v in params.items() if v is not None}

    def _get_params(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge per-call ``max_tokens`` (renamed to ``max_length``) over defaults.

        Other kwargs are ignored.
        """
        mapping = {"max_tokens": "max_length"}
        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
        return {**self._default_params, **params}
[docs]
class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):  # type: ignore
    """Local gemma text-completion (non-chat) model loaded from Kaggle."""

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Only needed for typing."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts.

        All prompts are passed to the keras-nlp model in a single ``generate``
        call.

        Returns:
            An ``LLMResult`` with one ``Generation`` per returned completion,
            each post-processed with ``enforce_stop_tokens`` when ``stop`` is
            given.
        """
        params = self._get_params(**kwargs)
        results = self.client.generate(prompts, **params)
        # A bare string result is wrapped so the code below always handles a
        # list of completions.
        results = [results] if isinstance(results, str) else results
        if stop:
            results = [enforce_stop_tokens(text, stop) for text in results]
        return LLMResult(generations=[[Generation(text=result)] for result in results])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_kaggle"
[docs]
class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):  # type: ignore
    """Local gemma chat model loaded from Kaggle."""

    parse_response: bool = False
    """Whether to post-process the chat response and strip repetitions or
    multi-turn statements."""

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_chat_kaggle"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce one chat reply for ``messages`` with the local keras model."""
        generation_kwargs = self._get_params(**kwargs)
        raw_reply = self.client.generate(
            gemma_messages_to_prompt(messages), **generation_kwargs
        )
        should_parse = self.parse_response or kwargs.get("parse_response")
        reply = _parse_gemma_chat_response(raw_reply) if should_parse else raw_reply
        if stop:
            reply = enforce_stop_tokens(reply, stop)
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=reply))])
class _GemmaLocalHFBase(_GemmaBase):
    """Local gemma model loaded from HuggingFace."""

    tokenizer: Any = None  #: :meta private:
    client: Any = None  #: :meta private:
    hf_access_token: str
    cache_dir: Optional[str] = None
    model_name: str = Field(default="google/gemma-2b", alias="model")
    """Gemma model name."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Load the tokenizer and ``GemmaForCausalLM`` for ``model_name``.

        Both are fetched with ``hf_access_token`` and cached under
        ``cache_dir`` (``None`` leaves the transformers default cache in use),
        then stored in ``values["tokenizer"]`` and ``values["client"]``.

        Raises:
            ImportError: If ``transformers`` with Gemma support is not installed.
        """
        try:
            from transformers import AutoTokenizer, GemmaForCausalLM  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Could not import GemmaForCausalLM library. "
                "Please install the GemmaForCausalLM library to "
                "use this  model: pip install transformers>=4.38.1"
            ) from e
        # The tokenizer and the model share cache_dir so their files are stored
        # in the same place.
        values["tokenizer"] = AutoTokenizer.from_pretrained(
            values["model_name"],
            token=values["hf_access_token"],
            cache_dir=values["cache_dir"],
        )
        values["client"] = GemmaForCausalLM.from_pretrained(
            values["model_name"],
            token=values["hf_access_token"],
            cache_dir=values["cache_dir"],
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling gemma.

        ``max_tokens`` is forwarded as ``max_length``; it is omitted when unset.
        """
        # NOTE(review): in transformers' generate(), max_length presumably
        # counts prompt tokens too (max_new_tokens counts only new ones);
        # confirm this matches the intended meaning of max_tokens.
        params = {"max_length": self.max_tokens}
        return {k: v for k, v in params.items() if v is not None}

    def _get_params(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge per-call ``max_tokens`` (renamed to ``max_length``) over defaults.

        Other kwargs are ignored.
        """
        mapping = {"max_tokens": "max_length"}
        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
        return {**self._default_params, **params}

    def _run(self, prompt: str, **kwargs: Any) -> str:
        """Generate for ``prompt`` and return the first decoded sequence.

        Special tokens are skipped during decoding. NOTE(review): for a
        decoder-only model the decoded text presumably still contains the
        prompt, which the chat subclass can strip via ``parse_response``.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        params = self._get_params(**kwargs)
        generate_ids = self.client.generate(inputs.input_ids, **params)
        return self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
[docs]
class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM):  # type: ignore
    """Local gemma model loaded from HuggingFace."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_hf"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the model once per prompt; each prompt gets one ``Generation``."""
        generations: List[List[Generation]] = []
        for prompt in prompts:
            completion = self._run(prompt, **kwargs)
            if stop:
                completion = enforce_stop_tokens(completion, stop)
            generations.append([Generation(text=completion)])
        return LLMResult(generations=generations)
[docs]
class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):  # type: ignore
    """Local gemma chat model loaded from HuggingFace."""

    parse_response: bool = False
    """Whether to post-process the chat response and strip repetitions or
    multi-turn statements."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gemma_local_chat_hf"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce one chat reply for ``messages`` with the local HF model."""
        raw_reply = self._run(gemma_messages_to_prompt(messages), **kwargs)
        should_parse = self.parse_response or kwargs.get("parse_response")
        reply = _parse_gemma_chat_response(raw_reply) if should_parse else raw_reply
        if stop:
            reply = enforce_stop_tokens(reply, stop)
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=reply))])