Skip to content

Commit

Permalink
Apply input formatter to user text, not only examples (#119)
Browse files Browse the repository at this point in the history
After this PR, the input formatter, if provided, will be 
applied to both the input part of the examples and 
the actual user input.

Before this, it was not applied to user input.
  • Loading branch information
eyurtsev authored Apr 4, 2023
1 parent 57dfcce commit ca85381
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 8 deletions.
10 changes: 5 additions & 5 deletions kor/encoders/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
]


def _format_text(text: str, input_formatter: InputFormatter = None) -> str:
# PUBLIC API


def format_text(text: str, input_formatter: InputFormatter = None) -> str:
"""An encoder for the input text.
Args:
Expand All @@ -46,9 +49,6 @@ def _format_text(text: str, input_formatter: InputFormatter = None) -> str:
)


# PUBLIC API


def encode_examples(
examples: Sequence[Tuple[str, str]],
encoder: Encoder,
Expand All @@ -58,7 +58,7 @@ def encode_examples(

return [
(
_format_text(input_example, input_formatter=input_formatter),
format_text(input_example, input_formatter=input_formatter),
encoder.encode(output_example),
)
for input_example, output_example in examples
Expand Down
12 changes: 9 additions & 3 deletions kor/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pydantic import Extra

from kor.encoders import Encoder
from kor.encoders.encode import InputFormatter, encode_examples
from kor.encoders.encode import InputFormatter, encode_examples, format_text
from kor.encoders.parser import KorParser
from kor.examples import generate_examples
from kor.nodes import Object
Expand Down Expand Up @@ -44,6 +44,10 @@ class Config:
extra = Extra.forbid
arbitrary_types_allowed = True

def get_formatted_text(self) -> str:
    """Return this instance's text passed through ``format_text``.

    The instance's ``input_formatter`` is forwarded as the
    ``input_formatter`` keyword argument of ``format_text``.
    """
    raw_text = self.text
    formatter = self.input_formatter
    return format_text(raw_text, input_formatter=formatter)

def to_string(self) -> str:
"""Format the template to a string."""
instruction_segment = self.generate_instruction_segment(self.node)
Expand All @@ -58,7 +62,8 @@ def to_string(self) -> str:
]
)

formatted_examples.append(f"Input: {self.text}\nOutput:")
text = self.get_formatted_text()
formatted_examples.append(f"Input: {text}\nOutput:")
input_output_block = "\n".join(formatted_examples)
return f"{instruction_segment}\n\n{input_output_block}"

Expand All @@ -77,7 +82,8 @@ def to_messages(self) -> List[BaseMessage]:
]
)

messages.append(HumanMessage(content=self.text))
content = self.get_formatted_text()
messages.append(HumanMessage(content=content))
return messages

def generate_encoded_examples(self, node: Object) -> List[Tuple[str, str]]:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_prompt_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

from kor import JSONEncoder, Object, TypeScriptDescriptor
from kor.encoders import InputFormatter
from kor.prompts import create_langchain_prompt


@pytest.mark.parametrize(
    "input_formatter, expected_string",
    [
        (None, "user input"),
        ("triple_quotes", '"""\nuser input\n"""'),
        ("text_prefix", 'Text: """\nuser input\n"""'),
    ],
)
def test_input_formatter_applied_correctly(
    input_formatter: InputFormatter, expected_string: str
) -> None:
    """Check that the input formatter is applied to the user-supplied text.

    The expected formatted text must equal the content of the last message
    and must also appear in the string rendering of the prompt.
    """
    schema = Object(
        id="obj",
        examples=[("text", {"text": "text"})],
        attributes=[],
    )
    langchain_prompt = create_langchain_prompt(
        schema,
        JSONEncoder(),
        TypeScriptDescriptor(),
        input_formatter=input_formatter,
    )

    rendered = langchain_prompt.format_prompt(text="user input")

    # Chat form: the user text is carried by the final message.
    final_message = rendered.to_messages()[-1]
    assert final_message.content == expected_string

    # String form: the formatted user text is embedded in the full prompt.
    prompt_as_string = rendered.to_string()
    assert expected_string in prompt_as_string

0 comments on commit ca85381

Please sign in to comment.