from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from google.cloud.aiplatform import telemetry
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, BaseLLM
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatResult, LLMResult
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import BaseModel, Field
from vertexai.preview.vision_models import (  # type: ignore[import-untyped]
    GeneratedImage,
    ImageGenerationModel,
)
from vertexai.vision_models import Image, ImageTextModel  # type: ignore[import-untyped]
from langchain_google_vertexai._image_utils import (
    ImageBytesLoader,
    create_image_content_part,
    get_image_str_from_content_part,
    get_text_str_from_content_part,
    image_bytes_to_b64_string,
)
from langchain_google_vertexai._utils import get_user_agent
class _BaseImageTextModel(BaseModel):
    """Shared configuration and helpers for integrations built on ImageTextModel."""

    cached_client: Any = Field(default=None)
    model_name: str = Field(default="imagetext@001")
    """Name of the model to load."""
    number_of_results: int = Field(default=1)
    """How many results a single query should return."""
    language: str = Field(default="en")
    """Language used for the query."""
    project: Union[str, None] = Field(default=None)
    """Google Cloud project."""

    @property
    def client(self) -> ImageTextModel:
        """Return the SDK client, creating and caching it on first access."""
        if self.cached_client is not None:
            return self.cached_client
        self.cached_client = ImageTextModel.from_pretrained(
            model_name=self.model_name,
        )
        return self.cached_client

    def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None:
        """Build an ``Image`` from a message part, if the part describes one.

        Args:
            message_part: Item of a message content.

        Returns:
            An ``Image`` created from the loaded bytes when an image string
            could be extracted from the part, otherwise ``None``.
        """
        raw_image = get_image_str_from_content_part(message_part)
        if not isinstance(raw_image, str):
            return None
        payload = ImageBytesLoader(project=self.project).load_bytes(raw_image)
        return Image(image_bytes=payload)

    def _get_text_from_message_part(self, message_part: str | Dict) -> str | None:
        """Extract the text of a message part, if the part represents text.

        Args:
            message_part: Item of a message content.

        Returns:
            Whatever ``get_text_str_from_content_part`` yields for the part:
            the text when available, otherwise ``None``.
        """
        return get_text_str_from_content_part(message_part)

    @property
    def _llm_type(self) -> str:
        """Identifier for this kind of LLM."""
        return "vertexai-vision"

    @property
    def _user_agent(self) -> str:
        """User agent string derived from the class name and the model name."""
        agent_key = f"{type(self).__name__}_{self.model_name}"
        _, agent = get_user_agent(agent_key)
        return agent

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Fresh dict of the call parameters configured on the instance."""
        return dict(
            number_of_results=self.number_of_results,
            language=self.language,
        )

    def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
        """Overlay call-time overrides on top of the default parameters.

        Only keys already present in the defaults are considered, and an
        override whose value is ``None`` leaves the default untouched.

        Returns:
            The merged parameter dictionary.
        """
        merged = self._default_params
        overrides = {
            name: value
            for name, value in kwargs.items()
            if name in merged and value is not None
        }
        merged.update(overrides)
        return merged
class _BaseVertexAIImageCaptioning(_BaseImageTextModel):
    """Common base for the image captioning integrations."""

    def _get_captions(
        self,
        image: Image,
        number_of_results: Optional[int] = None,
        language: Optional[str] = None,
    ) -> List[str]:
        """Ask the SDK client for captions describing an image.

        Args:
            image: Image to caption.
            number_of_results: Optional override for the number of captions.
            language: Optional override for the query language.

        Returns:
            The captions returned by the client for the image.
        """
        overrides = {"number_of_results": number_of_results, "language": language}
        # The client call runs inside the telemetry context so it is tagged
        # with this integration's user agent.
        with telemetry.tool_context_manager(self._user_agent):
            request_params = self._prepare_params(**overrides)
            return self.client.get_captions(image=image, **request_params)
[docs]
class VertexAIImageCaptioning(_BaseVertexAIImageCaptioning, BaseLLM):
    """Image Captioning model exposed through the LLM interface."""

    def _generate(
        self,
        prompts: List[str],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Caption every image referenced by the prompts.

        Args:
            prompts: List of prompts to use. Each prompt must be a string
                that represents an image. Currently supported are:
                - Google Cloud Storage URI
                - B64 encoded string
                - Local file path
                - Remote url
            stop: Not used by this implementation.
            run_manager: Not used by this implementation.
            **kwargs: Forwarded to the caption request of every prompt.

        Returns:
            Captions generated from every prompt.
        """
        per_prompt: List[List[Generation]] = []
        for prompt in prompts:
            per_prompt.append(self._generate_one(prompt=prompt, **kwargs))
        return LLMResult(generations=per_prompt)

    def _generate_one(self, prompt: str, **kwargs) -> List[Generation]:
        """Caption the single image referenced by ``prompt``.

        Args:
            prompt: Image string to load and caption.
            **kwargs: Forwarded to ``_get_captions``.

        Returns:
            One ``Generation`` per caption obtained.
        """
        loaded = ImageBytesLoader(project=self.project).load_bytes(prompt)
        captions = self._get_captions(image=Image(image_bytes=loaded), **kwargs)
        results: List[Generation] = []
        for text in captions:
            results.append(Generation(text=text))
        return results
[docs]
class VertexAIImageCaptioningChat(_BaseVertexAIImageCaptioning, BaseChatModel):
    """Implementation of the Image Captioning model as a chat."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generates captions for the image carried by a single message.

        Args:
            messages: List of messages. Currently only one message is supported.
                The message content must be a list with only one element with
                a dict with format:
                {
                    'type': 'image_url',
                    'image_url': {
                        'url': <image_string>
                    }
                }
                Currently supported image strings are:
                - Google Cloud Storage URI
                - B64 encoded string
                - Local file path
                - Remote url
                The message's ``additional_kwargs`` are forwarded to the
                caption request.
            stop: Not used by this implementation.
            run_manager: Not used by this implementation.
            **kwargs: Not used by this implementation.

        Returns:
            A ChatResult with one generation per caption.

        Raises:
            ValueError: If there is not exactly one message whose content is a
                one-element list from which an image can be extracted.
        """
        image = None
        content = messages[0].content if len(messages) == 1 else None
        if isinstance(content, list) and len(content) == 1:
            image = self._get_image_from_message_part(content[0])

        if image is None:
            # The documented key for the image string is 'url'; keep the error
            # message aligned with the format described in the docstring.
            raise ValueError(
                f"{self.__class__.__name__} messages should be a list with "
                "only one message. This message content must be a list with "
                "one dictionary with the format: "
                "{'type': 'image_url', 'image_url': {'url': <image_str>}}"
            )

        captions = self._get_captions(image, **messages[0].additional_kwargs)
        generations = [
            ChatGeneration(message=AIMessage(content=caption)) for caption in captions
        ]
        return ChatResult(generations=generations)
[docs]
class VertexAIVisualQnAChat(_BaseImageTextModel, BaseChatModel):
    """Chat implementation of a visual QnA model"""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default call parameters; only the number of results is used here."""
        return {"number_of_results": self.number_of_results}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Answers a question about an image.

        Args:
            messages: List with exactly one message whose content is a list
                of two parts:
                - First, the image as a dict with format
                  {'type': 'image_url', 'image_url': {'url': <image_string>}}.
                  Currently supported image strings are:
                    - Google Cloud Storage URI
                    - B64 encoded string
                    - Local file path
                    - Remote url
                - Second, the user question, either a simple string or a
                  dict with format {'type': 'text', 'text': <message>}.
                The message's ``additional_kwargs`` are forwarded to the
                question request.
            stop: Not used by this implementation.
            run_manager: Not used by this implementation.
            **kwargs: Not used by this implementation.

        Returns:
            A ChatResult with one generation per answer.

        Raises:
            ValueError: If the messages do not match the format above or the
                image / question cannot be extracted from the parts.
        """
        image = None
        user_question = None

        content = messages[0].content if len(messages) == 1 else None
        if isinstance(content, list) and len(content) == 2:
            image_part, user_question_part = content
            image = self._get_image_from_message_part(image_part)
            user_question = self._get_text_from_message_part(user_question_part)

        if (image is None) or (user_question is None):
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space.
            raise ValueError(
                f"{self.__class__.__name__} messages should be a list with "
                "only one message. The message content should be a list with "
                "two elements. The first element should be the image, a dictionary "
                "with format "
                "{'type': 'image_url', 'image_url': {'url': <image_str>}}. "
                "The second one should be the user question. Either a simple string "
                "or a dictionary with format {'type': 'text', 'text': <message>}"
            )

        answers = self._ask_questions(
            image=image, query=user_question, **messages[0].additional_kwargs
        )
        generations = [
            ChatGeneration(message=AIMessage(content=answer)) for answer in answers
        ]
        return ChatResult(generations=generations)

    def _ask_questions(
        self, image: Image, query: str, number_of_results: Optional[int] = None
    ) -> List[str]:
        """Interfaces with the sdk to get the answers to a question.

        Args:
            image: Image to question about.
            query: User query.
            number_of_results: Optional override for the number of answers.

        Returns:
            List of responses to the query.
        """
        with telemetry.tool_context_manager(self._user_agent):
            params = self._prepare_params(number_of_results=number_of_results)
            answers = self.client.ask_question(image=image, question=query, **params)
            return answers
class _BaseVertexAIImageGenerator(BaseModel):
    """Base class for generation and editing of images."""

    cached_client: Any = Field(default=None)
    model_name: str = Field(default="imagegeneration@002")
    """Name of the base model"""
    negative_prompt: Union[str, None] = Field(default=None)
    """A description of what you want to omit in the generated images"""
    number_of_results: int = Field(default=1)
    """Number of images to generate"""
    guidance_scale: Union[float, None] = Field(default=None)
    """Controls the strength of the prompt"""
    language: Union[str, None] = Field(default=None)
    """Language of the text prompt for the image. Supported values are "en" for
    English, "hi" for Hindi, "ja" for Japanese, "ko" for Korean, and "auto" for
    automatic language detection"""
    seed: Union[int, None] = Field(default=None)
    """Random seed for the image generation"""
    project: Union[str, None] = Field(default=None)
    """Google cloud project id"""

    @property
    def client(self) -> ImageGenerationModel:
        """Returns the SDK client, creating and caching it when not yet set."""
        if not self.cached_client:
            self.cached_client = ImageGenerationModel.from_pretrained(
                model_name=self.model_name,
            )
        return self.cached_client

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default call parameters, keyed by the names the client call expects."""
        return {
            "number_of_images": self.number_of_results,
            "language": self.language,
            "negative_prompt": self.negative_prompt,
            "guidance_scale": self.guidance_scale,
            "seed": self.seed,
        }

    def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
        """Merges call-time overrides into the defaults.

        ``number_of_results`` is translated to ``number_of_images``. Unknown
        keys and ``None`` overrides are ignored, and parameters whose final
        value is ``None`` are dropped from the result.

        Returns:
            Parameters to forward to the client call.
        """
        params = self._default_params
        mapping = {"number_of_results": "number_of_images"}
        for key, value in kwargs.items():
            key = mapping.get(key, key)
            if key in params and value is not None:
                params[key] = value
        return {k: v for k, v in params.items() if v is not None}

    def _generate_images(self, prompt: str, **kwargs: Any) -> List[str]:
        """Generates images given a prompt.

        Args:
            prompt: Description of what the image should look like.
            **kwargs: Overrides for the default generation parameters.

        Returns:
            List of b64 encoded strings.
        """
        with telemetry.tool_context_manager(self._user_agent):
            generation_result = self.client.generate_images(
                prompt=prompt, **self._prepare_params(**kwargs)
            )
        image_str_list = [
            self._to_b64_string(image) for image in generation_result.images
        ]
        return image_str_list

    def _edit_images(self, image_str: str, prompt: str, **kwargs: Any) -> List[str]:
        """Edit an image given a image and a prompt.

        Args:
            image_str: String representation of the image.
            prompt: Description of what the image should look like.
            **kwargs: Overrides for the default generation parameters.

        Returns:
            List of b64 encoded strings.
        """
        with telemetry.tool_context_manager(self._user_agent):
            image_loader = ImageBytesLoader(project=self.project)
            image_bytes = image_loader.load_bytes(image_str)
            image = Image(image_bytes=image_bytes)
            generation_result = self.client.edit_image(
                prompt=prompt, base_image=image, **self._prepare_params(**kwargs)
            )
        image_str_list = [
            self._to_b64_string(image) for image in generation_result.images
        ]
        return image_str_list

    def _to_b64_string(self, image: GeneratedImage) -> str:
        """Transforms a generated image into a b64 encoded string.

        Args:
            image: Image to convert.

        Returns:
            b64 encoded string of the image.
        """
        # This is a hack because at the moment, GeneratedImage doesn't provide
        # a way to get the bytes of the image (or anything else). There is
        # only private methods that are not reliable.
        from tempfile import NamedTemporaryFile

        # The context manager guarantees the temporary file is closed (and
        # thus removed) even if saving or reading the image raises.
        with NamedTemporaryFile() as temp_file:
            image.save(temp_file.name, include_generation_parameters=False)
            temp_file.seek(0)
            image_bytes = temp_file.read()
        return image_bytes_to_b64_string(image_bytes=image_bytes)

    @property
    def _llm_type(self) -> str:
        """Returns the type of LLM"""
        return "vertexai-vision"

    @property
    def _user_agent(self) -> str:
        """Gets the User Agent."""
        _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
        return user_agent
[docs]
class VertexAIImageGeneratorChat(_BaseVertexAIImageGenerator, BaseChatModel):
    """Generates an image from a prompt."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generates images from the prompt carried by a single message.

        Args:
            messages: The message must be a list of only one element with one part:
                The user prompt. The content may be a plain string or a
                one-element list holding the text part. The message's
                ``additional_kwargs`` are forwarded as generation overrides.
            stop: Not used by this implementation.
            run_manager: Not used by this implementation.
            **kwargs: Not used by this implementation.

        Returns:
            A ChatResult with one generation per image; each message content
            is a one-element list holding the image content part.

        Raises:
            ValueError: If there is not exactly one message or no text prompt
                can be extracted from it.
        """
        # Only one message allowed with one text part.
        user_query = None

        if len(messages) == 1:
            content = messages[0].content
            if isinstance(content, str):
                user_query = content
            elif len(content) == 1:
                user_query = get_text_str_from_content_part(content[0])

        if user_query is None:
            raise ValueError(
                "Only one message with one text part allowed for image generation."
                " It must be the prompt of the image."
            )

        image_str_list = self._generate_images(
            prompt=user_query, **messages[0].additional_kwargs
        )
        image_content_part_list = [
            create_image_content_part(image_str=image_str)
            for image_str in image_str_list
        ]
        generations = [
            ChatGeneration(message=AIMessage(content=[content_part]))
            for content_part in image_content_part_list
        ]
        return ChatResult(generations=generations)
[docs]
class VertexAIImageEditorChat(_BaseVertexAIImageGenerator, BaseChatModel):
    """Given an image and a prompt, edits the image.
    Currently only supports mask free editing.
    """

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Edits the image carried by a single message according to its prompt.

        Args:
            messages: The message must be a list of only one element with two part:
                - The image as a dict {
                    'type': 'image_url', 'image_url': {'url': <message_str>}
                    }
                - The user prompt.
                The message's ``additional_kwargs`` are forwarded as
                generation overrides.
            stop: Not used by this implementation.
            run_manager: Not used by this implementation.
            **kwargs: Not used by this implementation.

        Returns:
            A ChatResult with one generation per edited image; each message
            content is a one-element list holding the image content part.

        Raises:
            ValueError: If there is not exactly one two-part message or the
                image / prompt cannot be extracted from its parts.
        """
        # Only one message allowed with two parts: the image and the text.
        # Both names are initialised up front so the validation below never
        # relies on short-circuit evaluation to avoid an unbound local.
        image_str = None
        user_query = None
        is_valid = len(messages) == 1 and len(messages[0].content) == 2

        if is_valid:
            image_str = get_image_str_from_content_part(messages[0].content[0])
            user_query = get_text_str_from_content_part(messages[0].content[1])

        if (user_query is None) or (image_str is None):
            raise ValueError(
                "Only one message allowed for image edition. The message must have "
                "two parts: First the image and then the user prompt."
            )

        edited_str_list = self._edit_images(
            image_str=image_str, prompt=user_query, **messages[0].additional_kwargs
        )
        image_content_part_list = [
            create_image_content_part(image_str=edited_str)
            for edited_str in edited_str_list
        ]
        generations = [
            ChatGeneration(message=AIMessage(content=[content_part]))
            for content_part in image_content_part_list
        ]
        return ChatResult(generations=generations)