Remove deprecated code #25

Open
wants to merge 1 commit into main
24 changes: 11 additions & 13 deletions kork/chain.py
@@ -15,15 +15,13 @@
cast,
)

from langchain import LLMChain
from langchain.chains.base import Chain
from langchain.prompts import PromptTemplate
from langchain_core.prompts import PromptTemplate

try:
from langchain.base_language import BaseLanguageModel
except ImportError:
from langchain.schema import BaseLanguageModel
from pydantic import Extra

from kork import ast
from kork.ast_printer import AbstractAstPrinter, AstPrinter
@@ -128,7 +126,7 @@ class CodeChain(Chain):
class Config:
"""Configuration for this pydantic object."""

extra = Extra.allow
extra = "allow"
arbitrary_types_allowed = True

@property
@@ -226,14 +224,14 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore

environment, few_shot_template = self.prepare_context(query, variables)

chain = LLMChain(
prompt=few_shot_template,
llm=self.llm,
)
chain = few_shot_template | self.llm

formatted_query = format_text(query, self.input_formatter)
llm_output = cast(str, chain.predict_and_parse(query=formatted_query))
code = unwrap_tag("code", llm_output)
llm_output = chain.invoke({"query": formatted_query})

# Extract content from AIMessage
llm_output_content = llm_output.content
code = unwrap_tag("code", llm_output_content)

if not code:
return {
@@ -243,7 +241,7 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore
"is wrapped in a code block."
)
],
"raw": llm_output,
"raw": llm_output_content,
"code": "",
"environment": None,
}
@@ -253,14 +251,14 @@ def _call(self, inputs: Dict[str, str]) -> CodeResult: # type: ignore
if interpreter_result["errors"]:
return {
"errors": interpreter_result["errors"],
"raw": llm_output,
"raw": llm_output_content,
"code": code,
"environment": None,
}

return {
"errors": [],
"raw": llm_output,
"raw": llm_output_content,
"code": code,
"environment": interpreter_result["environment"],
}
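
Note on the hunk above: with LCEL, `few_shot_template | self.llm` builds a runnable sequence that replaces `LLMChain`, and `invoke()` on a chat model returns an `AIMessage` rather than a plain string, which is why the code is now pulled from `.content`. The composed runnable has no `predict_and_parse`, so tag extraction now happens explicitly after the call. A minimal sketch of the pattern, reusing the `ToyChatModel` test double from `tests/utils.py` (the prompt text is illustrative, not taken from kork):

    from langchain_core.prompts import PromptTemplate
    from tests.utils import ToyChatModel

    prompt = PromptTemplate.from_template("Answer with code: {query}")
    llm = ToyChatModel(response="<code>\nvar x = 1;\n</code>")

    chain = prompt | llm                       # replaces LLMChain(prompt=..., llm=...)
    message = chain.invoke({"query": "blah"})  # AIMessage, not str
    print(message.content)                     # "<code>\nvar x = 1;\n</code>"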
2 changes: 1 addition & 1 deletion kork/parser.py
@@ -17,7 +17,7 @@

from kork import ast

GRAMMAR = """
GRAMMAR = r"""
program: statement+

statement: function_decl
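
A note on the r""" prefix above: without it, backslashes inside the Lark grammar (escaped quotes, character classes in terminal regexes, and so on) are treated as Python string escapes, and unrecognized ones are flagged by newer CPython releases. Rough illustration with a made-up terminal (not part of kork's actual grammar):

    GRAMMAR = r"""
    ESCAPED_STRING: /"(\\.|[^"\\])*"/
    """
    # With a plain (non-raw) string, Python would try to interpret \\ and \"
    # itself before Lark ever sees the pattern.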
5 changes: 2 additions & 3 deletions kork/prompt_adapter.py
@@ -7,9 +7,8 @@
"""
from typing import Any, Callable, List, Sequence, Tuple

from langchain import BasePromptTemplate, PromptTemplate
from langchain.schema import BaseMessage, HumanMessage, PromptValue, SystemMessage
from pydantic import Extra
from langchain_core.prompts import BasePromptTemplate, PromptTemplate


class FewShotPromptValue(PromptValue):
@@ -21,7 +20,7 @@ class FewShotPromptValue(PromptValue):
class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
extra = "forbid"
arbitrary_types_allowed = True

def to_string(self) -> str:
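
On the extra = "forbid" change: both pydantic v1 and v2 accept the plain string here, so the `Extra` enum import can go. If the project later moves fully to pydantic v2 idioms, the equivalent configuration would look roughly like this (a sketch, not part of this PR; the class name is illustrative):

    from pydantic import BaseModel, ConfigDict

    class FewShotPromptValueSketch(BaseModel):
        """Same settings expressed with pydantic v2's ConfigDict."""
        model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)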
8 changes: 4 additions & 4 deletions tests/test_chain.py
@@ -18,7 +18,7 @@ def test_code_chain() -> None:
example_retriever=example_retriever,
)

response = chain(inputs={"query": "blah"})
response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
env = response.pop("environment")
@@ -43,7 +43,7 @@ def test_bad_program() -> None:
example_retriever=example_retriever,
)

response = chain(inputs={"query": "blah"})
response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["raw"] == "<code>\nINVALID PROGRAM\n</code>"
@@ -67,7 +67,7 @@ def test_llm_output_missing_program() -> None:
example_retriever=example_retriever,
)

response = chain(inputs={"query": "blah"})
response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["raw"] == "oops."
@@ -82,7 +82,7 @@ def test_from_defaults_instantiation() -> None:
"""Test from default instantiation."""
llm = ToyChatModel(response="<code>\nvar x = 1;\n</code>")
chain = CodeChain.from_defaults(llm=llm)
response = chain(inputs={"query": "blah"})
response = chain.invoke({"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["environment"].get_symbol("x") == 1
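
The switch from chain(inputs=...) to chain.invoke(...) follows the deprecation of calling a `Chain` directly. `invoke` takes the input dict positionally and, by default, still echoes the inputs back alongside the outputs, which is why `query` keeps appearing in the asserted keys. In short (mirroring `test_from_defaults_instantiation` above):

    llm = ToyChatModel(response="<code>\nvar x = 1;\n</code>")
    chain = CodeChain.from_defaults(llm=llm)
    response = chain.invoke({"query": "blah"})  # was: chain(inputs={"query": "blah"})
    assert sorted(response) == ["code", "environment", "errors", "query", "raw"]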
2 changes: 1 addition & 1 deletion tests/test_prompt_adapter.py
@@ -1,4 +1,4 @@
from langchain.prompts import PromptTemplate
from langchain_core.prompts import PromptTemplate

from kork.prompt_adapter import FewShotPromptValue, FewShotTemplate

7 changes: 5 additions & 2 deletions tests/utils.py
@@ -2,7 +2,6 @@

from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
from pydantic import Extra


class ToyChatModel(BaseChatModel):
@@ -11,7 +10,7 @@ class ToyChatModel(BaseChatModel):
class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
extra = "forbid"
arbitrary_types_allowed = True

def _generate(
@@ -28,3 +27,7 @@ async def _agenerate(
message = AIMessage(content=self.response)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])

@property
def _llm_type(self) -> str:
return "toy"