from abc import ABC
from typing import Any, Dict, List, Optional, Sequence
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform_v1beta1 import (
    EvaluationServiceAsyncClient,
    EvaluationServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import (
    EvaluateInstancesRequest,
    EvaluateInstancesResponse,
)
from google.protobuf.json_format import MessageToDict
from langchain_google_vertexai._utils import (
    get_client_info,
    get_user_agent,
)
from langchain_google_vertexai.evaluators._core import (
    PairwiseStringEvaluator,
    StringEvaluator,
)
# Metrics for which one EvaluateInstancesRequest may carry several instances
# (see _prepare_request and _parse_response).
_METRICS_MULTIPLE_INSTANCES = ["bleu", "exact_match", "rouge"]

# Pointwise metrics accepted by VertexStringEvaluator. Names are checked after
# _format_metric(), so "rouge" also admits rouge1/rouge2/rougeL/rougeLsum.
_METRICS = [
    *_METRICS_MULTIPLE_INSTANCES,
    "coherence",
    "fluency",
    "safety",
    "groundedness",
    "fulfillment",
    "summarization_quality",
    "summarization_helpfulness",
    "summarization_verbosity",
    "question_answering_quality",
    "question_answering_relevance",
    "question_answering_correctness",
]

# Metrics accepted by VertexPairWiseStringEvaluator.
_PAIRWISE_METRICS = [
    "pairwise_question_answering_quality",
    "pairwise_summarization_quality",
]

# metric_spec payload for each ROUGE variant; every other metric sends {}.
_METRICS_INPUTS = {
    variant: {"rouge_type": variant}
    for variant in ("rouge1", "rouge2", "rougeL", "rougeLsum")
}

# Instance fields each metric requires. Metrics missing from this table
# require "prediction" and "reference" (the default in _format_instance).
_METRICS_ATTRS = {
    "safety": ["prediction"],
    "coherence": ["prediction"],
    "fluency": ["prediction"],
    "groundedness": ["context", "prediction"],
    "fulfillment": ["prediction", "instruction"],
    "summarization_quality": ["prediction", "instruction", "context"],
    "summarization_helpfulness": ["prediction", "context"],
    "summarization_verbosity": ["prediction", "context"],
    "question_answering_quality": ["prediction", "context", "instruction"],
    "question_answering_relevance": ["prediction", "instruction"],
    "question_answering_correctness": ["prediction", "instruction"],
    **{
        pairwise_metric: [
            "prediction",
            "baseline_prediction",
            "context",
            "instruction",
        ]
        for pairwise_metric in _PAIRWISE_METRICS
    },
}

# Instance fields forwarded only when the caller provides them.
_METRICS_OPTIONAL_ATTRS = {
    "summarization_quality": ["reference"],
    "summarization_helpfulness": ["reference", "instruction"],
    "summarization_verbosity": ["reference", "instruction"],
    "question_answering_quality": ["reference"],
    "question_answering_relevance": ["reference", "context"],
    "question_answering_correctness": ["reference", "context"],
    **{pairwise_metric: ["reference"] for pairwise_metric in _PAIRWISE_METRICS},
}
def _format_metric(metric: str) -> str:
    """Map a metric name onto the metric family used for request fields.

    Every ROUGE variant ("rouge1", "rougeL", ...) shares the ``rouge_input`` /
    ``rouge_results`` fields, so any name starting with "rouge" collapses to
    "rouge"; all other names are returned unchanged.
    """
    return "rouge" if metric.startswith("rouge") else metric
def _format_instance(instance: Dict[str, str], metric: str) -> Dict[str, str]:
    """Keep only the fields of ``instance`` that ``metric`` accepts.

    Required fields come from ``_METRICS_ATTRS`` (defaulting to prediction and
    reference) and must be present; a missing one raises ``KeyError``.
    Optional fields from ``_METRICS_OPTIONAL_ATTRS`` are copied only when the
    instance has them. Any other keys are dropped.
    """
    required = _METRICS_ATTRS.get(metric, ["prediction", "reference"])
    optional = _METRICS_OPTIONAL_ATTRS.get(metric, [])
    formatted = {field: instance[field] for field in required}
    formatted.update(
        (field, instance[field]) for field in optional if field in instance
    )
    return formatted
def _prepare_request(
    instances: Sequence[Dict[str, str]], metric: str, location: str
) -> EvaluateInstancesRequest:
    """Build an ``EvaluateInstancesRequest`` for ``metric``.

    Args:
        instances: Raw instances (field name -> value); each one is trimmed to
            the fields the metric accepts via ``_format_instance``.
        metric: Metric name, possibly a ROUGE variant such as "rougeL"; the
            variant selects the ``metric_spec`` from ``_METRICS_INPUTS``.
        location: Location path stored on the request (callers in this module
            pass the client's ``common_location_path``).

    Returns:
        A request whose ``<metric family>_input`` field is populated.

    Raises:
        ValueError: if ``instances`` is empty, or if it holds more than one
            instance for a metric that accepts a single instance per request.
    """
    # Fail clearly instead of an IndexError (single-instance metrics) or an
    # empty request round-trip (batch metrics).
    if not instances:
        raise ValueError(f"Metric {metric} requires at least one instance, got 0!")
    family = _format_metric(metric)
    metric_input: Dict[str, Any] = {"metric_spec": _METRICS_INPUTS.get(metric, {})}
    if family in _METRICS_MULTIPLE_INSTANCES:
        metric_input["instances"] = [
            _format_instance(i, metric=metric) for i in instances
        ]
    else:
        if len(instances) > 1:
            raise ValueError(
                f"Metric {metric} supports only a single instance per request, "
                f"got {len(instances)}!"
            )
        metric_input["instance"] = _format_instance(instances[0], metric=metric)
    request = EvaluateInstancesRequest()
    # e.g. "rougeL" and "rouge1" both populate ``rouge_input``.
    setattr(request, f"{family}_input", metric_input)
    request.location = location
    return request
def _parse_response(
    response: EvaluateInstancesResponse, metric: str
) -> List[Dict[str, Any]]:
    """Turn an ``EvaluateInstancesResponse`` into a list of result dicts.

    Batch-capable metric families return every entry of
    ``<family>_results.<family>_metric_values``; all other metrics return a
    one-element list holding ``<family>_result``. Field names are kept as in
    the proto definition (``preserving_proto_field_name=True``).
    """
    family = _format_metric(metric)
    payload = MessageToDict(response._pb, preserving_proto_field_name=True)
    if family not in _METRICS_MULTIPLE_INSTANCES:
        return [payload[f"{family}_result"]]
    batch = payload[f"{family}_results"]
    return batch[f"{family}_metric_values"]
class _EvaluatorBase(ABC):
    """Shared plumbing for the Vertex AI evaluators in this module.

    Stores the metric name and creates a synchronous and an asynchronous
    ``EvaluationService`` client pointed at the regional endpoint for
    ``location``.
    """

    @property
    def _user_agent(self) -> str:
        """User agent string tagged with the concrete class name and metric."""
        _, user_agent = get_user_agent(f"{type(self).__name__}_{self._metric}")
        return user_agent

    def __init__(self, metric: str, project_id: str, location: str = "us-central1"):
        """Create the API clients for ``metric``.

        Args:
            metric: Name of the Vertex AI evaluation metric.
            project_id: Google Cloud project the requests are billed to.
            location: Region; selects the ``{location}-<base path>`` endpoint.
        """
        # Must be set first: ``_user_agent`` reads it.
        self._metric = metric
        client_options = ClientOptions(
            api_endpoint=f"{location}-{constants.PREDICTION_API_BASE_PATH}"
        )
        # Resolve the user agent once; each client still gets its own ClientInfo.
        user_agent = self._user_agent
        self._client = EvaluationServiceClient(
            client_options=client_options,
            client_info=get_client_info(module=user_agent),
        )
        self._async_client = EvaluationServiceAsyncClient(
            client_options=client_options,
            client_info=get_client_info(module=user_agent),
        )
        # Location resource path attached to every request built below.
        self._location = self._client.common_location_path(project_id, location)

    def _prepare_request(
        self,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluateInstancesRequest:
        """Build a single-instance request for this evaluator's metric.

        ``input`` is sent as the instance's ``context`` field. Extra keyword
        arguments become additional instance fields and win on a name clash.
        """
        instance = {"prediction": prediction}
        # ``is not None`` so an explicitly passed empty string is kept instead
        # of being silently dropped; metrics that require "reference" would
        # otherwise fail later with a KeyError.
        if reference is not None:
            instance["reference"] = reference
        if input is not None:
            instance["context"] = input
        instance.update(kwargs)
        return _prepare_request(
            [instance], metric=self._metric, location=self._location
        )
class VertexStringEvaluator(_EvaluatorBase, StringEvaluator):
    """Evaluate a predicted string with a Vertex AI pointwise metric.

    ``metric`` must be one of ``_METRICS``; ROUGE variants such as "rougeL"
    are accepted as well. The remaining keyword arguments (``project_id``,
    ``location``) are forwarded to ``_EvaluatorBase``.
    """

    def __init__(self, metric: str, **kwargs):
        # Validate before the base class builds API clients so an unsupported
        # metric fails fast.
        if _format_metric(metric) not in _METRICS:
            raise ValueError(f"Metric {metric} is not supported yet!")
        super().__init__(metric, **kwargs)

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Evaluate one prediction synchronously and return its result dict."""
        request = self._prepare_request(prediction, reference, input, **kwargs)
        response = self._client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]

    def evaluate(
        self,
        examples: Sequence[Dict[str, str]],
        predictions: Sequence[Dict[str, str]],
        *,
        question_key: str = "context",
        answer_key: str = "reference",
        prediction_key: str = "prediction",
        instruction_key: str = "instruction",
        **kwargs: Any,
    ) -> List[dict]:
        """Evaluate a batch of predictions against their examples.

        Metrics that support several instances per request (bleu, exact_match
        and every ROUGE variant) are sent in a single request; all others are
        evaluated with one request per example.

        Args:
            examples: One dict per example. ``answer_key``, ``question_key``
                and ``instruction_key`` are copied to the instance's
                ``reference``, ``context`` and ``instruction`` fields when
                present.
            predictions: Dicts aligned with ``examples``; ``prediction_key``
                is read from each.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            One result dict per example, in input order.

        Raises:
            ValueError: if ``examples`` and ``predictions`` differ in length.
        """
        # zip() would silently drop the unmatched tail.
        if len(examples) != len(predictions):
            raise ValueError(
                f"Got {len(examples)} examples but {len(predictions)} "
                "predictions; they must have the same length!"
            )
        instances: List[dict] = []
        for example, prediction in zip(examples, predictions):
            row = {"prediction": prediction[prediction_key]}
            if answer_key in example:
                row["reference"] = example[answer_key]
            if question_key in example:
                row["context"] = example[question_key]
            if instruction_key in example:
                row["instruction"] = example[instruction_key]
            instances.append(row)
        if not instances:
            return []
        # Compare the metric family so "rouge1"/"rougeL"/... take the batch path.
        if _format_metric(self._metric) in _METRICS_MULTIPLE_INSTANCES:
            request = _prepare_request(
                instances, metric=self._metric, location=self._location
            )
            response = self._client.evaluate_instances(request)
            return _parse_response(response, metric=self._metric)
        return [self._evaluate_strings(**i) for i in instances]

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Evaluate one prediction with the async client; return its result dict."""
        request = self._prepare_request(prediction, reference, input, **kwargs)
        response = await self._async_client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]
class VertexPairWiseStringEvaluator(_EvaluatorBase, PairwiseStringEvaluator):
    """Compare two predicted strings with a Vertex AI pairwise metric.

    ``metric`` must be one of ``_PAIRWISE_METRICS``. The first prediction
    (``prediction``) is sent as the service's ``baseline_prediction`` and the
    second (``prediction_b``) as its ``prediction``. The remaining keyword
    arguments (``project_id``, ``location``) are forwarded to
    ``_EvaluatorBase``.
    """

    def __init__(self, metric: str, **kwargs):
        # Validate before the base class builds API clients so an unsupported
        # metric fails fast.
        if _format_metric(metric) not in _PAIRWISE_METRICS:
            raise ValueError(f"Metric {metric} is not supported yet!")
        super().__init__(metric, **kwargs)

    def _prepare_pair_request(
        self,
        *,
        prediction: str,
        prediction_b: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ):
        """Build the request for one A/B pair (shared by sync and async paths).

        Returns the ``EvaluateInstancesRequest`` from ``_prepare_request``.
        """
        # A (``prediction``) is the baseline; B (``prediction_b``) fills the
        # service's ``prediction`` field.
        return self._prepare_request(
            prediction_b, reference, input, baseline_prediction=prediction, **kwargs
        )

    def _evaluate_string_pairs(
        self,
        *,
        prediction: str,
        prediction_b: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Compare one pair synchronously and return its result dict."""
        request = self._prepare_pair_request(
            prediction=prediction,
            prediction_b=prediction_b,
            reference=reference,
            input=input,
            **kwargs,
        )
        response = self._client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]

    async def _aevaluate_string_pairs(
        self,
        *,
        prediction: str,
        prediction_b: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        """Compare one pair with the async client and return its result dict."""
        request = self._prepare_pair_request(
            prediction=prediction,
            prediction_b=prediction_b,
            reference=reference,
            input=input,
            **kwargs,
        )
        response = await self._async_client.evaluate_instances(request)
        return _parse_response(response, metric=self._metric)[0]