import json
import logging
import re
from typing import Any, Dict, List, Mapping, MutableMapping, Tuple, Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import BaseOutputParser
from langchain_cohere import CohereCitation
OUTPUT_KEY = "output"
GROUNDED_ANSWER_KEY = "grounded_answer"
def parse_answer_with_prefixes(
    completion: str, prefixes: Dict[str, str]
) -> Dict[str, str]:
    """parses string into key-value pairs,
       according to patterns supplied in prefixes. Also strips.
    if inputs are:
        completion = "\nhello: sam\ngoodbye then: paul.",
        prefixes = {"greeting": "hello:", "farewell": "goodbye then:"}
    the expected returned result is:
        {"greeting": "sam", "farewell": "paul."}
    Args:
        completion (str): text to split
        prefixes (Dict[str, str]): a key-value dict of keys and patterns.
        See example above
    Returns:
        Dict[str, str]: parsed result
    """
    # Build a regex that captures any of the prefix patterns, plus a reverse
    # map from each pattern back to its key.
    re_pat = "(" + "|".join([re.escape(p) for p in prefixes.values()]) + ")"
    reverse_prefix_map = {v: k for k, v in prefixes.items()}
    # Split on the prefixes (the captured prefixes are kept in the result)
    # and drop any text that precedes the first prefix.
    split = re.split(re_pat, completion)
    split = split[1:]
    parsed = {}
    for prefix, value in zip(split[::2], split[1::2]):
        if prefix in reverse_prefix_map:  # if the prefix is a match
            if (
                reverse_prefix_map[prefix] not in parsed
            ):  # first occurrence of a prefix is kept, others discarded
                parsed[reverse_prefix_map[prefix]] = value.strip()
    return parsed 
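# A minimal usage sketch (not part of the library source), assuming the same
# hypothetical prefixes as the docstring example above. It also shows that only
# the first occurrence of a prefix is kept:
#
#     parse_answer_with_prefixes(
#         "hello: sam\nhello: max\ngoodbye then: paul.",
#         {"greeting": "hello:", "farewell": "goodbye then:"},
#     )
#     # -> {"greeting": "sam", "farewell": "paul."}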
def parse_actions(generation: str) -> Tuple[str, str, List[Dict]]:
    """Parse action selections from model output."""
    plan = ""
    generation = generation.strip()
    actions = generation
    try:
        if "Plan: " in generation or "Reflection: " in generation:
            # Model is trained to output a Plan or Reflection followed by an action.
            # Use regex to extract the plan and action.
            regex = r"^(Plan|Reflection)\s*\d*\s*:(.*?)(Action\s*\d*\s*:\s*\d*\s*```json\n.*?```)"  # noqa: E501
            action_match = re.search(regex, generation, re.DOTALL)
            if not action_match:
                raise ValueError(
                    f"Failed to parse multihop completion for input: {generation}"
                )
            plan = action_match.group(2).strip()
            actions = action_match.group(3).strip()
        else:
            # Catch the case where model outputs only an action.
            regex = r"^(Action\s*\d*\s*:\s*\d*\s*```json\n.*?```)"
            action_match = re.search(regex, generation, re.DOTALL)
            if not action_match:
                raise ValueError(
                    f"Failed to parse multihop completion for input: {generation}"
                )
            actions = action_match.group(1).strip()
    except Exception as e:
        logging.error(f"Failed to parse multihop completion for input: {generation}")
        logging.error(f"Error: {e}")
    parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:")
    return generation, plan, parsed_actions 
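# A minimal sketch (not part of the library source) of the generation shape
# this parser expects. The tool name and parameters are hypothetical; the JSON
# inside the fenced block is handed to parse_jsonified_tool_use_generation
# (defined elsewhere in this module), which is assumed to return one dict per
# tool call.
#
#     generation = (
#         "Plan: I will search for the population.\n"
#         "Action: ```json\n"
#         '[{"tool_name": "internet_search", '
#         '"parameters": {"query": "population of Toronto"}}]\n'
#         "```"
#     )
#     generation, plan, parsed_actions = parse_actions(generation)
#     # plan == "I will search for the population."
#     # parsed_actions is a list with one entry for the single tool call.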
def parse_citations(
    grounded_answer: str, documents: List[MutableMapping]
) -> Tuple[str, List[CohereCitation]]:
    """
    Parses a grounded_answer (from parse_actions) and documents (from
    convert_to_documents) into a (generation, CohereCitation list) tuple.
    """
    no_markup_answer, parsed_answer = _parse_answer_spans(grounded_answer)
    citations: List[CohereCitation] = []
    start = 0
    # Add an id field to each document. This may be useful for future deduplication.
    for i in range(len(documents)):
        documents[i]["id"] = documents[i].get("id") or f"doc_{i}"
    for answer in parsed_answer:
        text = answer.get("text", "")
        document_indexes = answer.get("cited_docs")
        if not document_indexes:
            # There were no citations for this piece of text.
            start += len(text)
            continue
        end = start + len(text)
        # Look up the cited document by index
        cited_documents: List[Mapping] = []
        cited_document_ids: List[str] = []
        for index in set(document_indexes):
            if index >= len(documents):
                # The document index doesn't exist
                continue
            cited_documents.append(documents[index])
            cited_document_ids.append(documents[index]["id"])
        citations.append(
            CohereCitation(
                start=start,
                end=end,
                text=text,
                documents=cited_documents,
                document_ids=set(cited_document_ids),
            )
        )
        start = end
    return no_markup_answer, citations 
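# A minimal sketch (not part of the library source): a grounded answer with one
# <co: 0> span and a single hypothetical document yields one citation whose
# start/end offsets index into the markup-free answer.
#
#     answer, citations = parse_citations(
#         grounded_answer="The capital is <co: 0>Paris</co: 0>.",
#         documents=[{"snippet": "Paris is the capital of France."}],
#     )
#     # answer == "The capital is Paris."
#     # citations[0].start == 15, citations[0].end == 20
#     # citations[0].text == "Paris", citations[0].document_ids == {"doc_0"}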
def _strip_spans(answer: str) -> str:
    """removes any <co> tags from a string, including trailing partial tags
    input: "hi my <co>name</co> is <co: 1> patrick</co:3> and <co"
    output: "hi my name is patrick and"
    Args:
        answer (str): string
    Returns:
        str: same string with co tags removed
    """
    # Remove complete <co ...> and </co ...> tags, then truncate at any
    # trailing partial tag left by an unfinished generation.
    answer = re.sub(r"<co(.*?)>|</co(.*?)>", "", answer)
    idx = answer.find("<co")
    if idx > -1:
        answer = answer[:idx]
    idx = answer.find("</")
    if idx > -1:
        answer = answer[:idx]
    return answer
def _parse_answer_spans(grounded_answer: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Split a grounded answer on its <co: ...> citation spans.

    Returns a (markup-free answer, parsed spans) tuple, where each span is a
    dict with a "text" key and, for cited spans, a "cited_docs" list of
    document indexes.
    """
    actual_cites = []
    for c in re.findall(r"<co:(.*?)>", grounded_answer):
        actual_cites.append(c.strip().split(","))
    no_markup_answer = _strip_spans(grounded_answer)
    current_idx = 0
    parsed_answer: List[Dict[str, Union[str, List[int]]]] = []
    cited_docs_set = []
    last_entry_is_open_cite = False
    parsed_current_cite_document_idxs: List[int] = []
    while current_idx < len(grounded_answer):
        current_cite = re.search(r"<co: (.*?)>", grounded_answer[current_idx:])
        if current_cite:
            # Text before the opening <co: ...> tag is an uncited span.
            parsed_answer.append(
                {
                    "text": grounded_answer[
                        current_idx : current_idx + current_cite.start()
                    ]
                }
            )
            current_cite_document_idxs = current_cite.group(1).split(",")
            parsed_current_cite_document_idxs = []
            for cited_idx in current_cite_document_idxs:
                if cited_idx.isdigit():
                    cited_idx = int(cited_idx.strip())
                    parsed_current_cite_document_idxs.append(cited_idx)
                    if cited_idx not in cited_docs_set:
                        cited_docs_set.append(cited_idx)
            current_idx += current_cite.end()
            current_cite_close = re.search(
                r"</co: " + current_cite.group(1) + ">", grounded_answer[current_idx:]
            )
            if current_cite_close:
                # there might have been issues parsing the ids, so we need to check
                # that they are actually ints and available
                if len(parsed_current_cite_document_idxs) > 0:
                    pt = grounded_answer[
                        current_idx : current_idx + current_cite_close.start()
                    ]
                    parsed_answer.append(
                        {"text": pt, "cited_docs": parsed_current_cite_document_idxs}
                    )
                else:
                    parsed_answer.append(
                        {
                            "text": grounded_answer[
                                current_idx : current_idx + current_cite_close.start()
                            ],
                        }
                    )
                current_idx += current_cite_close.end()
            else:
                last_entry_is_open_cite = True
                break
        else:
            break
    # Append whatever text remains after the last citation span.
    if last_entry_is_open_cite:
        pt = _strip_spans(grounded_answer[current_idx:])
        parsed_answer.append(
            {"text": pt, "cited_docs": parsed_current_cite_document_idxs}
        )
    else:
        parsed_answer.append({"text": _strip_spans(grounded_answer[current_idx:])})
    return no_markup_answer, parsed_answer
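# A minimal sketch (not part of the library source) of the intermediate
# structure _parse_answer_spans produces for a short grounded answer that cites
# two hypothetical documents:
#
#     _parse_answer_spans("Hi <co: 0,1>Sam</co: 0,1>!")
#     # -> ("Hi Sam!",
#     #     [{"text": "Hi "},
#     #      {"text": "Sam", "cited_docs": [0, 1]},
#     #      {"text": "!"}])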