Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Jul 20, 2024
1 parent 72bafb4 commit 19f8722
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 69 deletions.
57 changes: 27 additions & 30 deletions kor/extraction/api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Kor API for extraction related functionality."""

import asyncio
from contextlib import contextmanager
from typing import Any, Callable, List, Optional, Sequence, Type, Union, cast

from langchain.globals import get_verbose, set_verbose
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
Expand All @@ -20,16 +18,6 @@
from kor.validators import Validator


@contextmanager
def set_verbose_context(verbose: bool):
old_verbose = get_verbose()
set_verbose(verbose)
try:
yield
finally:
set_verbose(old_verbose)


async def _extract_from_document_with_semaphore(
semaphore: asyncio.Semaphore,
chain: Runnable,
Expand Down Expand Up @@ -89,8 +77,10 @@ def create_extraction_chain(
* "type_description": type description of the node (from TypeDescriptor)
* "format_instructions": information on how to format the output
(from Encoder)
verbose: if provided, sets the verbosity on the chain, otherwise default
verbosity of the chain will be used
verbose: Deprecated, use langchain_core.globals.set_verbose and
langchain_core.globals.set_debug instead.
Please reference this guide for more information:
https://python.langchain.com/v0.2/docs/how_to/debugging
encoder_kwargs: Keyword arguments to pass to the encoder class
Returns:
Expand All @@ -107,27 +97,34 @@ def create_extraction_chain(
chain = create_extraction_chain(llm, node, encoder_or_encoder_class="JSON",
input_formatter="triple_quotes")
"""

if verbose is not None:
raise NotImplementedError(
"The verbose argument is no longer supported. Instead if you want to see "
"verbose output, please reference this guide for more information: "
"https://python.langchain.com/v0.2/docs/how_to/debugging "
)

if not isinstance(node, Object):
raise ValueError(f"node must be an Object got {type(node)}")
encoder = initialize_encoder(encoder_or_encoder_class, node, **encoder_kwargs)
type_descriptor_to_use = initialize_type_descriptors(type_descriptor)

with set_verbose_context(verbose if verbose is not None else False):
prompt = create_langchain_prompt(
node,
encoder,
type_descriptor_to_use,
validator=validator,
instruction_template=instruction_template,
input_formatter=input_formatter,
)

chain = (
prompt
| llm
| StrOutputParser()
| KorParser(encoder=encoder, validator=validator, schema_=node)
)
prompt = create_langchain_prompt(
node,
encoder,
type_descriptor_to_use,
validator=validator,
instruction_template=instruction_template,
input_formatter=input_formatter,
)

chain = (
prompt
| llm
| StrOutputParser()
| KorParser(encoder=encoder, validator=validator, schema_=node)
)
return chain


Expand Down
47 changes: 8 additions & 39 deletions tests/extraction/test_extraction_with_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from typing import Any, Mapping, Optional

import pytest
from langchain.globals import get_verbose
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable

from kor.encoders import CSVEncoder, JSONEncoder
Expand Down Expand Up @@ -105,43 +103,14 @@ def test_not_implemented_assertion_raised_for_csv(options: Mapping[str, Any]) ->
create_extraction_chain(chat_model, **options)


@pytest.mark.parametrize("verbose", [True, False, None])
@pytest.mark.parametrize("verbose", [True, False])
def test_instantiation_with_verbose_flag(verbose: Optional[bool]) -> None:
"""Create an extraction chain."""
chat_model = ToyChatModel(response="hello")
chain = create_extraction_chain(
chat_model,
SIMPLE_OBJECT_SCHEMA,
encoder_or_encoder_class="json",
verbose=verbose,
)
assert isinstance(chain, Runnable)
expected_verbose = verbose if verbose is not None else get_verbose()
assert chain.verbose == expected_verbose


def test_using_custom_template() -> None:
"""Create an extraction chain with a custom template."""
template = PromptTemplate(
input_variables=["format_instructions", "type_description"],
template=(
"custom_prefix\n"
"{type_description}\n\n"
"{format_instructions}\n"
"custom_suffix"
),
)
chain = create_extraction_chain(
ToyChatModel(response="hello"),
OBJECT_SCHEMA_WITH_MANY,
instruction_template=template,
encoder_or_encoder_class="json",
)
prompt_value = chain.prompt.format_prompt(text="hello")
system_message = prompt_value.to_messages()[0]
string_value = prompt_value.to_string()

assert "custom_prefix" in string_value
assert "custom_suffix" in string_value
assert "custom_prefix" in system_message.content
assert "custom_suffix" in system_message.content
with pytest.raises(NotImplementedError):
create_extraction_chain(
chat_model,
SIMPLE_OBJECT_SCHEMA,
encoder_or_encoder_class="json",
verbose=verbose,
)

0 comments on commit 19f8722

Please sign in to comment.