chore: Use run instead of predict_and_parse (#188)
Related issue #185 

This PR sets the KorParser as the OutputParser on the LLMChain. This is
needed since `predict_and_parse` on the LLMChain is deprecated and
`run` should be used instead. For this, the minimum LangChain version is
raised to 0.0.205. A minimal before/after sketch is included below.

This PR includes

- Needed changes to the code
- Needed changes to the documentation
- A minimal integration test (needs an OpenAI API key) to make sure things
are still working correctly
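
For reference, a minimal before/after sketch of the migration (the schema, model name, and example text are illustrative and assume an OpenAI API key is configured):

```python
# Illustrative sketch only; the schema and model below are examples,
# not code from this commit.
from langchain.chat_models import ChatOpenAI

from kor.extraction import create_extraction_chain
from kor.nodes import Object, Text

llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

schema = Object(
    id="player",
    description="User is controlling a music player.",
    attributes=[
        Text(id="song", description="Songs the user wants to play", many=True),
    ],
    examples=[("play songs by paul simon", {"song": ["paul simon"]})],
)

chain = create_extraction_chain(llm, schema, encoder_or_encoder_class="json")

# Before (deprecated on LLMChain in newer LangChain releases):
# chain.predict_and_parse(text="play songs by paul simon and the doors")["data"]

# After: the chain is constructed with a KorParser as its output_parser,
# so a plain `run` call already returns the parsed result.
chain.run("play songs by paul simon and the doors")["data"]
```

Because the parser is attached when the chain is created, `run` and `arun` return the same parsed dictionary that `predict_and_parse` used to produce.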
BorisWilhelms authored Jul 11, 2023
1 parent eebb05f commit ba20aa1
Showing 8 changed files with 684 additions and 659 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -75,7 +75,7 @@ schema = Object(
)

chain = create_extraction_chain(llm, schema, encoder_or_encoder_class='json')
chain.predict_and_parse(text="play songs by paul simon and led zeppelin and the doors")['data']
chain.run("play songs by paul simon and led zeppelin and the doors")['data']
```

```python
@@ -119,7 +119,7 @@ schema, validator = from_pydantic(MusicRequest)
chain = create_extraction_chain(
llm, schema, encoder_or_encoder_class="json", validator=validator
)
chain.predict_and_parse(text="stop the music now")["validated_data"]
chain.run("stop the music now")["validated_data"]
```

```python
4 changes: 2 additions & 2 deletions docs/source/index.md
@@ -70,7 +70,7 @@ schema = Object(
)

chain = create_extraction_chain(llm, schema, encoder_or_encoder_class='json')
chain.predict_and_parse(text="play songs by paul simon and led zeppelin and the doors")['data']
chain.run("play songs by paul simon and led zeppelin and the doors")['data']
```

```python
@@ -114,7 +114,7 @@ schema, validator = from_pydantic(MusicRequest)
chain = create_extraction_chain(
llm, schema, encoder_or_encoder_class="json", validator=validator
)
chain.predict_and_parse(text="stop the music now")["validated_data"]
chain.run("stop the music now")["validated_data"]
```

```python
5 changes: 4 additions & 1 deletion kor/extraction/api.py
@@ -6,6 +6,8 @@
from langchain.chains import LLMChain
from langchain.docstore.document import Document

from kor.extraction.parser import KorParser

try: # Handle breaking change in langchain
from langchain.base_language import BaseLanguageModel
except ImportError:
@@ -29,7 +31,7 @@ async def _extract_from_document_with_semaphore(
"""Extract from document with a semaphore to limit concurrency."""
async with semaphore:
extraction_result: Extraction = cast(
Extraction, await chain.apredict_and_parse(text=document.page_content)
Extraction, await chain.arun(document.page_content)
)
return {
"uid": uid,
@@ -115,6 +117,7 @@ def create_extraction_chain(
instruction_template=instruction_template,
input_formatter=input_formatter,
),
output_parser=KorParser(encoder=encoder, validator=validator, schema_=node),
**chain_kwargs,
)

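The `_extract_from_document_with_semaphore` hunk above now awaits `chain.arun(document.page_content)` inside a semaphore guard. Below is a minimal, self-contained sketch of that bounded-concurrency pattern; the helper names and the fake extraction call are stand-ins, not code from this repository.

```python
# Sketch of semaphore-bounded concurrent extraction (illustrative names).
import asyncio
from typing import Any, Dict, List


async def _fake_extract(text: str) -> Dict[str, Any]:
    """Stand-in for `await chain.arun(document.page_content)`."""
    await asyncio.sleep(0.1)  # pretend this is the LLM call
    return {"data": {"length": len(text)}}


async def extract_all(texts: List[str], max_concurrency: int = 5) -> List[Dict[str, Any]]:
    """Extract from many texts, at most `max_concurrency` at a time."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(uid: int, text: str) -> Dict[str, Any]:
        async with semaphore:  # limits how many extractions run concurrently
            result = await _fake_extract(text)
        return {"uid": uid, **result}

    return await asyncio.gather(*(bounded(i, t) for i, t in enumerate(texts)))


if __name__ == "__main__":
    print(asyncio.run(extract_all(["alpha", "beta", "gamma"])))
```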
1,281 changes: 633 additions & 648 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions pyproject.toml
@@ -10,9 +10,9 @@ repository = "https://www.github.com/eyurtsev/kor"
[tool.poetry.dependencies]
python = "^3.8.1"
openai = "^0.27"
langchain = ">=0.0.110"
langchain = ">=0.0.205"
pandas = "^1.5.3"
markdownify = {version = "^0.11.6", optional = true}
markdownify = {version = "^0.11.6", optional = false}

[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
@@ -52,7 +52,7 @@ html = ["markdownify"]
[tool.poe.tasks]
black = "black"
ruff = "ruff"
pytest.cmd = "py.test --durations=5 -W error::RuntimeWarning --cov --cov-config=.coveragerc --cov-report xml --cov-report term-missing:skip-covered"
pytest.cmd = "py.test -s --durations=5 -W error::RuntimeWarning --cov --cov-config=.coveragerc --cov-report xml --cov-report term-missing:skip-covered"
mypy = "mypy . --pretty --show-error-codes"
fix = { shell = "poe black . && poe ruff --fix ." }
fix_docs = "black docs"
@@ -61,6 +61,8 @@ fix_strings = "black kor --preview"
test = { shell = "poe black . --check --diff && poe ruff . && poe pytest && poe mypy" }
# Use to auto-generate docs
apidoc = "sphinx-apidoc -o docs/source/generated kor"
pyintegration = { shell = "py.test -s ./tests/integration/*.py" }
integration = { shell = "poe black . --check --diff && poe ruff . && poe pyintegration && poe mypy" }

[tool.ruff]
select = [
4 changes: 2 additions & 2 deletions tests/extraction/test_extraction_with_chain.py
@@ -42,7 +42,7 @@ def test_create_extraction_chain(options: Mapping[str, Any]) -> None:
chain = create_extraction_chain(chat_model, schema, **options)
assert isinstance(chain, LLMChain)
# Try to run through predict and parse
chain.predict_and_parse(text="some string")
chain.run("some string")


@pytest.mark.parametrize(
@@ -62,7 +62,7 @@ def test_create_extraction_chain_with_csv_encoder(options: Mapping[str, Any]) ->
chain = create_extraction_chain(chat_model, **options)
assert isinstance(chain, LLMChain)
# Try to run through predict and parse
chain.predict_and_parse(text="some string")
chain.run("some string")


MANY_TEXT_SCHEMA = Text(
33 changes: 33 additions & 0 deletions tests/integration/extraction_with_chain.py
@@ -0,0 +1,33 @@
from langchain.chat_models import ChatOpenAI
from pydantic import BaseModel, Field

from kor.adapters import from_pydantic
from kor.extraction import create_extraction_chain


def test_pydantic() -> None:
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, max_tokens=2000)

class Person(BaseModel):
first_name: str = Field(
description="The first name of a person.",
)

schema, v = from_pydantic(
Person,
description="Personal information",
examples=[
(
"Alice and Bob are friends",
{"first_name": "Alice"},
)
],
many=True,
)

chain = create_extraction_chain(llm, schema)
result = chain.run("My name is Bobby. My brother's name Joe.") # type: ignore
data = result["data"] # type: ignore

assert "person" in data # type: ignore
assert len(data["person"]) == 2 # type: ignore
4 changes: 3 additions & 1 deletion tests/utils.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Optional

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@@ -23,6 +23,7 @@ def _generate(
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
message = AIMessage(content=self.response)
@@ -34,6 +35,7 @@ async def _agenerate(
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Async version of _generate."""
message = AIMessage(content=self.response)
