import json
import logging
import re
from typing import Any, Dict, List, Mapping, MutableMapping, Tuple, Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import BaseOutputParser
from langchain_cohere import CohereCitation
# Dictionary key names exposed at module level.
# NOTE(review): neither constant is referenced in the visible code; presumably
# they name fields of the agent's output mapping — confirm against their users.
OUTPUT_KEY: str = "output"
GROUNDED_ANSWER_KEY: str = "grounded_answer"


def parse_answer_with_prefixes(
    completion: str, prefixes: Dict[str, str]
) -> Dict[str, str]:
    """Parse ``completion`` into key-value pairs delimited by literal markers.

    Each value in ``prefixes`` is a literal marker. The text that follows a
    marker, up to the next marker, becomes the value for that marker's key,
    with surrounding whitespace stripped. Text before the first marker is
    ignored. When a marker occurs several times, only its first occurrence is
    kept.

    Example:
        completion = "\\nhello: sam\\ngoodbye then: paul."
        prefixes = {"greeting": "hello:", "farewell": "goodbye then:"}
        result -> {"greeting": "sam", "farewell": "paul."}

    Args:
        completion (str): text to split.
        prefixes (Dict[str, str]): mapping of output key -> literal marker.

    Returns:
        Dict[str, str]: each key whose marker was found, mapped to its
        stripped text.
    """
    # Try longer markers first. Regex alternation takes the first alternative
    # that matches at a given position, so a marker that is a literal prefix
    # of another (e.g. "Note:" vs "Note: extra:") would otherwise shadow it.
    ordered_markers = sorted(prefixes.values(), key=len, reverse=True)
    re_pat = "(" + "|".join(re.escape(p) for p in ordered_markers) + ")"
    reverse_prefix_map = {v: k for k, v in prefixes.items()}
    # Because the pattern has a capturing group, re.split returns
    # [preamble, marker, text, marker, text, ...]; drop the preamble.
    split = re.split(re_pat, completion)[1:]
    parsed: Dict[str, str] = {}
    for prefix, value in zip(split[::2], split[1::2]):
        key = reverse_prefix_map.get(prefix)
        # First occurrence of a marker wins; later ones are discarded.
        if key is not None and key not in parsed:
            parsed[key] = value.strip()
    return parsed


def parse_actions(generation: str) -> Tuple[str, str, List[Dict]]:
    """Split a model completion into its plan text and its parsed actions.

    Two shapes are recognised:

    * a ``Plan:`` or ``Reflection:`` section followed by an ``Action:`` line
      and a fenced json block, or
    * the ``Action:`` line and fenced json block on their own.

    If the completion matches neither shape, the failure is logged and the
    whole stripped completion is passed to the action parser instead.

    Args:
        generation: raw model output.

    Returns:
        ``(generation, plan, parsed_actions)``: the stripped input, the
        plan/reflection text (``""`` when absent or unparseable), and the
        result of ``parse_jsonified_tool_use_generation`` on the action text.
    """
    generation = generation.strip()
    plan = ""
    actions = generation

    if "Plan: " in generation or "Reflection: " in generation:
        # Plan/Reflection body is group 2, the action block is group 3.
        pattern = r"^(Plan|Reflection)\s*\d*\s*:(.*?)(Action\s*\d*\s*:\s*\d*\s*```json\n.*?```)"  # noqa: E501
        plan_group, action_group = 2, 3
    else:
        # Only an action block: it is group 1 and there is no plan.
        pattern = r"^(Action\s*\d*\s*:\s*\d*\s*```json\n.*?```)"
        plan_group, action_group = None, 1

    match = re.search(pattern, generation, re.DOTALL)
    if match is None:
        # Best effort: log the failure and fall back to the full completion.
        message = f"Failed to parse multihop completion for input: {generation}"
        logging.error(message)
        logging.error(f"Error: {message}")
    else:
        if plan_group is not None:
            plan = match.group(plan_group).strip()
        actions = match.group(action_group).strip()

    parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:")
    return generation, plan, parsed_actions


def parse_citations(
    grounded_answer: str, documents: List[MutableMapping]
) -> Tuple[str, List[CohereCitation]]:
    """Turn a grounded answer and its documents into text plus citations.

    ``grounded_answer`` is split into spans by ``_parse_answer_spans``. Every
    span that cites at least one document index becomes a ``CohereCitation``
    covering that span's character range in the markup-free answer. Cited
    indexes that are out of range for ``documents`` are skipped.

    Side effect: every mapping in ``documents`` gets an ``"id"`` entry. An
    existing truthy id is kept; otherwise ``"doc_<position>"`` is used.

    Args:
        grounded_answer: answer text containing ``<co: ...>`` citation tags.
        documents: documents the citation indexes refer to (mutated in place).

    Returns:
        ``(no_markup_answer, citations)``.
    """
    no_markup_answer, spans = _parse_answer_spans(grounded_answer)

    # Ensure every document carries an id; possibly useful for deduplication.
    for position, document in enumerate(documents):
        document["id"] = document.get("id") or f"doc_{position}"

    citations: List[CohereCitation] = []
    offset = 0
    for span in spans:
        span_text = span.get("text", "")
        span_end = offset + len(span_text)
        cited_indexes = span.get("cited_docs")
        if cited_indexes:
            # Deduplicate the indexes, then drop any that do not exist.
            in_range = [idx for idx in set(cited_indexes) if idx < len(documents)]
            cited = [documents[idx] for idx in in_range]
            citations.append(
                CohereCitation(
                    start=offset,
                    end=span_end,
                    text=span_text,
                    documents=cited,
                    document_ids={doc["id"] for doc in cited},
                )
            )
        offset = span_end
    return no_markup_answer, citations
def _strip_spans(answer: str) -> str:
"""removes any <co> tags from a string, including trailing partial tags
input: "hi my <co>name</co> is <co: 1> patrick</co:3> and <co"
output: "hi my name is patrick and"
Args:
answer (str): string
Returns:
str: same string with co tags removed
"""
answer = re.sub(r"<co(.*?)>|</co(.*?)>", "", answer)
idx = answer.find("<co")
if idx > -1:
answer = answer[:idx]
idx = answer.find("</")
if idx > -1:
answer = answer[:idx]
return answer
def _parse_answer_spans(grounded_answer: str) -> Tuple[str, List[Dict[str, Any]]]:
actual_cites = []
for c in re.findall(r"<co:(.*?)>", grounded_answer):
actual_cites.append(c.strip().split(","))
no_markup_answer = _strip_spans(grounded_answer)
current_idx = 0
parsed_answer: List[Dict[str, Union[str, List[int]]]] = []
cited_docs_set = []
last_entry_is_open_cite = False
parsed_current_cite_document_idxs: List[int] = []
while current_idx < len(grounded_answer):
current_cite = re.search(r"<co: (.*?)>", grounded_answer[current_idx:])
if current_cite:
# previous part
parsed_answer.append(
{
"text": grounded_answer[
current_idx : current_idx + current_cite.start()
]
}
)
current_cite_document_idxs = current_cite.group(1).split(",")
parsed_current_cite_document_idxs = []
for cited_idx in current_cite_document_idxs:
if cited_idx.isdigit():
cited_idx = int(cited_idx.strip())
parsed_current_cite_document_idxs.append(cited_idx)
if cited_idx not in cited_docs_set:
cited_docs_set.append(cited_idx)
current_idx += current_cite.end()
current_cite_close = re.search(
r"</co: " + current_cite.group(1) + ">", grounded_answer[current_idx:]
)
if current_cite_close:
# there might have been issues parsing the ids, so we need to check
# that they are actually ints and available
if len(parsed_current_cite_document_idxs) > 0:
pt = grounded_answer[
current_idx : current_idx + current_cite_close.start()
]
parsed_answer.append(
{"text": pt, "cited_docs": parsed_current_cite_document_idxs}
)
else:
parsed_answer.append(
{
"text": grounded_answer[
current_idx : current_idx + current_cite_close.start()
],
}
)
current_idx += current_cite_close.end()
else:
last_entry_is_open_cite = True
break
else:
break
# don't forget about the last one
if last_entry_is_open_cite:
pt = _strip_spans(grounded_answer[current_idx:])
parsed_answer.append(
{"text": pt, "cited_docs": parsed_current_cite_document_idxs}
)
else:
parsed_answer.append({"text": _strip_spans(grounded_answer[current_idx:])})
return no_markup_answer, parsed_answer