diff --git a/kork/chain.py b/kork/chain.py
index 215a839..5361955 100644
--- a/kork/chain.py
+++ b/kork/chain.py
@@ -15,15 +15,13 @@
cast,
)
-from langchain import LLMChain
from langchain.chains.base import Chain
-from langchain.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate
try:
from langchain.base_language import BaseLanguageModel
except ImportError:
from langchain.schema import BaseLanguageModel
-from pydantic import Extra
from kork import ast
from kork.ast_printer import AbstractAstPrinter, AstPrinter
@@ -128,7 +126,7 @@ class CodeChain(Chain):
class Config:
"""Configuration for this pydantic object."""
- extra = Extra.allow
+ extra = "allow"
arbitrary_types_allowed = True
@property
@@ -226,14 +224,14 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore
environment, few_shot_template = self.prepare_context(query, variables)
- chain = LLMChain(
- prompt=few_shot_template,
- llm=self.llm,
- )
+ chain = few_shot_template | self.llm
formatted_query = format_text(query, self.input_formatter)
- llm_output = cast(str, chain.predict_and_parse(query=formatted_query))
- code = unwrap_tag("code", llm_output)
+ llm_output = chain.invoke({"query": formatted_query})
+
+ # Extract content from AIMessage
+ llm_output_content = llm_output.content
+ code = unwrap_tag("code", llm_output_content)
if not code:
return {
@@ -243,7 +241,7 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore
"is wrapped in a code block."
)
],
- "raw": llm_output,
+ "raw": llm_output_content,
"code": "",
"environment": None,
}
@@ -253,14 +251,14 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore
if interpreter_result["errors"]:
return {
"errors": interpreter_result["errors"],
- "raw": llm_output,
+ "raw": llm_output_content,
"code": code,
"environment": None,
}
return {
"errors": [],
- "raw": llm_output,
+ "raw": llm_output_content,
"code": code,
"environment": interpreter_result["environment"],
}
diff --git a/kork/parser.py b/kork/parser.py
index 41f58d3..12801a2 100644
--- a/kork/parser.py
+++ b/kork/parser.py
@@ -17,7 +17,7 @@
from kork import ast
-GRAMMAR = """
+GRAMMAR = r"""
program: statement+
statement: function_decl
diff --git a/kork/prompt_adapter.py b/kork/prompt_adapter.py
index 59cd1cd..94439e6 100644
--- a/kork/prompt_adapter.py
+++ b/kork/prompt_adapter.py
@@ -7,9 +7,8 @@
"""
from typing import Any, Callable, List, Sequence, Tuple
-from langchain import BasePromptTemplate, PromptTemplate
from langchain.schema import BaseMessage, HumanMessage, PromptValue, SystemMessage
-from pydantic import Extra
+from langchain_core.prompts import BasePromptTemplate, PromptTemplate
class FewShotPromptValue(PromptValue):
@@ -21,7 +20,7 @@ class FewShotPromptValue(PromptValue):
class Config:
"""Configuration for this pydantic object."""
- extra = Extra.forbid
+ extra = "forbid"
arbitrary_types_allowed = True
def to_string(self) -> str:
diff --git a/tests/test_chain.py b/tests/test_chain.py
index c27bdf1..ef3ef0b 100644
--- a/tests/test_chain.py
+++ b/tests/test_chain.py
@@ -18,7 +18,7 @@ def test_code_chain() -> None:
example_retriever=example_retriever,
)
- response = chain(inputs={"query": "blah"})
+ response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
env = response.pop("environment")
@@ -43,7 +43,7 @@ def test_bad_program() -> None:
example_retriever=example_retriever,
)
- response = chain(inputs={"query": "blah"})
+ response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["raw"] == "\nINVALID PROGRAM\n
"
@@ -67,7 +67,7 @@ def test_llm_output_missing_program() -> None:
example_retriever=example_retriever,
)
- response = chain(inputs={"query": "blah"})
+ response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["raw"] == "oops."
@@ -82,7 +82,7 @@ def test_from_defaults_instantiation() -> None:
"""Test from default instantiation."""
llm = ToyChatModel(response="<code>\nvar x = 1;\n</code>")
chain = CodeChain.from_defaults(llm=llm)
- response = chain(inputs={"query": "blah"})
+ response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["environment"].get_symbol("x") == 1
diff --git a/tests/test_prompt_adapter.py b/tests/test_prompt_adapter.py
index 35b5379..3b39e09 100644
--- a/tests/test_prompt_adapter.py
+++ b/tests/test_prompt_adapter.py
@@ -1,4 +1,4 @@
-from langchain.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate
from kork.prompt_adapter import FewShotPromptValue, FewShotTemplate
diff --git a/tests/utils.py b/tests/utils.py
index cce5622..41339cb 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -2,7 +2,6 @@
from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
-from pydantic import Extra
class ToyChatModel(BaseChatModel):
@@ -11,7 +10,7 @@ class ToyChatModel(BaseChatModel):
class Config:
"""Configuration for this pydantic object."""
- extra = Extra.forbid
+ extra = "forbid"
arbitrary_types_allowed = True
def _generate(
@@ -28,3 +27,7 @@ async def _agenerate(
message = AIMessage(content=self.response)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
+
+ @property
+ def _llm_type(self) -> str:
+ return "toy"