Skip to content

Commit

Permalink
Fix to allow encoding non ascii in JSON (#153)
Browse files Browse the repository at this point in the history
Allow encoding non ascii characters in JSON.
  • Loading branch information
eyurtsev authored May 7, 2023
1 parent fed4215 commit aa10b41
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 37 deletions.
44 changes: 37 additions & 7 deletions kor/encoders/json_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,45 @@ class JSONEncoder(Encoder):
within the LLM response and extract it.
The usage of <json> tags is similar to the usage of ```JSON and ``` marks.
Examples:
.. code-block:: python
from kor import JSONEncoder
json_encoder = JSONEncoder(use_tags=True)
json_encoder.encode({"object": [{"a": 1}]})
# '<json>{"object": [{"a": 1}]}</json>'
json_encoder = JSONEncoder(use_tags=True, ensure_ascii=False)
data = {"name": "Café"}
json_encoder.encode(data)
# '<json>{"name": "Café"}</json>'
"""

def __init__(self, use_tags: bool = True) -> None:
def __init__(self, use_tags: bool = True, ensure_ascii: bool = True) -> None:
"""Initialize the JSON encoder.

Args:
    use_tags: Whether to wrap the output in special JSON tags.
        This may help identify the JSON content in cases when
        the model attempts to add clarifying explanations.
    ensure_ascii: Whether to escape non-ASCII characters. The value
        is forwarded to ``json.dumps`` as its ``ensure_ascii`` option,
        whose effect is illustrated below:

        .. code-block:: python

            data = {"name": "Café"}
            # Using ensure_ascii=True (default)
            json_str = json.dumps(data)
            print(json_str)  # {"name": "Caf\\u00e9"}
            # Using ensure_ascii=False
            json_str = json.dumps(data, ensure_ascii=False)
            print(json_str)  # {"name": "Café"}
"""
self.use_tags = use_tags
self.ensure_ascii = ensure_ascii

def encode(self, data: Any) -> str:
    """Encode the data as JSON.

    The ``ensure_ascii`` setting is applied whether or not the output is
    wrapped in tags, so both output modes escape (or preserve) non-ASCII
    characters consistently.

    Args:
        data: The data to serialize with ``json.dumps``.

    Returns:
        The JSON string; when ``use_tags`` is set, the string is passed
        through ``wrap_in_tag("json", ...)`` first.
    """
    # Serialize exactly once so the tagged and untagged outputs cannot
    # diverge (previously the untagged path ignored ``ensure_ascii``).
    content = json.dumps(data, ensure_ascii=self.ensure_ascii)
    if self.use_tags:
        return wrap_in_tag("json", content)
    return content

def decode(self, text: str) -> Any:
Expand All @@ -63,7 +91,9 @@ def decode(self, text: str) -> Any:
if content is None:
return {}
try:
return json.loads(content)
return json.loads(
content,
)
except json.JSONDecodeError as e:
raise ParseError(e)

Expand Down
2 changes: 2 additions & 0 deletions tests/extraction/test_extraction_with_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
{"encoder_or_encoder_class": "csv", "input_formatter": None},
{"encoder_or_encoder_class": "csv", "input_formatter": "text_prefix"},
{"encoder_or_encoder_class": "json"},
{"encoder_or_encoder_class": "json", "ensure_ascii": False},
{"encoder_or_encoder_class": "json", "ensure_ascii": True},
{"encoder_or_encoder_class": "xml"},
{"encoder_or_encoder_class": JSONEncoder()},
{"encoder_or_encoder_class": JSONEncoder},
Expand Down
30 changes: 0 additions & 30 deletions tests/test_encoders/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,36 +51,6 @@ def test_xml_encoding(node_data: Any, expected: str) -> None:
assert xml_encoder.encode(node_data) == expected


@pytest.mark.parametrize(
    "node_data,expected",
    [
        ({"object": [{"number": ["1"]}]}, '{"object": [{"number": ["1"]}]}'),
        ({"object": [{"text": ["3"]}]}, '{"object": [{"text": ["3"]}]}'),
        (
            {"object": [{"selection": ["option"]}]},
            '{"object": [{"selection": ["option"]}]}',
        ),
    ],
)
def test_json_encoding(node_data: Any, expected: str) -> None:
    """Check untagged JSON output against ``expected`` and decode it back.

    The encoder is a thin layer over ``json.dumps``, so a few cases suffice.
    """
    encoder = JSONEncoder(use_tags=False)

    # Encoding must yield the exact expected string ...
    encoded = encoder.encode(node_data)
    assert encoded == expected

    # ... and decoding that string must give back the original data.
    decoded = encoder.decode(expected)
    assert decoded == node_data


def test_json_encoding_with_tags() -> None:
    """Verify that the tagged encoder wraps output in ``<json>`` tags and parses it back."""
    encoder = JSONEncoder(use_tags=True)
    payload = {"object": [{"a": 1}]}
    tagged = '<json>{"object": [{"a": 1}]}</json>'

    assert encoder.encode(payload) == tagged
    assert encoder.decode(tagged) == payload


class NoOpEncoder(Encoder):
def encode(self, data: Any) -> str:
"""Identity function for encoding."""
Expand Down
54 changes: 54 additions & 0 deletions tests/test_encoders/test_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any

import pytest

from kor import JSONEncoder


@pytest.mark.parametrize(
    "node_data,expected",
    [
        ({"object": [{"number": ["1"]}]}, '{"object": [{"number": ["1"]}]}'),
        ({"object": [{"text": ["3"]}]}, '{"object": [{"text": ["3"]}]}'),
        (
            {"object": [{"selection": ["option"]}]},
            '{"object": [{"selection": ["option"]}]}',
        ),
    ],
)
def test_json_encoding(node_data: Any, expected: str) -> None:
    """Check untagged JSON output against ``expected`` and decode it back.

    The encoder is a thin layer over ``json.dumps``, so a few cases suffice.
    """
    encoder = JSONEncoder(use_tags=False)

    # Encoding must yield the exact expected string ...
    encoded = encoder.encode(node_data)
    assert encoded == expected

    # ... and decoding that string must give back the original data.
    decoded = encoder.decode(expected)
    assert decoded == node_data


def test_json_encoding_with_tags() -> None:
    """Verify that the tagged encoder wraps output in ``<json>`` tags and parses it back."""
    encoder = JSONEncoder(use_tags=True)
    payload = {"object": [{"a": 1}]}
    tagged = '<json>{"object": [{"a": 1}]}</json>'

    assert encoder.encode(payload) == tagged
    assert encoder.decode(tagged) == payload


def test_json_encoding_with_non_ascii_chars() -> None:
    """Exercise the JSON encoder on Chinese text under both ``ensure_ascii`` settings."""
    text = "我喜欢珍珠奶茶"

    # ensure_ascii=True: every non-ASCII character is emitted as an escape.
    escaping_encoder = JSONEncoder(use_tags=True, ensure_ascii=True)
    escaped_output = escaping_encoder.encode(text)
    assert (
        escaped_output
        == '<json>"\\u6211\\u559c\\u6b22\\u73cd\\u73e0\\u5976\\u8336"</json>'
    )

    # Still ensure_ascii=True: the escaped form must round-trip to the text.
    round_tripped = escaping_encoder.decode(escaping_encoder.encode(text))
    assert round_tripped == text

    # ensure_ascii=False: the characters are written out verbatim, and the
    # verbatim form decodes back to the original text.
    verbatim_encoder = JSONEncoder(use_tags=True, ensure_ascii=False)
    verbatim_output = verbatim_encoder.encode(text)
    assert verbatim_output == '<json>"我喜欢珍珠奶茶"</json>'
    assert verbatim_encoder.decode('<json>"我喜欢珍珠奶茶"</json>') == text

0 comments on commit aa10b41

Please sign in to comment.