diff --git a/kor/extraction.py b/kor/extraction.py index 2a8335c..b57b5d3 100644 --- a/kor/extraction.py +++ b/kor/extraction.py @@ -1,5 +1,6 @@ from typing import Any, Optional, Type, Union +from langchain import PromptTemplate from langchain.chains import LLMChain from langchain.schema import BaseLanguageModel @@ -20,6 +21,7 @@ def create_extraction_chain( type_descriptor: Union[TypeDescriptor, str] = "typescript", validator: Optional[Validator] = None, input_formatter: InputFormatter = None, + instruction_template: Optional[PromptTemplate] = None, **encoder_kwargs: Any, ) -> LLMChain: """Create an extraction chain. @@ -34,11 +36,16 @@ def create_extraction_chain( validator: optional validator to use for validation input_formatter: the formatter to use for encoding the input. Used for \ both input examples and the text to be analyzed. - * `None`: use for single sentences or single paragraph, no formatting * `triple_quotes`: for long text, surround input with \"\"\" * `text_prefix`: for long text, triple_quote with `TEXT: ` prefix * `Callable`: user provided function + instruction_template: optional prompt template to use, use to over-ride prompt + used for generating the instruction section of the prompt. + It accepts 2 optional input variables: + * "type_description": type description of the node (from TypeDescriptor) + * "format_instructions": information on how to format the output + (from Encoder) encoder_kwargs: Keyword arguments to pass to the encoder class Returns: @@ -66,7 +73,8 @@ def create_extraction_chain( node, encoder, type_descriptor_to_use, - validator, + validator=validator, + instruction_template=instruction_template, input_formatter=input_formatter, ), ) diff --git a/kor/prompts.py b/kor/prompts.py index 8a439dc..247dffa 100644 --- a/kor/prompts.py +++ b/kor/prompts.py @@ -4,6 +4,7 @@ from typing import Any, List, Optional, Tuple from langchain import BasePromptTemplate +from langchain.prompts import PromptTemplate from langchain.schema import ( AIMessage, BaseMessage, @@ -22,21 +23,48 @@ from .validators import Validator +DEFAULT_INSTRUCTION_TEMPLATE = PromptTemplate( + input_variables=["type_description", "format_instructions"], + template=( + "Your goal is to extract structured information from the user's input that" + " matches the form described below. When extracting information please make" + " sure it matches the type information exactly. Do not add any attributes that" + " do not appear in the schema shown below.\n\n" + "{type_description}\n\n" + "{format_instructions}\n\n" + ), +) + class ExtractionPromptValue(PromptValue): """Integration with langchain prompt format.""" - text: str + string: str + messages: List[BaseMessage] + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def to_string(self) -> str: + """Format the prompt to a string.""" + return self.string + + def to_messages(self) -> List[BaseMessage]: + """Get materialized messages.""" + return self.messages + + +class ExtractionPromptTemplate(BasePromptTemplate): + """Extraction prompt template.""" + encoder: Encoder node: Object type_descriptor: TypeDescriptor input_formatter: InputFormatter - prefix: str = ( - "Your goal is to extract structured information from the user's input that" - " matches the form described below. When extracting information please make" - " sure it matches the type information exactly. Do not add any attributes that" - " do not appear in the schema shown below." - ) + instruction_template: PromptTemplate class Config: """Configuration for this pydantic object.""" @@ -44,13 +72,28 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True - def get_formatted_text(self) -> str: - """Get the text encoded if needed.""" - return format_text(self.text, input_formatter=self.input_formatter) + def format_prompt( # type: ignore[override] + self, + text: str, + ) -> PromptValue: + """Format the prompt.""" + text = format_text(text, input_formatter=self.input_formatter) + return ExtractionPromptValue( + string=self.to_string(text), messages=self.to_messages(text) + ) - def to_string(self) -> str: + def format(self, **kwargs: Any) -> str: + """Implementation of deprecated format method.""" + raise NotImplementedError() + + @property + def _prompt_type(self) -> str: + """Prompt type.""" + return "ExtractionPromptTemplate" + + def to_string(self, text: str) -> str: """Format the template to a string.""" - instruction_segment = self.generate_instruction_segment(self.node) + instruction_segment = self.format_instruction_segment(self.node) encoded_examples = self.generate_encoded_examples(self.node) formatted_examples: List[str] = [] @@ -62,14 +105,13 @@ def to_string(self) -> str: ] ) - text = self.get_formatted_text() formatted_examples.append(f"Input: {text}\nOutput:") input_output_block = "\n".join(formatted_examples) return f"{instruction_segment}\n\n{input_output_block}" - def to_messages(self) -> List[BaseMessage]: + def to_messages(self, text: str) -> List[BaseMessage]: """Format the template to chat messages.""" - instruction_segment = self.generate_instruction_segment(self.node) + instruction_segment = self.format_instruction_segment(self.node) messages: List[BaseMessage] = [SystemMessage(content=instruction_segment)] encoded_examples = self.generate_encoded_examples(self.node) @@ -82,8 +124,7 @@ def to_messages(self) -> List[BaseMessage]: ] ) - content = self.get_formatted_text() - messages.append(HumanMessage(content=content)) + messages.append(HumanMessage(content=text)) return messages def generate_encoded_examples(self, node: Object) -> List[Tuple[str, str]]: @@ -93,47 +134,21 @@ def generate_encoded_examples(self, node: Object) -> List[Tuple[str, str]]: examples, self.encoder, input_formatter=self.input_formatter ) - def generate_instruction_segment(self, node: Object) -> str: + def format_instruction_segment(self, node: Object) -> str: """Generate the instruction segment of the extraction.""" type_description = self.type_descriptor.describe(node) - instruction_segment = self.encoder.get_instruction_segment() - return f"{self.prefix}\n\n{type_description}\n\n{instruction_segment}" + format_instructions = self.encoder.get_instruction_segment() + input_variables = self.instruction_template.input_variables + formatting_kwargs = {} -class ExtractionPromptTemplate(BasePromptTemplate): - """Extraction prompt template.""" + if "type_description" in input_variables: + formatting_kwargs["type_description"] = type_description - encoder: Encoder - node: Object - type_descriptor: TypeDescriptor - input_formatter: InputFormatter + if "format_instructions" in input_variables: + formatting_kwargs["format_instructions"] = format_instructions - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def format_prompt( # type: ignore[override] - self, text: str, **kwargs: Any - ) -> PromptValue: - """Format the prompt.""" - return ExtractionPromptValue( - text=text, - encoder=self.encoder, - node=self.node, - type_descriptor=self.type_descriptor, - input_formatter=self.input_formatter, - ) - - def format(self, **kwargs: Any) -> str: - """Implementation of deprecated format method.""" - raise NotImplementedError() - - @property - def _prompt_type(self) -> str: - """Prompt type.""" - return "ExtractionPromptTemplate" + return self.instruction_template.format(**formatting_kwargs) # PUBLIC API @@ -143,8 +158,10 @@ def create_langchain_prompt( schema: Object, encoder: Encoder, type_descriptor: TypeDescriptor, + *, validator: Optional[Validator] = None, input_formatter: InputFormatter = None, + instruction_template: Optional[PromptTemplate] = None, ) -> ExtractionPromptTemplate: """Create a langchain style prompt with specified encoder.""" return ExtractionPromptTemplate( @@ -154,4 +171,5 @@ def create_langchain_prompt( node=schema, input_formatter=input_formatter, type_descriptor=type_descriptor, + instruction_template=instruction_template or DEFAULT_INSTRUCTION_TEMPLATE, ) diff --git a/tests/test_extraction.py b/tests/test_extraction.py index d427a8c..93f4a27 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -2,6 +2,7 @@ from typing import Any, List, Mapping, Optional import pytest +from langchain import PromptTemplate from langchain.chains import LLMChain from langchain.chat_models.base import BaseChatModel from langchain.schema import ( @@ -129,3 +130,30 @@ def test_not_implemented_assertion_raised_for_csv(options: Mapping[str, Any]) -> with pytest.raises(NotImplementedError): create_extraction_chain(chat_model, **options) + + +def test_using_custom_template() -> None: + """Create an extraction chain with a custom template.""" + template = PromptTemplate( + input_variables=["format_instructions", "type_description"], + template=( + "custom_prefix\n" + "{type_description}\n\n" + "{format_instructions}\n" + "custom_suffix" + ), + ) + chain = create_extraction_chain( + ToyChatModel(response="hello"), + OBJECT_SCHEMA_WITH_MANY, + instruction_template=template, + encoder_or_encoder_class="json", + ) + prompt_value = chain.prompt.format_prompt(text="hello") + system_message = prompt_value.to_messages()[0] + string_value = prompt_value.to_string() + + assert "custom_prefix" in string_value + assert "custom_suffix" in string_value + assert "custom_prefix" in system_message.content + assert "custom_suffix" in system_message.content