From edcb1d86e58f7eb3764620b8fbb7f37e36020148 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 10 Mar 2023 11:50:18 -0500 Subject: [PATCH] Add lots of doc examples (#9) Add a lot of documentation, add some missing doc strings other minor changes. --- docs/extraction/kor_01.ipynb | 1526 +++++++++++++++++++++++++++------- kor/experimental/blocks.py | 2 +- kor/extraction.py | 40 +- kor/llms/openai.py | 6 +- kor/nodes.py | 13 +- kor/prompts.py | 3 +- kor/type_descriptors.py | 20 +- 7 files changed, 1294 insertions(+), 316 deletions(-) diff --git a/docs/extraction/kor_01.ipynb b/docs/extraction/kor_01.ipynb index 458ffbb..332a256 100644 --- a/docs/extraction/kor_01.ipynb +++ b/docs/extraction/kor_01.ipynb @@ -7,260 +7,1059 @@ "source": [ "# Kor\n", "\n", - "This notebooks shows a few extraction examples. \n", + "This notebook shows a few extraction examples using the library. \n", "\n", - "Please pay attention to errors that are made to better understand limitations. \n", - "This may not be the best approach for information extraction using an LLM." + "Please pay attention to errors that are made to better understand limitations.\n", + "\n", + "Again, this may not be a good approach for robust information extraction using LLMs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fe40bbba-e7ec-4cf9-ab79-16139e6e7e94", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ea602017-6b8a-4b42-850a-686cae35809f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import sys\n", + "import pprint" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ae7d98ba-5ca8-4ce5-b54a-ef100affa9f5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sys.path.insert(0, '../../')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "718c66a7-6186-4ed8-87e9-5ed28e3f209e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from kor.extraction import Extractor\n", + "from kor.nodes import Object, Text, Number\n", + "from kor.llms import OpenAIChatCompletion, OpenAICompletion" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "91859675-85a7-4b62-9368-cbb441cbe355", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "llm = OpenAIChatCompletion(model='gpt-3.5-turbo')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "25b908c2-6a02-49eb-9add-847e0c4017cd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Uncomment to use davinci model. 
This model may work better in some cases, but it's more expensive.\n", + "# llm = OpenAICompletion(model='text-davinci-003')" + ] + }, + { + "cell_type": "markdown", + "id": "afdabe59-b7d4-4b65-ac27-b91c37fabcd3", + "metadata": {}, + "source": [ + "# Create an extractor" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9666b3b9-e48e-41ab-91b5-7bc6ec5983df", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model = Extractor(llm)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2d4d94fc-a1cb-4b55-a2c8-1dcec9bfe7d5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Help on function __init__ in module kor.extraction:\n", + "\n", + "__init__(self, model: Union[kor.llms.typedefs.CompletionModel, kor.llms.typedefs.ChatCompletionModel], prompt_generator: kor.prompts.PromptGenerator = ExtractionTemplate(prefix=\"Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. \", type_descriptor='TypeScript', suffix=\"For Union types the output must EXACTLY match one of the members of the Union type.\\n\\nPlease enclose the extracted information in HTML style tags with the tag name corresponding to the corresponding component ID. Use angle style brackets for the tags ('>' and '<'). Only output tags when you're confident about the information that was extracted from the user's query. If you can extract several pieces of relevant information from the query, then include all of them. If the type is an array, please repeat the corresponding tag name multiple times once for each relevant extraction. \", example_generator=)) -> None\n", + " Initialize self. 
See help(type(self)) for accurate signature.\n", + "\n" + ] + } + ], + "source": [ + "help(Extractor.__init__)" + ] + }, + { + "cell_type": "markdown", + "id": "6645f896-c969-444d-b9f2-85318abb79d6", + "metadata": {}, + "source": [ + "# Define a simple schema " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8c97a013-5443-442a-a87c-b3f4bff21bf6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "schema = Text(id='first_name', description='The first name of a person', examples=[('I am billy.', 'billy'), ('John Smith is 33 years old', 'John')])" + ] + }, + { + "cell_type": "markdown", + "id": "a6849441-f9ae-468d-9003-b26bfa0253dd", + "metadata": { + "tags": [] + }, + "source": [ + "## Run Extraction" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f6102c88-4147-43f4-800f-28e1ce6b2aa2", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'first_name': ['Tom', 'Bobby']}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My name is Tom. I am a cat. My best friend is Bobby. He is not a cat.', schema) " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "69148111-77e1-450c-ba38-fd46c3730de2", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'first_name': ['WOW']}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My name is My name is My name is WOW.', schema) " + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "000dc5f9-ff7b-466d-ae18-142490129ba6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'first_name': ['MOO', 'MOO', 'MOO']}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note the extraction here. 
It's unlikely to be reasonable.\n", + "model('My name is My name is My name is MOO MOO.', schema)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "4ccc9bf3-022c-4fdb-a018-7e1f422734a9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'first_name': ['Bobby', 'Cobby']}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My name is Bobby. My brother\\'s name is the same as mine except that it starts with a `C`.', schema)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ebbe576b-0796-448a-99c4-ce903917ad27", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'first_name': ['Bobby']}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My name is Bobby. My brother\\'s name rhymes with mine.', schema)" + ] + }, + { + "cell_type": "markdown", + "id": "b74b5145-6360-4091-afca-4690a0a07d2f", + "metadata": {}, + "source": [ + "# Extract a phone number" + ] + }, + { + "cell_type": "markdown", + "id": "2726f5d4-3f1a-4908-9e71-84b8a84ab309", + "metadata": { + "tags": [] + }, + "source": [ + "## Sometimes it might work without examples, but having a few examples is recommended" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "086f5465-6455-4e90-9984-a65cbc2cf6fa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "schema = Text(id='phone_number', description='Any phone numbers found in the text format should be 9 digit')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f44d099a-0878-4d4b-b9a7-8fb14848d6f4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'phone_number': ['(123)4449999', '(333)1232832']}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My phone 
number is (123)-444-9999. I found my true love one on a blue sunday. Her number was (333)1232832', schema)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "559f5f66-2f85-4fc9-9a66-2fd3842041c8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "schema = Text(id='phone_number', description='Any phone numbers found in the text format should be 9 digit', examples=[('Someone called me from 123-123-1234', '123-123-1234')])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5571951a-bbaf-4c94-a749-60573fa6539d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'phone_number': ['(123)-444-9999', '(333)1232832']}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My phone number is (123)-444-9999. I found my true love one on a blue sunday. Her number was (333)1232832', schema)" + ] + }, + { + "cell_type": "markdown", + "id": "5d3beab3-6dea-4301-8c59-ae1685830afa", + "metadata": {}, + "source": [ + "# Extracting Multiple Attributes\n", + "\n", + "Here examples are specified independently on a per-attribute level.\n", + "\n", + "This is done for convenience and will sometimes work, even though individually specified examples can be contradictory (as is the case for first and last name below!)" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "3bb910e9-43c4-42dd-83dd-546b7df6e805", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "schema = Object(\n", + " id='personal_info', description='Personal information about a given person.',\n", + " attributes=[\n", + " Text(id='first_name', description='The first name of the person', examples=[('John Smith went to the store', 'John')]),\n", + " Text(id='last_name', description='The last name of the person', examples=[('John Smith went to the store', 'Smith')]),\n", + " Number(id='age', description='The age of the person in years.', examples=[('23 years old', '23'), ('I turned three on sunday', '3')]), \n", + " ],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "8be30aa6-a095-4506-8ec9-bde84dd107d3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{}" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My phone number is (123)-444-9999. I found my true love one on a blue sunday. Her number was (333)1232832', schema)" + ] + }, + { + "cell_type": "markdown", + "id": "0c2ee348-4908-4deb-859c-e860309a77f9", + "metadata": {}, + "source": [ + "## The model sometimes doesn't follow the schema\n", + "\n", + "One would have to create a validation layer on top (doesn't exist yet)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "611ed04e-e8be-4941-9745-d62000eb56bd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'personal_info': [{'first_name': ['Bob']},\n", + " {'phone_number': ['(123)-444-9999']},\n", + " {'phone_number': ['(333)1232832']}]}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My name is Bob and my phone number is (123)-444-9999. I found my true love one on a blue sunday. 
Her number was (333)1232832', schema)" + ] + }, + { + "cell_type": "markdown", + "id": "ea520ab9-38ae-4fc2-ab95-468df85f200e", + "metadata": {}, + "source": [ + "## But adding more examples helps prevent hallucinations" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "9a300e76-2f26-4914-b160-1d90548714a0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "schema = Object(\n", + " id='personal_info', description='Personal information about a given person.',\n", + " attributes=[\n", + " Text(id='first_name', description='The first name of the person', examples=[('John Smith went to the store', 'John')]),\n", + " Text(id='last_name', description='The last name of the person', examples=[('John Smith went to the store', 'Smith')]),\n", + " Number(id='age', description='The age of the person in years.', examples=[('23 years old', '23'), ('I turned three on sunday', '3')]), \n", + " ],\n", + " examples=[('John Smith was 23 years old', {'first_name': \"John\", 'last_name': \"Smith\", 'age': '23'})]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "582b5624-332b-412e-ae22-9298d8563ea5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'personal_info': [{'first_name': ['Bob']}]}" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('My name is Bob and my phone number is (123)-444-9999. I found my true love one on a blue sunday. Her number was (333)1232832', schema)" + ] + }, + { + "cell_type": "markdown", + "id": "8d5e4bcc-1fe5-4bdd-b1e2-00e0c7577cbb", + "metadata": {}, + "source": [ + "### What's the actual prompt?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "a2944e8c-4630-4b29-b505-b2ca6fceba01", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. \n", + "\n", + "```TypeScript\n", + "\n", + "personal_info: {\n", + " first_name: string[] // The first name of the person\n", + " last_name: string[] // The last name of the person\n", + " age: number[] // The age of the person in years.\n", + "}\n", + "```\n", + "\n", + "\n", + "For Union types the output must EXACTLY match one of the members of the Union type.\n", + "\n", + "Please enclose the extracted information in HTML style tags with the tag name corresponding to the corresponding component ID. Use angle style brackets for the tags ('>' and '<'). Only output tags when you're confident about the information that was extracted from the user's query. If you can extract several pieces of relevant information from the query, then include all of them. If the type is an array, please repeat the corresponding tag name multiple times once for each relevant extraction. 
\n", + "\n", + "Input: John Smith was 23 years old\n", + "Output: 23JohnSmith\n", + "Input: John Smith went to the store\n", + "Output: John\n", + "Input: John Smith went to the store\n", + "Output: Smith\n", + "Input: 23 years old\n", + "Output: 23\n", + "Input: I turned three on sunday\n", + "Output: 3\n", + "Input: user input goes here\n", + "Output:\n" + ] + } + ], + "source": [ + "print(model.prompt_generator.format_as_string('user input goes here', schema))" + ] + }, + { + "cell_type": "markdown", + "id": "b5797171-a806-474d-93b8-d569d0355a30", + "metadata": {}, + "source": [ + "## More complex prompt\n", + "\n", + "Same schema but more complex prompt.\n", + "\n", + "Please note that Alice's age is in days not in years!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7f1dece-3a0f-4507-8e8f-4c18cbc4b5fd", + "metadata": {}, + "outputs": [], + "source": [ + "user_input = (\n", + " 'Today Alice MacDonald is turning sixty days old. She had blue eyes. '\n", + " 'Bob is turning 10 years old. His eyes were bright red. Chris Prass used his '\n", + " 'green eyes to look at Dorothy to find 15 year old eyes staring back at him. '\n", + " 'Prass was a piano teacher. Dorothy was a certified mechanic. 
'\n", + " 'All certified mechanics have yellow eyes.'\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa43e654-df0c-4e60-8eab-f4d4927e368b", + "metadata": {}, + "outputs": [], + "source": [ + "schema = Object(\n", + " id='personal_info', description='Personal information about a given person.',\n", + " attributes=[\n", + " Text(id='first_name', description='The first name of the person', examples=[('John Smith went to the store', 'John')]),\n", + " Text(id='last_name', description='The last name of the person', examples=[('John Smith went to the store', 'Smith')]),\n", + " Number(id='age', description='The age of the person in years.', examples=[('23 years old', '23'), ('I turned three on sunday', '3')]), \n", + " ],\n", + " examples=[('John Smith was 23 years old', {'first_name': \"John\", 'last_name': \"Smith\", 'age': '23'})]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d4a342f-cca5-49a7-ab44-2411cc47c0be", + "metadata": {}, + "outputs": [], + "source": [ + "# Note that Alice age is reported in days above\n", + "model(user_input, schema)" ] }, { "cell_type": "markdown", - "id": "7db26299-a752-44c4-b20d-c338b90c6d15", + "id": "f2464529-fa6f-49f9-8773-2661675b1feb", + "metadata": {}, + "source": [ + "# More complex schema" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "4d9e07be-051d-46f5-9173-025f491fb99e", "metadata": { "tags": [] }, + "outputs": [], "source": [ - "## Temporary hack to add kor to PYTHONPATH" + "schema = Object(\n", + " id='personalinfo', \n", + " description='Collect information about a person.',\n", + " attributes=[\n", + " Text(\n", + " id='profession',\n", + " description='The person\\'s profession?',\n", + " examples=[('He was a professor', 'professor'), ('Bob was a lawyer and a politician', ['lawyer', 'politician'])]\n", + " ),\n", + " Text(\n", + " id='first_name',\n", + " description='The person\\'s first name',\n", + " examples=[('Billy was 
here', 'Billy'), ('Bob was very tall', 'Bob')]\n", + " ),\n", + " Text(\n", + " id='last_name',\n", + " description='The person\\'s last name',\n", + " examples=[('Joe Donatello was very tall', 'Donatello')]\n", + " ),\n", + " Text(\n", + " id='eye_color', \n", + " description='The person\\'s eye color?',\n", + " examples=[('my eyes are green', 'green')]\n", + " ),\n", + " Object(\n", + " id='age', \n", + " attributes=[\n", + " Number(\n", + " id='number',\n", + " description='what is the person\\'s age?',\n", + " examples=[('10 years old', 10)],\n", + " ),\n", + " Text(\n", + " id='unit',\n", + " description='In which units is the age reported in?',\n", + " examples=[('10 years old', 'years'), ('22 days', 'days')]\n", + " ),\n", + " ]\n", + " \n", + " )\n", + " ],\n", + " examples = []\n", + ")" ] }, { "cell_type": "code", - "execution_count": 1, - "id": "fe40bbba-e7ec-4cf9-ab79-16139e6e7e94", + "execution_count": 45, + "id": "89ad6a27-cf0c-4da8-8ff7-ae3f787e29fb", "metadata": { "tags": [] }, "outputs": [], "source": [ - "%load_ext autoreload\n", - "%autoreload 2" + "user_input = (\n", + " 'Today Alice MacDonald is turning sixty days old. She had blue eyes. '\n", + " 'Bob is turning 10 years old. His eyes were bright red. Chris Prass used his '\n", + " 'green eyes to look at Dorothy to find 15 year old eyes staring back at him. '\n", + " 'Prass was a piano teacher. Dorothy was a certified mechanic. 
'\n", + " 'All certified mechanics have yellow eyes.'\n", + ")\n" ] }, { "cell_type": "code", - "execution_count": 2, - "id": "ea602017-6b8a-4b42-850a-686cae35809f", + "execution_count": 48, + "id": "93668b59-6e8c-4ad8-abf9-8cba62ccc2b7", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'personalinfo': [{'first_name': ['Alice'],\n", + " 'last_name': ['MacDonald'],\n", + " 'age': [{'number': ['60'], 'unit': ['days']}],\n", + " 'eye_color': ['blue']},\n", + " {'first_name': ['Bob'],\n", + " 'age': [{'number': ['10'], 'unit': ['years']}],\n", + " 'eye_color': ['bright red']},\n", + " {'first_name': ['Chris'],\n", + " 'last_name': ['Prass'],\n", + " 'profession': ['piano teacher'],\n", + " 'eye_color': ['green']},\n", + " {'first_name': ['Dorothy'],\n", + " 'profession': ['certified mechanic'],\n", + " 'eye_color': ['yellow']},\n", + " {'age': [{'number': ['15'], 'unit': ['year']}],\n", + " 'eye_color': ['staring back']}]}" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import sys\n", - "import pprint" + "# Note that Alice age is reported in days above\n", + "model(user_input, schema)" ] }, { "cell_type": "code", - "execution_count": 3, - "id": "ae7d98ba-5ca8-4ce5-b54a-ef100affa9f5", + "execution_count": 53, + "id": "ec341db6-ab1f-456d-94e1-ede43b1c0381", "metadata": { "tags": [] }, "outputs": [], "source": [ - "sys.path.insert(0, '../../')" + "FROM_ADDRESS = Object(\n", + " id=\"from_address\",\n", + " description=\"Person moved away from this address\",\n", + " attributes=[\n", + " Text(id=\"street\"),\n", + " Text(id=\"city\"),\n", + " Text(id=\"state\"),\n", + " Text(id=\"zipcode\"),\n", + " Text(id=\"country\", description=\"A country in the world; e.g., France.\"),\n", + " ],\n", + " examples=[\n", + " (\n", + " \"100 Main St, Boston,MA, 23232, USA\",\n", + " {\n", + " \"street\": \"100 Marlo St\",\n", + " \"city\": \"Boston\",\n", + " \"state\": 
\"MA\",\n", + " \"zipcode\": \"23232\",\n", + " \"country\": \"USA\",\n", + " },\n", + " )\n", + " ],\n", + ")" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "718c66a7-6186-4ed8-87e9-5ed28e3f209e", + "execution_count": 54, + "id": "b333b284-9618-44b0-a2e8-69b1939b2392", "metadata": { "tags": [] }, "outputs": [], "source": [ - "from kor.elements import Form, TextInput, NumericRange, Number, Selection, Option, ObjectInput\n", - "from kor.extraction import extract, chat_extract\n", - "from kor.llm_utils import LLM, ChatLLM, ChatLLMWithChatInvoke" + "TO_ADDRESS = FROM_ADDRESS.replace(id='to_address', description='Address to which the person is moving')" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "3c0f282c-19b5-454d-a734-6f1cb4429d06", + "execution_count": 55, + "id": "2311ea1c-1f7f-4027-a22a-bd46b905d126", "metadata": { "tags": [] }, "outputs": [], "source": [ - "llm = ChatLLMWithChatInvoke(verbose=False)" + "form = Object(\n", + " id='information',\n", + " attributes=[\n", + " Text(id='person_name', description='The full name of the person or partial name', examples=[('John Smith was here', 'John Smith')]),\n", + " FROM_ADDRESS,\n", + " TO_ADDRESS,\n", + " ]\n", + ")\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "16aa93c3-ac85-4f0a-98f8-92b23d1d934f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'information': [{'person_name': ['Alice Doe', 'Bob Smith'],\n", + " 'from_address': [{'city': ['New York']}],\n", + " 'to_address': [{'city': ['Boston']}]}]}" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model('Alice Doe and Bob Smith moved from New York to Boston', form)" ] }, { "cell_type": "markdown", - "id": "58ecdb75-7b61-4676-bd12-9f86ff523441", + "id": "cfc92a23-56c2-466f-a89d-61f8cb5f10d5", "metadata": {}, "source": [ - "## Collect personal information" + "## LIMITATION! 
Currently, grouping correctly is difficult due to ambiguity\n", + "\n", + "Because every type in Kor could be interpreted as a list.\n", + "At the moment, one should specify object level examples, to help the model determine how to group things correctly." ] }, { "cell_type": "code", - "execution_count": 11, - "id": "0c611670-2678-4bc3-9ce2-42a181896194", + "execution_count": 60, + "id": "4397e231-b32d-45a3-89a4-f8dcc3672a1a", "metadata": { "tags": [] }, "outputs": [], "source": [ - "form = Form(\n", - " id='personal-info', \n", - " description='Collect information about a person.',\n", - " elements=[\n", - " TextInput(\n", - " id='profession',\n", - " description='what is the person\\'s profession?',\n", - " examples=[('He was a piano teacher', 'piano teacher'), ('Bob was a lawyer and a politician', ['lawyer', 'politician'])]\n", - " ),\n", - " TextInput(\n", - " id='first_name',\n", - " description='what is the person\\'s first name',\n", - " examples=[('Billy was here', 'Billy'), ('Bob Donatello was very tall', 'Bob')]\n", - " ),\n", - " TextInput(\n", - " id='last_name',\n", - " description='what is the person\\'s last name',\n", - " examples=[('Billy was here', ''), ('Bob Donatello was very tall', 'Donatello')]\n", - " ),\n", - " Number(\n", - " id='age',\n", - " description='what is the person\\'s age',\n", - " examples=[('26 years old', '26 years'), ('6 puppies', '') ],\n", - " ),\n", - " Selection(\n", - " id='eye-color', \n", - " description='What is the person\\'s eye color?',\n", - " options=[\n", - " Option(id='green', description='Green Eyes', examples=['my eyes are green']),\n", - " Option(id='blue',description='Blue Eyes', examples=['blue eyes']),\n", - " Option(id='brown',description='Brown Eyes', examples=['brown eye color']),\n", - " ],\n", - " null_examples=['violet eyes']\n", + "form = Object(\n", + " id='information',\n", + " attributes=[\n", + " Text(id='person_name', description='The full name of the person or partial name', examples=[('John 
Smith was here', 'John Smith')]),\n", + " FROM_ADDRESS,\n", + " TO_ADDRESS,\n", + " ],\n", + " examples=[\n", + " (\"John Smith moved to Boston from New York. Billy moved to LA.\", \n", + " [\n", + " {\"person_name\": \"John Smith\", \"from_address\": { \"city\": \"New York\" }, \"to_address\": { \"city\": \"Boston\" }}, \n", + " {\"person_name\": \"Billy\", \"to_address\": { \"city\": \"LA\" }}, \n", + " ]\n", " )\n", " ]\n", - ")" + ")\n", + " " ] }, { "cell_type": "code", - "execution_count": 7, - "id": "c8d6ce1e-68e0-4226-ad8f-f45cf64f5190", + "execution_count": 63, + "id": "63f1fa72-bf97-4b09-878b-792e8cbcbb0f", "metadata": { "tags": [] }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{ 'personal-info': [ { 'age': ['60 days'],\n", - " 'first_name': ['Alice'],\n", - " 'last_name': ['MacDonald']}]}\n", - "CPU times: user 61.4 ms, sys: 3.62 ms, total: 65 ms\n", - "Wall time: 1.69 s\n" - ] + "data": { + "text/plain": [ + "{'information': [{'from_address': [{'city': ['New York']}],\n", + " 'person_name': ['Alice Doe'],\n", + " 'to_address': [{'city': ['Boston']}]},\n", + " {'from_address': [{'city': ['New York']}],\n", + " 'person_name': ['Bob Smith'],\n", + " 'to_address': [{'city': ['Boston']}]},\n", + " {'person_name': ['Andrew']},\n", + " {'person_name': ['Joana'], 'to_address': [{'city': ['Boston']}]},\n", + " {'person_name': ['Paul'], 'to_address': [{'city': ['Boston']}]},\n", + " {'person_name': ['Betty'],\n", + " 'from_address': [{'city': ['Boston']}],\n", + " 'to_address': [{'city': ['New York']}]}]}" + ] + }, + "execution_count": 63, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "%%time\n", - "pprint.pprint(chat_extract('Today Alice MacDonald is turning sixty days old.', form, llm), indent=2)" + "model('Alice Doe and Bob Smith moved from New York to Boston. Andrew was 12 years old. He also moved to Boston. So did Joana and Paul. 
Betty did the opposite.', form)" ] }, { "cell_type": "markdown", - "id": "29d6e0ff-93d6-4783-8866-0ea91fd021d2", + "id": "e382f417-b281-4567-8cac-d205291de3a4", "metadata": {}, "source": [ - "**ATTENTION** At the moment, parsed information will be collected independently into lists if using a flat form. (Might change the API in a bit since it feels like this will be a source of confusion.)" + "## Untyped Objects\n", + "\n", + "One does not have to type an object, and can instead rely on just the examples.\n", + "\n", + "It's unclear 🤷🤷🤷 if the quality of the results is affected significantly, if one controls for the number of examples. " + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "4325ac08-248f-4c57-bca9-478f8cab0436", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "form = Object(\n", + " id='information',\n", + " attributes=[],\n", + " examples=[\n", + " (\"John Smith moved to Boston from New York. Billy moved to LA.\", \n", + " [\n", + " {\"person_name\": \"John Smith\", \"from_address\": { \"city\": \"New York\" }, \"to_address\": { \"city\": \"Boston\" }}, \n", + " {\"person_name\": \"Billy\", \"to_address\": { \"city\": \"LA\" }}, \n", + " ]\n", + " )\n", + " ]\n", + ")\n", + " " ] }, { "cell_type": "code", - "execution_count": 8, - "id": "ed6a1f62-c35b-4f91-b820-880b12ce9e85", + "execution_count": 75, + "id": "82c52f44-657e-45fb-8dc0-80155ae63a86", "metadata": { "tags": [] }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{ 'personal-info': [ { 'age': ['60 days'],\n", - " 'eye-color': ['blue'],\n", - " 'first_name': ['Alice'],\n", - " 'last_name': ['MacDonald']},\n", - " {'age': ['10 years'], 'first_name': ['Bob']}]}\n", - "CPU times: user 5.42 ms, sys: 143 µs, total: 5.57 ms\n", - "Wall time: 1.82 s\n" - ] + "data": { + "text/plain": [ + "{'information': [{'from_address': [{'city': ['New York']}],\n", + " 'person_name': ['Alice Doe'],\n", + " 'to_address': [{'city': ['Boston']}]},\n", 
+ " {'from_address': [{'city': ['New York']}],\n", + " 'person_name': ['Bob Smith'],\n", + " 'to_address': [{'city': ['Boston']}]},\n", + " {'person_name': ['Andrew'],\n", + " 'age': ['12'],\n", + " 'to_address': [{'city': ['Boston']}]},\n", + " {'person_name': ['Joana'], 'to_address': [{'city': ['Boston']}]},\n", + " {'person_name': ['Paul'], 'to_address': [{'city': ['Boston']}]},\n", + " {'person_name': ['Betty'], 'to_address': [{'city': ['New York']}]}]}" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "%%time\n", - "pprint.pprint(\n", - " chat_extract('Today Alice MacDonald is turning sixty days old. She had blue eyes. '\n", - " 'Bob is turning 10 years old. His eyes were bright red.', form, llm), \n", - "indent=2)" + "model('Alice Doe and Bob Smith moved from New York to Boston. Andrew was 12 years old. He also moved to Boston. So did Joana and Paul. Betty did the opposite.', form)" ] }, { "cell_type": "markdown", - "id": "c88a8104-78de-4d77-a18c-2773cf78d93b", + "id": "3c92aa32-25bd-416f-ba71-3506455cb68a", "metadata": {}, "source": [ - "## Using an Object input" + "# Flattened Objects" ] }, { "cell_type": "code", - "execution_count": 9, - "id": "b66bf239-1031-4f70-b8a8-3a92358b22d9", + "execution_count": 70, + "id": "df8f7a85-2308-4e50-bcf8-b31c9c5bc882", "metadata": { "tags": [] }, "outputs": [], "source": [ - "form = Form(\n", - " id='personal', \n", - " description='Collect information about a person.',\n", - " elements=[\n", - " ObjectInput(\n", - " id='info', \n", - " description='Personal information about a person like name, age, hobbies, date of birth, height etc.', \n", - " examples=[(\n", - " 'Billy Apple was born on 2020-01-01', {\n", - " 'first_name': 'Billy',\n", - " 'last_name': 'Apple', \n", - " 'born_on': '2020-01-01'\n", - " },\n", - " ), \n", - " (\n", - " 'Frank was born on 2020-01-01 and is 2 years old today', {\n", - " 'first_name': 'Frank', \n", - " 'born_on': '2020-01-01',\n", - 
" 'age': '2 years old',\n", - " }\n", - " )\n", - " ]\n", - " )\n", - " ]\n", - ")" + "form = Object(\n", + " id='information',\n", + " attributes=[\n", + " Text(id='person_name', description='The first name of the person or partial name', examples=[('John Smith was here', 'John')]),\n", + " Text(id='last_name', description='The last name of the person or partial name', examples=[('John Smith was here', 'Smith')]),\n", + " ],\n", + " examples=[],\n", + " group_as_object=False # <-- Please note\n", + ")\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "5aefd7ff-526e-45f4-9fe0-e3f0586dfee1", + "metadata": {}, + "source": [ + "## Let's build an API!" ] }, { "cell_type": "markdown", - "id": "29263531-2fba-4518-9b4d-964b6e9b9b66", + "id": "4a90e50e-aa1b-45c6-920b-4589b424e561", "metadata": {}, "source": [ - "If outputs fail for other sentences, try to add some more examples." + "### Order tickets?" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "e12c4842-0853-45bb-bd30-46903056dc95", + "execution_count": 81, + "id": "c50b080b-7179-4bbe-b234-83ce59e2d215", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "form = Object(\n", + " id='action', \n", + " description='User is looking for sports tickets',\n", + " attributes=[\n", + " Text(\n", + " id='sport',\n", + " description='which sports do you want to buy tickets for?',\n", + " examples=[('I want to buy tickets to basketball and football games', ['basketball', 'footbal'])]\n", + " ),\n", + " Text(\n", + " id='location',\n", + " description='where would you like to watch the game?',\n", + " examples=[('in boston', 'boston'), ('in france or italy', ['france', 'italy'])]\n", + " ),\n", + " \n", + " Object(\n", + " id='price_range',\n", + " description='how much do you want to spend?',\n", + " attributes=[],\n", + " examples=[('no more than $100', {'price_max': '100', 'currency': \"$\"}), \n", + " ('between 50 and 100 dollars', {'price_max': '100', 'price_min': '50', 'currency': 
\"$\"})]\n", + " ),\n", + " ]\n", + "\n", + " \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "73c31ace-32dd-4a33-ae39-475db6934f6d", "metadata": { "tags": [] }, @@ -269,41 +1068,32 @@ "name": "stdout", "output_type": "stream", "text": [ - "{ 'personal-info': [ { 'age': ['60 days'],\n", - " 'eye-color': ['blue'],\n", - " 'first_name': ['Alice'],\n", - " 'last_name': ['MacDonald']},\n", - " { 'age': ['10 years'],\n", - " 'eye-color': ['bright red'],\n", - " 'first_name': ['Bob']},\n", - " { 'eye-color': ['green'],\n", - " 'first_name': ['Chris'],\n", - " 'last_name': ['Prass'],\n", - " 'profession': ['piano teacher']},\n", - " { 'age': ['less than 2 years'],\n", - " 'eye-color': ['yellow'],\n", - " 'first_name': ['Dorothy'],\n", - " 'profession': ['certified mechanic', 'chef']}]}\n", - "CPU times: user 6.61 ms, sys: 202 µs, total: 6.81 ms\n", - "Wall time: 5.54 s\n" + "CPU times: user 11.8 ms, sys: 4.02 ms, total: 15.8 ms\n", + "Wall time: 2.01 s\n" ] + }, + { + "data": { + "text/plain": [ + "{'action': [{'sport': ['baseball'],\n", + " 'location': ['LA'],\n", + " 'price_range': [{'currency': ['$'], 'price_max': ['100']}]}]}" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "%%time\n", - "pprint.pprint(\n", - " chat_extract('Today Alice MacDonald is turning sixty days old. She had blue eyes. '\n", - " 'Bob is turning 10 years old. His eyes were bright red. Chris Prass used his '\n", - " 'green eyes to look at Dorothy to find 15 year old eyes staring back at him. '\n", - " 'Prass was a piano teacher. Dorothy was a certified mechanic and a chef. Dorothy was not even 2 years old. 
' \n", - " 'All certified mechanics have yellow eyes.', form, llm), \n", - "indent=2)" + "model('I want to buy tickets for a baseball game in LA area under $100', form)" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "d7c7c74c-1ec0-41fe-8a5a-a099fb6c75e8", + "execution_count": 85, + "id": "78e3b3af-bfa8-4503-854a-b83a7f8f49e6", "metadata": { "tags": [] }, @@ -312,84 +1102,84 @@ "name": "stdout", "output_type": "stream", "text": [ - "{ 'personal': [ { 'info': [ { 'age': ['60 days old'],\n", - " 'eye_color': ['blue'],\n", - " 'first_name': ['Alice'],\n", - " 'last_name': ['MacDonald']}]},\n", - " { 'info': [ { 'age': ['10 years old'],\n", - " 'eye_color': ['bright red'],\n", - " 'first_name': ['Bob']}]},\n", - " { 'info': [ { 'age': ['15 years old'],\n", - " 'eye_color': ['yellow'],\n", - " 'first_name': ['Dorothy']}]},\n", - " { 'info': [ { 'eye_color': ['green'],\n", - " 'first_name': ['Chris'],\n", - " 'last_name': ['Prass'],\n", - " 'profession': ['piano teacher']}]}]}\n", - "CPU times: user 7.33 ms, sys: 184 µs, total: 7.51 ms\n", - "Wall time: 4.82 s\n" + "CPU times: user 3.24 ms, sys: 459 µs, total: 3.7 ms\n", + "Wall time: 2.27 s\n" ] + }, + { + "data": { + "text/plain": [ + "{'action': [{'sport': ['basketball'],\n", + " 'location': ['boston'],\n", + " 'price_range': [{'currency': ['$'],\n", + " 'price_max': ['40'],\n", + " 'price_min': ['20']}]}]}" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "%%time\n", - "pprint.pprint(\n", - " chat_extract('Today Alice MacDonald is turning sixty days old. She had blue eyes. '\n", - " 'Bob is turning 10 years old. His eyes were bright red. Chris Prass used his '\n", - " 'green eyes to look at Dorothy to find 15 year old eyes staring back at him. '\n", - " 'Prass was a piano teacher. Dorothy was a certified mechanic. 
' \n", - " 'All certified mechanics have yellow eyes.', form, llm), \n", - "indent=2)" + "model('I want to see a celtics game in boston somewhere between 20 and 40 dollars per ticket', form)" ] }, { "cell_type": "markdown", - "id": "5aefd7ff-526e-45f4-9fe0-e3f0586dfee1", + "id": "e87b70f3-d4eb-4d82-b0bc-ca43abfe8e96", "metadata": {}, "source": [ - "## Buying sports tickets" + "### Play some music?" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "c50b080b-7179-4bbe-b234-83ce59e2d215", + "execution_count": 88, + "id": "cf7c3511-dcc4-4d5d-95e4-9dca7e5f8abf", "metadata": { "tags": [] }, "outputs": [], "source": [ - "form = Form(\n", - " id='action', \n", - " description='User is looking for sports tickets',\n", - " elements=[\n", - " TextInput(\n", - " id='sport',\n", - " description='which sports do you want to buy tickets for?',\n", - " examples=[('I want to buy tickets to basketball and football games', ['basketball', 'footbal'])]\n", + "form = Object(\n", + " id='player', \n", + " description='User is controling a music player to select songs, pause or start them or play music by a particular artist.',\n", + " attributes=[\n", + " Text(\n", + " id='song',\n", + " description='User wants to play this song',\n", + " examples=[]\n", " ),\n", - " TextInput(\n", - " id='location',\n", - " description='where would you like to watch the game?',\n", - " examples=[('in boston', 'boston'), ('in france or italy', ['france', 'italy'])]\n", + " Text(\n", + " id='album',\n", + " description='User wants to play this album',\n", + " examples=[]\n", " ),\n", - " \n", - " ObjectInput(\n", - " id='price-range',\n", - " description='how much do you want to spend?',\n", - " examples=[('no more than $100', {'price-max': '100', 'currency': \"$\"}), \n", - " ('between 50 and 100 dollars', {'price-max': '100', 'price-min': '50', 'currency': \"$\"})]\n", + " Text(\n", + " id='artist',\n", + " description='Music by the given artist',\n", + " examples=[('Songs by paul simon', 
'paul simon')],\n", + " ),\n", + " Text(\n", + " id='stop_playing', \n", + " description=\"STOP if the user wants to stop playing music.\",\n", + " examples=[('Please stop the music', 'stop'), ('please keep playing', '')]\n", " ),\n", + " Text(\n", + " id='start_playing', \n", + " description=\"START if the user wants to play music.\",\n", + " examples=[('play something', 'start'), ('please stop', '')]\n", + " )\n", " ]\n", - "\n", - " \n", ")" ] }, { "cell_type": "code", - "execution_count": 11, - "id": "73c31ace-32dd-4a33-ae39-475db6934f6d", + "execution_count": 87, + "id": "992fb598-a32d-406c-b2f1-315bf11615e0", "metadata": { "tags": [] }, @@ -398,52 +1188,147 @@ "name": "stdout", "output_type": "stream", "text": [ - "{ 'location': ['LA area'],\n", - " 'price-range': [{'currency': ['$'], 'price-max': ['100']}],\n", - " 'sport': ['baseball']}\n", - "CPU times: user 4.1 ms, sys: 0 ns, total: 4.1 ms\n", - "Wall time: 2.13 s\n" + "CPU times: user 11.8 ms, sys: 4.02 ms, total: 15.8 ms\n", + "Wall time: 1.3 s\n" ] + }, + { + "data": { + "text/plain": [ + "{'player': [{'stop_playing': ['stop']}]}" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "%%time\n", - "pprint.pprint(extract('I want to buy tickets for a baseball game in LA area under $100', form, llm), indent=2)" + "model('stop the music now', form)" ] }, { - "cell_type": "markdown", - "id": "20ddd000-df00-4014-9433-fba85181ba46", - "metadata": {}, + "cell_type": "code", + "execution_count": 90, + "id": "266f2c81-4b93-41df-ad2f-4c0645d2eae0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 4.14 ms, sys: 8 µs, total: 4.15 ms\n", + "Wall time: 1.04 s\n" + ] + }, + { + "data": { + "text/plain": [ + "{'player': [{'start_playing': ['start']}]}" + ] + }, + "execution_count": 90, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "## More complex sentence\n", - 
"\n", - "Use an LLM to parse a sentence and later convert it into a database query." + "%%time\n", + "model('i want to hear a song', form)" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "ca1a8200-5da8-48ac-8601-4fc7ff00d960", + "execution_count": 91, + "id": "23328a30-bcb9-4c30-a5cb-6e1935781f0b", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.17 ms, sys: 0 ns, total: 3.17 ms\n", + "Wall time: 985 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "{'player': [{'album': ['lion king']}]}" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "model('can you play the album lion king from the movie', form)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "a5580e02-71d7-4f39-be12-6665e0e776f3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.26 ms, sys: 45 µs, total: 3.3 ms\n", + "Wall time: 2.56 s\n" + ] + }, + { + "data": { + "text/plain": [ + "{'player': [{'artist': ['paul simon', 'led zepplin']}]}" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "model('can you play all the songs from paul simon and led zepplin', form)" + ] + }, + { + "cell_type": "markdown", + "id": "20ddd000-df00-4014-9433-fba85181ba46", + "metadata": {}, "source": [ - "from kor import elements" + "## Issue some database queries?\n", + "\n", + "Please note that this is a demo about how to build complexity.\n", + "\n", + "This particular format is actually *NOT* good for issuing database queries.\n", + "\n", + "I may publish another package showing how to do this properly for things like database queries." 
] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 103, "id": "2b0bcf09-a3ae-4a8a-9ce3-f86834ce6ca2", "metadata": { "tags": [] }, "outputs": [], "source": [ - "company_name = elements.TextInput(\n", - " id=\"company-name\",\n", + "company_name = Text(\n", + " id=\"company_name\",\n", " description=\"what is the name of the company you want to find\",\n", " examples=[\n", " (\"Apple inc\", \"Apple inc\"),\n", @@ -452,8 +1337,8 @@ " ],\n", ")\n", "\n", - "industry_name = elements.TextInput(\n", - " id=\"industry-name\",\n", + "industry_name = Text(\n", + " id=\"industry_name\",\n", " description=\"what is the name of the company's industry\",\n", " examples=[\n", " (\"companies in the steel manufacturing industry\", \"steel manufacturing\"),\n", @@ -464,8 +1349,8 @@ " ],\n", ")\n", "\n", - "geography_name = elements.TextInput(\n", - " id=\"geography-name\",\n", + "geography_name = Text(\n", + " id=\"geography_name\",\n", " description=\"where is the company based?\",\n", " examples=[\n", " (\"chinese companies\", \"china\"),\n", @@ -475,14 +1360,14 @@ " ],\n", ")\n", "\n", - "foundation_date = elements.DateInput(\n", - " id=\"foundation-date\",\n", + "foundation_date = Text(\n", + " id=\"foundation_date\",\n", " description=\"Foundation date of the company\",\n", " examples=[(\"companies founded in 2023\", \"2023\")],\n", ")\n", "\n", - "attribute_filter = elements.ObjectInput(\n", - " id=\"attribute-filter\",\n", + "attribute_filter = Text(\n", + " id=\"attribute_filter\",\n", " description=(\n", " \"Filter by a value of an attribute using a binary expression. Specify the attribute's name, \"\n", " \"an operator (>, <, =, !=, >=, <=, in, not in) and a value.\"\n", @@ -519,8 +1404,8 @@ " ],\n", ")\n", "\n", - "sales_geography = elements.TextInput(\n", - " id=\"geography-sales\",\n", + "sales_geography = Text(\n", + " id=\"geography_sales\",\n", " description=(\n", " \"where is the company doing sales? 
Please use a single country name.\"\n", " ),\n", @@ -531,7 +1416,7 @@ " ],\n", ")\n", "\n", - "attribute_selection_block = elements.TextInput(\n", + "attribute_selection_block = Text(\n", " id=\"attribute_selection\",\n", " description=\"Asking to see the value of one or more attributes\",\n", " examples=[\n", @@ -543,12 +1428,16 @@ " ],\n", ")\n", "\n", - "sort_by_attribute_block = elements.ObjectInput(\n", - " id=\"sort-block\",\n", + "sort_by_attribute_block = Object(\n", + " id=\"sort_block\",\n", " description=(\n", " \"Use to request to sort the results by a particular attribute. \"\n", " \"Can specify the direction\"\n", " ),\n", + " attributes=[\n", + " Text(id='direction', description='The direction of the sort'),\n", + " Text(id='attribute', description='The sort attribute')\n", + " ],\n", " examples=[\n", " (\n", " \"Largest by market-cap tech companies\",\n", @@ -561,10 +1450,10 @@ " ],\n", ")\n", "\n", - "form = elements.Form(\n", - " id=\"search-for-companies\",\n", + "form = Object(\n", + " id=\"search_for_companies\",\n", " description=\"Search for companies matching the following criteria.\",\n", - " elements=[\n", + " attributes=[\n", " company_name,\n", " geography_name,\n", " foundation_date,\n", @@ -577,9 +1466,29 @@ ")\n" ] }, + { + "cell_type": "markdown", + "id": "ee6725a8-246b-4163-a657-5f3eddbf5d2b", + "metadata": {}, + "source": [ + "# Note some of the examples below could fail" + ] + }, + { + "cell_type": "markdown", + "id": "4f585aa0-071c-40ab-b445-0616be92f430", + "metadata": {}, + "source": [ + "Please note that some of the queries below could *fail* for different reasons.\n", + "\n", + "One common reason with complex objects is ambiguity in how to group things together.\n", + "\n", + "These failures can be remedied by adding object level examples." 
+ ] + }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 100, "id": "6aae8e17-bb1e-4f8d-94e6-4855f2077a26", "metadata": { "tags": [] @@ -589,23 +1498,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "{}\n", - "CPU times: user 3.5 ms, sys: 0 ns, total: 3.5 ms\n", - "Wall time: 790 ms\n" + "CPU times: user 14.6 ms, sys: 311 µs, total: 14.9 ms\n", + "Wall time: 1.74 s\n" ] + }, + { + "data": { + "text/plain": [ + "({'search_for_companies': [{}]},)" + ] + }, + "execution_count": 100, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "%%time\n", - "pprint.pprint(\n", - " extract('Today Alice MacDonald is turning sixty days old. She had blue eyes. '\n", - " 'Bob is turning 10 years old. His eyes were bright red.', form, llm), \n", - "indent=2)" + "model('Today Alice MacDonald is turning sixty days old. She had blue eyes. '\n", + " 'Bob is turning 10 years old. His eyes were bright red.', form), " ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 104, "id": "214daf5f-0026-4880-a066-06ecaf73b7b8", "metadata": { "tags": [] @@ -615,29 +1531,35 @@ "name": "stdout", "output_type": "stream", "text": [ - "{ 'attribute-filter': [ { 'attribute': ['market cap'],\n", - " 'op': ['>'],\n", - " 'value': ['1 million']},\n", - " { 'attribute': ['employees'],\n", - " 'op': ['<'],\n", - " 'value': ['50']}],\n", - " 'attribute_selection': ['revenue', 'eps']}\n", - "CPU times: user 2.65 ms, sys: 4.03 ms, total: 6.68 ms\n", - "Wall time: 4.12 s\n" + "CPU times: user 4.21 ms, sys: 15 µs, total: 4.22 ms\n", + "Wall time: 5.12 s\n" ] + }, + { + "data": { + "text/plain": [ + "{'search_for_companies': [{'attribute_filter': [{'attribute': ['market cap',\n", + " 'market cap',\n", + " 'employees'],\n", + " 'op': ['>', '<', 'in'],\n", + " 'value': ['1 million', '50', 'red', 'blue']}],\n", + " 'attribute_selection': ['revenue', 'eps']}]}" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" } 
], "source": [ "%%time\n", - "pprint.pprint(\n", - " extract('revenue, eps of indian companies that have market cap of over 1 million, '\n", - " 'but less than 50 employees and own red and blue buildings', form, llm\n", - " ), indent=2)" + "model('revenue, eps of indian companies that have market cap of over 1 million, '\n", + " 'but less than 50 employees and own red and blue buildings', form)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 105, "id": "d631862c-175e-42bb-83ab-1cfea635a684", "metadata": { "tags": [] @@ -647,27 +1569,33 @@ "name": "stdout", "output_type": "stream", "text": [ - "{ 'attribute-filter': [ { 'attribute': ['market cap'],\n", - " 'op': ['>'],\n", - " 'value': ['1 million']},\n", - " { 'attribute': ['employees'],\n", - " 'op': ['in'],\n", - " 'value': ['20', '50']}],\n", - " 'attribute_selection': ['revenue', 'eps'],\n", - " 'geography-name': ['india']}\n", - "CPU times: user 5.81 ms, sys: 137 µs, total: 5.94 ms\n", - "Wall time: 4.37 s\n" + "CPU times: user 54 µs, sys: 3.53 ms, total: 3.58 ms\n", + "Wall time: 4.29 s\n" ] + }, + { + "data": { + "text/plain": [ + "{'search_for_companies': [{'attribute_filter': [{'attribute': ['market cap'],\n", + " 'op': ['>'],\n", + " 'value': ['1 million']},\n", + " {'attribute': ['employees'], 'op': ['in'], 'value': ['20', '50']}],\n", + " 'attribute_selection': ['revenue', 'eps']}]}" + ] + }, + "execution_count": 105, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "%%time\n", - "pprint.pprint(extract('revenue, eps of indian companies that have market cap of over 1 million, and and between 20-50 employees', form, llm), indent=2)" + "model('revenue, eps of indian companies that have market cap of over 1 million, and and between 20-50 employees', form)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 106, "id": "2a620246-4c85-4256-8f58-0acbcc9455a3", "metadata": { "tags": [] @@ -677,22 +1605,31 @@ "name": "stdout", "output_type": 
"stream", "text": [ - "{ 'attribute-filter': [ { 'attribute': ['building color'],\n", - " 'op': ['in'],\n", - " 'value': ['red', 'blue']}]}\n", - "CPU times: user 5.6 ms, sys: 0 ns, total: 5.6 ms\n", - "Wall time: 1.74 s\n" + "CPU times: user 3.38 ms, sys: 0 ns, total: 3.38 ms\n", + "Wall time: 2.28 s\n" ] + }, + { + "data": { + "text/plain": [ + "{'search_for_companies': [{'attribute_filter': [{'attribute': ['building-colors'],\n", + " 'op': ['in'],\n", + " 'value': ['red', 'blue']}]}]}" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "%%time\n", - "pprint.pprint(extract('companies that own red and blue buildings', form, llm), indent=2)" + "model('companies that own red and blue buildings', form)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 108, "id": "4745517e-507e-4d1a-97e0-d143fa34cea2", "metadata": { "tags": [] @@ -702,18 +1639,27 @@ "name": "stdout", "output_type": "stream", "text": [ - "{ 'attribute_selection': ['revenue'],\n", - " 'geography-sales': ['germany'],\n", - " 'sort-block': [ { 'attribute': ['number of employees'],\n", - " 'direction': ['descending']}]}\n", - "CPU times: user 17.1 ms, sys: 0 ns, total: 17.1 ms\n", - "Wall time: 2.66 s\n" + "CPU times: user 3.74 ms, sys: 0 ns, total: 3.74 ms\n", + "Wall time: 2.64 s\n" ] + }, + { + "data": { + "text/plain": [ + "{'search_for_companies': [{'geography_name': ['germany'],\n", + " 'sort_block': [{'attribute': ['number of employees'],\n", + " 'direction': ['descending']}],\n", + " 'attribute_selection': ['revenue']}]}" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "b%%time\n", - "pprint.pprint(extract('revenue of largest german companies sorted by number of employees', form, llm), indent=2)" + "%%time\n", + "model('revenue of largest german companies sorted by number of employees', form)" ] } ], @@ -733,7 +1679,7 @@ "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.11.1" } }, "nbformat": 4, diff --git a/kor/experimental/blocks.py b/kor/experimental/blocks.py index a1137f3..846a5d2 100644 --- a/kor/experimental/blocks.py +++ b/kor/experimental/blocks.py @@ -60,7 +60,7 @@ PRICE = Object( id="price", description="The price of the item, including currency", - children=[ + attributes=[ Number(id="amount", description="The amount in digit format."), Text(id="currency", description="The currency."), ], diff --git a/kor/extraction.py b/kor/extraction.py index 37bbe1b..9d65c1b 100644 --- a/kor/extraction.py +++ b/kor/extraction.py @@ -1,22 +1,28 @@ -from typing import Callable, Mapping, Sequence +import abc +from typing import Union, Any from kor import nodes, prompts +from kor.llms import CompletionModel, ChatCompletionModel from kor.parsing import parse_llm_output -def extract( - user_input: str, - node: nodes.AbstractInput, - model: Callable[[str], str] | Callable[[Sequence[Mapping[str, str]]], str], - prompt_generator: prompts.PromptGenerator = prompts.STANDARD_EXTRACTION_TEMPLATE, - prompt_format: prompts.PROMPT_FORMAT = "string", -) -> dict[str, list[str]]: - """Extract information from the user input using the given form.""" - if prompt_format == "string": - chat_prompt = prompt_generator.format_as_string(user_input, node) - elif prompt_format == "openai-chat": - chat_prompt = prompt_generator.format_as_chat(user_input, node) - else: - raise NotImplementedError(f"Unknown prompt format {prompt_format}") - model_output = model(chat_prompt) - return parse_llm_output(model_output) +class Extractor(abc.ABC): + def __init__( + self, + model: Union[CompletionModel, ChatCompletionModel], + prompt_generator: prompts.PromptGenerator = prompts.STANDARD_EXTRACTION_TEMPLATE, + ) -> None: + self.model = model + self.prompt_generator = prompt_generator + + def __call__(self, user_input: str, node: nodes.AbstractInput) -> Any: + if isinstance(self.model, 
CompletionModel): + prompt = self.prompt_generator.format_as_string(user_input, node) + elif isinstance(self.model, ChatCompletionModel): + prompt = self.prompt_generator.format_as_chat(user_input, node) + else: + raise NotImplementedError( + f"Unsupported model interface for type {type(self.model)}." + ) + model_output = self.model(prompt) + return parse_llm_output(model_output) diff --git a/kor/llms/openai.py b/kor/llms/openai.py index 0678a4c..4d16153 100644 --- a/kor/llms/openai.py +++ b/kor/llms/openai.py @@ -33,10 +33,10 @@ def _set_openai_api_key_if_needed() -> None: class OpenAICompletion(CompletionModel): """Wrapper around OpenAI Completion endpoint.""" - model: str = "text-davinci-001" + model: str verbose: bool = False temperature: float = 0 - max_tokens: int = 1000 + max_tokens: int = 2000 frequency_penalty: float = 0 presence_penalty: float = 0 top_p: float = 1.0 @@ -68,7 +68,7 @@ def __call__(self, prompt: str) -> str: class OpenAIChatCompletion(ChatCompletionModel): """Wrapper around OpenAI Chat Completion endpoint.""" - model: str = "gpt-3.5-turbo" + model: str verbose: bool = False temperature: float = 0 max_tokens: int = 1000 diff --git a/kor/nodes.py b/kor/nodes.py index bcb8916..8375d10 100644 --- a/kor/nodes.py +++ b/kor/nodes.py @@ -2,7 +2,7 @@ import abc import dataclasses import re -from typing import Sequence, Mapping, Any, Generic, TypeVar +from typing import Sequence, Mapping, Any, Generic, TypeVar, Optional, Self # For now, limit what's allowed for identifiers. 
# The main constraints @@ -77,6 +77,17 @@ def accept(self, visitor: AbstractVisitor) -> Any: """Accept a visitor.""" raise NotImplementedError() + def replace( + self, id: Optional[str] = None, description: Optional[str] = None + ) -> Self: + """Wrapper around data-classes replace.""" + attributes = {} + if id: + attributes["id"] = id + if description: + attributes["description"] = description + return dataclasses.replace(self, **attributes) + @dataclasses.dataclass(frozen=True, kw_only=True) class ExtractionInput(AbstractInput, abc.ABC): diff --git a/kor/prompts.py b/kor/prompts.py index a1bd42c..4e176d1 100644 --- a/kor/prompts.py +++ b/kor/prompts.py @@ -113,7 +113,8 @@ def format_as_chat( prefix=( "Your goal is to extract structured information from the user's input that matches " "the form described below. " - "When extracting information please make sure it matches the type information exactly. " + "When extracting information please make sure it matches the type information exactly. Do " + "not add any attributes that do not appear in the schema shown below." 
), type_descriptor="TypeScript", suffix=( diff --git a/kor/type_descriptors.py b/kor/type_descriptors.py index 304a88c..4016c99 100644 --- a/kor/type_descriptors.py +++ b/kor/type_descriptors.py @@ -10,18 +10,21 @@ class BulletPointTypeGenerator(AbstractVisitor[None]): + """Mutable visitor used to generate a bullet point style schema description.""" + def __init__(self) -> None: - """Use to print the type.""" self.depth = 0 self.type_str_messages = [] def visit_default(self, node: "AbstractInput") -> None: + """Default action for a node.""" space = "* " + self.depth * " " self.type_str_messages.append( f"{space}{node.id}: {node.__class__.__name__} # {node.description}" ) def visit_object(self, node: Object) -> None: + """Visit an object node.""" self.visit_default(node) self.depth += 1 for child in node.attributes: @@ -34,11 +37,14 @@ def get_type_description(self) -> str: class TypeScriptTypeGenerator(AbstractVisitor[None]): + """A mutable visitor (not thread safe) that helps generate TypeScript schema.""" + def __init__(self) -> None: self.depth = 0 self.code_lines = [] def visit_default(self, node: "AbstractInput") -> None: + """Default action for a node.""" space = self.depth * " " if isinstance(node, Selection): @@ -53,10 +59,11 @@ def visit_default(self, node: "AbstractInput") -> None: raise NotImplementedError() self.code_lines.append( - f"{space}{node.id}: {finalized_type} // {node.description}" + f"{space}{node.id}: {finalized_type}[] // {node.description}" ) def visit_object(self, node: Object) -> None: + """Visit an object node.""" space = self.depth * " " self.code_lines.append(f"{space}{node.id}: {{") @@ -73,9 +80,16 @@ def get_type_description(self) -> str: return "\n".join(self.code_lines) def describe(self, node: "AbstractInput") -> str: + """Describe the node type in TypeScript notation.""" self.depth = 0 self.code_lines = [] + node.accept(self) + + # Add curly brackets if top level node is not an object. 
+ if not isinstance(node, Object): + self.code_lines.insert(0, "{\n") + self.code_lines.append("}\n") return self.get_type_description() @@ -93,4 +107,4 @@ def generate_typescript_description(node: AbstractInput) -> str: """Generate a description of the object_input type in TypeScript syntax.""" code_generator = TypeScriptTypeGenerator() type_script_code = code_generator.describe(node) - return f"```TypeScript\n{type_script_code}\n```\n" + return f"```TypeScript\n\n{type_script_code}\n```\n"