Skip to content

Commit

Permalink
Allow customizing the prompt (#121)
Browse files Browse the repository at this point in the history
Allow customizing the prompt.

```python
from langchain.prompts import PromptTemplate

# A user-supplied instruction template. It may reference two optional
# input variables, which kor fills in when building the prompt:
#   * "type_description":     schema description produced by the TypeDescriptor
#   * "format_instructions":  output-format guidance produced by the Encoder
CUSTOM_PROMPT_TEMPLATE = PromptTemplate(
    input_variables=["format_instructions", "type_description"],
    template=(
        "Write some stuff here\n\n"
        "{type_description}\n\n"
        "{format_instructions}\n"
        "Suffix here\n"
    ),
)


# Pass the template via `instruction_template` to override the default
# instruction section of the extraction prompt.
chain = create_extraction_chain(llm, schema, instruction_template=CUSTOM_PROMPT_TEMPLATE)

print(chain.prompt.format_prompt(text='hello').to_string())
```
  • Loading branch information
eyurtsev authored Apr 6, 2023
1 parent d1ea14b commit 52db60a
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 54 deletions.
12 changes: 10 additions & 2 deletions kor/extraction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Optional, Type, Union

from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.schema import BaseLanguageModel

Expand All @@ -20,6 +21,7 @@ def create_extraction_chain(
type_descriptor: Union[TypeDescriptor, str] = "typescript",
validator: Optional[Validator] = None,
input_formatter: InputFormatter = None,
instruction_template: Optional[PromptTemplate] = None,
**encoder_kwargs: Any,
) -> LLMChain:
"""Create an extraction chain.
Expand All @@ -34,11 +36,16 @@ def create_extraction_chain(
validator: optional validator to use for validation
input_formatter: the formatter to use for encoding the input. Used for \
both input examples and the text to be analyzed.
* `None`: use for single sentences or single paragraph, no formatting
* `triple_quotes`: for long text, surround input with \"\"\"
* `text_prefix`: for long text, triple_quote with `TEXT: ` prefix
* `Callable`: user provided function
instruction_template: optional prompt template to use, use to over-ride prompt
used for generating the instruction section of the prompt.
It accepts 2 optional input variables:
* "type_description": type description of the node (from TypeDescriptor)
* "format_instructions": information on how to format the output
(from Encoder)
encoder_kwargs: Keyword arguments to pass to the encoder class
Returns:
Expand Down Expand Up @@ -66,7 +73,8 @@ def create_extraction_chain(
node,
encoder,
type_descriptor_to_use,
validator,
validator=validator,
instruction_template=instruction_template,
input_formatter=input_formatter,
),
)
122 changes: 70 additions & 52 deletions kor/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, List, Optional, Tuple

from langchain import BasePromptTemplate
from langchain.prompts import PromptTemplate
from langchain.schema import (
AIMessage,
BaseMessage,
Expand All @@ -22,35 +23,77 @@

from .validators import Validator

# Default template for the instruction section of the extraction prompt.
# Used when the caller does not pass `instruction_template` to
# `create_langchain_prompt` / `create_extraction_chain`. Both input
# variables are optional for user-supplied replacements: the formatting
# code only fills in the variables the template actually declares.
#   * "type_description":     schema description from the TypeDescriptor
#   * "format_instructions":  output-format guidance from the Encoder
DEFAULT_INSTRUCTION_TEMPLATE = PromptTemplate(
    input_variables=["type_description", "format_instructions"],
    template=(
        "Your goal is to extract structured information from the user's input that"
        " matches the form described below. When extracting information please make"
        " sure it matches the type information exactly. Do not add any attributes that"
        " do not appear in the schema shown below.\n\n"
        "{type_description}\n\n"
        "{format_instructions}\n\n"
    ),
)


class ExtractionPromptValue(PromptValue):
    """Integration with langchain prompt format.

    Holds the prompt pre-rendered in both shapes langchain may request:
    a single string (for LLMs) and a list of chat messages (for chat models).
    """

    # NOTE(review): removed the stale `text: str` field left over alongside
    # `string: str` — with `Extra.forbid` an extra required field would make
    # construction with only string/messages fail, and nothing reads it.
    string: str  # prompt rendered as one flat string
    messages: List[BaseMessage]  # prompt rendered as chat messages

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def to_string(self) -> str:
        """Format the prompt to a string."""
        return self.string

    def to_messages(self) -> List[BaseMessage]:
        """Get materialized messages."""
        return self.messages


class ExtractionPromptTemplate(BasePromptTemplate):
"""Extraction prompt template."""

encoder: Encoder
node: Object
type_descriptor: TypeDescriptor
input_formatter: InputFormatter
prefix: str = (
"Your goal is to extract structured information from the user's input that"
" matches the form described below. When extracting information please make"
" sure it matches the type information exactly. Do not add any attributes that"
" do not appear in the schema shown below."
)
instruction_template: PromptTemplate

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

def get_formatted_text(self) -> str:
"""Get the text encoded if needed."""
return format_text(self.text, input_formatter=self.input_formatter)
def format_prompt( # type: ignore[override]
self,
text: str,
) -> PromptValue:
"""Format the prompt."""
text = format_text(text, input_formatter=self.input_formatter)
return ExtractionPromptValue(
string=self.to_string(text), messages=self.to_messages(text)
)

def to_string(self) -> str:
    def format(self, **kwargs: Any) -> str:
        """Implementation of deprecated format method."""
        # Deliberately unsupported: this template needs the input text to
        # build examples, so callers must use format_prompt() instead.
        raise NotImplementedError()

    @property
    def _prompt_type(self) -> str:
        """Prompt type."""
        # Identifier used by langchain when serializing/inspecting prompts.
        return "ExtractionPromptTemplate"

def to_string(self, text: str) -> str:
"""Format the template to a string."""
instruction_segment = self.generate_instruction_segment(self.node)
instruction_segment = self.format_instruction_segment(self.node)
encoded_examples = self.generate_encoded_examples(self.node)
formatted_examples: List[str] = []

Expand All @@ -62,14 +105,13 @@ def to_string(self) -> str:
]
)

text = self.get_formatted_text()
formatted_examples.append(f"Input: {text}\nOutput:")
input_output_block = "\n".join(formatted_examples)
return f"{instruction_segment}\n\n{input_output_block}"

def to_messages(self) -> List[BaseMessage]:
def to_messages(self, text: str) -> List[BaseMessage]:
"""Format the template to chat messages."""
instruction_segment = self.generate_instruction_segment(self.node)
instruction_segment = self.format_instruction_segment(self.node)

messages: List[BaseMessage] = [SystemMessage(content=instruction_segment)]
encoded_examples = self.generate_encoded_examples(self.node)
Expand All @@ -82,8 +124,7 @@ def to_messages(self) -> List[BaseMessage]:
]
)

content = self.get_formatted_text()
messages.append(HumanMessage(content=content))
messages.append(HumanMessage(content=text))
return messages

def generate_encoded_examples(self, node: Object) -> List[Tuple[str, str]]:
Expand All @@ -93,47 +134,21 @@ def generate_encoded_examples(self, node: Object) -> List[Tuple[str, str]]:
examples, self.encoder, input_formatter=self.input_formatter
)

def generate_instruction_segment(self, node: Object) -> str:
def format_instruction_segment(self, node: Object) -> str:
"""Generate the instruction segment of the extraction."""
type_description = self.type_descriptor.describe(node)
instruction_segment = self.encoder.get_instruction_segment()
return f"{self.prefix}\n\n{type_description}\n\n{instruction_segment}"
format_instructions = self.encoder.get_instruction_segment()
input_variables = self.instruction_template.input_variables

formatting_kwargs = {}

class ExtractionPromptTemplate(BasePromptTemplate):
"""Extraction prompt template."""
if "type_description" in input_variables:
formatting_kwargs["type_description"] = type_description

encoder: Encoder
node: Object
type_descriptor: TypeDescriptor
input_formatter: InputFormatter
if "format_instructions" in input_variables:
formatting_kwargs["format_instructions"] = format_instructions

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

def format_prompt( # type: ignore[override]
self, text: str, **kwargs: Any
) -> PromptValue:
"""Format the prompt."""
return ExtractionPromptValue(
text=text,
encoder=self.encoder,
node=self.node,
type_descriptor=self.type_descriptor,
input_formatter=self.input_formatter,
)

def format(self, **kwargs: Any) -> str:
"""Implementation of deprecated format method."""
raise NotImplementedError()

@property
def _prompt_type(self) -> str:
"""Prompt type."""
return "ExtractionPromptTemplate"
return self.instruction_template.format(**formatting_kwargs)


# PUBLIC API
Expand All @@ -143,8 +158,10 @@ def create_langchain_prompt(
schema: Object,
encoder: Encoder,
type_descriptor: TypeDescriptor,
*,
validator: Optional[Validator] = None,
input_formatter: InputFormatter = None,
instruction_template: Optional[PromptTemplate] = None,
) -> ExtractionPromptTemplate:
"""Create a langchain style prompt with specified encoder."""
return ExtractionPromptTemplate(
Expand All @@ -154,4 +171,5 @@ def create_langchain_prompt(
node=schema,
input_formatter=input_formatter,
type_descriptor=type_descriptor,
instruction_template=instruction_template or DEFAULT_INSTRUCTION_TEMPLATE,
)
28 changes: 28 additions & 0 deletions tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, List, Mapping, Optional

import pytest
from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
Expand Down Expand Up @@ -129,3 +130,30 @@ def test_not_implemented_assertion_raised_for_csv(options: Mapping[str, Any]) ->

with pytest.raises(NotImplementedError):
create_extraction_chain(chat_model, **options)


def test_using_custom_template() -> None:
    """Verify a user-supplied instruction template is used by the chain."""
    custom_template = PromptTemplate(
        input_variables=["format_instructions", "type_description"],
        template=(
            "custom_prefix\n"
            "{type_description}\n\n"
            "{format_instructions}\n"
            "custom_suffix"
        ),
    )
    chain = create_extraction_chain(
        ToyChatModel(response="hello"),
        OBJECT_SCHEMA_WITH_MANY,
        instruction_template=custom_template,
        encoder_or_encoder_class="json",
    )
    prompt_value = chain.prompt.format_prompt(text="hello")
    rendered_string = prompt_value.to_string()
    first_message = prompt_value.to_messages()[0]

    # The custom prefix/suffix must survive into both renderings.
    for marker in ("custom_prefix", "custom_suffix"):
        assert marker in rendered_string
        assert marker in first_message.content

0 comments on commit 52db60a

Please sign in to comment.