From fc5b9844601e9318a1d68b7ea880ff9fa01fded7 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 2 May 2023 11:01:39 -0400 Subject: [PATCH] update --- kor/adapters.py | 6 +-- kor/encoders/csv_data.py | 6 +-- kor/encoders/encode.py | 4 +- kor/encoders/typedefs.py | 4 +- kor/examples.py | 14 +++---- kor/nodes.py | 63 +++++++++++++++++++++------- kor/type_descriptors.py | 6 +-- tests/test_encoders/test_encoders.py | 4 +- tests/test_nodes.py | 8 ++-- tests/test_visitors.py | 8 ++-- 10 files changed, 77 insertions(+), 46 deletions(-) diff --git a/kor/adapters.py b/kor/adapters.py index 95836b4..ec8f8d4 100644 --- a/kor/adapters.py +++ b/kor/adapters.py @@ -15,7 +15,7 @@ from pydantic import BaseModel -from .nodes import ExtractionSchemaNode, Number, Object, Option, Selection, Text +from .nodes import ExtractionValueNode, Number, Object, Option, Selection, Text from .validators import PydanticValidator, Validator # Not going to support dicts or lists since that requires recursive checks. @@ -46,7 +46,7 @@ def _translate_pydantic_to_kor( The Kor internal representation of the model. """ - attributes: List[Union[ExtractionSchemaNode, Selection, "Object"]] = [] + attributes: List[Union[ExtractionValueNode, Selection, "Object"]] = [] for field_name, field in model_class.__fields__.items(): field_info = field.field_info extra = field_info.extra @@ -59,7 +59,7 @@ def _translate_pydantic_to_kor( type_ = field.type_ field_many = get_origin(field.outer_type_) is list - attribute: Union[ExtractionSchemaNode, Selection, "Object"] + attribute: Union[ExtractionValueNode, Selection, "Object"] # Precedence matters here since bool is a subclass of int if get_origin(type_) is Union: args = get_args(type_) diff --git a/kor/encoders/csv_data.py b/kor/encoders/csv_data.py index 7c1521d..1a160bd 100644 --- a/kor/encoders/csv_data.py +++ b/kor/encoders/csv_data.py @@ -13,12 +13,12 @@ from kor.encoders.typedefs import SchemaBasedEncoder from kor.encoders.utils import unwrap_tag, wrap_in_tag from kor.exceptions import ParseError -from kor.nodes import AbstractSchemaNode, Object +from kor.nodes import AbstractValueNode, Object DELIMITER = "|" -def _extract_top_level_fieldnames(node: AbstractSchemaNode) -> List[str]: +def _extract_top_level_fieldnames(node: AbstractValueNode) -> List[str]: """Temporary schema description for CSV extraction.""" if isinstance(node, Object): return [attributes.id for attributes in node.attributes] @@ -32,7 +32,7 @@ def _extract_top_level_fieldnames(node: AbstractSchemaNode) -> List[str]: class CSVEncoder(SchemaBasedEncoder): """CSV encoder.""" - def __init__(self, node: AbstractSchemaNode, use_tags: bool = False) -> None: + def __init__(self, node: AbstractValueNode, use_tags: bool = False) -> None: """Attach node to the encoder to allow the encoder to understand schema. Args: diff --git a/kor/encoders/encode.py b/kor/encoders/encode.py index 6decaba..1d8a20c 100644 --- a/kor/encoders/encode.py +++ b/kor/encoders/encode.py @@ -1,6 +1,6 @@ from typing import Any, Callable, List, Literal, Mapping, Sequence, Tuple, Type, Union -from kor.nodes import AbstractSchemaNode +from kor.nodes import AbstractValueNode from .csv_data import CSVEncoder from .json_data import JSONEncoder @@ -67,7 +67,7 @@ def encode_examples( def initialize_encoder( encoder_or_encoder_class: Union[Type[Encoder], Encoder, str], - schema: AbstractSchemaNode, + schema: AbstractValueNode, **kwargs: Any, ) -> Encoder: """Flexible way to initialize an encoder, used only for top level API. diff --git a/kor/encoders/typedefs.py b/kor/encoders/typedefs.py index 8578c1a..283b718 100644 --- a/kor/encoders/typedefs.py +++ b/kor/encoders/typedefs.py @@ -9,7 +9,7 @@ import abc from typing import Any -from kor.nodes import AbstractSchemaNode +from kor.nodes import AbstractValueNode class Encoder(abc.ABC): @@ -46,6 +46,6 @@ class SchemaBasedEncoder(Encoder, abc.ABC): of the data that's being encoded. """ - def __init__(self, node: AbstractSchemaNode, **kwargs: Any) -> None: + def __init__(self, node: AbstractValueNode, **kwargs: Any) -> None: """Attach node to the encoder to allow the encoder to understand schema.""" self.node = node diff --git a/kor/examples.py b/kor/examples.py index 0d072c4..7d86506 100644 --- a/kor/examples.py +++ b/kor/examples.py @@ -9,9 +9,9 @@ from typing import Any, List, Tuple from kor.nodes import ( - AbstractSchemaNode, + AbstractValueNode, AbstractVisitor, - ExtractionSchemaNode, + ExtractionValueNode, Object, Option, Selection, @@ -29,7 +29,7 @@ def visit_option(self, node: "Option", **kwargs: Any) -> List[Tuple[str, str]]: raise AssertionError("Should never visit an Option node.") @staticmethod - def _assemble_output(node: AbstractSchemaNode, data: Any) -> Any: + def _assemble_output(node: AbstractValueNode, data: Any) -> Any: """Assemble the output data according to the type of the node.""" if not data: return {} @@ -83,10 +83,10 @@ def visit_selection( return examples def visit_default( - self, node: "AbstractSchemaNode", **kwargs: Any + self, node: "AbstractValueNode", **kwargs: Any ) -> List[Tuple[str, str]]: """Default visitor implementation.""" - if not isinstance(node, ExtractionSchemaNode): + if not isinstance(node, ExtractionValueNode): raise AssertionError() examples = [] @@ -95,7 +95,7 @@ def visit_default( examples.append((text, value)) return examples - def visit(self, node: "AbstractSchemaNode") -> List[Tuple[str, str]]: + def visit(self, node: "AbstractValueNode") -> List[Tuple[str, str]]: """Entry-point.""" return node.accept(self) @@ -103,7 +103,7 @@ def visit(self, node: "AbstractSchemaNode") -> List[Tuple[str, str]]: # PUBLIC API -def generate_examples(node: AbstractSchemaNode) -> List[Tuple[str, str]]: +def generate_examples(node: AbstractValueNode) -> List[Tuple[str, str]]: """Generate examples for a given element. A rudimentary implementation that simply concatenates all available examples diff --git a/kor/nodes.py b/kor/nodes.py index 8ab00d6..587565b 100644 --- a/kor/nodes.py +++ b/kor/nodes.py @@ -51,13 +51,29 @@ def visit_option(self, node: "Option", **kwargs: Any) -> T: """Visit option node.""" return self.visit_default(node, **kwargs) - def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> T: + def visit_list(self, node: "List", **kwargs: Any) -> T: + """Visit list node.""" + return self.visit_default(node, **kwargs) + + def visit_any_of(self, node: "AnyOf", **kwargs: Any) -> T: + """Visit any of node.""" + return self.visit_default(node, **kwargs) + + def visit_default(self, node: "AbstractValueNode", **kwargs: Any) -> T: """Default node implementation.""" raise NotImplementedError() -class AbstractSchemaNode(BaseModel): - """Abstract schema node. +class AbstractSchemaNode(BaseModel, abc.ABC): + """An abstract schema node that defines the schema tree.""" + + @abc.abstractmethod + def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: + """Accept a visitor.""" + + +class AbstractValueNode(AbstractSchemaNode, abc.ABC): + """Abstract value node. Each node is expected to have a unique ID, and should only use alphanumeric characters. @@ -83,17 +99,12 @@ def ensure_valid_uid(cls, uid: str) -> str: ) return uid - @abc.abstractmethod - def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: - """Accept a visitor.""" - raise NotImplementedError() - # Update return type to `Self` when bumping python version. def replace( self, id: Optional[str] = None, # pylint: disable=redefined-builtin description: Optional[str] = None, - ) -> "AbstractSchemaNode": + ) -> "AbstractValueNode": """Wrapper around data-classes replace.""" new_object = copy.copy(self) if id: @@ -103,7 +114,7 @@ def replace( return new_object -class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC): +class ExtractionValueNode(AbstractValueNode, abc.ABC): """An abstract definition for inputs that involve extraction. An extraction input can be associated with extraction examples. @@ -124,7 +135,7 @@ class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC): examples: Sequence[Tuple[str, Union[str, Sequence[str]]]] = tuple() -class Number(ExtractionSchemaNode): +class Number(ExtractionValueNode): """Built-in number input.""" def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: @@ -132,7 +143,7 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: return visitor.visit_number(self, **kwargs) -class Text(ExtractionSchemaNode): +class Text(ExtractionValueNode): """Built-in text input.""" def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: @@ -140,7 +151,7 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: return visitor.visit_text(self, **kwargs) -class Option(AbstractSchemaNode): +class Option(AbstractValueNode): """Built-in option input must be part of a selection input.""" examples: Sequence[str] = tuple() @@ -150,7 +161,27 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: return visitor.visit_option(self, **kwargs) -class Selection(AbstractSchemaNode): +class List_(AbstractSchemaNode): + """Represent a list of nodes, equivalent to many=True.""" + + nodes: Sequence[AbstractSchemaNode] + + def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: + """Accept a visitor.""" + return visitor.visit_list(self, **kwargs) + + +class AnyOf(AbstractSchemaNode): + """Equivalent of a Union type.""" + + nodes: Sequence[AbstractSchemaNode] + + def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: + """Accept a visitor.""" + return visitor.visit_any_of(self, **kwargs) + + +class Selection(AbstractValueNode): """Built-in selection node (aka Enum). A selection input is composed of one or more options. @@ -192,7 +223,7 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: return visitor.visit_selection(self, **kwargs) -class Object(AbstractSchemaNode): +class Object(AbstractValueNode): """Built-in representation for an object. Use an object node to represent an entire object that should be extracted. @@ -219,7 +250,7 @@ class Object(AbstractSchemaNode): """ - attributes: Sequence[Union[ExtractionSchemaNode, Selection, "Object"]] + attributes: Sequence[Union[ExtractionValueNode, Selection, "Object"]] examples: Sequence[ Tuple[ diff --git a/kor/type_descriptors.py b/kor/type_descriptors.py index 3ae8f34..757933a 100644 --- a/kor/type_descriptors.py +++ b/kor/type_descriptors.py @@ -11,7 +11,7 @@ from typing import Any, Iterable, List, TypeVar, Union from kor.nodes import ( - AbstractSchemaNode, + AbstractValueNode, AbstractVisitor, Number, Object, @@ -42,7 +42,7 @@ def describe(self, node: Object) -> str: class BulletPointDescriptor(TypeDescriptor[Iterable[str]]): """Generate a bullet point style schema description.""" - def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> List[str]: + def visit_default(self, node: "AbstractValueNode", **kwargs: Any) -> List[str]: """Default action for a node.""" depth = kwargs["depth"] space = "* " + depth * " " @@ -65,7 +65,7 @@ def describe(self, node: Object) -> str: class TypeScriptDescriptor(TypeDescriptor[Iterable[str]]): """Generate a typescript style schema description.""" - def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> List[str]: + def visit_default(self, node: "AbstractValueNode", **kwargs: Any) -> List[str]: """Default action for a node.""" depth = kwargs["depth"] space = depth * " " diff --git a/tests/test_encoders/test_encoders.py b/tests/test_encoders/test_encoders.py index 23ebbf7..01bc4b8 100644 --- a/tests/test_encoders/test_encoders.py +++ b/tests/test_encoders/test_encoders.py @@ -3,10 +3,10 @@ import pytest from kor.encoders import Encoder, JSONEncoder, XMLEncoder, encode_examples -from kor.nodes import AbstractSchemaNode, Number, Object, Option, Selection, Text +from kor.nodes import AbstractValueNode, Number, Object, Option, Selection, Text -def _get_schema() -> AbstractSchemaNode: +def _get_schema() -> AbstractValueNode: """Make an abstract input node.""" option = Option(id="option", description="Option", examples=["selection"]) number = Number(id="number", description="Number", examples=[("number", "2")]) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index d83b37d..6907d3d 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -6,7 +6,7 @@ from kor.nodes import AbstractVisitor -class FakeSchemaNode(nodes.AbstractSchemaNode): +class FakeValueNode(nodes.AbstractValueNode): """Fake Schema Node for testing purposes.""" def accept(self, visitor: AbstractVisitor, **kwargs: Any) -> Any: @@ -17,19 +17,19 @@ def accept(self, visitor: AbstractVisitor, **kwargs: Any) -> Any: @pytest.mark.parametrize("invalid_id", ["", "@@#", " ", "NAME", "1name", "name-name"]) def test_invalid_identifier_raises_error(invalid_id: str) -> None: with pytest.raises(ValueError): - FakeSchemaNode(id=invalid_id, description="Toy") + FakeValueNode(id=invalid_id, description="Toy") @pytest.mark.parametrize("valid_id", ["name", "name_name", "_name", "n1ame"]) def test_can_instantiate_with_valid_id(valid_id: str) -> None: """Can instantiate an abstract input with a valid ID.""" - FakeSchemaNode(id=valid_id, description="Toy") + FakeValueNode(id=valid_id, description="Toy") def test_extraction_input_cannot_be_instantiated() -> None: """ExtractionInput is abstract and should not be instantiated.""" with pytest.raises(TypeError): - nodes.ExtractionSchemaNode( # type: ignore[abstract] + nodes.ExtractionValueNode( # type: ignore[abstract] id="help", description="description", examples=[], diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 7eccb5a..e02ecd5 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -3,7 +3,7 @@ import pytest from kor.nodes import ( - AbstractSchemaNode, + AbstractValueNode, AbstractVisitor, Number, Object, @@ -16,11 +16,11 @@ class TestVisitor(AbstractVisitor[Tuple[str, Any]]): """Toy input for tests.""" - def visit_default(self, node: AbstractSchemaNode, **kwargs: Any) -> Tuple[str, Any]: + def visit_default(self, node: AbstractValueNode, **kwargs: Any) -> Tuple[str, Any]: """Verify default is invoked""" return node.id, kwargs - def visit(self, node: AbstractSchemaNode, **kwargs: Any) -> Tuple[str, Any]: + def visit(self, node: AbstractValueNode, **kwargs: Any) -> Tuple[str, Any]: """Convenience method.""" return node.accept(self, **kwargs) @@ -38,6 +38,6 @@ def visit(self, node: AbstractSchemaNode, **kwargs: Any) -> Tuple[str, Any]: Option(id="uid"), ], ) -def test_visit_default_is_invoked(node: AbstractSchemaNode) -> None: +def test_visit_default_is_invoked(node: AbstractValueNode) -> None: visitor = TestVisitor() assert visitor.visit(node, a="a", b="b") == ("uid", {"a": "a", "b": "b"})