diff --git a/kork/chain.py b/kork/chain.py index 215a839..5361955 100644 --- a/kork/chain.py +++ b/kork/chain.py @@ -15,15 +15,13 @@ cast, ) -from langchain import LLMChain from langchain.chains.base import Chain -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate try: from langchain.base_language import BaseLanguageModel except ImportError: from langchain.schema import BaseLanguageModel -from pydantic import Extra from kork import ast from kork.ast_printer import AbstractAstPrinter, AstPrinter @@ -128,7 +126,7 @@ class CodeChain(Chain): class Config: """Configuration for this pydantic object.""" - extra = Extra.allow + extra = "allow" arbitrary_types_allowed = True @property @@ -226,14 +224,14 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore environment, few_shot_template = self.prepare_context(query, variables) - chain = LLMChain( - prompt=few_shot_template, - llm=self.llm, - ) + chain = few_shot_template | self.llm formatted_query = format_text(query, self.input_formatter) - llm_output = cast(str, chain.predict_and_parse(query=formatted_query)) - code = unwrap_tag("code", llm_output) + llm_output = chain.invoke({"query": formatted_query}) + + # Extract content from AIMessage + llm_output_content = llm_output.content + code = unwrap_tag("code", llm_output_content) if not code: return { @@ -243,7 +241,7 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore "is wrapped in a code block." ) ], - "raw": llm_output, + "raw": llm_output_content, "code": "", "environment": None, } @@ -253,14 +251,14 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore if interpreter_result["errors"]: return { "errors": interpreter_result["errors"], - "raw": llm_output, + "raw": llm_output_content, "code": code, "environment": None, } return { "errors": [], - "raw": llm_output, + "raw": llm_output_content, "code": code, "environment": interpreter_result["environment"], } diff --git a/kork/parser.py b/kork/parser.py index 41f58d3..12801a2 100644 --- a/kork/parser.py +++ b/kork/parser.py @@ -17,7 +17,7 @@ from kork import ast -GRAMMAR = """ +GRAMMAR = r""" program: statement+ statement: function_decl diff --git a/kork/prompt_adapter.py b/kork/prompt_adapter.py index 59cd1cd..94439e6 100644 --- a/kork/prompt_adapter.py +++ b/kork/prompt_adapter.py @@ -7,9 +7,8 @@ """ from typing import Any, Callable, List, Sequence, Tuple -from langchain import BasePromptTemplate, PromptTemplate from langchain.schema import BaseMessage, HumanMessage, PromptValue, SystemMessage -from pydantic import Extra +from langchain_core.prompts import BasePromptTemplate, PromptTemplate class FewShotPromptValue(PromptValue): @@ -21,7 +20,7 @@ class FewShotPromptValue(PromptValue): class Config: """Configuration for this pydantic object.""" - extra = Extra.forbid + extra = "forbid" arbitrary_types_allowed = True def to_string(self) -> str: diff --git a/tests/test_chain.py b/tests/test_chain.py index c27bdf1..ef3ef0b 100644 --- a/tests/test_chain.py +++ b/tests/test_chain.py @@ -18,7 +18,7 @@ def test_code_chain() -> None: example_retriever=example_retriever, ) - response = chain(inputs={"query": "blah"}) + response = chain.invoke({"query": "blah"}) # Why does the chain return a `query` key? assert sorted(response) == ["code", "environment", "errors", "query", "raw"] env = response.pop("environment") @@ -43,7 +43,7 @@ def test_bad_program() -> None: example_retriever=example_retriever, ) - response = chain(inputs={"query": "blah"}) + response = chain.invoke({"query": "blah"}) # Why does the chain return a `query` key? assert sorted(response) == ["code", "environment", "errors", "query", "raw"] assert response["raw"] == "\nINVALID PROGRAM\n" @@ -67,7 +67,7 @@ def test_llm_output_missing_program() -> None: example_retriever=example_retriever, ) - response = chain(inputs={"query": "blah"}) + response = chain.invoke({"query": "blah"}) # Why does the chain return a `query` key? assert sorted(response) == ["code", "environment", "errors", "query", "raw"] assert response["raw"] == "oops." @@ -82,7 +82,7 @@ def test_from_defaults_instantiation() -> None: """Test from default instantiation.""" llm = ToyChatModel(response="\nvar x = 1;\n") chain = CodeChain.from_defaults(llm=llm) - response = chain(inputs={"query": "blah"}) + response = chain.invoke({"query": "blah"}) # Why does the chain return a `query` key? assert sorted(response) == ["code", "environment", "errors", "query", "raw"] assert response["environment"].get_symbol("x") == 1 diff --git a/tests/test_prompt_adapter.py b/tests/test_prompt_adapter.py index 35b5379..3b39e09 100644 --- a/tests/test_prompt_adapter.py +++ b/tests/test_prompt_adapter.py @@ -1,4 +1,4 @@ -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate from kork.prompt_adapter import FewShotPromptValue, FewShotTemplate diff --git a/tests/utils.py b/tests/utils.py index cce5622..41339cb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,7 +2,6 @@ from langchain.chat_models.base import BaseChatModel from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult -from pydantic import Extra class ToyChatModel(BaseChatModel): @@ -11,7 +10,7 @@ class ToyChatModel(BaseChatModel): class Config: """Configuration for this pydantic object.""" - extra = Extra.forbid + extra = "forbid" arbitrary_types_allowed = True def _generate( @@ -28,3 +27,7 @@ async def _agenerate( message = AIMessage(content=self.response) generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + return "toy"