# Source code for langchain.chains.transform
"""Chain that runs an arbitrary python function."""
import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
# Module-level logger; TransformChain._log_once emits its warning through it.
logger = logging.getLogger(__name__)
# [docs]
class TransformChain(Chain):
    """Chain that runs an arbitrary Python function over its inputs.

    The callable supplied as ``transform`` receives the chain's input
    dictionary and its return value is used as the chain's output
    dictionary. An optional ``atransform`` coroutine function is used for
    async execution; when it is not supplied, the async path falls back to
    the synchronous ``transform``.

    Example:
        .. code-block:: python

            from langchain.chains import TransformChain

            def func(inputs: dict) -> dict:
                return {"entities": inputs["text"].split()}

            transform_chain = TransformChain(
                input_variables=["text"],
                output_variables=["entities"],
                transform=func,
            )
    """

    input_variables: List[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: List[str]
    """The keys returned by the transform's output dictionary."""
    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
    """The transform function (set via the ``transform`` alias)."""
    atransform_cb: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = (
        Field(None, alias="atransform")
    )
    """The async coroutine transform function (set via the ``atransform`` alias)."""

    @staticmethod
    @functools.lru_cache
    def _log_once(msg: str) -> None:
        """Log a warning, skipping repeats of a message that is still cached.

        :meta private:
        """
        # lru_cache memoizes on ``msg``: an identical message only reaches
        # the logger on its first call, for as long as it stays in the cache.
        logger.warning(msg)

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys this chain expects (``input_variables``).

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys this chain produces (``output_variables``).

        :meta private:
        """
        return self.output_variables

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the synchronous transform on ``inputs`` and return its result.

        ``run_manager`` is accepted but not used by this implementation.
        """
        return self.transform_cb(inputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the async transform, or fall back to the synchronous one.

        When ``atransform_cb`` is set it is awaited with ``inputs``.
        Otherwise a warning is logged through ``_log_once`` and
        ``transform_cb`` is called directly inside this coroutine.
        ``run_manager`` is accepted but not used by this implementation.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)

        # No async transform configured: warn, then run the synchronous
        # transform in place (it is neither awaited nor offloaded).
        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform"
        )
        return self.transform_cb(inputs)