Source code for langchain_elasticsearch.cache

import base64
import hashlib
import logging
from datetime import datetime
from functools import cached_property
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
)

from elasticsearch import (
    Elasticsearch,
    exceptions,
    helpers,
)
from elasticsearch.helpers import BulkIndexError
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.load import dumps, loads
from langchain_core.stores import ByteStore

from langchain_elasticsearch.client import create_elasticsearch_client

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

# Module-level logger, used by the index-management helper and by both cache
# classes in this module to report index creation, deduplication and errors.
logger = logging.getLogger(__name__)


def _manage_cache_index(
    es_client: Elasticsearch, index_name: str, mapping: Dict[str, Any]
) -> bool:
    """Make sure the cache target exists and carries the expected mapping.

    Args:
        es_client: Client used to inspect and create indices.
        index_name: Name of the index or alias that backs the cache.
        mapping: Full index definition; its ``"mappings"`` entry is what gets
            pushed when ``index_name`` is an alias.

    Returns:
        True when ``index_name`` is an existing alias (its mapping is updated
        through the alias). False otherwise: either a new index was created,
        or a concrete index with that name already existed and was left as-is.
    """
    if es_client.indices.exists_alias(name=index_name):
        es_client.indices.put_mapping(index=index_name, body=mapping["mappings"])
        return True

    if es_client.indices.exists(index=index_name):
        # A plain index is already present; it is reused without changes.
        return False

    logger.debug(f"Creating new Elasticsearch index: {index_name}")
    es_client.indices.create(index=index_name, body=mapping)
    return False


class ElasticsearchCache(BaseCache):
    """An Elasticsearch cache integration for LLMs."""

    def __init__(
        self,
        index_name: str,
        store_input: bool = True,
        store_input_params: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
        *,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        es_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Elasticsearch cache store by specifying the index/alias
        to use and determining which additional information (like input, input
        parameters, and any other metadata) should be stored in the cache.

        Args:
            index_name (str): The name of the index or the alias to use for the
                cache. If neither exists, an index is created according to the
                default mapping defined by the `mapping` property.
            store_input (bool): Whether to store the LLM input in the cache,
                i.e., the input prompt. Defaults to True.
            store_input_params (bool): Whether to store the input parameters in
                the cache, i.e., the LLM parameters used to generate the LLM
                response. Defaults to True.
            metadata (Optional[dict]): Additional metadata to store in the
                cache, for filtering purposes. This must be JSON serializable
                in an Elasticsearch document. Defaults to None.
            es_url: URL of the Elasticsearch instance to connect to.
            es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
            es_user: Username to use when connecting to Elasticsearch.
            es_password: Password to use when connecting to Elasticsearch.
            es_api_key: API key to use when connecting to Elasticsearch.
            es_params: Other parameters for the Elasticsearch client.
        """
        self._index_name = index_name
        self._store_input = store_input
        self._store_input_params = store_input_params
        self._metadata = metadata
        self._es_client = create_elasticsearch_client(
            url=es_url,
            cloud_id=es_cloud_id,
            api_key=es_api_key,
            username=es_user,
            password=es_password,
            params=es_params,
        )
        # True when `index_name` is an alias; this switches lookup to a search
        # and makes writes require an alias.
        self._is_alias = _manage_cache_index(
            self._es_client,
            self._index_name,
            self.mapping,
        )

    @cached_property
    def mapping(self) -> Dict[str, Any]:
        """Get the default mapping for the index."""
        return {
            "mappings": {
                "properties": {
                    "llm_output": {"type": "text", "index": False},
                    "llm_params": {"type": "text", "index": False},
                    "llm_input": {"type": "text", "index": False},
                    "metadata": {"type": "object"},
                    "timestamp": {"type": "date"},
                }
            }
        }

    @staticmethod
    def _key(prompt: str, llm_string: str) -> str:
        """Generate a key for the cache store (MD5 hex digest of prompt + llm_string)."""
        return hashlib.md5((prompt + llm_string).encode()).hexdigest()

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The deserialized cached generations, or None on a cache miss.
        """
        cache_key = self._key(prompt, llm_string)
        if self._is_alias:
            # Several indices behind the alias may hold a document with this
            # id. Sort newest-first so hits[0] is the most recently written
            # record. (Sorting "asc" would return the oldest one.)
            result = self._es_client.search(
                index=self._index_name,
                body={
                    "query": {"term": {"_id": cache_key}},
                    "sort": {"timestamp": {"order": "desc"}},
                    "size": 1,
                },
                source_includes=["llm_output"],
            )
            hits = result["hits"]["hits"]
            if not hits:
                return None
            record = hits[0]
        else:
            try:
                record = self._es_client.get(
                    index=self._index_name, id=cache_key, source=["llm_output"]
                )
            except exceptions.NotFoundError:
                return None
        return [loads(item) for item in record["_source"]["llm_output"]]

    def build_document(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> Dict[str, Any]:
        """Build the Elasticsearch document for storing a single LLM interaction"""
        body: Dict[str, Any] = {
            "llm_output": [dumps(item) for item in return_val],
            "timestamp": datetime.now().isoformat(),
        }
        if self._store_input_params:
            body["llm_params"] = llm_string
        if self._metadata is not None:
            body["metadata"] = self._metadata
        if self._store_input:
            body["llm_input"] = prompt
        return body

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string."""
        body = self.build_document(prompt, llm_string, return_val)
        self._es_client.index(
            index=self._index_name,
            id=self._key(prompt, llm_string),
            body=body,
            require_alias=self._is_alias,
            refresh=True,
        )

    def clear(self, **kwargs: Any) -> None:
        """Clear cache by deleting every document in the index/alias."""
        self._es_client.delete_by_query(
            index=self._index_name,
            body={"query": {"match_all": {}}},
            refresh=True,
            wait_for_completion=True,
        )
class ElasticsearchEmbeddingsCache(ByteStore):
    """An Elasticsearch store for caching embeddings."""

    def __init__(
        self,
        index_name: str,
        store_input: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
        namespace: Optional[str] = None,
        maximum_duplicates_allowed: int = 1,
        *,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        es_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Elasticsearch cache store by specifying the index/alias
        to use and determining which additional information (the input text
        and any other metadata) should be stored in the cache.
        Provide a namespace to organize the cache.

        Args:
            index_name (str): The name of the index or the alias to use for the
                cache. If neither exists, an index is created according to the
                default mapping defined by the `mapping` property.
            store_input (bool): Whether to store the input in the cache.
                Defaults to True.
            metadata (Optional[dict]): Additional metadata to store in the
                cache, for filtering purposes. This must be JSON serializable
                in an Elasticsearch document. Defaults to None.
            namespace (Optional[str]): A namespace to use for the cache.
            maximum_duplicates_allowed (int): Defines the maximum number of
                duplicate keys permitted. Must be used in scenarios where the
                same key appears across multiple indices that share the same
                alias. Defaults to 1.
            es_url: URL of the Elasticsearch instance to connect to.
            es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
            es_user: Username to use when connecting to Elasticsearch.
            es_password: Password to use when connecting to Elasticsearch.
            es_api_key: API key to use when connecting to Elasticsearch.
            es_params: Other parameters for the Elasticsearch client.
        """
        self._namespace = namespace
        self._maximum_duplicates_allowed = maximum_duplicates_allowed
        self._index_name = index_name
        self._store_input = store_input
        self._metadata = metadata
        self._es_client = create_elasticsearch_client(
            url=es_url,
            cloud_id=es_cloud_id,
            api_key=es_api_key,
            username=es_user,
            password=es_password,
            params=es_params,
        )
        # True when `index_name` is an alias; this switches mget to a search
        # and makes bulk writes require an alias.
        self._is_alias = _manage_cache_index(
            self._es_client,
            self._index_name,
            self.mapping,
        )

    @staticmethod
    def encode_vector(data: bytes) -> str:
        """Encode the raw vector bytes as a base64 string."""
        return base64.b64encode(data).decode("utf-8")

    @staticmethod
    def decode_vector(data: str) -> bytes:
        """Decode a base64 string back into the raw vector bytes."""
        return base64.b64decode(data)

    @cached_property
    def mapping(self) -> Dict[str, Any]:
        """Get the default mapping for the index."""
        return {
            "mappings": {
                "properties": {
                    "text_input": {"type": "text", "index": False},
                    "vector_dump": {
                        "type": "binary",
                        "doc_values": False,
                    },
                    "metadata": {"type": "object"},
                    "timestamp": {"type": "date"},
                    "namespace": {"type": "keyword"},
                }
            }
        }

    def _key(self, input_text: str) -> str:
        """Generate a key for the store (MD5 hex digest of namespace + text)."""
        return hashlib.md5(((self._namespace or "") + input_text).encode()).hexdigest()

    @classmethod
    def _deduplicate_hits(cls, hits: List[dict]) -> Dict[str, bytes]:
        """
        Collapse the results from a search query with multiple indices
        returning only the latest version of the documents
        """
        map_ids: Dict[str, bytes] = {}
        # Newest first, so the first occurrence of each id is the one kept.
        for hit in sorted(
            hits,
            key=lambda x: datetime.fromisoformat(x["_source"]["timestamp"]),
            reverse=True,
        ):
            vector_id: str = hit["_id"]
            if vector_id not in map_ids:
                map_ids[vector_id] = cls.decode_vector(hit["_source"]["vector_dump"])
        return map_ids

    @staticmethod
    def _root_cause_reason(error: Exception) -> str:
        """Return the first ``root_cause`` reason from an Elasticsearch error body.

        Returns an empty string if the body does not have the expected shape,
        so this helper can never mask the original exception.
        """
        body = getattr(error, "body", None)
        if not isinstance(body, dict):
            return ""
        details = body.get("error")
        if not isinstance(details, dict):
            return ""
        root_causes = details.get("root_cause") or []
        if not root_causes or not isinstance(root_causes[0], dict):
            return ""
        return root_causes[0].get("reason") or ""

    def _mget_from_alias(self, cache_keys: List[str]) -> List[Optional[bytes]]:
        """Resolve hashed keys against an alias with an ``ids`` search query."""
        try:
            results = self._es_client.search(
                index=self._index_name,
                body={
                    "query": {"ids": {"values": cache_keys}},
                    # Leave room for the same id existing in several indices.
                    "size": len(cache_keys) * self._maximum_duplicates_allowed,
                },
                source_includes=["vector_dump", "timestamp"],
            )
        except exceptions.BadRequestError as e:
            if "window too large" in self._root_cause_reason(e):
                logger.warning(
                    "Exceeded the maximum window size. "
                    "Reduce the duplicates manually or lower "
                    "`maximum_duplicates_allowed`."
                )
            raise

        hits = results["hits"]["hits"]
        total_hits = results["hits"]["total"]["value"]
        if self._maximum_duplicates_allowed > 1 and total_hits > len(cache_keys):
            logger.warning(
                f"Deduplicating, found {total_hits} hits for {len(cache_keys)} keys"
            )
            map_ids = self._deduplicate_hits(hits)
        else:
            map_ids = {
                hit["_id"]: self.decode_vector(hit["_source"]["vector_dump"])
                for hit in hits
            }
        return [map_ids.get(k) for k in cache_keys]

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Get the values associated with the given keys.

        Returns:
            One entry per key, in order: the stored vector bytes, or None when
            no document exists for that key.
        """
        # Only an empty request short-circuits. `any(keys)` would also return
        # [] for e.g. [""], breaking the one-result-per-key contract.
        if not keys:
            return []
        cache_keys = [self._key(k) for k in keys]
        if self._is_alias:
            return self._mget_from_alias(cache_keys)

        records = self._es_client.mget(
            index=self._index_name, ids=cache_keys, source_includes=["vector_dump"]
        )
        return [
            self.decode_vector(r["_source"]["vector_dump"]) if r["found"] else None
            for r in records["docs"]
        ]

    def build_document(self, text_input: str, vector: bytes) -> Dict[str, Any]:
        """Build the Elasticsearch document for storing a single embedding"""
        body: Dict[str, Any] = {
            "vector_dump": self.encode_vector(vector),
            "timestamp": datetime.now().isoformat(),
        }
        if self._metadata is not None:
            body["metadata"] = self._metadata
        if self._store_input:
            body["text_input"] = text_input
        if self._namespace:
            body["namespace"] = self._namespace
        return body

    @staticmethod
    def _first_bulk_error_reason(errors: List[Dict[str, Any]]) -> Optional[str]:
        """Return a readable reason for the first failed item of a bulk request."""
        if not errors:
            return None
        # Each entry is keyed by its op type ("index", "delete", ...), so take
        # the payload regardless of which operation failed.
        item = next(iter(errors[0].values()), {})
        if not isinstance(item, dict):
            return str(item)
        error = item.get("error") or {}
        reason = error.get("reason") if isinstance(error, dict) else error
        # NOTE(review): some failed items (presumably e.g. deletes of missing
        # docs) may carry no "error" object; fall back to their "result".
        return reason or item.get("result")

    def _bulk(self, actions: Iterable[Dict[str, Any]]) -> None:
        """Run bulk actions against the index/alias, logging the first failure."""
        try:
            helpers.bulk(
                client=self._es_client,
                actions=actions,
                index=self._index_name,
                require_alias=self._is_alias,
                refresh=True,
            )
        except BulkIndexError as e:
            logger.error(
                f"First bulk error reason: {self._first_bulk_error_reason(e.errors)}"
            )
            raise

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Set the values for the given keys."""
        actions = (
            {
                "_op_type": "index",
                "_id": self._key(key),
                "_source": self.build_document(key, vector),
            }
            for key, vector in key_value_pairs
        )
        self._bulk(actions)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys and their associated values."""
        actions = ({"_op_type": "delete", "_id": self._key(key)} for key in keys)
        self._bulk(actions)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Get an iterator over keys that match the given prefix."""
        # TODO This method is not currently used by CacheBackedEmbeddings,
        # we can leave it blank. It could be implemented with ES "index_prefixes",
        # but they are limited and expensive.
        raise NotImplementedError()