from __future__ import annotations
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import yaml  # type: ignore[import-untyped]
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from typing_extensions import TYPE_CHECKING, Literal
from langchain_aws.vectorstores.inmemorydb.constants import INMEMORYDB_VECTOR_DTYPE_MAP
if TYPE_CHECKING:
    from redis.commands.search.field import (  # type: ignore
        NumericField,
        TagField,
        TextField,
        VectorField,
    )
class InMemoryDBDistanceMetric(str, Enum):
    """Distance metrics for Redis vector fields.

    Members also subclass ``str``; each value is the upper-case identifier
    that the vector field schema stores under its ``DISTANCE_METRIC`` entry.
    """
    # Presumably Euclidean (L2) distance -- confirm against the server docs.
    l2 = "L2"
    # Cosine metric; this is the default used by InMemoryDBVectorField.
    cosine = "COSINE"
    # Presumably inner product -- confirm against the server docs.
    ip = "IP"
class InMemoryDBField(BaseModel):
    """Base class for Redis fields.

    The text, tag, numeric and vector field schemas in this module all derive
    from it, so every field schema carries a ``name``.
    """
    # Name of the field; required, since ``Field(...)`` declares no default.
    name: str = Field(...)
class TextFieldSchema(InMemoryDBField):
    """Schema for text fields in Redis.

    Holds the options of a single text field and converts them into a
    redis-py ``TextField`` through :meth:`as_field`.
    """

    weight: float = 1
    no_stem: bool = False
    phonetic_matcher: Optional[str] = None
    # NOTE(review): declared here but not forwarded by as_field(); confirm
    # whether that omission is intentional before relying on this option.
    withsuffixtrie: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TextField:
        """Build the redis-py ``TextField`` described by this schema.

        Returns:
            A ``TextField`` named ``self.name`` carrying the weight, stemming,
            phonetic-matcher, sortable and no-index settings of this schema.
        """
        # Imported locally: at module level redis is only imported under
        # TYPE_CHECKING.
        from redis.commands.search.field import TextField  # type: ignore

        return TextField(
            self.name,
            weight=self.weight,
            no_stem=self.no_stem,
            phonetic_matcher=self.phonetic_matcher,  # type: ignore
            sortable=self.sortable,
            no_index=self.no_index,
        )
 
class TagFieldSchema(InMemoryDBField):
    """Schema for tag fields in Redis.

    Holds the options of a single tag field and converts them into a
    redis-py ``TagField`` through :meth:`as_field`.
    """

    separator: str = ","
    case_sensitive: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TagField:
        """Build the redis-py ``TagField`` described by this schema.

        Returns:
            A ``TagField`` named ``self.name`` carrying the separator,
            case-sensitivity, sortable and no-index settings of this schema.
        """
        # Imported locally: at module level redis is only imported under
        # TYPE_CHECKING.
        from redis.commands.search.field import TagField  # type: ignore

        return TagField(
            self.name,
            separator=self.separator,
            case_sensitive=self.case_sensitive,
            sortable=self.sortable,
            no_index=self.no_index,
        )
 
class NumericFieldSchema(InMemoryDBField):
    """Schema for numeric fields in Redis.

    Holds the options of a single numeric field and converts them into a
    redis-py ``NumericField`` through :meth:`as_field`.
    """

    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> NumericField:
        """Build the redis-py ``NumericField`` described by this schema.

        Returns:
            A ``NumericField`` named ``self.name`` carrying the sortable and
            no-index settings of this schema.
        """
        # Imported locally: at module level redis is only imported under
        # TYPE_CHECKING.
        from redis.commands.search.field import NumericField  # type: ignore

        return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
 
class InMemoryDBVectorField(InMemoryDBField):
    """Base class for Redis vector fields.

    Declares the settings shared by the concrete vector field schemas and
    exposes them, via :meth:`_fields`, as the attribute mapping that the
    subclasses hand to redis-py's ``VectorField``.
    """

    dims: int = Field(...)
    algorithm: object = Field(...)
    datatype: str = Field(default="FLOAT32")
    distance_metric: InMemoryDBDistanceMetric = Field(default="COSINE")
    initial_cap: Optional[int] = None

    @validator("algorithm", "datatype", "distance_metric", pre=True, each_item=True)
    def uppercase_strings(cls, v: str) -> str:
        """Upper-case the raw value supplied for the listed fields."""
        return v.upper()

    @validator("datatype", pre=True)
    def uppercase_and_check_dtype(cls, v: str) -> str:
        """Upper-case ``datatype`` and reject values missing from the dtype map.

        Raises:
            ValueError: If the upper-cased value is not a key of
                ``INMEMORYDB_VECTOR_DTYPE_MAP``.
        """
        normalised = v.upper()
        if normalised in INMEMORYDB_VECTOR_DTYPE_MAP:
            return normalised
        raise ValueError(
            f"datatype must be one of {INMEMORYDB_VECTOR_DTYPE_MAP.keys()}. Got {v}"
        )

    def _fields(self) -> Dict[str, Any]:
        """Return the vector attributes common to every algorithm.

        ``INITIAL_CAP`` is only present when ``initial_cap`` has been set.
        """
        attributes: Dict[str, Any] = {
            "TYPE": self.datatype,
            "DIM": self.dims,
            "DISTANCE_METRIC": self.distance_metric,
        }
        if self.initial_cap is None:
            return attributes
        attributes["INITIAL_CAP"] = self.initial_cap
        return attributes
class FlatVectorField(InMemoryDBVectorField):
    """Schema for flat vector fields in Redis.

    Pins ``algorithm`` to ``"FLAT"`` and adds the optional ``block_size``
    setting on top of the shared vector attributes.
    """

    algorithm: Literal["FLAT"] = "FLAT"
    block_size: Optional[int] = None

    def as_field(self) -> VectorField:
        """Build the redis-py ``VectorField`` described by this schema.

        Returns:
            A ``VectorField`` named ``self.name`` using the FLAT algorithm;
            ``BLOCK_SIZE`` is included only when ``block_size`` is set.
        """
        # Imported locally: at module level redis is only imported under
        # TYPE_CHECKING.
        from redis.commands.search.field import VectorField  # type: ignore

        field_data = super()._fields()
        if self.block_size is not None:
            field_data["BLOCK_SIZE"] = self.block_size
        return VectorField(self.name, self.algorithm, field_data)
 
class HNSWVectorField(InMemoryDBVectorField):
    """Schema for HNSW vector fields in Redis.

    Pins ``algorithm`` to ``"HNSW"`` and adds the ``m``, ``ef_construction``,
    ``ef_runtime`` and ``epsilon`` settings on top of the shared vector
    attributes.
    """

    algorithm: Literal["HNSW"] = "HNSW"
    m: int = Field(default=16)
    ef_construction: int = Field(default=200)
    ef_runtime: int = Field(default=10)
    epsilon: float = Field(default=0.01)

    def as_field(self) -> VectorField:
        """Build the redis-py ``VectorField`` described by this schema.

        Returns:
            A ``VectorField`` named ``self.name`` using the HNSW algorithm,
            with the shared attributes plus ``M``, ``EF_CONSTRUCTION``,
            ``EF_RUNTIME`` and ``EPSILON``.
        """
        # Imported locally: at module level redis is only imported under
        # TYPE_CHECKING.
        from redis.commands.search.field import VectorField  # type: ignore

        field_data = super()._fields()
        field_data.update(
            {
                "M": self.m,
                "EF_CONSTRUCTION": self.ef_construction,
                "EF_RUNTIME": self.ef_runtime,
                "EPSILON": self.epsilon,
            }
        )
        return VectorField(self.name, self.algorithm, field_data)
 
class InMemoryDBModel(BaseModel):
    """Schema for MemoryDB index.

    Groups the field schemas of an index by kind (``text``, ``tag``,
    ``numeric``, ``vector``, plus ``extra``) and stores the names used for
    the content text field and the content vector field.
    """

    # always have a content field for text
    text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
    tag: Optional[List[TagFieldSchema]] = None
    numeric: Optional[List[NumericFieldSchema]] = None
    extra: Optional[List[InMemoryDBField]] = None
    # filled by default_vector_schema
    vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
    content_key: str = "content"
    content_vector_key: str = "content_vector"

    def add_content_field(self) -> None:
        """Ensure ``text`` contains a field named ``content_key``.

        A new ``TextFieldSchema`` is appended only when no existing text
        field already uses that name.
        """
        if self.text is None:
            self.text = []
        if any(field.name == self.content_key for field in self.text):
            return
        self.text.append(TextFieldSchema(name=self.content_key))

    def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
        """Validate ``vector_field`` and append it to ``vector``.

        Args:
            vector_field: Keyword arguments for ``FlatVectorField`` or
                ``HNSWVectorField``; must contain an ``"algorithm"`` key.

        Raises:
            KeyError: If ``vector_field`` has no ``"algorithm"`` key.
            ValueError: If the algorithm is neither FLAT nor HNSW.
        """
        # catch case where user inputted no vector field spec
        # in the index schema
        if self.vector is None:
            self.vector = []

        algorithm = vector_field["algorithm"]
        # The vector field validators upper-case ``algorithm`` themselves, so
        # select the model case-insensitively; otherwise "flat"/"hnsw" would be
        # rejected here even though the model accepts them.
        selector = algorithm.upper() if isinstance(algorithm, str) else algorithm

        # ignore types as pydantic is handling type validation and conversion
        if selector == "FLAT":
            self.vector.append(FlatVectorField(**vector_field))  # type: ignore
        elif selector == "HNSW":
            self.vector.append(HNSWVectorField(**vector_field))  # type: ignore
        else:
            raise ValueError(
                f"algorithm must be either FLAT or HNSW. Got "
                f"{vector_field['algorithm']}"
            )

    def as_dict(self) -> Dict[str, List[Any]]:
        """Serialise every non-empty field list into plain dictionaries.

        Returns:
            A mapping from attribute name (``"text"``, ``"tag"``, ...) to a
            list with one dict per field. Enum values are written as their
            ``.value``; ``None`` settings and empty categories are omitted.
        """
        schemas: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
        # iter over all class attributes
        for attr, attr_value in self.__dict__.items():
            # only non-empty lists
            if isinstance(attr_value, list) and len(attr_value) > 0:
                field_values: List[Dict[str, Any]] = []
                # iterate over all fields in each category (tag, text, etc)
                for val in attr_value:
                    value: Dict[str, Any] = {}
                    # iterate over values within each field to extract
                    # settings for that field (i.e. name, weight, etc)
                    for field, field_value in val.__dict__.items():
                        # make enums into strings
                        if isinstance(field_value, Enum):
                            value[field] = field_value.value
                        # don't write null values
                        elif field_value is not None:
                            value[field] = field_value
                    field_values.append(value)
                schemas[attr] = field_values
        # only write non-empty lists from defaults
        return {k: v for k, v in schemas.items() if len(v) > 0}

    @property
    def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
        """Return the vector field whose name equals ``content_vector_key``.

        Raises:
            ValueError: If there are no vector fields, or none has that name.
        """
        if not self.vector:
            raise ValueError("No vector fields found")
        for field in self.vector:
            if field.name == self.content_vector_key:
                return field
        raise ValueError("No content_vector field found")

    @property
    def vector_dtype(self) -> np.dtype:
        """Look up the content vector's ``datatype`` in the dtype map."""
        # should only ever be called after pydantic has validated the schema
        return INMEMORYDB_VECTOR_DTYPE_MAP[self.content_vector.datatype]

    @property
    def is_empty(self) -> bool:
        """Return True only when tag, text, numeric and vector are all None."""
        return all(
            field is None for field in [self.tag, self.text, self.numeric, self.vector]
        )

    def get_fields(self) -> List["InMemoryDBField"]:
        """Convert every field schema (except ``extra``) via ``as_field()``.

        Returns:
            The objects produced by each schema's ``as_field()``; an empty
            list when the model is empty.
        """
        # NOTE(review): the elements are whatever ``as_field()`` returns (the
        # redis-py field objects in this module), not InMemoryDBField
        # instances; the annotation is kept unchanged for compatibility.
        memorydb_fields: List["InMemoryDBField"] = []
        if self.is_empty:
            return memorydb_fields
        for field_name in self.__fields__.keys():
            if field_name not in ["content_key", "content_vector_key", "extra"]:
                field_group = getattr(self, field_name)
                if field_group is not None:
                    for field in field_group:
                        memorydb_fields.append(field.as_field())
        return memorydb_fields

    @property
    def metadata_keys(self) -> List[str]:
        """Return the names of all fields other than content and content vector."""
        keys: List[str] = []
        if self.is_empty:
            return keys
        reserved = [self.content_key, self.content_vector_key]
        for field_name in self.__fields__.keys():
            field_group = getattr(self, field_name)
            if field_group is None:
                continue
            for field in field_group:
                # ``content_key`` and ``content_vector_key`` are plain strings,
                # so iterating them yields characters; the str check skips
                # those and keeps only real field schemas.
                # check if it's a metadata field. exclude vector and content key
                if not isinstance(field, str) and field.name not in reserved:
                    keys.append(field.name)
        return keys
def read_schema(
    index_schema: Optional[Union[Dict[str, List[Any]], str, os.PathLike]],
) -> Dict[str, Any]:
    """Read in the index schema from a dict or yaml file.

    Args:
        index_schema: Either the schema itself as a dict (returned unchanged),
            or the location of a yaml file given as a ``str`` or as any
            ``os.PathLike`` object (``pathlib.Path`` included).

    Returns:
        The schema dict, or whatever ``yaml.safe_load`` parses from the file.

    Raises:
        FileNotFoundError: If a ``str`` path does not point to an existing
            file. Path-like objects surface the error raised by ``open``.
        TypeError: If ``index_schema`` is none of the supported types.
    """
    if isinstance(index_schema, dict):
        return index_schema

    # Accept every os.PathLike, as the signature promises, rather than only
    # pathlib.Path instances.
    if isinstance(index_schema, os.PathLike):
        with open(index_schema, "rb") as f:
            return yaml.safe_load(f)

    if isinstance(index_schema, str):
        if not Path(index_schema).resolve().is_file():
            raise FileNotFoundError(f"index_schema file {index_schema} does not exist")
        with open(index_schema, "rb") as f:
            return yaml.safe_load(f)

    raise TypeError(
        f"index_schema must be a dict, or path to a yaml file "
        f"Got {type(index_schema)}"
    )