Source code for langchain_ai21.semantic_text_splitter
import copy
import logging
import re
from typing import (
    Any,
    Iterable,
    List,
    Optional,
)

from ai21.models import DocumentType
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import SecretStr
from langchain_text_splitters import TextSplitter

from langchain_ai21.ai21_base import AI21Base

logger = logging.getLogger(__name__)

class AI21SemanticTextSplitter(TextSplitter):
    """Splitting text into coherent and readable units,
    based on distinct topics and lines.
    """
    def __init__(
        self,
        chunk_size: int = 0,
        chunk_overlap: int = 0,
        client: Optional[Any] = None,
        api_key: Optional[SecretStr] = None,
        api_host: Optional[str] = None,
        timeout_sec: Optional[float] = None,
        num_retries: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            **kwargs,
        )
        self._segmentation = AI21Base(
            client=client,
            api_key=api_key,
            api_host=api_host,
            timeout_sec=timeout_sec,
            num_retries=num_retries,
        ).client.segmentation
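
    # A minimal usage sketch (not part of the original source). It assumes a
    # valid AI21 key is available, e.g. via the AI21_API_KEY environment
    # variable or the api_key parameter:
    #
    #     splitter = AI21SemanticTextSplitter(chunk_size=1000)
    #     chunks = splitter.split_text(long_text)
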
    def split_text(self, source: str) -> List[str]:
        """Split text into multiple components.

        Args:
            source: Specifies the text input for text segmentation.
        """
        response = self._segmentation.create(
            source=source, source_type=DocumentType.TEXT
        )
        segments = [segment.segment_text for segment in response.segments]
        if self._chunk_size > 0:
            return self._merge_splits_no_separator(segments)
        return segments
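
    # Illustrative sketch of the two modes (output shapes are hypothetical):
    #
    #     splitter.split_text(source)
    #     # chunk_size == 0 -> one string per semantic segment
    #     # chunk_size > 0  -> adjacent segments merged until chunk_size is
    #     #                    reached, as measured by the length function
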
    def split_text_to_documents(self, source: str) -> List[Document]:
        """Split text into multiple documents.

        Args:
            source: Specifies the text input for text segmentation.
        """
        response = self._segmentation.create(
            source=source, source_type=DocumentType.TEXT
        )
        return [
            Document(
                page_content=segment.segment_text,
                metadata={"source_type": segment.segment_type},
            )
            for segment in response.segments
        ]
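
    # Minimal sketch (illustrative): each Document records the segment type
    # reported by the segmentation API, e.g.
    #
    #     docs = splitter.split_text_to_documents(source)
    #     docs[0].metadata  # {"source_type": <segment_type from the API>}
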
    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            normalized_text = self._normalized_text(text)
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text_to_documents(text):
                # merge user-provided metadata (if any) with metadata
                # from the segmentation API
                metadata = copy.deepcopy(_metadatas[i])
                metadata.update(chunk.metadata)
                if self._add_start_index:
                    # find the start index of the chunk
                    offset = index + previous_chunk_len - self._chunk_overlap
                    normalized_chunk = self._normalized_text(chunk.page_content)
                    index = normalized_text.find(normalized_chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(normalized_chunk)
                documents.append(
                    Document(
                        page_content=chunk.page_content,
                        metadata=metadata,
                    )
                )
        return documents
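
    # Minimal sketch (add_start_index is an inherited TextSplitter option;
    # names and values here are illustrative):
    #
    #     splitter = AI21SemanticTextSplitter(add_start_index=True)
    #     docs = splitter.create_documents(
    #         texts=[source], metadatas=[{"source": "example.txt"}]
    #     )
    #     # each Document's metadata merges the user metadata, the segment
    #     # type, and the chunk's "start_index" within the whitespace-
    #     # normalized source text
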
    def _normalized_text(self, string: str) -> str:
        """Use a regular expression to remove all whitespace from the string."""
        return re.sub(r"\s+", "", string)
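    # e.g. self._normalized_text("a b\nc") returns "abc" (illustrative)
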
    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Override the default TextSplitter implementation.

        The separator is ignored because the segments arrive pre-split from
        the segmentation API.
        """
        return self._merge_splits_no_separator(splits)
    def _merge_splits_no_separator(self, splits: Iterable[str]) -> List[str]:
        """Merge splits into chunks.

        If a segment is longer than chunk_size, it is left as is
        (it won't be truncated to fit chunk_size).
        If a segment is shorter than chunk_size, it is merged with the
        following segments until chunk_size is reached.
        """
        chunks = []
        current_chunk = ""
        for split in splits:
            split_len = self._length_function(split)
            if split_len > self._chunk_size:
                logger.warning(
                    f"Split of length {split_len} "
                    f"exceeds chunk size {self._chunk_size}."
                )
            if self._length_function(current_chunk) + split_len > self._chunk_size:
                if current_chunk != "":
                    chunks.append(current_chunk)
                    current_chunk = ""
            current_chunk += split
        if current_chunk != "":
            chunks.append(current_chunk)
        return chunks
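
    # Worked example (illustrative): with chunk_size=10 and the default length
    # function len, splits ["aaaa", "bbbb", "cccc"] merge to
    # ["aaaabbbb", "cccc"]: "aaaa" + "bbbb" fits (8 <= 10), but adding "cccc"
    # would exceed 10, so the current chunk is flushed and a new one starts.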