-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Apply input formatter to user text, not only examples (#119)
After this PR, the input formatter, if provided, will be applied to both the input part of the examples as well as to the actual user input. Before this, it was not applied to user input.
- Loading branch information
Showing
3 changed files
with
48 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import pytest | ||
|
||
from kor import JSONEncoder, Object, TypeScriptDescriptor | ||
from kor.encoders import InputFormatter | ||
from kor.prompts import create_langchain_prompt | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_formatter, expected_string", | ||
[ | ||
(None, "user input"), | ||
("triple_quotes", '"""\nuser input\n"""'), | ||
("text_prefix", 'Text: """\nuser input\n"""'), | ||
], | ||
) | ||
def test_input_formatter_applied_correctly( | ||
input_formatter: InputFormatter, expected_string: str | ||
) -> None: | ||
untyped_object = Object( | ||
id="obj", | ||
examples=[("text", {"text": "text"})], | ||
attributes=[], | ||
) | ||
prompt = create_langchain_prompt( | ||
untyped_object, | ||
JSONEncoder(), | ||
TypeScriptDescriptor(), | ||
input_formatter=input_formatter, | ||
) | ||
|
||
prompt_value = prompt.format_prompt(text="user input") | ||
|
||
assert prompt_value.to_messages()[-1].content == expected_string | ||
assert expected_string in prompt_value.to_string() |