From 1b12dfc6d27922f19ddbb1d23cd77a8e8a64385d Mon Sep 17 00:00:00 2001 From: Toby Osborne Date: Fri, 3 May 2024 09:48:31 +1200 Subject: [PATCH 1/2] implemented GenericChunkingModel and NLTKChunkingModel --- ragna/chunking_models/__init__.py | 7 +++ .../_generic_chunking_model.py | 54 +++++++++++++++++++ ragna/chunking_models/_nltk_chunking_model.py | 20 +++++++ ragna/core/__init__.py | 2 + ragna/core/_components.py | 19 ++++++- ragna/core/_rag.py | 50 ++++++++++------- 6 files changed, 131 insertions(+), 21 deletions(-) create mode 100644 ragna/chunking_models/__init__.py create mode 100644 ragna/chunking_models/_generic_chunking_model.py create mode 100644 ragna/chunking_models/_nltk_chunking_model.py diff --git a/ragna/chunking_models/__init__.py b/ragna/chunking_models/__init__.py new file mode 100644 index 00000000..1185ac49 --- /dev/null +++ b/ragna/chunking_models/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "GenericChunkingModel", + "NLTKChunkingModel", +] + +from ._generic_chunking_model import GenericChunkingModel +from ._nltk_chunking_model import NLTKChunkingModel diff --git a/ragna/chunking_models/_generic_chunking_model.py b/ragna/chunking_models/_generic_chunking_model.py new file mode 100644 index 00000000..33e4be0c --- /dev/null +++ b/ragna/chunking_models/_generic_chunking_model.py @@ -0,0 +1,54 @@ +from ragna.core import Document, Chunk, ChunkingModel + +import functools + +from typing import TYPE_CHECKING, TypeVar, Iterable, Iterator, Deque + +from collections import deque + +if TYPE_CHECKING: + import tiktoken + +T = TypeVar("T") + + +# The function is adapted from more_itertools.windowed to allow a ragged last window +# https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed +def _windowed_ragged( + iterable: Iterable[T], *, n: int, step: int +) -> Iterator[tuple[T, ...]]: + window: Deque[T] = deque(maxlen=n) + i = n + for _ in map(window.append, iterable): + i -= 1 + if not i: + i = step + yield tuple(window) + + if len(window) < n: + yield tuple(window) + elif 0 < i < min(step, n): + yield tuple(window)[i:] + +class GenericChunkingModel(ChunkingModel): + def chunk_documents(self, documents: list[Document], chunk_size: int = 500, chunk_overlap: int = 250) -> list[Chunk]: + chunks = [] + for document in documents: + for window in _windowed_ragged( + ( + (tokens, page.number) + for page in document.extract_pages() + for tokens in self.tokenizer.encode(page.text) + ), + n=chunk_size, + step=chunk_size - chunk_overlap, + ): + tokens, page_numbers = zip(*window) + chunks.append(Chunk( + text=self.tokenizer.decode(tokens), # type: ignore[arg-type] + document_id=document.id, + page_numbers=list(filter(lambda n: n is not None, page_numbers)) or None, + num_tokens=len(tokens), + )) + + return chunks diff --git a/ragna/chunking_models/_nltk_chunking_model.py b/ragna/chunking_models/_nltk_chunking_model.py new file mode 100644 index 00000000..6847799a --- /dev/null +++ b/ragna/chunking_models/_nltk_chunking_model.py @@ -0,0 +1,20 @@ +from ragna.core import Document, Chunk, ChunkingModel + +class NLTKChunkingModel(ChunkingModel): + def __init__(self): + super().__init__() + + # our text splitter goes here + from langchain.text_splitter import NLTKTextSplitter + self.text_splitter = NLTKTextSplitter() + + def chunk_documents(self, documents: list[Document]) -> list[Chunk]: + # This is not perfect, but it's the only way I could get this to somewhat work + chunks = [] + for document in documents: + pages = list(document.extract_pages()) + text = 
"".join([page.text for page in pages]) + + chunks += self.generate_chunks_from_text(self.text_splitter.split_text(text), document.id) + + return chunks diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index 0740139e..7b1a119a 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -2,6 +2,7 @@ "Assistant", "Chat", "Chunk", + "ChunkingModel", "Component", "Document", "DocumentHandler", @@ -52,6 +53,7 @@ from ._components import ( Assistant, Component, + ChunkingModel, Embedding, EmbeddingModel, Message, diff --git a/ragna/core/_components.py b/ragna/core/_components.py index 5c80ac21..731e8ba8 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -11,6 +11,7 @@ AsyncIterable, AsyncIterator, Iterator, + Iterable, Optional, Type, Union, @@ -23,9 +24,11 @@ import pydantic import pydantic.utils -from ._document import Chunk, Document +from ._document import Chunk, Document, Page from ._utils import RequirementsMixin, merge_models +from uuid import UUID + class Component(RequirementsMixin): """Base class for RAG components. @@ -92,6 +95,20 @@ def _protocol_model(cls) -> Type[pydantic.BaseModel]: return merge_models(cls.display_name(), *cls._protocol_models().values()) +class ChunkingModel(Component, ABC): + def __init__(self): + import tiktoken + self.tokenizer = tiktoken.get_encoding("cl100k_base") + + @abstractmethod + def chunk_documents(self, documents: list[Document]) -> list[Chunk]: + raise NotImplementedError + + def generate_chunks_from_text(self, chunks: list[str], document_id: UUID) -> list[Chunk]: + return [Chunk(page_numbers=[1], text=chunks[i], document_id=document_id, + num_tokens=len(self.tokenizer.encode(chunks[i]))) for i in range(len(chunks))] + + @dataclass class Embedding: values: list[float] diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index c141d238..4b4542ab 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -21,10 +21,10 @@ import pydantic from starlette.concurrency import iterate_in_threadpool, run_in_threadpool -from ._compat import chunk_pages from ._components import ( Assistant, Chunk, + ChunkingModel, Component, Embedding, EmbeddingModel, @@ -91,6 +91,7 @@ def chat( source_storage: Union[Type[SourceStorage], SourceStorage], assistant: Union[Type[Assistant], Assistant], embedding_model: Optional[Union[Type[EmbeddingModel], EmbeddingModel]] = None, + chunking_model: Optional[Union[Type[ChunkingModel], ChunkingModel]] = None, **params: Any, ) -> Chat: """Create a new [ragna.core.Chat][]. @@ -100,6 +101,8 @@ def chat( [ragna.core.LocalDocument.from_path][] is invoked on it. source_storage: Source storage to use. assistant: Assistant to use. + embedding_model: Embedding model to use + chunking_model: Chunking model to use (Token Based, NLTK, Spacy) **params: Additional parameters passed to the source storage and assistant. 
""" return Chat( @@ -108,6 +111,7 @@ def chat( source_storage=source_storage, assistant=assistant, embedding_model=embedding_model, + chunking_model=chunking_model, **params, ) @@ -164,22 +168,32 @@ def __init__( source_storage: Union[Type[SourceStorage], SourceStorage], assistant: Union[Type[Assistant], Assistant], embedding_model: Optional[Union[Type[EmbeddingModel], EmbeddingModel]], + chunking_model: Optional[Union[Type[ChunkingModel], ChunkingModel]], **params: Any, ) -> None: self._rag = rag self.documents = self._parse_documents(documents) - if embedding_model is None and issubclass( + if (embedding_model is None or chunking_model is None) and issubclass( source_storage.__ragna_input_type__, Embedding ): raise RagnaException - elif embedding_model is not None: - embedding_model = cast( - EmbeddingModel, self._rag._load_component(embedding_model) - ) + else: + if embedding_model is not None: + embedding_model = cast( + EmbeddingModel, self._rag._load_component(embedding_model) + ) + + if chunking_model is not None: + chunking_model = cast( + ChunkingModel, self._rag._load_component(chunking_model) + ) + self.embedding_model = embedding_model + self.chunking_model = chunking_model + self.source_storage = cast( SourceStorage, self._rag._load_component(source_storage) ) @@ -225,20 +239,14 @@ async def prepare(self) -> Message: detail=RagnaException.EVENT, ) - chunks = [ - chunk - for document in self.documents - for chunk in chunk_pages( - document.extract_pages(), - document_id=document.id, - chunk_size=self.params["chunk_size"], - chunk_overlap=self.params["chunk_overlap"], - ) - ] - - input: Union[list[Document], list[Embedding]] = self.documents + # I vaguely recall you mentioning 3 distinct cases, in which the source_storage may take any one of + # Document, Embedding or Chunk. 
I have accounted for that here + input: Union[list[Document], list[Embedding], list[Chunk]] = self.documents if not issubclass(self.source_storage.__ragna_input_type__, Document): - input = cast(EmbeddingModel, self.embedding_model).embed_chunks(chunks) + input = cast(ChunkingModel, self.chunking_model).chunk_documents(input) + if not issubclass(self.source_storage.__ragna_input_type__, Chunk): + input = cast(EmbeddingModel, self.embedding_model).embed_chunks(input) + await self._run(self.source_storage.store, input) self._prepared = True @@ -271,7 +279,9 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: self._messages.append(Message(content=prompt, role=MessageRole.USER)) input: Union[str, list[float]] = prompt - if not issubclass(self.source_storage.__ragna_input_type__, Document): + # Both Chunk and Document would take a string prompt as input + if (not issubclass(self.source_storage.__ragna_input_type__, Document) + and not issubclass(self.source_storage.__ragna_input_type__, Chunk)): input = self._embed_text(prompt) sources = await self._run(self.source_storage.retrieve, self.documents, input) From c65d2fb3626177c33db1350f1fa10a3bca8a27f8 Mon Sep 17 00:00:00 2001 From: Toby Osborne Date: Fri, 3 May 2024 11:01:13 +1200 Subject: [PATCH 2/2] removed compat file (no longer needed) --- ragna/core/_compat.py | 70 ------------------------------------------- ragna/core/_rag.py | 2 ++ 2 files changed, 2 insertions(+), 70 deletions(-) delete mode 100644 ragna/core/_compat.py diff --git a/ragna/core/_compat.py b/ragna/core/_compat.py deleted file mode 100644 index d022cf17..00000000 --- a/ragna/core/_compat.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Temporary module -""" - -import functools -import uuid -from collections import deque -from typing import TYPE_CHECKING, Deque, Iterable, Iterator, TypeVar - -from ._document import Chunk, Page - -if TYPE_CHECKING: - import tiktoken - -__all__ = ["chunk_pages"] - -T = TypeVar("T") - - -# The function is adapted from more_itertools.windowed to allow a ragged last window -# https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed -def _windowed_ragged( - iterable: Iterable[T], *, n: int, step: int -) -> Iterator[tuple[T, ...]]: - window: Deque[T] = deque(maxlen=n) - i = n - for _ in map(window.append, iterable): - i -= 1 - if not i: - i = step - yield tuple(window) - - if len(window) < n: - yield tuple(window) - elif 0 < i < min(step, n): - yield tuple(window)[i:] - - -@functools.cache -def _get_tokenizer() -> "tiktoken.Encoding": - import tiktoken - - return tiktoken.get_encoding("cl100k_base") - - -def chunk_pages( - pages: Iterable[Page], - document_id: uuid.UUID, - *, - chunk_size: int, - chunk_overlap: int, -) -> Iterator[Chunk]: - tokenizer = _get_tokenizer() - - for window in _windowed_ragged( - ( - (tokens, page.number) - for page in pages - for tokens in tokenizer.encode(page.text) - ), - n=chunk_size, - step=chunk_size - chunk_overlap, - ): - tokens, page_numbers = zip(*window) - yield Chunk( - text=tokenizer.decode(tokens), # type: ignore[arg-type] - document_id=document_id, - page_numbers=list(filter(lambda n: n is not None, page_numbers)) or None, - num_tokens=len(tokens), - ) diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 4b4542ab..1af8b373 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -157,6 +157,8 @@ class Chat: [ragna.core.LocalDocument.from_path][] is invoked on it. source_storage: Source storage to use. assistant: Assistant to use. 
+        embedding_model: Embedding model to use. Required for source storages that take embeddings as input.
+        chunking_model: Chunking model to use. Required for source storages that take embeddings or chunks as input.
         **params: Additional parameters passed to the source storage and assistant.
     """
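
Note for reviewers: a rough usage sketch of the new chunking_model parameter
introduced by this patch. The demo assistant and source storage names below are
taken from upstream Ragna and are assumptions here (they may be named
differently on this branch), the document path is purely illustrative, and
NLTKChunkingModel additionally needs langchain and nltk installed. Per the
updated prepare() logic, the chunking model only runs for source storages whose
__ragna_input_type__ is Chunk or Embedding, and embedding_model is only
required in the Embedding case.

    import asyncio

    from ragna import Rag
    from ragna.assistants import RagnaDemoAssistant            # assumed upstream component
    from ragna.source_storages import RagnaDemoSourceStorage   # assumed upstream component
    from ragna.chunking_models import NLTKChunkingModel        # added by this patch


    async def main() -> None:
        chat = Rag().chat(
            documents=["ragna.txt"],  # illustrative local file path
            source_storage=RagnaDemoSourceStorage,
            assistant=RagnaDemoAssistant,
            # New in this patch: controls how documents are split into chunks
            # before (optionally) being embedded and stored.
            chunking_model=NLTKChunkingModel,
            # embedding_model=...,  # only needed if the source storage takes embeddings
        )
        await chat.prepare()
        message = await chat.answer("What is Ragna?")
        print(message.content)


    asyncio.run(main())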