Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP On AddAny type #146

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions kor/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from pydantic import BaseModel

from .nodes import ExtractionSchemaNode, Number, Object, Option, Selection, Text
from .nodes import ExtractionValueNode, Number, Object, Option, Selection, Text
from .validators import PydanticValidator, Validator

# Not going to support dicts or lists since that requires recursive checks.
Expand Down Expand Up @@ -46,7 +46,7 @@ def _translate_pydantic_to_kor(
The Kor internal representation of the model.
"""

attributes: List[Union[ExtractionSchemaNode, Selection, "Object"]] = []
attributes: List[Union[ExtractionValueNode, Selection, "Object"]] = []
for field_name, field in model_class.__fields__.items():
field_info = field.field_info
extra = field_info.extra
Expand All @@ -59,7 +59,7 @@ def _translate_pydantic_to_kor(

type_ = field.type_
field_many = get_origin(field.outer_type_) is list
attribute: Union[ExtractionSchemaNode, Selection, "Object"]
attribute: Union[ExtractionValueNode, Selection, "Object"]
# Precedence matters here since bool is a subclass of int
if get_origin(type_) is Union:
args = get_args(type_)
Expand Down
6 changes: 3 additions & 3 deletions kor/encoders/csv_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from kor.encoders.typedefs import SchemaBasedEncoder
from kor.encoders.utils import unwrap_tag, wrap_in_tag
from kor.exceptions import ParseError
from kor.nodes import AbstractSchemaNode, Object
from kor.nodes import AbstractValueNode, Object

DELIMITER = "|"


def _extract_top_level_fieldnames(node: AbstractSchemaNode) -> List[str]:
def _extract_top_level_fieldnames(node: AbstractValueNode) -> List[str]:
"""Temporary schema description for CSV extraction."""
if isinstance(node, Object):
return [attributes.id for attributes in node.attributes]
Expand All @@ -32,7 +32,7 @@ def _extract_top_level_fieldnames(node: AbstractSchemaNode) -> List[str]:
class CSVEncoder(SchemaBasedEncoder):
"""CSV encoder."""

def __init__(self, node: AbstractSchemaNode, use_tags: bool = False) -> None:
def __init__(self, node: AbstractValueNode, use_tags: bool = False) -> None:
"""Attach node to the encoder to allow the encoder to understand schema.

Args:
Expand Down
4 changes: 2 additions & 2 deletions kor/encoders/encode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, List, Literal, Mapping, Sequence, Tuple, Type, Union

from kor.nodes import AbstractSchemaNode
from kor.nodes import AbstractValueNode

from .csv_data import CSVEncoder
from .json_data import JSONEncoder
Expand Down Expand Up @@ -67,7 +67,7 @@ def encode_examples(

def initialize_encoder(
encoder_or_encoder_class: Union[Type[Encoder], Encoder, str],
schema: AbstractSchemaNode,
schema: AbstractValueNode,
**kwargs: Any,
) -> Encoder:
"""Flexible way to initialize an encoder, used only for top level API.
Expand Down
4 changes: 2 additions & 2 deletions kor/encoders/typedefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import abc
from typing import Any

from kor.nodes import AbstractSchemaNode
from kor.nodes import AbstractValueNode


class Encoder(abc.ABC):
Expand Down Expand Up @@ -46,6 +46,6 @@ class SchemaBasedEncoder(Encoder, abc.ABC):
of the data that's being encoded.
"""

def __init__(self, node: AbstractSchemaNode, **kwargs: Any) -> None:
def __init__(self, node: AbstractValueNode, **kwargs: Any) -> None:
"""Attach node to the encoder to allow the encoder to understand schema."""
self.node = node
14 changes: 7 additions & 7 deletions kor/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from typing import Any, List, Tuple

from kor.nodes import (
AbstractSchemaNode,
AbstractValueNode,
AbstractVisitor,
ExtractionSchemaNode,
ExtractionValueNode,
Object,
Option,
Selection,
Expand All @@ -29,7 +29,7 @@ def visit_option(self, node: "Option", **kwargs: Any) -> List[Tuple[str, str]]:
raise AssertionError("Should never visit an Option node.")

@staticmethod
def _assemble_output(node: AbstractSchemaNode, data: Any) -> Any:
def _assemble_output(node: AbstractValueNode, data: Any) -> Any:
"""Assemble the output data according to the type of the node."""
if not data:
return {}
Expand Down Expand Up @@ -83,10 +83,10 @@ def visit_selection(
return examples

def visit_default(
self, node: "AbstractSchemaNode", **kwargs: Any
self, node: "AbstractValueNode", **kwargs: Any
) -> List[Tuple[str, str]]:
"""Default visitor implementation."""
if not isinstance(node, ExtractionSchemaNode):
if not isinstance(node, ExtractionValueNode):
raise AssertionError()
examples = []

Expand All @@ -95,15 +95,15 @@ def visit_default(
examples.append((text, value))
return examples

def visit(self, node: "AbstractSchemaNode") -> List[Tuple[str, str]]:
def visit(self, node: "AbstractValueNode") -> List[Tuple[str, str]]:
"""Entry-point."""
return node.accept(self)


# PUBLIC API


def generate_examples(node: AbstractSchemaNode) -> List[Tuple[str, str]]:
def generate_examples(node: AbstractValueNode) -> List[Tuple[str, str]]:
"""Generate examples for a given element.

A rudimentary implementation that simply concatenates all available examples
Expand Down
63 changes: 47 additions & 16 deletions kor/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,29 @@ def visit_option(self, node: "Option", **kwargs: Any) -> T:
"""Visit option node."""
return self.visit_default(node, **kwargs)

def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> T:
def visit_list(self, node: "List", **kwargs: Any) -> T:
"""Visit list node."""
return self.visit_default(node, **kwargs)

def visit_any_of(self, node: "AnyOf", **kwargs: Any) -> T:
"""Visit any of node."""
return self.visit_default(node, **kwargs)

def visit_default(self, node: "AbstractValueNode", **kwargs: Any) -> T:
"""Default node implementation."""
raise NotImplementedError()


class AbstractSchemaNode(BaseModel):
"""Abstract schema node.
class AbstractSchemaNode(BaseModel, abc.ABC):
"""An abstract schema node that defines the schema tree."""

@abc.abstractmethod
def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
"""Accept a visitor."""


class AbstractValueNode(AbstractSchemaNode, abc.ABC):
"""Abstract value node.

Each node is expected to have a unique ID, and should
only use alphanumeric characters.
Expand All @@ -83,17 +99,12 @@ def ensure_valid_uid(cls, uid: str) -> str:
)
return uid

@abc.abstractmethod
def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
"""Accept a visitor."""
raise NotImplementedError()

# Update return type to `Self` when bumping python version.
def replace(
self,
id: Optional[str] = None, # pylint: disable=redefined-builtin
description: Optional[str] = None,
) -> "AbstractSchemaNode":
) -> "AbstractValueNode":
"""Wrapper around data-classes replace."""
new_object = copy.copy(self)
if id:
Expand All @@ -103,7 +114,7 @@ def replace(
return new_object


class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC):
class ExtractionValueNode(AbstractValueNode, abc.ABC):
"""An abstract definition for inputs that involve extraction.

An extraction input can be associated with extraction examples.
Expand All @@ -124,23 +135,23 @@ class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC):
examples: Sequence[Tuple[str, Union[str, Sequence[str]]]] = tuple()


class Number(ExtractionSchemaNode):
class Number(ExtractionValueNode):
"""Built-in number input."""

def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
"""Accept a visitor."""
return visitor.visit_number(self, **kwargs)


class Text(ExtractionSchemaNode):
class Text(ExtractionValueNode):
"""Built-in text input."""

def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
"""Accept a visitor."""
return visitor.visit_text(self, **kwargs)


class Option(AbstractSchemaNode):
class Option(AbstractValueNode):
"""Built-in option input must be part of a selection input."""

examples: Sequence[str] = tuple()
Expand All @@ -150,7 +161,27 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
return visitor.visit_option(self, **kwargs)


class Selection(AbstractSchemaNode):
class List_(AbstractSchemaNode):
"""Represent a list of nodes, equivalent to many=True."""

nodes: Sequence[AbstractSchemaNode]

def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
"""Accept a visitor."""
return visitor.visit_list(self, **kwargs)


class AnyOf(AbstractSchemaNode):
"""Equivalent of a Union type."""

nodes: Sequence[AbstractSchemaNode]

def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
"""Accept a visitor."""
return visitor.visit_any_of(self, **kwargs)


class Selection(AbstractValueNode):
"""Built-in selection node (aka Enum).

A selection input is composed of one or more options.
Expand Down Expand Up @@ -192,7 +223,7 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
return visitor.visit_selection(self, **kwargs)


class Object(AbstractSchemaNode):
class Object(AbstractValueNode):
"""Built-in representation for an object.

Use an object node to represent an entire object that should be extracted.
Expand All @@ -219,7 +250,7 @@ class Object(AbstractSchemaNode):

"""

attributes: Sequence[Union[ExtractionSchemaNode, Selection, "Object"]]
attributes: Sequence[Union[ExtractionValueNode, Selection, "Object"]]

examples: Sequence[
Tuple[
Expand Down
6 changes: 3 additions & 3 deletions kor/type_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Iterable, List, TypeVar, Union

from kor.nodes import (
AbstractSchemaNode,
AbstractValueNode,
AbstractVisitor,
Number,
Object,
Expand Down Expand Up @@ -42,7 +42,7 @@ def describe(self, node: Object) -> str:
class BulletPointDescriptor(TypeDescriptor[Iterable[str]]):
"""Generate a bullet point style schema description."""

def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> List[str]:
def visit_default(self, node: "AbstractValueNode", **kwargs: Any) -> List[str]:
"""Default action for a node."""
depth = kwargs["depth"]
space = "* " + depth * " "
Expand All @@ -65,7 +65,7 @@ def describe(self, node: Object) -> str:
class TypeScriptDescriptor(TypeDescriptor[Iterable[str]]):
"""Generate a typescript style schema description."""

def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> List[str]:
def visit_default(self, node: "AbstractValueNode", **kwargs: Any) -> List[str]:
"""Default action for a node."""
depth = kwargs["depth"]
space = depth * " "
Expand Down
4 changes: 2 additions & 2 deletions tests/test_encoders/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import pytest

from kor.encoders import Encoder, JSONEncoder, XMLEncoder, encode_examples
from kor.nodes import AbstractSchemaNode, Number, Object, Option, Selection, Text
from kor.nodes import AbstractValueNode, Number, Object, Option, Selection, Text


def _get_schema() -> AbstractSchemaNode:
def _get_schema() -> AbstractValueNode:
"""Make an abstract input node."""
option = Option(id="option", description="Option", examples=["selection"])
number = Number(id="number", description="Number", examples=[("number", "2")])
Expand Down
8 changes: 4 additions & 4 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kor.nodes import AbstractVisitor


class FakeSchemaNode(nodes.AbstractSchemaNode):
class FakeValueNode(nodes.AbstractValueNode):
"""Fake Schema Node for testing purposes."""

def accept(self, visitor: AbstractVisitor, **kwargs: Any) -> Any:
Expand All @@ -17,19 +17,19 @@ def accept(self, visitor: AbstractVisitor, **kwargs: Any) -> Any:
@pytest.mark.parametrize("invalid_id", ["", "@@#", " ", "NAME", "1name", "name-name"])
def test_invalid_identifier_raises_error(invalid_id: str) -> None:
with pytest.raises(ValueError):
FakeSchemaNode(id=invalid_id, description="Toy")
FakeValueNode(id=invalid_id, description="Toy")


@pytest.mark.parametrize("valid_id", ["name", "name_name", "_name", "n1ame"])
def test_can_instantiate_with_valid_id(valid_id: str) -> None:
"""Can instantiate an abstract input with a valid ID."""
FakeSchemaNode(id=valid_id, description="Toy")
FakeValueNode(id=valid_id, description="Toy")


def test_extraction_input_cannot_be_instantiated() -> None:
"""ExtractionInput is abstract and should not be instantiated."""
with pytest.raises(TypeError):
nodes.ExtractionSchemaNode( # type: ignore[abstract]
nodes.ExtractionValueNode( # type: ignore[abstract]
id="help",
description="description",
examples=[],
Expand Down
8 changes: 4 additions & 4 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from kor.nodes import (
AbstractSchemaNode,
AbstractValueNode,
AbstractVisitor,
Number,
Object,
Expand All @@ -16,11 +16,11 @@
class TestVisitor(AbstractVisitor[Tuple[str, Any]]):
"""Toy input for tests."""

def visit_default(self, node: AbstractSchemaNode, **kwargs: Any) -> Tuple[str, Any]:
def visit_default(self, node: AbstractValueNode, **kwargs: Any) -> Tuple[str, Any]:
"""Verify default is invoked"""
return node.id, kwargs

def visit(self, node: AbstractSchemaNode, **kwargs: Any) -> Tuple[str, Any]:
def visit(self, node: AbstractValueNode, **kwargs: Any) -> Tuple[str, Any]:
"""Convenience method."""
return node.accept(self, **kwargs)

Expand All @@ -38,6 +38,6 @@ def visit(self, node: AbstractSchemaNode, **kwargs: Any) -> Tuple[str, Any]:
Option(id="uid"),
],
)
def test_visit_default_is_invoked(node: AbstractSchemaNode) -> None:
def test_visit_default_is_invoked(node: AbstractValueNode) -> None:
visitor = TestVisitor()
assert visitor.visit(node, a="a", b="b") == ("uid", {"a": "a", "b": "b"})