Skip to content

Commit

Permalink
Remove Deprecated LLMChain
Browse files Browse the repository at this point in the history
  • Loading branch information
Sachin-Bhat committed Jul 16, 2024
1 parent 9a3afa3 commit 33c49d4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
44 changes: 28 additions & 16 deletions kor/extraction/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Kor API for extraction related functionality."""

import asyncio
from contextlib import contextmanager
from typing import Any, Callable, List, Optional, Sequence, Type, Union, cast

from langchain.chains import LLMChain
from langchain.globals import get_verbose, set_verbose
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence

from kor.encoders import Encoder, InputFormatter, initialize_encoder
from kor.extraction.parser import KorParser
Expand All @@ -16,17 +20,25 @@
from kor.validators import Validator


@contextmanager
def set_verbose_context(verbose: bool):
    """Temporarily set langchain's global verbose flag within a ``with`` block.

    Args:
        verbose: The verbosity value to apply for the duration of the context.

    Yields:
        None. The previous global verbose setting is restored on exit,
        even if the body raises.
    """
    old_verbose = get_verbose()
    set_verbose(verbose)
    try:
        yield
    finally:
        # Restore in a finally so an exception inside the with-block cannot
        # leak the temporary verbosity into global state.
        set_verbose(old_verbose)


async def _extract_from_document_with_semaphore(
semaphore: asyncio.Semaphore,
chain: LLMChain,
chain: RunnableSequence,
document: Document,
uid: str,
source_uid: str,
) -> DocumentExtraction:
"""Extract from document with a semaphore to limit concurrency."""
async with semaphore:
extraction_result: Extraction = cast(
Extraction, await chain.arun(document.page_content)
Extraction, await chain.ainvoke(document.page_content)
)
return {
"uid": uid,
Expand All @@ -52,7 +64,7 @@ def create_extraction_chain(
instruction_template: Optional[PromptTemplate] = None,
verbose: Optional[bool] = None,
**encoder_kwargs: Any,
) -> LLMChain:
) -> RunnableSequence:
"""Create an extraction chain.
Args:
Expand Down Expand Up @@ -98,27 +110,27 @@ def create_extraction_chain(
encoder = initialize_encoder(encoder_or_encoder_class, node, **encoder_kwargs)
type_descriptor_to_use = initialize_type_descriptors(type_descriptor)

chain_kwargs = {}
if verbose is not None:
chain_kwargs["verbose"] = verbose

return LLMChain(
llm=llm,
prompt=create_langchain_prompt(
with set_verbose_context(verbose if verbose is not None else False):
prompt = create_langchain_prompt(
node,
encoder,
type_descriptor_to_use,
validator=validator,
instruction_template=instruction_template,
input_formatter=input_formatter,
),
output_parser=KorParser(encoder=encoder, validator=validator, schema_=node),
**chain_kwargs,
)
)

chain = (
prompt
| llm
| StrOutputParser()
| KorParser(encoder=encoder, validator=validator, schema_=node)
)
return chain


async def extract_from_documents(
chain: LLMChain,
chain: RunnableSequence,
documents: Sequence[Document],
*,
max_concurrency: int = 1,
Expand Down
15 changes: 6 additions & 9 deletions tests/extraction/test_extraction_with_chain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Test that the extraction chain works as expected."""
from typing import Any, Mapping, Optional

import langchain
import pytest
from langchain.chains import LLMChain
from langchain.globals import get_verbose
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence

from kor.encoders import CSVEncoder, JSONEncoder
from kor.extraction import create_extraction_chain
Expand Down Expand Up @@ -40,7 +40,7 @@ def test_create_extraction_chain(options: Mapping[str, Any]) -> None:

for schema in [SIMPLE_OBJECT_SCHEMA]:
chain = create_extraction_chain(chat_model, schema, **options)
assert isinstance(chain, LLMChain)
assert isinstance(chain, RunnableSequence)
# Try to run through predict and parse
chain.invoke("some string") # type: ignore

Expand All @@ -60,7 +60,7 @@ def test_create_extraction_chain_with_csv_encoder(options: Mapping[str, Any]) ->
chat_model = ToyChatModel(response="hello")

chain = create_extraction_chain(chat_model, **options)
assert isinstance(chain, LLMChain)
assert isinstance(chain, RunnableSequence)
# Try to run through predict and parse
chain.invoke("some string") # type: ignore

Expand Down Expand Up @@ -115,11 +115,8 @@ def test_instantiation_with_verbose_flag(verbose: Optional[bool]) -> None:
encoder_or_encoder_class="json",
verbose=verbose,
)
assert isinstance(chain, LLMChain)
if verbose is None:
expected_verbose = langchain.verbose
else:
expected_verbose = verbose
assert isinstance(chain, RunnableSequence)
expected_verbose = verbose if verbose is not None else get_verbose()
assert chain.verbose == expected_verbose


Expand Down

0 comments on commit 33c49d4

Please sign in to comment.