Source code for langchain.chains.qa_generation.base

from __future__ import annotations

import json
from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR


def _parse_qa_json(text: str) -> Any:
    """Parse a model completion as JSON, tolerating a Markdown code fence.

    The text is first handed to ``json.loads`` unchanged, so every completion
    that is already valid JSON yields exactly the value it always did. Only
    when that fails, and the whitespace-stripped text both starts and ends
    with a triple-backtick fence, is the fenced payload parsed instead. An
    opening fence line that is empty or holds a single alphanumeric tag
    (for example ``json``) is discarded.

    Prose before or after the fence is not handled; such completions still
    fail.

    Args:
        text: Raw completion text.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If neither the raw text nor the fenced payload
            is valid JSON. The error raised is the one from the raw text, so
            its message and offsets refer to the original completion.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError as raw_error:
        stripped = text.strip()
        fence = "```"
        is_fenced = (
            len(stripped) >= 2 * len(fence)
            and stripped.startswith(fence)
            and stripped.endswith(fence)
        )
        if not is_fenced:
            raise

        payload = stripped[len(fence) : -len(fence)]
        tag, newline, remainder = payload.partition("\n")
        # Drop the opening fence line only when it looks like a language tag;
        # otherwise it is already part of the JSON document.
        if newline and (not tag.strip() or tag.strip().isalnum()):
            payload = remainder

        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            # Report the failure against the text the caller actually saw.
            raise raw_error from None


@deprecated(
    since="0.2.7",
    alternative=(
        "example in API reference with more detail: "
        "https://api.python.langchain.com/en/latest/chains/langchain.chains.qa_generation.base.QAGenerationChain.html"  # noqa: E501
    ),
    removal="1.0",
)
class QAGenerationChain(Chain):
    """Base class for question-answer generation chains.

    This class is deprecated. See below for an alternative implementation.

    Advantages of the alternative implementation include:

    - Supports async and streaming;
    - Surfaces prompt and text splitter for easier customization;
    - Use of JsonOutputParser supports JSONPatch operations in streaming mode,
      as well as robustness to markdown.

        .. code-block:: python

            from langchain.chains.qa_generation.prompt import CHAT_PROMPT as prompt
            # Note: import PROMPT if using a legacy non-chat model.
            from langchain_core.output_parsers import JsonOutputParser
            from langchain_core.runnables import (
                RunnableLambda,
                RunnableParallel,
                RunnablePassthrough,
            )
            from langchain_core.runnables.base import RunnableEach
            from langchain_openai import ChatOpenAI
            from langchain_text_splitters import RecursiveCharacterTextSplitter

            llm = ChatOpenAI()
            text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=500)
            split_text = RunnableLambda(
                lambda x: text_splitter.create_documents([x])
            )

            chain = RunnableParallel(
                text=RunnablePassthrough(),
                questions=(
                    split_text | RunnableEach(bound=prompt | llm | JsonOutputParser())
                )
            )
    """

    llm_chain: LLMChain
    """LLM Chain that generates responses from user input and context."""
    text_splitter: TextSplitter = Field(
        default=RecursiveCharacterTextSplitter(chunk_overlap=500)
    )
    """Text splitter that splits the input into chunks."""
    input_key: str = "text"
    """Key of the input to the chain."""
    output_key: str = "questions"
    """Key of the output of the chain."""
    k: Optional[int] = None
    """Number of questions to generate.

    Note: neither ``from_llm`` nor ``_call`` in this class reads this value.
    """

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> QAGenerationChain:
        """
        Create a QAGenerationChain from a language model.

        Args:
            llm: a language model
            prompt: a prompt template; when omitted (or falsy) the prompt is
                chosen with ``PROMPT_SELECTOR.get_prompt(llm)``
            **kwargs: additional arguments forwarded to the class constructor

        Returns:
            a QAGenerationChain instance wrapping an ``LLMChain`` built from
            ``llm`` and the selected prompt
        """
        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=chain, **kwargs)

    @property
    def _chain_type(self) -> str:
        """Chain type identifier; not provided by this chain.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @property
    def input_keys(self) -> List[str]:
        """Single-element list holding ``input_key``."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Single-element list holding ``output_key``."""
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, List]:
        """Split the input text, query the LLM per chunk, and decode the JSON.

        Args:
            inputs: Mapping that must contain ``input_key``; its value is the
                text to split.
            run_manager: Optional callback manager, forwarded to
                ``llm_chain.generate``.

        Returns:
            A dict mapping ``output_key`` to a list with one decoded JSON
            value per chunk, in chunk order.

        Raises:
            KeyError: If ``inputs`` lacks ``input_key``.
            json.JSONDecodeError: If a completion is not valid JSON, even
                after unwrapping a surrounding Markdown code fence.
        """
        docs = self.text_splitter.create_documents([inputs[self.input_key]])
        # The prompt variable name "text" is fixed here; it does not follow
        # ``input_key``.
        results = self.llm_chain.generate(
            [{"text": d.page_content} for d in docs], run_manager=run_manager
        )
        # Only the first generation of each chunk is used.
        qa = [_parse_qa_json(res[0].text) for res in results.generations]
        return {self.output_key: qa}