diff --git a/kor/extraction/parser.py b/kor/extraction/parser.py index 46c3358..a317ac7 100644 --- a/kor/extraction/parser.py +++ b/kor/extraction/parser.py @@ -12,7 +12,7 @@ from kor.validators import Validator -class KorParser(BaseOutputParser): +class KorParser(BaseOutputParser[Extraction]): """A Kor langchain parser integration. This parser can use any of Kor's encoders to support encoding/decoding diff --git a/kor/extraction/typedefs.py b/kor/extraction/typedefs.py index 2e4291d..9706d48 100644 --- a/kor/extraction/typedefs.py +++ b/kor/extraction/typedefs.py @@ -1,5 +1,7 @@ """Type definitions for the extraction package.""" -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, List + +from typing_extensions import TypedDict class Extraction(TypedDict): diff --git a/tests/extraction/test_extraction_with_chain.py b/tests/extraction/test_extraction_with_chain.py index 68b3ab8..ff66bcb 100644 --- a/tests/extraction/test_extraction_with_chain.py +++ b/tests/extraction/test_extraction_with_chain.py @@ -114,3 +114,43 @@ def test_instantiation_with_verbose_flag(verbose: Optional[bool]) -> None: encoder_or_encoder_class="json", verbose=verbose, ) + + +def test_get_prompt() -> None: + """Create an extraction chain.""" + chat_model = ToyChatModel(response="hello") + chain = create_extraction_chain( + chat_model, + SIMPLE_OBJECT_SCHEMA, + encoder_or_encoder_class="json", + ) + prompts = chain.get_prompts() + prompt = prompts[0] + assert prompt.format_prompt(text="[text]").to_string() == ( + "Your goal is to extract structured information from the user's input that " + "matches the form described below. When extracting information please make " + "sure it matches the type information exactly. Do not add any attributes that " + "do not appear in the schema shown below.\n" + "\n" + "```TypeScript\n" + "\n" + "obj: { // \n" + " text_node: string // Text Field\n" + "}\n" + "```\n" + "\n" + "\n" + "Please output the extracted information in JSON format. Do not output " + "anything except for the extracted information. Do not add any clarifying " + "information. Do not add any fields that are not in the schema. If the text " + "contains attributes that do not appear in the schema, please ignore them. " + "All output must be in JSON format and follow the schema specified above. " + "Wrap the JSON in tags.\n" + "\n" + "\n" + "\n" + "Input: hello\n" + 'Output: {"obj": {"text_node": "goodbye"}}\n' + "Input: [text]\n" + "Output:" + )