from __future__ import annotations
import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import Generation
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
PROMPT,
QUESTION_GENERATOR_PROMPT,
FinishedOutputParser,
)
from langchain.chains.llm import LLMChain
class _ResponseChain(LLMChain):
    """Abstract LLM chain that yields a response as tokens with log probs.

    Subclasses implement ``_extract_tokens_and_log_probs`` to pull the tokens
    and their log probabilities out of the generations returned by the model.
    """

    prompt: BasePromptTemplate = PROMPT

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Always returns ``False``."""
        return False

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain, taken from the prompt's input variables."""
        return self.prompt.input_variables

    def generate_tokens_and_log_probs(
        self,
        _input: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Run the chain on a single input and return its tokens and log probs.

        Args:
            _input: Mapping of prompt variables for one generation call.
            run_manager: Optional callback manager forwarded to ``generate``.

        Returns:
            A ``(tokens, log_probs)`` pair extracted from the first entry of
            the result's ``generations``.
        """
        # ``generate`` works on a batch, so wrap the single input in a list
        # and read back the generations that belong to it.
        batch = [_input]
        result = self.generate(batch, run_manager=run_manager)
        first_generations = result.generations[0]
        return self._extract_tokens_and_log_probs(first_generations)

    @abstractmethod
    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Extract tokens and log probs from response."""
class _OpenAIResponseChain(_ResponseChain):
    """Chain that generates responses from user input and context."""

    llm: BaseLanguageModel

    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Collect tokens and log probs from each generation's metadata.

        Args:
            generations: Generations whose ``generation_info["logprobs"]``
                holds ``"tokens"`` and ``"token_logprobs"`` entries.

        Returns:
            A ``(tokens, log_probs)`` pair, each concatenated across all
            generations in the order given.

        Raises:
            ValueError: If a generation has no ``generation_info`` or its
                ``generation_info`` carries no ``"logprobs"`` data.
        """
        tokens: List[str] = []
        log_probs: List[float] = []
        for gen in generations:
            info = gen.generation_info
            if info is None:
                raise ValueError(
                    "Generation is missing `generation_info`; cannot extract "
                    "tokens and log probabilities."
                )
            logprobs = info.get("logprobs")
            if logprobs is None:
                # Without this check a missing/None entry would surface as an
                # opaque KeyError/TypeError on the subscripts below.
                raise ValueError(
                    "Generation info contains no `logprobs`; the LLM must be "
                    "configured to return log probabilities."
                )
            tokens.extend(logprobs["tokens"])
            log_probs.extend(logprobs["token_logprobs"])
        return tokens, log_probs
[docs]
class QuestionGeneratorChain(LLMChain):
    """LLM chain that produces questions from uncertain spans."""

    prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
    """Prompt used by the chain; defaults to ``QUESTION_GENERATOR_PROMPT``."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Always returns ``False``."""
        return False

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain declares."""
        # NOTE(review): FlareChain._do_retrieval feeds this chain the keys
        # "user_input", "current_response" and "uncertain_span" -- confirm
        # whether the keys declared here are meant to match the prompt.
        declared = ["user_input", "context", "response"]
        return declared
def _low_confidence_spans(
tokens: Sequence[str],
log_probs: Sequence[float],
min_prob: float,
min_token_gap: int,
num_pad_tokens: int,
) -> List[str]:
_low_idx = np.where(np.exp(log_probs) < min_prob)[0]
low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
if len(low_idx) == 0:
return []
spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
for i, idx in enumerate(low_idx[1:]):
end = idx + num_pad_tokens + 1
if idx - low_idx[i] < min_token_gap:
spans[-1][1] = end
else:
spans.append([idx, end])
return ["".join(tokens[start:end]) for start, end in spans]
[docs]
class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator."""

    question_generator_chain: QuestionGeneratorChain
    """Chain that generates questions from uncertain spans."""
    response_chain: _ResponseChain
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
    retriever: BaseRetriever
    """Retriever that retrieves relevant documents from a user input."""
    min_prob: float = 0.2
    """Minimum probability for a token to be considered low confidence."""
    min_token_gap: int = 5
    """Minimum number of tokens between two low confidence spans."""
    num_pad_tokens: int = 2
    """Number of tokens to pad around a low confidence span."""
    max_iter: int = 10
    """Maximum number of iterations."""
    start_with_retrieval: bool = True
    """Whether to start with retrieval."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for the chain."""
        return ["response"]

    def _do_generation(
        self,
        questions: List[str],
        user_input: str,
        response: str,
        _run_manager: CallbackManagerForChainRun,
    ) -> Tuple[str, bool]:
        """Retrieve documents for the questions and generate a continuation.

        Args:
            questions: Queries passed one by one to the retriever.
            user_input: The original user input.
            response: The response accumulated so far.
            _run_manager: Callback manager whose child is used for the call.

        Returns:
            The ``(marginal, finished)`` pair produced by the output parser
            from the response chain's prediction.
        """
        callbacks = _run_manager.get_child()
        docs = []
        for question in questions:
            docs.extend(self.retriever.invoke(question))
        # Page contents of all retrieved documents form the context.
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.predict(
            user_input=user_input,
            context=context,
            response=response,
            callbacks=callbacks,
        )
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

    def _do_retrieval(
        self,
        low_confidence_spans: List[str],
        _run_manager: CallbackManagerForChainRun,
        user_input: str,
        response: str,
        initial_response: str,
    ) -> Tuple[str, bool]:
        """Generate one question per uncertain span, then regenerate.

        Args:
            low_confidence_spans: Text of the spans judged low confidence.
            _run_manager: Callback manager for child callbacks and logging.
            user_input: The original user input.
            response: The response accumulated before the tentative tokens.
            initial_response: ``response`` plus the tentative continuation.

        Returns:
            The ``(marginal, finished)`` pair from ``_do_generation``.
        """
        question_gen_inputs = [
            {
                "user_input": user_input,
                "current_response": initial_response,
                "uncertain_span": span,
            }
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        question_gen_outputs = self.question_generator_chain.apply(
            question_gen_inputs, callbacks=callbacks
        )
        questions = [
            output[self.question_generator_chain.output_keys[0]]
            for output in question_gen_outputs
        ]
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
        return self._do_generation(questions, user_input, response, _run_manager)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Iteratively build a response, retrieving when confidence is low.

        Each of up to ``max_iter`` iterations generates a tentative
        continuation with an empty context. If it contains no low-confidence
        span it is appended to the response as is; otherwise questions are
        generated for the uncertain spans, documents are retrieved, and the
        continuation is regenerated with that context.

        Args:
            inputs: Chain inputs; the value under the first input key is used.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            A dict mapping the first output key to the response text.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        user_input = inputs[self.input_keys[0]]

        response = ""

        for _ in range(self.max_iter):
            _run_manager.on_text(
                f"Current Response: {response}", color="blue", end="\n"
            )
            _input = {"user_input": user_input, "context": "", "response": response}
            tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
                _input, run_manager=_run_manager
            )
            low_confidence_spans = _low_confidence_spans(
                tokens,
                log_probs,
                self.min_prob,
                self.min_token_gap,
                self.num_pad_tokens,
            )
            initial_response = response.strip() + " " + "".join(tokens)
            if not low_confidence_spans:
                # Confident continuation: accept it without retrieval.
                response = initial_response
                final_response, finished = self.output_parser.parse(response)
                if finished:
                    return {self.output_keys[0]: final_response}
                continue

            marginal, finished = self._do_retrieval(
                low_confidence_spans,
                _run_manager,
                user_input,
                response,
                initial_response,
            )
            response = response.strip() + " " + marginal
            if finished:
                break
        return {self.output_keys[0]: response}

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
    ) -> FlareChain:
        """Creates a FlareChain from a language model.

        ``llm`` is used for the question generator chain only; the response
        chain is always built on a new ``langchain_openai.OpenAI`` instance
        that requests log probabilities.

        Args:
            llm: Language model to use.
            max_generation_len: Maximum length of the generated response.
            kwargs: Additional arguments to pass to the constructor
                (e.g. ``retriever``, which this method does not set).

        Returns:
            FlareChain class with the given language model.

        Raises:
            ImportError: If ``langchain_openai`` cannot be imported.
        """
        try:
            from langchain_openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai: "
                "`pip install langchain-openai`"
            ) from e
        question_gen_chain = QuestionGeneratorChain(llm=llm)
        response_llm = OpenAI(
            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
        )
        response_chain = _OpenAIResponseChain(llm=response_llm)
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
            **kwargs,
        )