From 94d6bcb3fc58922fd11217141ff0a410ea17b63c Mon Sep 17 00:00:00 2001 From: Boris Wilhelms Date: Tue, 27 Jun 2023 23:31:21 +0200 Subject: [PATCH] Load schema from JSON (#178) This PR makes the schema loadable from JSON via `Object.parse_raw`. PR includes feature, tests and documentation Link to Issue #177 --- CONTRIBUTING.md | 2 +- docs/source/index.md | 1 + docs/source/schema_from_json.ipynb | 221 +++++++++++++++++++++++++++++ kor/nodes.py | 20 +++ tests/test_deserialization.py | 159 +++++++++++++++++++++ 5 files changed, 402 insertions(+), 1 deletion(-) create mode 100644 docs/source/schema_from_json.ipynb create mode 100644 tests/test_deserialization.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7a869fa..31dfd83 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,7 +19,7 @@ The package uses [poetry](https://python-poetry.org/) together with ```shell -poetry install --with dev,test,doc +poetry install --with dev,test,docs ``` ### List tasks diff --git a/docs/source/index.md b/docs/source/index.md index 30b601d..de5f130 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -194,6 +194,7 @@ untyped_objects apis validation document_extraction +schema_from_json guidelines ``` diff --git a/docs/source/schema_from_json.ipynb b/docs/source/schema_from_json.ipynb new file mode 100644 index 0000000..fb000e1 --- /dev/null +++ b/docs/source/schema_from_json.ipynb @@ -0,0 +1,221 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "4b3a0584-b52c-4873-abb8-8382e13ff5c0", + "metadata": {}, + "source": [ + "# Schema from JSON\n", + "\n", + "Kor lets you define the schema in JSON. The structure of the JSON matches the struture of the `Object` type.\n", + "\n", + "The following attribute types must be annotated with a type descrimintator (`$type`):\n", + "\n", + "- Number\n", + "- Text\n", + "- Bool\n", + "- Selection" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "0b4597b2-2a43-4491-8830-bf9f79428074", + "metadata": { + "nbsphinx": "hidden", + "tags": [ + "remove-cell" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import sys\n", + "\n", + "sys.path.insert(0, \"../../\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3bd33817", + "metadata": {}, + "outputs": [], + "source": [ + "json = \"\"\"\n", + "{\n", + " \"id\": \"personal_info\",\n", + " \"description\": \"Personal information about a given person.\",\n", + " \"attributes\": [\n", + " {\n", + " \"$type\": \"Text\",\n", + " \"id\": \"first_name\",\n", + " \"description\": \"The first name of the person\",\n", + " \"examples\": [[\"John Smith went to the store\", \"John\"]]\n", + " },\n", + " {\n", + " \"$type\": \"Text\",\n", + " \"id\": \"last_name\",\n", + " \"description\": \"The last name of the person\",\n", + " \"examples\": [[\"John Smith went to the store\", \"Smith\"]]\n", + " },\n", + " {\n", + " \"$type\": \"Number\",\n", + " \"id\": \"age\",\n", + " \"description\": \"The age of the person in years.\",\n", + " \"examples\": [[\"23 years old\", \"23\"], [\"I turned three on sunday\", \"3\"]]\n", + " }\n", + " ],\n", + " \"examples\": [\n", + " [\n", + " \"John Smith was 23 years old. He was very tall. He knew Jane Doe. She was 5 years old.\",\n", + " [\n", + " {\"first_name\": \"John\", \"last_name\": \"Smith\", \"age\": 23},\n", + " {\"first_name\": \"Jane\", \"last_name\": \"Doe\", \"age\": 5}\n", + " ]\n", + " ]\n", + " ],\n", + " \"many\": true\n", + "}\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "6088c98a", + "metadata": {}, + "outputs": [], + "source": [ + "from kor import Object\n", + "schema = Object.parse_raw(json)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "718c66a7-6186-4ed8-87e9-5ed28e3f209e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from kor.extraction import create_extraction_chain\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "9bc98f35-ea5f-4b74-a32e-a300a22c0c89", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "llm = ChatOpenAI(\n", + " model_name=\"gpt-3.5-turbo\",\n", + " temperature=0,\n", + " max_tokens=2000,\n", + " model_kwargs={\"frequency_penalty\":0,\"presence_penalty\":0, \"top_p\": 1.0}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "54a199a5-24b4-442c-8907-1449e437a880", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chain = create_extraction_chain(llm, schema)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "193e257b-df01-45ec-af77-076d2070533b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'personal_info': [{'first_name': 'Eugene', 'last_name': '', 'age': '18'}]}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.predict_and_parse(text=\"Eugene was 18 years old a long time ago.\")[\"data\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "c8295f36-f986-4db2-97bc-ef2e6cdbcc87", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'personal_info': [{'first_name': 'Bob', 'last_name': 'Alice', 'age': ''}, {'first_name': 'Moana', 'last_name': 'Sunrise', 'age': '10'}]}\n" + ] + } + ], + "source": [ + "chain = create_extraction_chain(llm, schema)\n", + "print(\n", + " chain.predict_and_parse(\n", + " text=(\n", + " \"My name is Bob Alice and my phone number is (123)-444-9999. I found my true love one\"\n", + " \" on a blue sunday. Her number was (333)1232832. Her name was Moana Sunrise and she was 10 years old.\"\n", + " )\n", + " )[\"data\"]\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/kor/nodes.py b/kor/nodes.py index 2913df4..eeb3cd1 100644 --- a/kor/nodes.py +++ b/kor/nodes.py @@ -11,6 +11,7 @@ Optional, Sequence, Tuple, + Type, TypeVar, Union, ) @@ -129,6 +130,25 @@ class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC): examples: Sequence[Tuple[str, Union[str, Sequence[str]]]] = tuple() + @classmethod + def parse_obj(cls: Type[ExtractionSchemaNode], data: dict) -> ExtractionSchemaNode: + type = data.get("$type") + if type is None: + raise ValueError("Need to specify type ($type)") + for sub in cls.__subclasses__(): + if type == sub.__name__: + return sub(**data) + raise TypeError(f"Unknown sub-type: {type}") + + @classmethod + def validate(cls: Type[ExtractionSchemaNode], v: Any) -> ExtractionSchemaNode: + if isinstance(v, dict): + return cls.parse_obj(v) + elif isinstance(v, cls): + return v + else: + raise TypeError(f"Unsupported type: {type(v)}") + class Number(ExtractionSchemaNode): """Built-in number input.""" diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py new file mode 100644 index 0000000..35db8f1 --- /dev/null +++ b/tests/test_deserialization.py @@ -0,0 +1,159 @@ +import pytest + +from kor import Bool, Number, Object, Selection, Text + + +def test_simple_deserialization() -> None: + json = """ + { + "id": "sample_object", + "description": "Deserialization Example", + "many": true, + "attributes": [ + { + "$type": "Number", + "id": "number_attribute", + "description": "Description for Number", + "many": true, + "examples": [ + ["Here is 1 number", 1], + ["Here are 0 numbers", 0] + ] + }, + { + "$type": "Text", + "id": "text_attribute", + "description": "Description for Text", + "many": true, + "examples": [ + ["Here is a text", "a text"], + ["Here is no text", "no text"] + ] + }, + { + "$type": "Bool", + "id": "bool_attribute", + "description": "Description for Bool", + "many": true, + "examples": [ + ["This is soo true", true], + ["This is wrong", false] + ] + }, + { + "$type": "Selection", + "id": "selection_attribute", + "description": "Description for Selection", + "many": true, + "options": [ + { + "id": "option1", + "description": "description for option 1" + }, + { + "id": "option2", + "description": "description for option 2" + } + ], + "examples": [ + ["This is soo true", true], + ["This is wrong", false] + ] + } + ] + } + """ + scheme = Object.parse_raw(json) + + assert scheme.id == "sample_object" + assert scheme.description == "Deserialization Example" + assert scheme.many is True + + assert isinstance(scheme.attributes[0], Number) + assert scheme.attributes[0].id == "number_attribute" + assert scheme.attributes[0].description == "Description for Number" + assert scheme.attributes[0].many is True + assert len(scheme.attributes[0].examples) == 2 + + assert isinstance(scheme.attributes[1], Text) + assert scheme.attributes[1].id == "text_attribute" + assert scheme.attributes[1].description == "Description for Text" + assert scheme.attributes[1].many is True + assert len(scheme.attributes[1].examples) == 2 + + assert isinstance(scheme.attributes[2], Bool) + assert scheme.attributes[2].id == "bool_attribute" + assert scheme.attributes[2].description == "Description for Bool" + assert scheme.attributes[2].many is True + assert len(scheme.attributes[2].examples) == 2 + + assert isinstance(scheme.attributes[3], Selection) + assert scheme.attributes[3].id == "selection_attribute" + assert scheme.attributes[3].description == "Description for Selection" + assert scheme.attributes[3].many is True + assert len(scheme.attributes[3].options) == 2 + assert len(scheme.attributes[3].examples) == 2 + + +def test_nested_object_deserialization() -> None: + json = """ + { + "id": "root_object", + "description": "Deserialization Example", + "many": true, + "attributes": [ + { + "id": "nested_object", + "description": "Description nested object", + "many": true, + "attributes": [ + { + "$type": "Number", + "id": "number_attribute", + "description": "Description for Number", + "many": true, + "examples": [ + ["Here is 1 number", 1], + ["Here are 0 numbers", 0] + ] + } + ] + } + ] + } + """ + scheme = Object.parse_raw(json) + + assert scheme.id == "root_object" + assert scheme.description == "Deserialization Example" + assert scheme.many is True + + assert isinstance(scheme.attributes[0], Object) + assert scheme.attributes[0].id == "nested_object" + assert scheme.attributes[0].description == "Description nested object" + assert scheme.attributes[0].many is True + assert len(scheme.attributes[0].attributes) == 1 + + +def test_extractionschemanode_without_type_cannot_be_deserialized() -> None: + json = """ + { + "id": "root_object", + "description": "Deserialization Example", + "many": true, + "attributes": [ + { + "id": "number_attribute", + "description": "Description for Number", + "many": true, + "examples": [ + ["Here is 1 number", 1], + ["Here are 0 numbers", 0] + ] + } + ] + } + """ + + with pytest.raises(ValueError): + Object.parse_raw(json)