from __future__ import annotations
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Iterator,
    List,
    Mapping,
    Optional,
    Union,
    cast,
)
from typing_extensions import TypedDict
from langchain_core.runnables.base import (
    Input,
    Output,
    Runnable,
    RunnableSerializable,
    coerce_to_runnable,
)
from langchain_core.runnables.config import (
    RunnableConfig,
    get_config_list,
    get_executor_for_config,
)
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    gather_with_concurrency,
    get_unique_config_specs,
)
class RouterInput(TypedDict):
    """Input accepted by :class:`RouterRunnable`.

    Attributes:
        key: The key of the Runnable to route to; looked up in
            ``RouterRunnable.runnables``.
        input: The value passed on to the selected Runnable.
    """

    key: str
    input: Any


class RouterRunnable(RunnableSerializable[RouterInput, Output]):
    """Runnable that routes to a set of Runnables based on ``input["key"]``.

    The value stored under ``input["input"]`` is passed to the selected
    Runnable, and that Runnable's output is returned.

    Parameters:
        runnables: A mapping of keys to Runnables. Plain callables are also
            accepted; ``__init__`` wraps each value with ``coerce_to_runnable``.

    For example,

    .. code-block:: python

        from langchain_core.runnables.router import RouterRunnable
        from langchain_core.runnables import RunnableLambda

        add = RunnableLambda(func=lambda x: x + 1)
        square = RunnableLambda(func=lambda x: x**2)

        router = RouterRunnable(runnables={"add": add, "square": square})
        router.invoke({"key": "square", "input": 3})
    """

    # Routing table: key -> Runnable (values coerced in __init__).
    runnables: Mapping[str, Runnable[Any, Output]]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """Config specs of every routed Runnable, merged by ``get_unique_config_specs``."""
        return get_unique_config_specs(
            spec for step in self.runnables.values() for spec in step.config_specs
        )

    def __init__(
        self,
        runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
    ) -> None:
        """Create a router.

        Args:
            runnables: Mapping of routing keys to Runnables or plain callables;
                each value is converted with ``coerce_to_runnable``.
        """
        super().__init__(  # type: ignore[call-arg]
            runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
        )

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    def _get_runnable(self, key: str) -> Runnable[Any, Output]:
        """Return the Runnable registered under ``key``.

        Raises:
            ValueError: If no Runnable is registered under ``key``.
        """
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")
        return self.runnables[key]

    def invoke(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Output:
        """Run the Runnable selected by ``input["key"]`` on ``input["input"]``.

        Args:
            input: Mapping with a ``"key"`` (routing key) and an ``"input"``
                (payload for the selected Runnable).
            config: Config passed through to the selected Runnable.
            **kwargs: Accepted for parity with ``ainvoke``/``stream``/``astream``;
                not forwarded to the selected Runnable.

        Returns:
            The selected Runnable's output.

        Raises:
            ValueError: If ``input["key"]`` has no associated Runnable.
        """
        key = input["key"]
        actual_input = input["input"]
        runnable = self._get_runnable(key)
        return runnable.invoke(actual_input, config)

    async def ainvoke(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        """Async version of :meth:`invoke`.

        Args:
            input: Mapping with a ``"key"`` and an ``"input"``.
            config: Config passed through to the selected Runnable.
            **kwargs: Accepted but not forwarded to the selected Runnable.

        Returns:
            The selected Runnable's output.

        Raises:
            ValueError: If ``input["key"]`` has no associated Runnable.
        """
        key = input["key"]
        actual_input = input["input"]
        runnable = self._get_runnable(key)
        return await runnable.ainvoke(actual_input, config)

    def batch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Route each input to its Runnable and run them in a thread pool.

        Args:
            inputs: Router inputs, one per call.
            config: A single config applied to every call, or one per input.
            return_exceptions: If true, an exception raised by a routed
                Runnable is returned in its slot instead of being raised.
            **kwargs: Forwarded to each selected Runnable's ``invoke``.

        Returns:
            Outputs (or exceptions, if ``return_exceptions``) in input order.

        Raises:
            ValueError: If any key has no associated Runnable. This check runs
                before any Runnable is called, regardless of
                ``return_exceptions``.
        """
        if not inputs:
            return []

        keys = [input["key"] for input in inputs]
        actual_inputs = [input["input"] for input in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        def _invoke_one(
            runnable: Runnable, input: Input, config: RunnableConfig
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return runnable.invoke(input, config, **kwargs)
                except Exception as e:
                    return e
            else:
                return runnable.invoke(input, config, **kwargs)

        runnables = [self.runnables[key] for key in keys]
        configs = get_config_list(config, len(inputs))
        # The executor is sized from the first config only.
        with get_executor_for_config(configs[0]) as executor:
            return cast(
                List[Output],
                list(executor.map(_invoke_one, runnables, actual_inputs, configs)),
            )

    async def abatch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Async version of :meth:`batch`, bounded by ``max_concurrency``.

        Args:
            inputs: Router inputs, one per call.
            config: A single config applied to every call, or one per input.
            return_exceptions: If true, an exception raised by a routed
                Runnable is returned in its slot instead of being raised.
            **kwargs: Forwarded to each selected Runnable's ``ainvoke``.

        Returns:
            Outputs (or exceptions, if ``return_exceptions``) in input order.

        Raises:
            ValueError: If any key has no associated Runnable.
        """
        if not inputs:
            return []

        keys = [input["key"] for input in inputs]
        actual_inputs = [input["input"] for input in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        async def _ainvoke_one(
            runnable: Runnable, input: Input, config: RunnableConfig
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return await runnable.ainvoke(input, config, **kwargs)
                except Exception as e:
                    return e
            else:
                return await runnable.ainvoke(input, config, **kwargs)

        runnables = [self.runnables[key] for key in keys]
        configs = get_config_list(config, len(inputs))
        # Concurrency limit is read from the first config only.
        return await gather_with_concurrency(
            configs[0].get("max_concurrency"),
            *(
                _ainvoke_one(runnable, input, config)
                for runnable, input, config in zip(runnables, actual_inputs, configs)
            ),
        )

    def stream(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        """Stream the output of the Runnable selected by ``input["key"]``.

        Because this is a generator, the key check (and its ``ValueError``)
        happens on first iteration, not at call time.

        Args:
            input: Mapping with a ``"key"`` and an ``"input"``.
            config: Config passed through to the selected Runnable.
            **kwargs: Accepted but not forwarded to the selected Runnable.

        Yields:
            Chunks produced by the selected Runnable's ``stream``.

        Raises:
            ValueError: If ``input["key"]`` has no associated Runnable.
        """
        key = input["key"]
        actual_input = input["input"]
        runnable = self._get_runnable(key)
        yield from runnable.stream(actual_input, config)

    async def astream(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        """Async version of :meth:`stream`.

        Args:
            input: Mapping with a ``"key"`` and an ``"input"``.
            config: Config passed through to the selected Runnable.
            **kwargs: Accepted but not forwarded to the selected Runnable.

        Yields:
            Chunks produced by the selected Runnable's ``astream``.

        Raises:
            ValueError: If ``input["key"]`` has no associated Runnable.
        """
        key = input["key"]
        actual_input = input["input"]
        runnable = self._get_runnable(key)
        async for output in runnable.astream(actual_input, config):
            yield output