From 857bd421931ab754d7a864f8b83eec8c91cb708f Mon Sep 17 00:00:00 2001
From: jf
Date: Mon, 6 Dec 2021 14:49:38 -0800
Subject: [PATCH 01/13] Update for README.md

---
 README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 74ce3d1e1..c4ea671be 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,8 @@ LabGraph is a streaming framework built by the Facebook Reality Labs Research te
 ### Method 1 - using PyPI (Recommended)
 
 **Prerequisites**:
-- [Python 3.6](https://www.python.org/downloads/release/python-368/)
-- Windows and Linux (CentOS 7, CentOS 8, Ubuntu 20.04)
+- Python 3.6+ (Python 3.8 recommended)
+- Mac (Big Sur), Windows and Linux (CentOS 7, CentOS 8, Ubuntu 20.04; Python 3.6 only)
 - Based on [PyPa](https://github.com/pypa/manylinux), the following Linux systems are also supported: Fedora 32+, Mageia 8+, openSUSE 15.3+, Photon OS 4.0+ (3.0+ with updates), Ubuntu 20.04+
 
 ```
 pip install labgraph
 ```
 
@@ -20,7 +20,7 @@
 ### Method 2 - Building from source
 
 **Prerequisites**:
 - [Buck](https://buck.build/setup/getting_started.html) ([Watchman](https://facebook.github.io/watchman/docs/install) also recommended)
-- [Python 3.6](https://www.python.org/downloads/release/python-368/) (note: currently incompatible with Anaconda)
+- [Python 3.6 - Python 3.10](https://www.python.org/downloads/release/)
 - **Windows only:** [Build Tools for Visual Studio 2019](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019)
 
 ```

From f1d83c4da2425fc54b5b7344cbeb4834e75a4264 Mon Sep 17 00:00:00 2001
From: jf
Date: Mon, 6 Dec 2021 14:53:21 -0800
Subject: [PATCH 02/13] Update setup and test_script

---
 setup.py       | 42 ++++++++++++++++++++++++------------------
 test_script.sh | 26 +++++++++++++-------------
 2 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/setup.py b/setup.py
index 6e09f0292..4be29f8cf 100644
--- a/setup.py
+++ b/setup.py
@@ -22,28 +22,34 @@ setup(
     name="labgraph",
-    version="1.0.3",
-    description="Python streaming framework",
+    version="2.0.0",
+    description="Research-friendly framework for in-lab experiments",
+    long_description="LabGraph is a Python framework for rapidly prototyping experimental "
+    "systems for real-time streaming applications. "
+    "It is particularly well-suited to real-time neuroscience, "
+    "physiology and psychology experiments.",
+    url="https://github.com/facebookresearch/labgraph",
+    license="MIT",
+    keywords="python streaming framework, reality lab, neuroscience, physiology, psychology",
     packages=find_packages(),
     package_data={"labgraph": ["tests/mypy.ini"]},
-    python_requires=">=3.6, <3.7",
+    python_requires=">=3.6",
     ext_modules=LIBRARY_EXTENSIONS,
     cmdclass={"build_ext": buck_build_ext},
     install_requires=[
-        "appdirs==1.4.3",
-        "click==7.0",
-        "dataclasses==0.6",
-        "h5py==2.10.0",
-        "matplotlib==3.1.1",
-        "mypy==0.782",
-        "numpy==1.16.4",
-        "psutil==5.6.7",
-        "pytest==3.10.1",
-        "pytest_mock==2.0.0",
-        "pyzmq==18.1.0",
-        "typeguard==2.5.1",
-        "typing_extensions>=3.7.4.3",
-        "yappi==1.2.5",
-        "pylsl==1.15.0",
+        "appdirs>=1.4.4",
+        "click>=7.1.2",
+        "h5py>=3.3.0",
+        "matplotlib>=3.1.2",
+        "mypy>=0.910",
+        "numpy>=1.19.5",
+        "psutil>=5.6.7",
+        "pytest>=3.10.1",
+        "pytest_mock>=2.0.0",
+        "pyzmq>=19.0.2",
+        "typeguard>=2.10.0",
+        "typing_extensions>=3.7.4",
+        "websockets>=8.1",
+        "yappi>=1.2.5",
     ],
 )

diff --git a/test_script.sh b/test_script.sh
index 2715af17a..af9fb3ac2 100644
--- a/test_script.sh
+++ b/test_script.sh
@@ -2,16 +2,16 @@
 # RUN export LC_ALL=C.UTF-8
 # RUN export LANG=en_US.utf-8
 
-python3.6 -m pytest --pyargs -v labgraph._cthulhu
-python3.6 -m pytest --pyargs -v labgraph.events
-python3.6 -m pytest --pyargs -v labgraph.graphs
-python3.6 -m pytest --pyargs -v labgraph.loggers
-python3.6 -m pytest --pyargs -v labgraph.messages
-python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_process_manager
-python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_aligner
-python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_cpp
-python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_exception
-python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_launch
-python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_runner
-python3.6 -m pytest --pyargs -v labgraph.zmq_node
-python3.6 -m pytest --pyargs -v labgraph.devices.protocols.lsl
+python3 -m pytest --pyargs -v labgraph._cthulhu
+python3 -m pytest --pyargs -v labgraph.events
+python3 -m pytest --pyargs -v labgraph.graphs
+python3 -m pytest --pyargs -v labgraph.loggers
+python3 -m pytest --pyargs -v labgraph.messages
+python3 -m pytest --pyargs -v labgraph.runners.tests.test_process_manager
+python3 -m pytest --pyargs -v labgraph.runners.tests.test_aligner
+python3 -m pytest --pyargs -v labgraph.runners.tests.test_cpp
+python3 -m pytest --pyargs -v labgraph.runners.tests.test_exception
+python3 -m pytest --pyargs -v labgraph.runners.tests.test_launch
+python3 -m pytest --pyargs -v labgraph.runners.tests.test_runner
+python3 -m pytest --pyargs -v labgraph.zmq_node
+python3 -m pytest --pyargs -v labgraph.devices.protocols.lsl

From eda02c8429f30f64fb3cf4430841bb5b60dddca3 Mon Sep 17 00:00:00 2001
From: jf
Date: Mon, 6 Dec 2021 15:34:33 -0800
Subject: [PATCH 03/13] Updates for labgraph and third-party.
---
 third-party/pybind11/BUCK | 2 +-
 third-party/python/BUCK | 4 ++--
 third-party/python/DEFS | 6 +++---
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/third-party/pybind11/BUCK b/third-party/pybind11/BUCK
index fe708e5cc..0d753ecd6 100644
--- a/third-party/pybind11/BUCK
+++ b/third-party/pybind11/BUCK
@@ -4,6 +4,6 @@ cxx_library(
     public_include_directories = ["include"],
     preferred_linkage = "static",
     exported_headers = glob(["include/pybind11/**/*.h"]),
-    deps = ["//third-party/python:python3.6"],
+    deps = ["//third-party/python:python"],
     visibility = ["PUBLIC"],
 )

diff --git a/third-party/python/BUCK b/third-party/python/BUCK
index dc343a83d..7a557b3f4 100644
--- a/third-party/python/BUCK
+++ b/third-party/python/BUCK
@@ -18,7 +18,7 @@ if PLATFORM == "win":
     )
 
     prebuilt_cxx_library(
-        name = "python" + PYTHON_VERSION,
+        name = "python",
         header_dirs = [":python" + PYTHON_VERSION + "_include"],
         static_lib = ":python" + PYTHON_VERSION + "_lib",
         preferred_linkage = "static",
@@ -33,7 +33,7 @@ else:
     )
 
     prebuilt_cxx_library(
-        name = "python" + PYTHON_VERSION,
+        name = "python",
         header_dirs = [":python" + PYTHON_VERSION + "_include"],
         shared_lib = ":python" + PYTHON_VERSION + "_lib",
         preferred_linkage = "shared",

diff --git a/third-party/python/DEFS b/third-party/python/DEFS
index 3ed517c5f..54f43c4bd 100644
--- a/third-party/python/DEFS
+++ b/third-party/python/DEFS
@@ -1,7 +1,7 @@
 with allow_unsafe_import():
     import os, platform, shutil
 
-PYTHON_VERSION = "3.6"
+PYTHON_VERSION = "3.6"  # To be modified for other python versions.
 
 
 def _config_var(executable, key):
@@ -18,10 +18,10 @@ def _config_var(executable, key):
 
 
 def config_var(key):
-    for executable in ("python", "python3", "python" + PYTHON_VERSION):
+    for executable in ("python3", "python" + PYTHON_VERSION):
         if not shutil.which(executable):
             continue
-        if _config_var(executable, "py_version_short") != "3.6":
+        if _config_var(executable, "py_version_short") != PYTHON_VERSION:
             continue
         return _config_var(executable, key)
     raise Exception("Could not find Python " + PYTHON_VERSION + " on PATH")

From dc873e3ec2149f979f70530690588a3ebf7e9405 Mon Sep 17 00:00:00 2001
From: jf
Date: Mon, 6 Dec 2021 15:47:33 -0800
Subject: [PATCH 04/13] Updates for labgraph folder.
--- labgraph/_cthulhu/bindings.py | 2 +- labgraph/_cthulhu/cthulhu.py | 41 +-- labgraph/_cthulhu/tests/test_cthulhu.py | 12 +- labgraph/cpp/bindings.cpp | 2 +- labgraph/cpp/include/labgraph/Node.h | 18 +- labgraph/cpp/tests/bindings.cpp | 2 +- labgraph/events/event_generator.py | 10 +- labgraph/events/event_generator_node.py | 2 +- labgraph/events/tests/test_event_generator.py | 10 +- labgraph/graphs/config.py | 4 +- labgraph/graphs/cpp_node.py | 4 +- labgraph/graphs/graph.py | 8 +- labgraph/graphs/group.py | 6 +- labgraph/graphs/method.py | 40 +-- labgraph/graphs/module.py | 22 +- labgraph/graphs/node.py | 6 +- labgraph/graphs/node_test_harness.py | 2 +- labgraph/graphs/state.py | 2 +- labgraph/graphs/tests/test_config.py | 4 +- labgraph/graphs/tests/test_graph.py | 3 +- labgraph/graphs/tests/test_group.py | 4 +- labgraph/graphs/tests/test_harness.py | 10 + labgraph/graphs/tests/test_node.py | 8 +- labgraph/graphs/tests/test_publisher.py | 4 +- labgraph/graphs/tests/test_subscriber.py | 6 +- labgraph/graphs/topic.py | 2 +- labgraph/loggers/hdf5/logger.py | 36 +- labgraph/loggers/hdf5/reader.py | 90 +++++ labgraph/loggers/hdf5/tests/__init__.py | 2 + labgraph/loggers/hdf5/tests/test_logger.py | 104 +----- labgraph/loggers/hdf5/tests/test_reader.py | 38 +++ labgraph/loggers/hdf5/tests/test_utils.py | 110 ++++++ labgraph/loggers/logger.py | 4 +- labgraph/messages/message.py | 34 +- labgraph/messages/tests/test_message.py | 35 +- labgraph/messages/types.py | 317 ++++++++++++++---- labgraph/runners/aligner.py | 18 +- labgraph/runners/entry.py | 2 +- labgraph/runners/exceptions.py | 2 +- labgraph/runners/local_runner.py | 16 +- labgraph/runners/parallel_runner.py | 6 +- labgraph/runners/runner.py | 2 +- labgraph/runners/tests/test_cpp.py | 132 -------- labgraph/runners/tests/test_exception.py | 39 +-- .../runners/tests/test_process_manager.py | 125 ------- labgraph/runners/util.py | 2 +- labgraph/tests/test_imports.py | 4 +- labgraph/tests/test_typecheck.py | 73 ---- labgraph/util/__init__.py | 4 +- labgraph/util/error.py | 6 +- labgraph/util/typing.py | 19 ++ labgraph/util/version.py | 2 +- labgraph/zmq_node/tests/test_zmq_node.py | 272 --------------- labgraph/zmq_node/zmq_poller_node.py | 1 + labgraph/zmq_node/zmq_sender_node.py | 6 +- 55 files changed, 752 insertions(+), 983 deletions(-) create mode 100644 labgraph/loggers/hdf5/reader.py create mode 100644 labgraph/loggers/hdf5/tests/test_reader.py create mode 100644 labgraph/loggers/hdf5/tests/test_utils.py delete mode 100644 labgraph/runners/tests/test_cpp.py delete mode 100644 labgraph/tests/test_typecheck.py create mode 100644 labgraph/util/typing.py delete mode 100644 labgraph/zmq_node/tests/test_zmq_node.py diff --git a/labgraph/_cthulhu/bindings.py b/labgraph/_cthulhu/bindings.py index 6c65d2dce..401b3cb83 100644 --- a/labgraph/_cthulhu/bindings.py +++ b/labgraph/_cthulhu/bindings.py @@ -2,7 +2,7 @@ # Copyright 2004-present Facebook. All Rights Reserved. # This is a wrapper around cthulhubindings to randomize the name of the shared memory it -# uses. This allows us to keep shared memory for different LabGraph graphs separate +# uses. This allows us to keep shared memory for different Labgraph graphs separate # when they are running simultaneously. 
import os diff --git a/labgraph/_cthulhu/cthulhu.py b/labgraph/_cthulhu/cthulhu.py index e2557adbd..166a99d6b 100644 --- a/labgraph/_cthulhu/cthulhu.py +++ b/labgraph/_cthulhu/cthulhu.py @@ -6,7 +6,8 @@ from typing import Callable, Generic, Optional, Type, TypeVar from ..messages.message import Message -from ..util.error import LabGraphError +from ..util.error import LabgraphError +from ..util.typing import is_generic_subclass from .bindings import ( # type: ignore PerformanceSummary, StreamConsumer, @@ -22,7 +23,7 @@ T = TypeVar("T") -class LabGraphCallbackParams(Generic[T]): +class LabgraphCallbackParams(Generic[T]): def __init__(self, message: T, stream_id: Optional[str]) -> None: self.message = message self.stream_id = stream_id @@ -31,7 +32,7 @@ def __init__(self, message: T, stream_id: Optional[str]) -> None: stream_id: Optional[str] -LabGraphCallback = Callable[..., None] +LabgraphCallback = Callable[..., None] CthulhuCallback = Callable[[StreamSample], None] @@ -43,17 +44,17 @@ class Mode(Enum): class Consumer(StreamConsumer): # type: ignore """ Convenience wrapper of Cthulhu's `StreamConsumer` that allows us to specify a - callback accepting LabGraph `Message`s. + callback accepting Labgraph `Message`s. Args: stream_interface: The stream interface to use. - sample_callback: The callback to use (uses LabGraph messages). + sample_callback: The callback to use (uses Labgraph messages). """ def __init__( self, stream_interface: StreamInterface, - sample_callback: LabGraphCallback, + sample_callback: LabgraphCallback, mode: Mode = Mode.SYNC, stream_id: Optional[str] = None, ) -> None: @@ -66,9 +67,9 @@ def __init__( ) self.stream_id = stream_id - def _to_cthulhu_callback(self, callback: LabGraphCallback) -> CthulhuCallback: + def _to_cthulhu_callback(self, callback: LabgraphCallback) -> CthulhuCallback: """ - Given a LabGraph callback, creates a Cthulhu callback (accepting + Given a Labgraph callback, creates a Cthulhu callback (accepting `StreamSample`s). """ @@ -83,23 +84,23 @@ def wrapped_callback(sample: StreamSample) -> None: message_types = [ arg_type for arg_type in annotated_types.values() - if issubclass(arg_type, Message) - or issubclass(arg_type, LabGraphCallbackParams) + if is_generic_subclass(arg_type, LabgraphCallbackParams) + or issubclass(arg_type, Message) ] assert len(message_types) == 1 message_type = message_types[0] - if issubclass(message_type, Message): - message = message_type(__sample__=sample) - callback(message) - elif issubclass(message_type, LabGraphCallbackParams): + if is_generic_subclass(message_type, LabgraphCallbackParams): (arg_type,) = message_type.__args__ message = arg_type(__sample__=sample) - params = LabGraphCallbackParams(message, self.stream_id) + params = LabgraphCallbackParams(message, self.stream_id) callback(params) + elif issubclass(message_type, Message): + message = message_type(__sample__=sample) + callback(message) else: raise TypeError( - f"Expected callback taking type '{Message.__name__}' or '{LabGraphCallbackParams.__name__}', got '{message_type.__name__}'" + f"Expected callback taking type '{Message.__name__}' or '{LabgraphCallbackParams.__name__}', got '{message_type.__name__}'" ) return wrapped_callback @@ -118,7 +119,7 @@ def __exit__( class Producer(StreamProducer): # type: ignore """ - Convenience wrapper of Cthulhu's `StreamProducer` that accepts a LabGraph message. + Convenience wrapper of Cthulhu's `StreamProducer` that accepts a Labgraph message. Args: stream_interface: The stream interface to use. 
@@ -134,7 +135,7 @@ def __init__( def produce_message(self, message: Message) -> None: """ - Produces a LabGraph message to the Cthulhu stream. + Produces a Labgraph message to the Cthulhu stream. Args: message: The message to produce. @@ -155,7 +156,7 @@ def __exit__( def register_stream(name: str, message_type: Type[Message]) -> StreamInterface: """ - Registers a stream with a LabGraph message type to the Cthulhu stream registry. + Registers a stream with a Labgraph message type to the Cthulhu stream registry. Args: name: The name of the stream. @@ -168,7 +169,7 @@ def register_stream(name: str, message_type: Type[Message]) -> StreamInterface: type_id = existing_stream.description.type existing_type = typeRegistry().findTypeID(type_id) if existing_type.typeName != message_type.versioned_name: - raise LabGraphError( + raise LabgraphError( f"Tried to register stream '{name}' with type " f"'{message_type.versioned_name}', but it already exists with type " f"'{existing_type.typeName}'" diff --git a/labgraph/_cthulhu/tests/test_cthulhu.py b/labgraph/_cthulhu/tests/test_cthulhu.py index 5f8cb30d3..4eba5db69 100644 --- a/labgraph/_cthulhu/tests/test_cthulhu.py +++ b/labgraph/_cthulhu/tests/test_cthulhu.py @@ -8,7 +8,7 @@ from ...messages.message import Message from ...util.random import random_string from ...util.testing import local_test -from ..cthulhu import Consumer, LabGraphCallbackParams, Producer, register_stream +from ..cthulhu import Consumer, LabgraphCallbackParams, Producer, register_stream RANDOM_ID_LENGTH = 128 @@ -23,7 +23,7 @@ class MyMessage(Message): @local_test def test_producer_and_consumer() -> None: """ - Tests that we can use the LabGraph wrappers around the Cthulhu APIs to publish + Tests that we can use the Labgraph wrappers around the Cthulhu APIs to publish and subscribe to messages. """ stream_name = random_string(length=RANDOM_ID_LENGTH) @@ -33,7 +33,7 @@ def test_producer_and_consumer() -> None: with Producer(stream_interface=stream_interface) as producer: - def callback(params: LabGraphCallbackParams[MyMessage]) -> None: + def callback(params: LabgraphCallbackParams[MyMessage]) -> None: received_messages.append(params.message) with Consumer(stream_interface=stream_interface, sample_callback=callback): @@ -51,7 +51,7 @@ def callback(params: LabGraphCallbackParams[MyMessage]) -> None: @local_test def test_complex_graph() -> None: """ - Tests that we can use the LabGraph wrappers around the Cthulhu APIs to stream + Tests that we can use the Labgraph wrappers around the Cthulhu APIs to stream messages in a more complex graph. 
""" stream_name1 = random_string(length=RANDOM_ID_LENGTH) @@ -65,14 +65,14 @@ def test_complex_graph() -> None: with Producer(stream_interface=stream2) as producer2: - def transform_callback(params: LabGraphCallbackParams[MyMessage]) -> None: + def transform_callback(params: LabgraphCallbackParams[MyMessage]) -> None: producer2.produce_message( MyMessage(int_field=params.message.int_field * 2) ) with Consumer(stream_interface=stream1, sample_callback=transform_callback): - def sink_callback(params: LabGraphCallbackParams[MyMessage]) -> None: + def sink_callback(params: LabgraphCallbackParams[MyMessage]) -> None: received_messages.append(params.message) with Consumer(stream_interface=stream2, sample_callback=sink_callback): diff --git a/labgraph/cpp/bindings.cpp b/labgraph/cpp/bindings.cpp index a7294df31..2e4a8db6e 100644 --- a/labgraph/cpp/bindings.cpp +++ b/labgraph/cpp/bindings.cpp @@ -12,7 +12,7 @@ namespace py = pybind11; namespace labgraph { void bindings(py::module_& m) { - m.doc() = "LabGraph C++: C++ nodes for LabGraph"; + m.doc() = "Labgraph C++: C++ nodes for Labgraph"; py::class_(m, "Node") .def("setup", &Node::setup) diff --git a/labgraph/cpp/include/labgraph/Node.h b/labgraph/cpp/include/labgraph/Node.h index 3bec829df..b9e678ea9 100644 --- a/labgraph/cpp/include/labgraph/Node.h +++ b/labgraph/cpp/include/labgraph/Node.h @@ -10,7 +10,7 @@ namespace labgraph { /** * struct NodeTopic * - * Describes a mapping between a LabGraph topic and a Cthulhu stream. + * Describes a mapping between a Labgraph topic and a Cthulhu stream. */ struct NodeTopic { std::string topicName; @@ -20,11 +20,11 @@ struct NodeTopic { /** * struct NodeBootstrapInfo * - * Contains all information needed to bootstrap a LabGraph C++ node into a state ready - * for execution in an existing LabGraph graph. + * Contains all information needed to bootstrap a Labgraph C++ node into a state ready + * for execution in an existing Labgraph graph. */ struct NodeBootstrapInfo { - std::vector topics; // Mapping of LabGraph topics to Cthulhu streams + std::vector topics; // Mapping of Labgraph topics to Cthulhu streams }; typedef std::function Publisher; @@ -67,26 +67,26 @@ struct TransformerInfo { /** * class Node * - * Describes a C++ node in a LabGraph graph. + * Describes a C++ node in a Labgraph graph. */ class Node { public: Node(); virtual ~Node(); - /*** Setup function that is run when the LabGraph graph is starting up. */ + /*** Setup function that is run when the Labgraph graph is starting up. */ virtual void setup(); /** - * Entry point that is run in the LabGraph graph to start all the node's publishers. + * Entry point that is run in the Labgraph graph to start all the node's publishers. */ void run(); - /*** Cleanup function that is run when the LabGraph graph is shutting down. */ + /*** Cleanup function that is run when the Labgraph graph is shutting down. */ virtual void cleanup(); /** - * Bootstrapping function that is run by the LabGraph graph to connect this node's + * Bootstrapping function that is run by the Labgraph graph to connect this node's * topics with their corresponding Cthulhu streams. 
*/ void bootstrap(NodeBootstrapInfo& bootstrapInfo); diff --git a/labgraph/cpp/tests/bindings.cpp b/labgraph/cpp/tests/bindings.cpp index 662260338..d3e1be2ae 100644 --- a/labgraph/cpp/tests/bindings.cpp +++ b/labgraph/cpp/tests/bindings.cpp @@ -8,7 +8,7 @@ namespace py = pybind11; PYBIND11_MODULE(MyCPPNodes, m) { - m.doc() = "LabGraph C++: MyCPPNodes unit test"; + m.doc() = "Labgraph C++: MyCPPNodes unit test"; std::vector sourceTopics = {"A"}; labgraph::bindNode(m, "MyCPPSource", sourceTopics) diff --git a/labgraph/events/event_generator.py b/labgraph/events/event_generator.py index 9cd5ba7a1..cf13f937e 100644 --- a/labgraph/events/event_generator.py +++ b/labgraph/events/event_generator.py @@ -8,7 +8,7 @@ from ..graphs.topic import Topic from ..messages.message import Message, TimestampedMessage -from ..util.error import LabGraphError +from ..util.error import LabgraphError from ..util.min_heap import MinHeap @@ -61,7 +61,7 @@ class Event: def __post_init__(self) -> None: if self.duration < 0.0: - raise LabGraphError("event cannot have a negative duration.") + raise LabgraphError("event cannot have a negative duration.") def __hash__(self) -> int: # Needed for usage as dictionary key return hash(id(self)) @@ -126,7 +126,7 @@ def _add_start_event(self, event: Event) -> None: Adds `event` to the heap as the first event of the graph. """ if event.delay != 0.0: - raise LabGraphError("start_event cannot have a non-zero delay.") + raise LabgraphError("start_event cannot have a non-zero delay.") self._push_heap_entry(event, 0.0) def _get_accumulated_time( @@ -134,12 +134,12 @@ def _get_accumulated_time( ) -> float: accumulated_time = self._accumulated_times.get(previous_event) if accumulated_time is None: - raise LabGraphError("previous_event has not been inserted yet.") + raise LabgraphError("previous_event has not been inserted yet.") if add_duration: accumulated_time += previous_event.duration accumulated_time += event.delay if accumulated_time < 0.0: - raise LabGraphError("event occurs before start time.") + raise LabgraphError("event occurs before start time.") return accumulated_time def _push_heap_entry(self, event: Event, accumulated_time: float) -> None: diff --git a/labgraph/events/event_generator_node.py b/labgraph/events/event_generator_node.py index fb7ea75d4..f631f63d2 100644 --- a/labgraph/events/event_generator_node.py +++ b/labgraph/events/event_generator_node.py @@ -2,7 +2,7 @@ # Copyright 2004-present Facebook. All Rights Reserved. from abc import abstractmethod -from time import time # TODO: Replace with LabGraph clock +from time import time # TODO: Replace with Labgraph clock from typing import Any, Dict, List, Tuple from ..graphs.method import AsyncPublisher, get_method_metadata diff --git a/labgraph/events/tests/test_event_generator.py b/labgraph/events/tests/test_event_generator.py index 4a4bfa3d9..d13c05858 100644 --- a/labgraph/events/tests/test_event_generator.py +++ b/labgraph/events/tests/test_event_generator.py @@ -7,7 +7,7 @@ from ...graphs import Topic from ...messages import Message, TimestampedMessage -from ...util.error import LabGraphError +from ...util.error import LabgraphError from .. 
import ( BaseEventGenerator, DeferredMessage, @@ -93,7 +93,7 @@ def test_event_init_negative_duration(mocker: Any) -> None: MyMessage, "unittest_args", kwargs_field="unittest_kwargs" ) topic = Topic(MyMessage) - with pytest.raises(LabGraphError): + with pytest.raises(LabgraphError): _ = Event(message, topic, 0.0, -1.0) @@ -124,7 +124,7 @@ def test_event_graph_init_bad_start_event(mocker: Any) -> None: ) topic = Topic(MyMessage) event = Event(message, topic, -1.0) - with pytest.raises(LabGraphError): + with pytest.raises(LabgraphError): _ = EventGraph(event) @@ -157,7 +157,7 @@ def test_event_graph_accumulated_time_no_previous(mocker: Any) -> None: graph = EventGraph(start) parent = Event(message, topic, 0.0, 1.0) child = Event(message, topic, 0.0, 1.0) - with pytest.raises(LabGraphError): + with pytest.raises(LabgraphError): graph.add_event_at_end(child, parent) @@ -171,7 +171,7 @@ def test_event_graph_accumulated_time_before_start(mocker: Any) -> None: parent = Event(message, topic, 0.0, 1.0) child = Event(message, topic, -3.0, 1.0) graph.add_event_at_end(parent, start) - with pytest.raises(LabGraphError): + with pytest.raises(LabgraphError): graph.add_event_at_end(child, parent) diff --git a/labgraph/graphs/config.py b/labgraph/graphs/config.py index c37068ba1..9f23c3403 100644 --- a/labgraph/graphs/config.py +++ b/labgraph/graphs/config.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional from ..messages.message import Field, Message -from ..util.error import LabGraphError +from ..util.error import LabgraphError class Config(Message): @@ -59,7 +59,7 @@ def _add_field_to_argument_parser( required=field.required, ) else: - raise LabGraphError( + raise LabgraphError( "Invalid type for argument parsing for config field. " f"'{field_name}' has type '{python_type.__name__}'" ) diff --git a/labgraph/graphs/cpp_node.py b/labgraph/graphs/cpp_node.py index 3f8b7410a..de11c5cc8 100644 --- a/labgraph/graphs/cpp_node.py +++ b/labgraph/graphs/cpp_node.py @@ -6,11 +6,11 @@ from labgraph_cpp import Node as _Node, NodeBootstrapInfo # type: ignore -# HACK: Import from LabGraph's wrapper of Cthulhu before importing dynamic libs to set +# HACK: Import from Labgraph's wrapper of Cthulhu before importing dynamic libs to set # the shared memory name from .._cthulhu import bindings # noqa: F401 from ..messages.message import Message -from ..util.error import LabGraphError +from ..util.error import LabgraphError from .config import Config from .method import Subscriber, background from .node import Node diff --git a/labgraph/graphs/graph.py b/labgraph/graphs/graph.py index 9bca19e35..03ea4e0ff 100644 --- a/labgraph/graphs/graph.py +++ b/labgraph/graphs/graph.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple -from ..util.error import LabGraphError +from ..util.error import LabgraphError from ..util.logger import get_logger from .config import Config from .group import Group, GroupMeta @@ -29,7 +29,7 @@ def __init__( if hasattr(cls, "__annotations__"): if "config" in cls.__annotations__: if not issubclass(cls.__annotations__["config"], Config): - raise LabGraphError( + raise LabgraphError( "The config for a Graph must be a subclass of Config, got " f"{cls.__annotations__['config']}" ) @@ -101,11 +101,11 @@ def _validate_topics(self) -> None: message += "".join(sorted(submessages)) message += ( "This could mean that there are publishers and/or subscribers of " - "Cthulhu streams that LabGraph doesn't know about, and/or that data " + "Cthulhu streams that Labgraph doesn't know about, 
and/or that data " "in some topics is being discarded.\n" ) - # TODO: We warn instead of raising an error because LabGraph currently + # TODO: We warn instead of raising an error because Labgraph currently # tries to run any publishers/subscribers it knows about as async functions, # so for now we keep it ignorant of C++ publisher/subcriber methods. logger.warning(message.strip()) diff --git a/labgraph/graphs/group.py b/labgraph/graphs/group.py index 5f659a370..cbbcb270c 100644 --- a/labgraph/graphs/group.py +++ b/labgraph/graphs/group.py @@ -14,7 +14,7 @@ # the shared memory name from .._cthulhu import bindings as cthulhu # noqa: F401 from ..messages.message import Message -from ..util.error import LabGraphError +from ..util.error import LabgraphError from .config import Config from .cpp_node import CPPNode, CPPNodeConfig from .method import Main, Publisher, Subscriber @@ -215,7 +215,7 @@ def _get_streams(self) -> Dict[str, Stream]: error_message += ( f"- {topic_path}: {message_types_by_topic[topic_path]}\n" ) - raise LabGraphError(error_message) + raise LabgraphError(error_message) if len(message_types) == 0: warning_message = ( @@ -284,7 +284,7 @@ def connections(self) -> "Connections": def _is_valid_child_name(self, name: str) -> bool: """ Returns true if the given name is valid for a child of this group. Notably, - returns false for LabGraph-internal fields starting with "__" as well as + returns false for Labgraph-internal fields starting with "__" as well as special fields like `config`. """ if name.startswith("__"): diff --git a/labgraph/graphs/method.py b/labgraph/graphs/method.py index b1bbdc153..3919d183b 100644 --- a/labgraph/graphs/method.py +++ b/labgraph/graphs/method.py @@ -17,7 +17,7 @@ from typing_extensions import Protocol -from ..util.error import LabGraphError +from ..util.error import LabgraphError from .topic import Topic @@ -27,7 +27,7 @@ class AsyncPublisher(Protocol): """ Convenience return type for async publisher methods. An async method that yields - tuples of LabGraph topics and messages can be typed as returning this. + tuples of Labgraph topics and messages can be typed as returning this. For example: ``` @@ -57,7 +57,7 @@ def __aiter__(self) -> AsyncIterator[Tuple[Topic, Any]]: class NodeMethod(ABC): """ - Represents a method on a Node that has been decorated by LabGraph. Subclasses + Represents a method on a Node that has been decorated by Labgraph. Subclasses include topic paths, if applicable. """ @@ -71,7 +71,7 @@ def __init__(self, name: str) -> None: @dataclass class MethodMetadata: """ - Represents metadata on a method that is created by a LabGraph decorator. The + Represents metadata on a method that is created by a Labgraph decorator. The `MethodMetadata` is used to validate the decorator usage and then construct a `NodeMethod`. 
""" @@ -119,43 +119,43 @@ def node_method(self) -> NodeMethod: elif self.is_main: return Main(name=self.name) else: - raise LabGraphError("Unexpected NodeMethod type") + raise LabgraphError("Unexpected NodeMethod type") def validate(self) -> None: for i, topic1 in enumerate(self.published_topics): for j, topic2 in enumerate(self.published_topics): if i != j and topic1 is topic2: - raise LabGraphError( + raise LabgraphError( f"Method '{self.name}' got two @publisher decorators for the " "same topic" ) if len(self.published_topics) > 0: if self.is_background: - raise LabGraphError( + raise LabgraphError( f"Method '{self.name}' cannot have both a @{publisher.__name__} " f"decorator and a @{background.__name__} decorator" ) if self.is_main: - raise LabGraphError( + raise LabgraphError( f"Method '{self.name}' cannot have both a @{publisher.__name__} " f"decorator and a @{main.__name__} decorator" ) if self.subscribed_topic is not None: if self.is_background: - raise LabGraphError( + raise LabgraphError( f"Method '{self.name}' cannot have both a @{subscriber.__name__} " f"decorator and a @{background.__name__} decorator" ) if self.is_main: - raise LabGraphError( + raise LabgraphError( f"Method '{self.name}' cannot have both a @{subscriber.__name__} " f"decorator and a @{main.__name__} decorator" ) if self.is_background and self.is_main: - raise LabGraphError( + raise LabgraphError( f"Method '{self.name}' cannot have both a @{background.__name__} " f"decorator and a @{main.__name__} decorator" ) @@ -178,7 +178,7 @@ def get_method_metadata(method: Callable[..., Any]) -> MethodMetadata: class Publisher(NodeMethod): """ - Represents a LabGraph method decorated by `@publisher`. + Represents a Labgraph method decorated by `@publisher`. """ published_topic_paths: Tuple[str, ...] @@ -205,7 +205,7 @@ def publisher_wrapper(method: PublisherType) -> PublisherType: class Subscriber(NodeMethod): """ - Represents a LabGraph method decorated by `@subscriber`. + Represents a Labgraph method decorated by `@subscriber`. """ subscribed_topic_path: str @@ -235,7 +235,7 @@ def subscriber_wrapper(method: SubscriberType) -> SubscriberType: or method.__code__.co_varnames[1] != list(annotations.keys())[0] # TODO: We could also check the return type here ): - raise LabGraphError( + raise LabgraphError( f"Expected subscriber '{method.__name__}' to have signature def " f"{method.__name__}(self, message: {topic.message_type.__name__}) -> " "None" @@ -243,7 +243,7 @@ def subscriber_wrapper(method: SubscriberType) -> SubscriberType: metadata = get_method_metadata(method) if metadata.subscribed_topic is not None: - raise LabGraphError( + raise LabgraphError( f"Method '{metadata.name}' already has a @{subscriber.__name__} " "decorator" ) @@ -257,7 +257,7 @@ def subscriber_wrapper(method: SubscriberType) -> SubscriberType: class Transformer(Publisher, Subscriber): """ - Represents a LabGraph method decorated by both `@publisher` and `@subscriber`. + Represents a Labgraph method decorated by both `@publisher` and `@subscriber`. """ def __init__( @@ -272,7 +272,7 @@ def __init__( class Background(NodeMethod): """ - Represents a LabGraph method decorated by `@background`. + Represents a Labgraph method decorated by `@background`. 
""" def __init__(self, name: str) -> None: @@ -288,7 +288,7 @@ def background(method: BackgroundType) -> BackgroundType: """ metadata = get_method_metadata(method) if metadata.is_background: - raise LabGraphError( + raise LabgraphError( f"Method '{metadata.name}' already has a @{background.__name__} decorator" ) metadata.is_background = True @@ -298,7 +298,7 @@ def background(method: BackgroundType) -> BackgroundType: class Main(NodeMethod): """ - Represents a LabGraph method decorated by `@main`. + Represents a Labgraph method decorated by `@main`. """ def __init__(self, name: str) -> None: @@ -315,7 +315,7 @@ def main(method: MainType) -> MainType: """ metadata = get_method_metadata(method) if metadata.is_main: - raise LabGraphError( + raise LabgraphError( f"Method '{metadata.name}' already has a @{main.__name__} decorator" ) metadata.is_main = True diff --git a/labgraph/graphs/module.py b/labgraph/graphs/module.py index d70c1cf65..efc44ed09 100644 --- a/labgraph/graphs/module.py +++ b/labgraph/graphs/module.py @@ -8,7 +8,7 @@ import typeguard from ..messages.message import Message -from ..util.error import LabGraphError +from ..util.error import LabgraphError from ..util.random import random_string from .config import Config from .method import ( @@ -63,7 +63,7 @@ def __init__( # Raise if names of topics in multiple base classes collide if topic_name in cls.__topics__: - raise LabGraphError( + raise LabgraphError( f"Base classes of {cls.__name__} have conflicting topics named " f"{topic_name}" ) @@ -78,7 +78,7 @@ def __init__( # Raise if a topic object was already used by a module (i.e., its _name) was # set if field_value._name is not None: - raise LabGraphError( + raise LabgraphError( "Duplicate topic object found: please assign different Topic " f"objects to values {field_value.name} and {cls.__name__}." f"{field_name}" @@ -88,7 +88,7 @@ def __init__( # Raise if a topic name collides with a superclass's topic name if field_name in cls.__topics__: - raise LabGraphError( + raise LabgraphError( f"Topic {cls.__name__}/{field_name} hides superclass's topic" ) @@ -105,7 +105,7 @@ def __init__( class Module(ABC, metaclass=ModuleMeta): """ - An abstraction for a LabGraph component that can be run within a single process. + An abstraction for a Labgraph component that can be run within a single process. """ state: State @@ -163,7 +163,7 @@ def config(self) -> Config: # message type with fields that all have default values) self._config = self.__class__.__config_type__() except TypeError: - raise LabGraphError( + raise LabgraphError( f"Configuration not set. Call {self.__class__.__name__}.configure() to set the " "configuration." 
) @@ -302,7 +302,7 @@ def main(self) -> Tuple[Optional[str], Optional[Main]]: main_methods = self._get_methods_of_type(Main) if len(main_methods) > 1: method_names = ", ".join(main_methods.keys()) - raise LabGraphError( + raise LabgraphError( "Cannot have multiple methods decorated with @main in nodes in the " f"same process: found methods {method_names}" ) @@ -318,7 +318,7 @@ def _stream_for_topic_path(self, topic_path: str) -> Stream: for stream in self.__streams__.values(): if topic_path in stream.topic_paths: return stream - raise LabGraphError(f"Topic '{topic_path}' is not in a stream") + raise LabgraphError(f"Topic '{topic_path}' is not in a stream") def _get_topic_path(self, topic: Topic) -> str: """ @@ -328,7 +328,7 @@ def _get_topic_path(self, topic: Topic) -> str: if topic is candidate_topic: return topic_path - raise LabGraphError( + raise LabgraphError( f"Could not find topic '{topic.name}' in module {self.__class__.__name__}" ) @@ -343,7 +343,7 @@ def _get_module_path(self, module: "Module") -> str: if candidate_module is module: return module_path - raise LabGraphError( + raise LabgraphError( f"Could not find module '{module.__class__.__name__}' ({module.id}) in " f"module {self.__class__.__name__}" ) @@ -381,4 +381,4 @@ def _validate_streams(self) -> None: error_message += "\n".join( sorted(f"- {path}" for path in publisher_paths) ) - raise LabGraphError(error_message) + raise LabgraphError(error_message) diff --git a/labgraph/graphs/node.py b/labgraph/graphs/node.py index cc7d2413f..56aadf8c5 100644 --- a/labgraph/graphs/node.py +++ b/labgraph/graphs/node.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import Any, Dict, Optional, Tuple -from ..util.error import LabGraphError +from ..util.error import LabgraphError from .config import Config from .method import _METADATA_LABEL, NodeMethod, get_method_metadata from .module import Module, ModuleMeta @@ -55,7 +55,7 @@ def __init__( if topic is subscribed_topic else "published" ) - raise LabGraphError( + raise LabgraphError( f"Invalid topic {topic_verb} by {cls.__name__}." f"{field_name} - set the topic as a class variable " f"in {cls.__name__}" @@ -73,7 +73,7 @@ def __init__( for other_method_name, other_method in other_methods.items(): if get_method_metadata(other_method).is_main: - raise LabGraphError( + raise LabgraphError( f"Cannot have multiple methods decorated with @main in " f"{name}: found methods '{field_name}' and " f"'{other_method_name}'" diff --git a/labgraph/graphs/node_test_harness.py b/labgraph/graphs/node_test_harness.py index a7c69c67a..915138c1e 100644 --- a/labgraph/graphs/node_test_harness.py +++ b/labgraph/graphs/node_test_harness.py @@ -38,7 +38,7 @@ class NodeTestHarness(Generic[N]): """ - Utility class for testing LabGraph nodes. This allows a user to test some behavior + Utility class for testing Labgraph nodes. This allows a user to test some behavior of a node in an asyncio event loop, with the harness taking care of setting up and cleaning up the node. diff --git a/labgraph/graphs/state.py b/labgraph/graphs/state.py index ac9c736b6..e3c3d7f46 100644 --- a/labgraph/graphs/state.py +++ b/labgraph/graphs/state.py @@ -18,7 +18,7 @@ def __init__( class State(metaclass=StateMeta): """ - Represents the state of a LabGraph module. State objects are useful when we would + Represents the state of a Labgraph module. 
State objects are useful when we would like the module to use some memory to run its algorithm while also being able to: - bootstrap that module into some state in the "middle" of its algorithm diff --git a/labgraph/graphs/tests/test_config.py b/labgraph/graphs/tests/test_config.py index 3f454a892..6f14df7c8 100644 --- a/labgraph/graphs/tests/test_config.py +++ b/labgraph/graphs/tests/test_config.py @@ -5,7 +5,7 @@ import pytest -from ...util import LabGraphError +from ...util import LabgraphError from ..config import Config from ..node import Node @@ -89,7 +89,7 @@ def test_node_no_config() -> None: """ node = MyNode() - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: node.setup() assert ( diff --git a/labgraph/graphs/tests/test_graph.py b/labgraph/graphs/tests/test_graph.py index acadcae8a..86b7a3568 100644 --- a/labgraph/graphs/tests/test_graph.py +++ b/labgraph/graphs/tests/test_graph.py @@ -4,6 +4,7 @@ from typing import Any, Dict from ...messages.message import Message +from ...util.error import LabgraphError from ..graph import Graph from ..group import Connections, Group from ..method import AsyncPublisher, publisher, subscriber @@ -70,7 +71,7 @@ def logging(self) -> Dict[str, Topic]: "\t- MY_COMPONENT/MY_CHILD2/MY_NODE1/A has no publishers\n" "\t- MY_COMPONENT/MY_CHILD2/MY_NODE2/B has no subscribers\n" "This could mean that there are publishers and/or subscribers of Cthulhu " - "streams that LabGraph doesn't know about, and/or that data in some topics is " + "streams that Labgraph doesn't know about, and/or that data in some topics is " "being discarded." ) diff --git a/labgraph/graphs/tests/test_group.py b/labgraph/graphs/tests/test_group.py index ebfd53210..330815585 100644 --- a/labgraph/graphs/tests/test_group.py +++ b/labgraph/graphs/tests/test_group.py @@ -6,7 +6,7 @@ import pytest from ...messages.message import Message -from ...util.error import LabGraphError +from ...util.error import LabgraphError from ..group import Connections, Group from ..method import AsyncPublisher, NodeMethod, Transformer, publisher, subscriber from ..module import Module @@ -245,7 +245,7 @@ def connections(self) -> Connections: def test_bad_publishers_group() -> None: - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: _ = BadPublishersGroup() assert ( diff --git a/labgraph/graphs/tests/test_harness.py b/labgraph/graphs/tests/test_harness.py index da6c6cdcd..1769f8363 100644 --- a/labgraph/graphs/tests/test_harness.py +++ b/labgraph/graphs/tests/test_harness.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. 
+import asyncio import dataclasses import pytest from ...messages.message import Message +from ...util.testing import get_event_loop from ..config import Config from ..method import AsyncPublisher, publisher, subscriber from ..node import Node @@ -130,6 +132,10 @@ def test_run_with_harness_max_num_results() -> None: # Check that using max_num_results with a non-generator raises with pytest.raises(TypeError): run_with_harness(MyNode, _not_a_generator, max_num_results=1) # type: ignore + loop = get_event_loop() + for task in asyncio.Task.all_tasks(loop=loop): + task.cancel() + loop.run_until_complete(loop.shutdown_asyncgens()) def test_run_async_max_num_results() -> None: @@ -146,6 +152,10 @@ def test_run_async_max_num_results() -> None: # Check that using max_num_results with a non-generator raises with pytest.raises(TypeError): run_async(_not_a_generator, max_num_results=1) # type: ignore + loop = get_event_loop() + for task in asyncio.Task.all_tasks(loop=loop): + task.cancel() + loop.run_until_complete(loop.shutdown_asyncgens()) async def _not_a_generator() -> None: diff --git a/labgraph/graphs/tests/test_node.py b/labgraph/graphs/tests/test_node.py index 8aa2f4fbe..fc7f2fb7b 100644 --- a/labgraph/graphs/tests/test_node.py +++ b/labgraph/graphs/tests/test_node.py @@ -6,7 +6,7 @@ import pytest from ...messages.message import Message -from ...util.error import LabGraphError +from ...util.error import LabgraphError from ..method import ( AsyncPublisher, NodeMethod, @@ -77,7 +77,7 @@ def test_node_methods(node_type: Type[Node], methods: Dict[str, NodeMethod]) -> def test_node_invalid_published_topic() -> None: bad_topic = Topic(MyMessage) - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: class BadTopicNode(Node): @publisher(bad_topic) @@ -92,7 +92,7 @@ def my_publisher(self) -> AsyncPublisher: def test_node_invalid_subscribed_topic() -> None: bad_topic = Topic(MyMessage) - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: class BadTopicNode(Node): @subscriber(bad_topic) @@ -118,7 +118,7 @@ async def bad_publisher2(self) -> AsyncPublisher: def test_bad_publishers_node() -> None: - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: _ = BadPublishersNode() assert ( diff --git a/labgraph/graphs/tests/test_publisher.py b/labgraph/graphs/tests/test_publisher.py index ebdb0db75..e746547f9 100644 --- a/labgraph/graphs/tests/test_publisher.py +++ b/labgraph/graphs/tests/test_publisher.py @@ -4,7 +4,7 @@ import pytest from ...messages.message import Message -from ...util.error import LabGraphError +from ...util.error import LabgraphError from ..method import AsyncPublisher, publisher from ..topic import Topic @@ -14,7 +14,7 @@ class MyMessage(Message): def test_duplicate_publisher() -> None: - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: A = Topic(MyMessage) @publisher(A) diff --git a/labgraph/graphs/tests/test_subscriber.py b/labgraph/graphs/tests/test_subscriber.py index cb4cde766..0b88c2fa7 100644 --- a/labgraph/graphs/tests/test_subscriber.py +++ b/labgraph/graphs/tests/test_subscriber.py @@ -4,7 +4,7 @@ import pytest from ...messages.message import Message -from ...util.error import LabGraphError +from ...util.error import LabgraphError from ..method import subscriber from ..node import Node from ..topic import Topic @@ -19,7 +19,7 @@ def test_duplicate_subscriber() -> None: Tests that an error is thrown when multiple subscriber decorators are 
applied to a method. """ - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: class MyNode(Node): A = Topic(MyMessage) @@ -39,7 +39,7 @@ def test_subscriber_signature() -> None: Tests that an error is thrown when a subscriber has an invalid signature for message callbacks. """ - with pytest.raises(LabGraphError) as err: + with pytest.raises(LabgraphError) as err: class MyNode(Node): A = Topic(MyMessage) diff --git a/labgraph/graphs/topic.py b/labgraph/graphs/topic.py index 073d83be0..ef5677762 100644 --- a/labgraph/graphs/topic.py +++ b/labgraph/graphs/topic.py @@ -7,7 +7,7 @@ from typing import Any, Dict, NamedTuple, Optional, Tuple, Type, Union, cast from ..messages.message import Message -from ..util.error import LabGraphError +from ..util.error import LabgraphError PATH_DELIMITER = "/" diff --git a/labgraph/loggers/hdf5/logger.py b/labgraph/loggers/hdf5/logger.py index d5504e904..b3ff8ce47 100644 --- a/labgraph/loggers/hdf5/logger.py +++ b/labgraph/loggers/hdf5/logger.py @@ -19,11 +19,14 @@ BytesType, CFloatType, CIntType, + DataclassType, + DictType, DynamicType, FieldType, FloatType, IntEnumType, IntType, + ListType, NumpyDynamicType, NumpyType, StrDynamicType, @@ -31,11 +34,18 @@ StrType, T, ) -from ...util.error import LabGraphError +from ...util.error import LabgraphError from ..logger import Logger HDF5_PATH_DELIMITER = "/" +SERIALIZABLE_DYNAMIC_TYPES = ( + ListType, + DataclassType, + DictType, + NumpyDynamicType, + StrEnumType, +) logger = logging.getLogger(__name__) @@ -92,14 +102,12 @@ def write(self, messages_by_logging_id: Mapping[str, Sequence[Message]]) -> None # Convert dynamic-length bytes fields into numpy arrays so h5py can # read/write them message_fields = list(message.astuple()) + fields = list(message.__class__.__message_fields__.values()) for j, value in enumerate(message_fields): - if not isinstance( - list(message.__class__.__message_fields__.values())[ - j - ].data_type, - DynamicType, - ): + if not isinstance(fields[j].data_type, DynamicType): continue + if isinstance(fields[j].data_type, SERIALIZABLE_DYNAMIC_TYPES): + value = fields[j].data_type.preprocess(value) if isinstance(value, bytes): message_fields[j] = np.array(bytearray(value)) elif isinstance(value, bytearray): @@ -124,15 +132,11 @@ def cleanup(self) -> None: def get_numpy_type_for_field_type( field_type: FieldType[T], ) -> Union[Tuple[np.dtype], Tuple[np.dtype, Tuple[int, ...]]]: - if ( - isinstance(field_type, StrType) - or isinstance(field_type, BytesType) - or isinstance(field_type, StrEnumType) - ): + if isinstance(field_type, StrType) or isinstance(field_type, BytesType): encoding = ( - field_type.encoding if field_type in (StrType, StrEnumType) else "ascii" + field_type.encoding if isinstance(field_type, StrType) else "ascii" # type: ignore ) - return (h5py.string_dtype(encoding=encoding, length=field_type.length),) + return (h5py.string_dtype(encoding=encoding, length=field_type.length),) # type: ignore elif isinstance(field_type, IntType) or isinstance(field_type, IntEnumType): return (get_numpy_type_for_int_type(field_type),) elif isinstance(field_type, FloatType): @@ -144,7 +148,7 @@ def get_numpy_type_for_field_type( elif isinstance(field_type, DynamicType): return (get_dynamic_type(field_type),) - raise LabGraphError(f"No equivalent numpy type for field type: {field_type}") + raise LabgraphError(f"No equivalent numpy type for field type: {field_type}") def get_numpy_type_for_int_type(int_type: Union[IntType, IntEnumType[T_I]]) -> np.dtype: @@ -171,7 
+175,5 @@ def get_numpy_type_for_float_type(float_type: FloatType) -> np.dtype: def get_dynamic_type(field_type: FieldType[Any]) -> np.dtype: if isinstance(field_type, StrDynamicType): return h5py.string_dtype(encoding=field_type.encoding) - elif isinstance(field_type, NumpyDynamicType): - return h5py.vlen_dtype(field_type.dtype) else: return h5py.vlen_dtype(np.uint8) diff --git a/labgraph/loggers/hdf5/reader.py b/labgraph/loggers/hdf5/reader.py new file mode 100644 index 000000000..2ca895995 --- /dev/null +++ b/labgraph/loggers/hdf5/reader.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import logging +import sys +from typing import Any, BinaryIO, Dict, List, Optional, Type, Union + +import h5py + +from ...messages.message import Message +from ...messages.types import ( + BoolType, + BytesDynamicType, + BytesType, + FieldType, + FloatType, + IntEnumType, + IntType, + NumpyDynamicType, + NumpyType, + StrDynamicType, + StrType, + T, +) +from .logger import SERIALIZABLE_DYNAMIC_TYPES + +FILELIKE_T = Union[str, BinaryIO] +LOGGER = logging.getLogger(__name__) +if sys.version_info > (3, 8): + STR_TYPES = (StrType, StrDynamicType) +else: + STR_TYPES = (StrType,) + + +class HDF5Reader: + def __init__(self, path: FILELIKE_T, log_types: Dict[str, Type[Message]]) -> None: + self.path = path + self.log_types = log_types + self._logs: Optional[Dict[str, List[Message]]] = None + + @property + def logs(self) -> Optional[Dict[str, List[Message]]]: + if self._logs is None: + self._parse() + return self._logs + + def _parse(self) -> None: + self._logs = {} + with h5py.File(self.path, "r") as f: + for key, type_ in self.log_types.items(): + if key not in f: + LOGGER.warning(f"{key} not found in h5 file, skipping.") + continue + messages = [] + for raw in f[key]: + kwargs = {} + raw_values = tuple(raw) + for index, field in enumerate(type_.__message_fields__.values()): + if isinstance(field.data_type, SERIALIZABLE_DYNAMIC_TYPES): + value = field.data_type.postprocess( + bytes(raw_values[index]) + ) + else: + value = get_deserialized_value( + raw_values[index], field.data_type + ) + kwargs[field.name] = value + messages.append(type_(**kwargs)) + self._logs[key] = messages + + +def get_deserialized_value(value: Any, field_type: FieldType[T]) -> Any: + if isinstance(field_type, BoolType): + return bool(value) + elif isinstance(field_type, (BytesDynamicType, BytesType)): + return bytes(value) + elif isinstance(field_type, FloatType): + return float(value) + elif isinstance(field_type, IntEnumType): + return field_type.enum_type(int(value)) + elif isinstance(field_type, IntType): + return int(value) + elif isinstance(field_type, NumpyDynamicType): + return field_type.postprocess(bytes(value)) + elif isinstance(field_type, NumpyType): + return value + elif isinstance(field_type, STR_TYPES): + return bytes(value).decode(field_type.encoding) # type: ignore + else: + return value diff --git a/labgraph/loggers/hdf5/tests/__init__.py b/labgraph/loggers/hdf5/tests/__init__.py index 860ac27c6..47642f1ff 100644 --- a/labgraph/loggers/hdf5/tests/__init__.py +++ b/labgraph/loggers/hdf5/tests/__init__.py @@ -1,2 +1,4 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. 
+ +__all__ = [] diff --git a/labgraph/loggers/hdf5/tests/test_logger.py b/labgraph/loggers/hdf5/tests/test_logger.py index 06cf097e4..0189d369a 100644 --- a/labgraph/loggers/hdf5/tests/test_logger.py +++ b/labgraph/loggers/hdf5/tests/test_logger.py @@ -1,56 +1,17 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. -import asyncio -import functools -import random -import tempfile -from enum import Enum -from pathlib import Path -from typing import List, Tuple +import sys import h5py -from ....graphs.graph import Graph -from ....graphs.method import AsyncPublisher, publisher, subscriber -from ....graphs.node_test_harness import run_with_harness -from ....graphs.stream import Stream -from ....graphs.topic import Topic -from ....messages.message import Message from ....messages.types import ( - BytesType, DynamicType, StrDynamicType, - StrEnumType, StrType, ) -from ....util.random import random_string -from ...logger import LoggerConfig -from ..logger import HDF5Logger - - -NUM_MESSAGES = 100 - - -class MyIntEnum(int, Enum): - A = 1 - B = 2 - - -class MyStrEnum(str, Enum): - A = "A" - B = "B" - - -class MyMessage(Message): - int_field: int - str_field: str - float_field: float - bool_field: bool - bytes_field: bytes - int_enum_field: MyIntEnum - str_enum_field: MyStrEnum - fixed_bytes_field: BytesType(length=10) # type: ignore +from ..logger import HDF5Logger, SERIALIZABLE_DYNAMIC_TYPES +from .test_utils import LOGGING_IDS, write_logs_to_hdf5 def test_hdf5_logger() -> None: @@ -58,70 +19,33 @@ def test_hdf5_logger() -> None: Tests that we can write messages to an HDF5 file and then read them back. """ - logging_ids = ("test1", "test2") - logging_ids_and_messages = [] - for i in range(NUM_MESSAGES): - message = MyMessage( - int_field=i, - str_field=str(i), - float_field=float(i), - bool_field=i % 2 == 0, - bytes_field=str(i).encode("ascii"), - int_enum_field=list(MyIntEnum)[i % 2], - str_enum_field=list(MyStrEnum)[i % 2], - fixed_bytes_field=b"0123456789", - ) - for logging_id in random.sample(logging_ids, k=len(logging_ids)): - logging_ids_and_messages.append((logging_id, message)) - - output_directory = tempfile.gettempdir() - recording_name = random_string(16) - config = LoggerConfig( - output_directory=output_directory, recording_name=recording_name - ) - run_with_harness( - HDF5Logger, - functools.partial(_test_fn, logging_ids_and_messages=logging_ids_and_messages), - config=config, - ) + if sys.version_info > (3, 8): + str_types = (StrType, StrDynamicType) + else: + str_types = (StrType,) + # Write the messages to a file + output_path, logging_ids_and_messages = write_logs_to_hdf5(HDF5Logger) # Read the messages back from the file and compare to the messages array - output_path = Path(output_directory) / Path(f"{recording_name}.h5") with h5py.File(str(output_path), "r") as h5py_file: - for logging_id in logging_ids: + for logging_id in LOGGING_IDS: messages = [l[1] for l in logging_ids_and_messages if l[0] == logging_id] for i, message in enumerate(messages): for field in message.__class__.__message_fields__.values(): expected_value = getattr(message, field.name) actual_value = h5py_file[logging_id][i][field.name] - if isinstance(field.data_type, StrType) or isinstance( - field.data_type, StrEnumType - ): + if isinstance(field.data_type, str_types): assert ( actual_value.decode(field.data_type.encoding) == expected_value ) + elif isinstance(field.data_type, SERIALIZABLE_DYNAMIC_TYPES): + actual_value = field.data_type.postprocess(bytes(actual_value)) + assert 
actual_value == expected_value elif isinstance(field.data_type, DynamicType) and not isinstance( field.data_type, StrDynamicType ): assert bytes(actual_value) == expected_value else: assert actual_value == expected_value - - -async def _test_fn( - logger: HDF5Logger, logging_ids_and_messages: List[Tuple[str, Message]] -) -> None: - await asyncio.gather( - logger.run_logger(), _write_messages(logger, logging_ids_and_messages) - ) - - -async def _write_messages( - logger: HDF5Logger, logging_ids_and_messages: List[Tuple[str, Message]] -) -> None: - for logging_id, message in logging_ids_and_messages: - logger.buffer_message(logging_id, message) - await asyncio.sleep(0.01) - logger.running = False \ No newline at end of file diff --git a/labgraph/loggers/hdf5/tests/test_reader.py b/labgraph/loggers/hdf5/tests/test_reader.py new file mode 100644 index 000000000..b26fbebb0 --- /dev/null +++ b/labgraph/loggers/hdf5/tests/test_reader.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +from ..logger import HDF5Logger +from ..reader import HDF5Reader +from .test_utils import ( + MyDataclass, + MyIntEnum, + MyMessage, + MyStrEnum, + NUM_MESSAGES, + write_logs_to_hdf5, +) + + +def test_hdf5_reader() -> None: + path, _ = write_logs_to_hdf5(HDF5Logger) + log_types = { + "test1": MyMessage, + "test2": MyMessage, + } + reader = HDF5Reader(path, log_types) + for index in range(NUM_MESSAGES): + expected = MyMessage( + int_field=index, + str_field=str(index), + float_field=float(index), + bool_field=index % 2 == 0, + bytes_field=str(index).encode("ascii"), + int_enum_field=list(MyIntEnum)[index % 2], + str_enum_field=list(MyStrEnum)[index % 2], + fixed_bytes_field=b"0123456789", + list_field=[5, 6, 7], + dict_field={"test_key": "test_val"}, + dataclass_field=MyDataclass(sub_int_field=7, sub_str_field="seven"), + ) + assert reader.logs["test1"][index] == expected + assert reader.logs["test2"][index] == expected diff --git a/labgraph/loggers/hdf5/tests/test_utils.py b/labgraph/loggers/hdf5/tests/test_utils.py new file mode 100644 index 000000000..4ba29196c --- /dev/null +++ b/labgraph/loggers/hdf5/tests/test_utils.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +""" +Module containing utilities used in HDF5 unit tests. + +This includes a Message that has all known field types. +- When a new field type is added it should be updated here to unit test correctly. 
+""" + + +import asyncio +import functools +import random +import tempfile +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Dict, List, Tuple, Type + +from ....graphs.node_test_harness import run_with_harness +from ....messages.message import Message +from ....messages.types import BytesType +from ....util.random import random_string +from ...logger import Logger, LoggerConfig + +LOGGING_IDS = ("test1", "test2") +NUM_MESSAGES = 100 + + +class MyIntEnum(int, Enum): + A = 1 + B = 2 + + +class MyStrEnum(str, Enum): + A = "A" + B = "B" + + +@dataclass +class MyDataclass: + sub_int_field: int + sub_str_field: str + + +class MyMessage(Message): + int_field: int + str_field: str + float_field: float + bool_field: bool + bytes_field: bytes + int_enum_field: MyIntEnum + str_enum_field: MyStrEnum + fixed_bytes_field: BytesType(length=10) # type: ignore + list_field: List[int] + dict_field: Dict[str, str] + dataclass_field: MyDataclass + + +async def _test_fn( + logger: Logger, logging_ids_and_messages: List[Tuple[str, Message]] +) -> None: + await asyncio.gather( + logger.run_logger(), _write_messages(logger, logging_ids_and_messages) + ) + + +async def _write_messages( + logger: Logger, logging_ids_and_messages: List[Tuple[str, Message]] +) -> None: + for logging_id, message in logging_ids_and_messages: + logger.buffer_message(logging_id, message) + await asyncio.sleep(0.01) + logger.running = False + + +def write_logs_to_hdf5(logger: Type[Logger]) -> Tuple[Path, List[Tuple[str, Message]]]: + logging_ids_and_messages = [] + for i in range(NUM_MESSAGES): + message = MyMessage( + int_field=i, + str_field=str(i), + float_field=float(i), + bool_field=i % 2 == 0, + bytes_field=str(i).encode("ascii"), + int_enum_field=list(MyIntEnum)[i % 2], + str_enum_field=list(MyStrEnum)[i % 2], + fixed_bytes_field=b"0123456789", + list_field=[5, 6, 7], + dict_field={"test_key": "test_val"}, + dataclass_field=MyDataclass(sub_int_field=7, sub_str_field="seven"), + ) + for logging_id in random.sample(LOGGING_IDS, k=len(LOGGING_IDS)): + logging_ids_and_messages.append((logging_id, message)) + + output_directory = tempfile.gettempdir() + recording_name = random_string(16) + config = LoggerConfig( + output_directory=output_directory, recording_name=recording_name + ) + run_with_harness( + logger, + functools.partial(_test_fn, logging_ids_and_messages=logging_ids_and_messages), + config=config, + ) + return ( + Path(output_directory) / Path(f"{recording_name}.h5"), + logging_ids_and_messages, + ) diff --git a/labgraph/loggers/logger.py b/labgraph/loggers/logger.py index c2ba3e78b..91733e060 100644 --- a/labgraph/loggers/logger.py +++ b/labgraph/loggers/logger.py @@ -42,7 +42,7 @@ class LoggerConfig(Config): disk. Defaults to 1 second. If `None`, the logger will only flush when the buffer is full. streams_by_logging_id: - A dictionary of the LabGraph stream objects by logging id. When specified, + A dictionary of the Labgraph stream objects by logging id. When specified, the logger will subscribe to the Cthulhu streams itself. This should always be provided unless the logger is being unit tested. 
""" @@ -125,7 +125,7 @@ def flush_buffer(self) -> Dict[str, List[Message]]: def _get_logger_callback( self, logging_id: str, stream: Stream ) -> Callable[[Message], None]: - # We add the correct type annotation so LabGraph knows what message type to + # We add the correct type annotation so Labgraph knows what message type to # deserialize to assert stream.message_type is not None MessageType: Type[Message] = stream.message_type diff --git a/labgraph/messages/message.py b/labgraph/messages/message.py index 93adb0872..5d4c3792f 100644 --- a/labgraph/messages/message.py +++ b/labgraph/messages/message.py @@ -1,14 +1,16 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. -# Defines simple messaging constructs for LabGraph +# Defines simple messaging constructs for Labgraph import dataclasses import hashlib +import importlib import logging import struct from collections import OrderedDict -from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, Union +from enum import Enum +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union from .._cthulhu.bindings import ( Field as CthulhuField, @@ -17,26 +19,24 @@ TypeDefinition, typeRegistry, ) -from ..util.error import LabGraphError -from .types import DEFAULT_BYTE_ORDER, FieldType, StructType, get_field_type +from ..util.error import LabgraphError +from .types import ( + DEFAULT_BYTE_ORDER, + FieldType, + LOCAL_INTERNAL_FIELDS, + StructType, + get_field_type, +) logger = logging.getLogger(__name__) T = TypeVar("T") -# Internal fields that are present on Message instances but are not included when -# serializing the message for streaming -LOCAL_INTERNAL_FIELDS = ( - "__sample__", - "__original_message__", - "__original_message_type__", -) - class Field(Generic[T]): """ - Represents a field in a LabGraph message. + Represents a field in a Labgraph message. Args: name: The name of the field. @@ -246,7 +246,7 @@ def _index_of_field(cls, field_name: str) -> int: for i, (_, field) in enumerate(cls.__message_fields__.items()): if field.name == field_name: return i - raise LabGraphError(f"{cls.__name__} has no field '{field_name}'") + raise LabgraphError(f"{cls.__name__} has no field '{field_name}'") class IsOriginalMessage: @@ -258,7 +258,7 @@ class IsOriginalMessage: class Message(metaclass=MessageMeta): """ - Represents a LabGraph message. A message is a collection of data that can be sent + Represents a Labgraph message. A message is a collection of data that can be sent between nodes via topics. The fields available to every message of a certain type are defined via type annotations on the corresponding subclass of `Message`. Subclasses recursively include their superclasses' fields. @@ -457,7 +457,7 @@ def __getattribute__(self, name: str) -> Any: ][0] result = getattr(self.__original_message__, original_field_name) if not field.data_type.isinstance(result): - raise LabGraphError( + raise LabgraphError( f"Could not convert from {message_cls.__name__}." f"{original_field_name} to {cls.__name__}.{name}: invalid " f"value {result}" @@ -493,7 +493,7 @@ def __eq__(self, other: Any) -> bool: class TimestampedMessage(Message): """ - Represents a simple timestamped LabGraph message. All messages which + Represents a simple timestamped Labgraph message. All messages which may be aligned using a timestamp should inherit from this class. 
""" diff --git a/labgraph/messages/tests/test_message.py b/labgraph/messages/tests/test_message.py index 05e855fd9..b222b1e21 100644 --- a/labgraph/messages/tests/test_message.py +++ b/labgraph/messages/tests/test_message.py @@ -3,6 +3,7 @@ # Unit tests for the Message class. +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List @@ -129,7 +130,7 @@ class MyDynamicNumpyIntMessage(Message): """ field1: str - field2: NumpyDynamicType(dtype=np.int64) # type: ignore + field2: np.ndarray field3: int @@ -140,7 +141,16 @@ class MyDynamicMessage(Message): field1: Dict[str, Any] field2: int - field3: List[Any] + field3: List[int] + + +@dataclass +class MyDataclass: + sub_field1: int + + +class MyDataclassMessage(Message): + field1: MyDataclass class MyInvalidDefaultMessage(Message): @@ -394,16 +404,6 @@ def test_dynamic_numpy_field_with_type() -> None: assert message.field3 == 5 -def test_dynamic_numpy_field_with_invalid_type() -> None: - """ - Tests that we throw an error when we construct a message type with an incorrect - dtype for the dynamic numpy field. - """ - - with pytest.raises(TypeError): - MyDynamicNumpyIntMessage(field1="hello", field2=np.random.rand(3, 3), field3=5) - - def test_static_to_dynamic_conversion() -> None: """ Tests that we can convert a static field to a dynamic field between equivalent @@ -440,13 +440,22 @@ def test_dynamic_fields() -> None: """ field1 = {"key1": 5, "key2": "value2"} field2 = 12 - field3 = [5, "hello", 6.2] + field3 = [5, 6, 7] message = MyDynamicMessage(field1=field1, field2=field2, field3=field3) assert message.field1 == field1 assert message.field2 == field2 assert message.field3 == field3 +def test_dataclass_fields() -> None: + """ + Tests that we can serialize some more dynamic field types. + """ + field1 = MyDataclass(sub_field1=7) + message = MyDataclassMessage(field1=field1) + assert message.field1 == field1 + + def test_invalid_default_field() -> None: """ Tests that a badly-typed default field value raises an error. diff --git a/labgraph/messages/types.py b/labgraph/messages/types.py index 474335809..3be3fbcc9 100644 --- a/labgraph/messages/types.py +++ b/labgraph/messages/types.py @@ -1,18 +1,27 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. +import dataclasses import pickle import struct from abc import ABC, abstractmethod, abstractproperty from enum import Enum from io import BytesIO -from typing import Any, Generic, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import numpy as np import typeguard +DEFAULT_LEN_LENGTH = 10 DEFAULT_STR_LENGTH = 128 +# Internal fields that are present on Message instances but are not included when +# serializing the message for streaming +LOCAL_INTERNAL_FIELDS = ( + "__sample__", + "__original_message__", + "__original_message_type__", +) class ByteOrder(str, Enum): @@ -60,7 +69,7 @@ class CFloatType(Enum): class FieldType(ABC, Generic[T]): """ - Represents a LabGraph field type. Subclasses implement methods that describe how to + Represents a Labgraph field type. Subclasses implement methods that describe how to check for instances of the type and serialize the type. """ @@ -306,49 +315,6 @@ def python_type(self) -> type: return str -T_S = TypeVar("T_S", bound=Enum) # Enumerated string type - - -class StrEnumType(StructType[T_S]): - """ - Represents an enumerated string type. - - Args: - encoding: The encoding to use to serialize strings for this field. 
- """ - - length: int - encoding: str - - def __init__(self, enum_type: Type[T_S], encoding: str = "utf-8") -> None: - self.length = max(len(item.value) for item in enum_type) - self.encoding = encoding - self.enum_type = enum_type - - @property - def format_string(self) -> str: - return f"{self.length}s" - - def isinstance(self, obj: Any) -> bool: - return isinstance(obj, self.enum_type) - - @property - def description(self) -> str: - return self.enum_type.__name__ - - def preprocess(self, value: T_S) -> bytes: - str_value: str = value.value - return str_value.encode(self.encoding) - - def postprocess(self, value: bytes) -> T_S: - str_value = value.decode(self.encoding).rstrip("\0") - return self.enum_type(str_value) - - @property - def python_type(self) -> type: - return self.enum_type - - class BytesType(StructType[bytes]): """ Represents a bytes type. @@ -461,6 +427,47 @@ def size(self) -> Optional[int]: return None +T_S = TypeVar("T_S", bound=Enum) # Enumerated string type + + +class StrEnumType(DynamicType[T_S]): + """ + Represents an enumerated string type. + + Args: + encoding: The encoding to use to serialize strings for this field. + """ + + def __init__(self, enum_type: Type[T_S], encoding: str = "utf-8") -> None: + self.enum_type = enum_type + self.encoding = encoding + + def isinstance(self, obj: Any) -> bool: + return isinstance(obj, self.enum_type) + + @property + def description(self) -> str: + return self.enum_type.__name__ + + def preprocess(self, obj: T_S) -> bytes: + type_ = pickle.dumps(type(obj)) + type_len = get_len_bytes(type_) + str_ = bytes(obj.value, encoding=self.encoding) + str_len = get_len_bytes(str_) + return type_len + type_ + str_len + str_ + + def postprocess(self, obj_bytes: bytes) -> T_S: + raw, curr = get_next_bytes(obj_bytes, 0) + type_ = pickle.loads(raw) + raw, curr = get_next_bytes(obj_bytes, curr) + str_ = raw.decode(self.encoding) + return type_(str_) + + @property + def python_type(self) -> type: + return self.enum_type + + class ObjectDynamicType(DynamicType[Any]): """ Represents a dynamic field type for any object. This is a fallback type for when we @@ -486,39 +493,41 @@ def postprocess(self, obj_bytes: bytes) -> T: class NumpyDynamicType(DynamicType[np.ndarray]): """ Represents a numpy dynamic field type. - - Args: - dtype: The dtype of a numpy array of this type. 
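+
+    The array's dtype is serialized alongside the data, so arrays of any dtype
+    round-trip without per-field configuration.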
""" - dtype: np.dtype - - def __init__(self, dtype: np.dtype = np.float64) -> None: - self.dtype = dtype - @property def python_type(self) -> type: return np.ndarray # type: ignore def isinstance(self, obj: Any) -> bool: - return isinstance(obj, np.ndarray) and obj.dtype == self.dtype + return isinstance(obj, np.ndarray) def preprocess(self, obj: np.ndarray) -> bytes: assert self.isinstance(obj) + values = [] + dtype_bytes = get_packed_value(ObjectDynamicType(), obj.dtype) + dtype_len = get_len_bytes(dtype_bytes) + values.extend([dtype_len, dtype_bytes]) buf = BytesIO() np.save(buf, obj) buf.seek(0) - return buf.read() + array_bytes = buf.read() + array_len = get_len_bytes(array_bytes) + values.extend([array_len, array_bytes]) + return b"".join(values) def postprocess(self, obj_bytes: bytes) -> np.ndarray: - buf = BytesIO(obj_bytes) + raw, curr = get_next_bytes(obj_bytes, 0) + dtype = get_unpacked_value(ObjectDynamicType(), raw) + raw, curr = get_next_bytes(obj_bytes, curr) + buf = BytesIO(raw) arr = np.load(buf) assert isinstance(arr, np.ndarray) - return arr.astype(self.dtype) + return arr.astype(dtype) @property def description(self) -> str: - return f"numpy.ndarray({self.dtype})" + return "numpy.ndarray" class StrDynamicType(DynamicType[str]): @@ -561,6 +570,152 @@ def description(self) -> str: return "bytes" +class ListType(DynamicType[List[T]]): + """ + Represents a list dynamic field type. + """ + + type_: Type[T] + + def __init__(self, type_: Type[T]) -> None: + self.type_ = type_ + self._sub_type = get_field_type(self.type_) + + @property + def python_type(self) -> type: + return list + + def isinstance(self, obj: Any) -> bool: + return isinstance(obj, list) + + def preprocess(self, obj: List[T]) -> bytes: + values = [get_len_bytes(obj)] + for item in obj: + value = get_packed_value(self._sub_type, item) + value_len = get_len_bytes(value) + values.append(value_len) + values.append(value) + return b"".join(values) + + def postprocess(self, obj_bytes: bytes) -> List[T]: + obj_len = int(obj_bytes[0:DEFAULT_LEN_LENGTH]) + curr = DEFAULT_LEN_LENGTH + values = [] + for _ in range(obj_len): + raw, curr = get_next_bytes(obj_bytes, curr) + values.append(get_unpacked_value(self._sub_type, raw)) + return values + + @property + def description(self) -> str: + return f"typing.List[{self.type_}]" + + +V = TypeVar("V") + + +class DictType(DynamicType[Dict[T, V]]): + """ + Represents a list dynamic field type. 
+ """ + + type_: Tuple[Type[T], Type[V]] + + def __init__(self, type_: Tuple[Type[T], Type[V]]) -> None: + self.type_ = type_ + self._key_type = get_field_type(self.type_[0]) + self._val_type = get_field_type(self.type_[1]) + + @property + def python_type(self) -> type: + return dict + + def isinstance(self, obj: Any) -> bool: + return isinstance(obj, dict) + + def preprocess(self, obj: Dict[T, V]) -> bytes: + values = [bytes(str(len(obj)).rjust(DEFAULT_LEN_LENGTH, "0"), encoding="ascii")] + for key, val in obj.items(): + key_bytes = get_packed_value(self._key_type, key) + val_bytes = get_packed_value(self._val_type, val) + key_len = get_len_bytes(key_bytes) + val_len = get_len_bytes(val_bytes) + values.extend([key_len, key_bytes, val_len, val_bytes]) + return b"".join(values) + + def postprocess(self, obj_bytes: bytes) -> Dict[T, V]: + obj_len = int(obj_bytes[0:DEFAULT_LEN_LENGTH]) + curr = DEFAULT_LEN_LENGTH + values = {} + for _ in range(obj_len): + raw, curr = get_next_bytes(obj_bytes, curr) + key = get_unpacked_value(self._key_type, raw) + raw, curr = get_next_bytes(obj_bytes, curr) + val = get_unpacked_value(self._val_type, raw) + values[key] = val + return values + + @property + def description(self) -> str: + return f"typing.Dict[{self.type_[0]}, {self.type_[1]}]" + + +class DataclassType(DynamicType[T]): + """ + Represents a dataclass dynamic field type. + """ + + type_: Type[T] + + def __init__(self, type_: Type[T]) -> None: + self.type_ = type_ + + @property + def python_type(self) -> type: + return self.type_ + + def isinstance(self, obj: Any) -> bool: + return isinstance(obj, self.type_) + + def preprocess(self, obj: T) -> bytes: + type_ = pickle.dumps(type(obj)) + type_len = get_len_bytes(type_) + values = [type_len, type_] + for field in dataclasses.fields(obj): + if field.name in LOCAL_INTERNAL_FIELDS: + continue + sub_type = get_field_type(field.type) + value = get_packed_value(sub_type, getattr(obj, field.name)) + value_len = get_len_bytes(value) + values.append(value_len) + values.append(value) + return b"".join(values) + + def postprocess(self, obj_bytes: bytes) -> T: + raw, curr = get_next_bytes(obj_bytes, 0) + type_ = pickle.loads(raw) + init_values = {} + non_init_values = {} + for field in dataclasses.fields(type_): + if field.name in LOCAL_INTERNAL_FIELDS: + continue + sub_type = get_field_type(field.type) + raw, curr = get_next_bytes(obj_bytes, curr) + value = get_unpacked_value(sub_type, raw) + if field.init: + init_values[field.name] = value + else: + non_init_values[field.name] = value + obj = type_(**init_values) + for field_name, value in non_init_values.items(): + setattr(obj, field_name, value) + return obj + + @property + def description(self) -> str: + return f"{self.type_}(...))" + + PRIMITIVE_TYPES = { int: IntType, float: FloatType, @@ -572,15 +727,29 @@ def description(self) -> str: def get_field_type(python_type: Type[T]) -> FieldType[T]: """ - Returns a `FieldType` that contains all the information LabGraph needs for a field. + Returns a `FieldType` that contains all the information Labgraph needs for a field. Args: `python_type`: A Python type to get a `FieldType` for. 
""" - if not isinstance(python_type, type): + if isinstance(python_type, FieldType): + return python_type + if python_type.__module__ == "typing": + # TODO: Switch to `typing.get_origin` for py38 + origin = getattr(python_type, "__origin__", None) + if origin in (list, List): + # TODO: Switch to `typing.get_args` for py38 + return ListType(python_type.__args__[0]) # type: ignore + elif origin in (dict, Dict): + # TODO: Switch to `typing.get_args` for py38 + return DictType(python_type.__args__) # type: ignore return ObjectDynamicType() - if issubclass(python_type, Enum): + elif not isinstance(python_type, type): + return ObjectDynamicType() + elif dataclasses.is_dataclass(python_type): + return DataclassType(python_type) + elif issubclass(python_type, Enum): if issubclass(python_type, str): return StrEnumType(python_type) # type: ignore elif issubclass(python_type, int): @@ -595,3 +764,29 @@ def get_field_type(python_type: Type[T]) -> FieldType[T]: return NumpyDynamicType() return ObjectDynamicType() + + +def get_len_bytes(obj: Any) -> bytes: + return bytes(str(len(obj)).rjust(DEFAULT_LEN_LENGTH, "0"), encoding="ascii") + + +def get_next_bytes(obj: bytes, curr: int) -> Tuple[bytes, int]: + length = int(obj[curr : curr + DEFAULT_LEN_LENGTH]) + curr += DEFAULT_LEN_LENGTH + raw = obj[curr : curr + length] + curr += length + return (raw, curr) + + +def get_packed_value(type_: FieldType[T], value: Any) -> bytes: + value = type_.preprocess(value) + if isinstance(type_, StructType): + value = struct.pack(type_.format_string, value) + return value + + +def get_unpacked_value(type_: FieldType[T], value: bytes) -> Any: + if isinstance(type_, StructType): + value = struct.unpack(type_.format_string, value)[0] + value = type_.postprocess(value) # type: ignore + return value diff --git a/labgraph/runners/aligner.py b/labgraph/runners/aligner.py index ac1371050..134b86e9a 100644 --- a/labgraph/runners/aligner.py +++ b/labgraph/runners/aligner.py @@ -7,9 +7,9 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple -from .._cthulhu.cthulhu import LabGraphCallback, LabGraphCallbackParams +from .._cthulhu.cthulhu import LabgraphCallback, LabgraphCallbackParams from ..messages.message import TimestampedMessage -from ..util.error import LabGraphError +from ..util.error import LabgraphError from ..util.min_heap import MinHeap @@ -33,11 +33,11 @@ class Aligner(ABC): """ @abstractmethod - def register(self, stream_id: str, callback: LabGraphCallback) -> None: + def register(self, stream_id: str, callback: LabgraphCallback) -> None: pass @abstractmethod - def push(self, params: LabGraphCallbackParams[Any]) -> None: + def push(self, params: LabgraphCallbackParams[Any]) -> None: pass @abstractmethod @@ -54,7 +54,7 @@ async def run(self) -> None: TimestampedHeapEntry = Tuple[ - float, int, str, LabGraphCallbackParams[TimestampedMessage] + float, int, str, LabgraphCallbackParams[TimestampedMessage] ] TimestampedHeap = MinHeap[TimestampedHeapEntry] @@ -87,20 +87,20 @@ def __init__(self, lag: float) -> None: # Lag (in seconds) during which incoming messages are buffered self.lag: float = lag # Callbacks keyed by the id of the stream they're subscribing to - self.callbacks: Dict[str, List[LabGraphCallback]] = collections.defaultdict( + self.callbacks: Dict[str, List[LabgraphCallback]] = collections.defaultdict( list ) self.active: bool = True # Current state of associated runner self.terminate: bool = False # Flag to quit immediately - def register(self, stream_id: str, callback: 
LabGraphCallback) -> None: + def register(self, stream_id: str, callback: LabgraphCallback) -> None: self.callbacks[stream_id].append(callback) - def push(self, params: LabGraphCallbackParams[TimestampedMessage]) -> None: + def push(self, params: LabgraphCallbackParams[TimestampedMessage]) -> None: message = params.message if params.stream_id is None: - raise LabGraphError( + raise LabgraphError( "TimestampAligner::push expected stream id, but got None." ) heap_entry: TimestampedHeapEntry = ( diff --git a/labgraph/runners/entry.py b/labgraph/runners/entry.py index 3ec5006ef..e8f25e656 100644 --- a/labgraph/runners/entry.py +++ b/labgraph/runners/entry.py @@ -64,7 +64,7 @@ def main( f"--{ProcessManagerState.SUBPROCESS_ARG}" ) - # Get the Python class for the LabGraph module + # Get the Python class for the Labgraph module module_cls = get_module_class(*module.rsplit(".", 1)) # Restore the config and state for the module diff --git a/labgraph/runners/exceptions.py b/labgraph/runners/exceptions.py index da9683c95..d96607fef 100644 --- a/labgraph/runners/exceptions.py +++ b/labgraph/runners/exceptions.py @@ -14,7 +14,7 @@ class NormalTermination(Exception): class ExceptionMessage(Message): """ - Holds the bytes for a thrown exception in a LabGraph message. + Holds the bytes for a thrown exception in a Labgraph message. Convenient for passing exceptions between processes. """ diff --git a/labgraph/runners/local_runner.py b/labgraph/runners/local_runner.py index a31c6b603..4b6ae4fd8 100644 --- a/labgraph/runners/local_runner.py +++ b/labgraph/runners/local_runner.py @@ -27,11 +27,11 @@ import yappi from labgraph_cpp import NodeBootstrapInfo, NodeTopic # type: ignore -# HACK: Import from LabGraph's wrapper of Cthulhu before importing dynamic libs to set +# HACK: Import from Labgraph's wrapper of Cthulhu before importing dynamic libs to set # the shared memory name from .._cthulhu.cthulhu import ( Consumer, - LabGraphCallbackParams, + LabgraphCallbackParams, Mode, Producer, format_performance_summary, @@ -95,7 +95,7 @@ class LocalRunnerState: class LocalRunner(Runner): """ - A utility for running LabGraph modules. Given a module, runs the computation it + A utility for running Labgraph modules. Given a module, runs the computation it describes by creating two threads, one for its foreground processing and one for its event loop processing. @@ -125,7 +125,7 @@ def __init__(self, module: Module, options: Optional[RunnerOptions] = None) -> N def run(self) -> None: """ - Starts the LabGraph module. Returns when the module has terminated. + Starts the Labgraph module. Returns when the module has terminated. """ try: if should_profile(): @@ -240,7 +240,7 @@ def run(self) -> None: def _setup_cthulhu(self) -> None: """ - Sets up Cthulhu as the transport for the LabGraph graph. Creates streams only + Sets up Cthulhu as the transport for the Labgraph graph. Creates streams only if the module has no parent graph. Then creates producers and consumers according to the publishers and subscribers in the module. 
""" @@ -355,7 +355,7 @@ def _callback_for_stream(self, stream_id: str) -> Callable[..., None]: assert MessageType is not None if self._options.aligner is not None: # Type with extra information for the aligner - MessageType = LabGraphCallbackParams[MessageType] # type: ignore + MessageType = LabgraphCallbackParams[MessageType] # type: ignore def callback(message: MessageType) -> None: # type: ignore with self._state.lock: @@ -569,6 +569,10 @@ def run(self) -> None: logger.debug(f"{self.module}:background thread:terminate aligner") self.options.aligner.wait_for_completion() logger.debug(f"{self.module}:background thread:shutting down async gens") + + # https://bugs.python.org/issue38559 + for task in asyncio.Task.all_tasks(loop=loop): + task.cancel() loop.run_until_complete(loop.shutdown_asyncgens()) logger.debug(f"{self.module}:background thread:waiting for pending tasks") diff --git a/labgraph/runners/parallel_runner.py b/labgraph/runners/parallel_runner.py index 313160fe8..850299c31 100644 --- a/labgraph/runners/parallel_runner.py +++ b/labgraph/runners/parallel_runner.py @@ -54,7 +54,7 @@ def __init__(self, graph: Graph, options: Optional[RunnerOptions] = None) -> Non def run(self) -> None: """ - Starts the LabGraph graph. Returns when the graph has terminated. + Starts the Labgraph graph. Returns when the graph has terminated. """ self._graph.setup() self._create_logger() @@ -192,7 +192,7 @@ def _get_class_module(self, cls: type) -> str: except (ImportError, RuntimeError): raise RuntimeError( f"Putting {cls.__name__} in the main scope is preventing its use with " - f"df.{self.__class__.__name__}. Please consider either a) creating " + f"lg.{self.__class__.__name__}. Please consider either a) creating " "another module to use as the main module or b) using __main__.py " "instead.\nhttps://docs.python.org/3.6/library/__main__.html" ) @@ -200,7 +200,7 @@ def _get_class_module(self, cls: type) -> str: def run(graph_type: Type[Graph]) -> None: """ - Entry point for running LabGraph graphs. Call `run` with a LabGraph graph type to + Entry point for running Labgraph graphs. Call `run` with a Labgraph graph type to run a new graph of that type. """ config_type = graph_type.__config_type__ diff --git a/labgraph/runners/runner.py b/labgraph/runners/runner.py index c7cb72066..693c0c064 100644 --- a/labgraph/runners/runner.py +++ b/labgraph/runners/runner.py @@ -13,7 +13,7 @@ from ..loggers.logger import Logger, LoggerConfig from ..messages.message import Message from ..messages.types import BytesType -from ..util.error import LabGraphError +from ..util.error import LabgraphError from ..util.logger import get_logger from .aligner import Aligner from .process_manager import ProcessManagerState diff --git a/labgraph/runners/tests/test_cpp.py b/labgraph/runners/tests/test_cpp.py deleted file mode 100644 index 194fd2a87..000000000 --- a/labgraph/runners/tests/test_cpp.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2004-present Facebook. All Rights Reserved. 
- -import time -from pathlib import Path -from typing import Dict, Sequence - -import h5py -from MyCPPNodes import MyCPPSink, MyCPPSource # type: ignore - -from ...graphs.config import Config -from ...graphs.cpp_node import CPPNodeConfig -from ...graphs.graph import Graph -from ...graphs.group import Connections -from ...graphs.method import subscriber -from ...graphs.module import Module -from ...graphs.node import Node -from ...graphs.topic import Topic -from ...messages.message import Message -from ...runners.exceptions import NormalTermination -from ...runners.parallel_runner import ParallelRunner -from ...util.testing import get_test_filename, local_test - - -class MyMessage(Message): - value: int - - -class MySinkConfig(Config): - filename: str - - -class MyPythonSink(Node): - C = Topic(MyMessage) - - config: MySinkConfig - - def setup(self) -> None: - self.messages_seen: int = 0 - - @subscriber(C) - def sink(self, message: MyMessage) -> None: - with open(self.config.filename, "a") as f: - f.write(str(message.value) + "\n") - self.messages_seen += 1 - if self.messages_seen == MyCPPSource.NUM_SAMPLES: - time.sleep(2) - raise NormalTermination() - - -class MyGraphConfig(Config): - python_filename: str - cpp_filename: str - - -class MyMixedGraph(Graph): - CPP_SOURCE: MyCPPSource - CPP_SINK: MyCPPSink - PYTHON_SINK: MyPythonSink - - config: MyGraphConfig - - def setup(self) -> None: - self.CPP_SINK.configure(CPPNodeConfig(args=[self.config.cpp_filename])) - self.PYTHON_SINK.configure(MySinkConfig(filename=self.config.python_filename)) - - def connections(self) -> Connections: - return ( - (self.CPP_SOURCE.A, self.CPP_SINK.B), - (self.CPP_SOURCE.A, self.PYTHON_SINK.C), - ) - - def process_modules(self) -> Sequence[Module]: - return [self.CPP_SOURCE, self.CPP_SINK, self.PYTHON_SINK] - - def logging(self) -> Dict[str, Topic]: - return { - "cpp_source": self.CPP_SOURCE.A, - "cpp_sink": self.CPP_SINK.B, - "python_sink": self.PYTHON_SINK.C, - } - - -@local_test -def test_cpp_graph() -> None: - """ - Tests that we can run a graph with both C++ and Python nodes, and read the results - on disk. 
- """ - # Run the graph - graph = MyMixedGraph() - python_filename = get_test_filename() - cpp_filename = get_test_filename() - graph.configure( - MyGraphConfig(python_filename=python_filename, cpp_filename=cpp_filename) - ) - runner = ParallelRunner(graph=graph) - # Get the HDF5 log path to verify the logs later - output_path = str( # noqa: F841 - Path(runner._options.logger_config.output_directory) - / Path(f"{runner._options.logger_config.recording_name}.h5") - ) - runner.run() - - # Check C++ sink output - cpp_nums = set(range(MyCPPSource.NUM_SAMPLES)) - with open(cpp_filename, "r") as f: - for line in f: - num = int(line.strip()) - cpp_nums.remove(num) - - assert len(cpp_nums) == 0, f"Missing numbers in C++ sink output: {cpp_nums}" - - # Check Python sink output - python_nums = set(range(MyCPPSource.NUM_SAMPLES)) - with open(python_filename, "r") as f: - for line in f: - num = int(line.strip()) - if num in python_nums: - python_nums.remove(num) - - assert ( - len(python_nums) == 0 - ), f"Missing numbers in Python sink output: {python_nums}" - - # Check HDF5 logger output - with h5py.File(output_path, "r") as h5py_file: - for hdf5_path in ("cpp_source", "cpp_sink", "python_sink"): - dataset = h5py_file[hdf5_path] - assert dataset.shape == (MyCPPSource.NUM_SAMPLES,) - dataset_nums = {int(num[0]) for num in dataset} - assert dataset_nums == set(range(MyCPPSource.NUM_SAMPLES)) diff --git a/labgraph/runners/tests/test_exception.py b/labgraph/runners/tests/test_exception.py index eb06bf9d2..f60ab42fa 100644 --- a/labgraph/runners/tests/test_exception.py +++ b/labgraph/runners/tests/test_exception.py @@ -69,14 +69,6 @@ async def publisher(self) -> AsyncPublisher: ) -@local_test -@pytest.mark.parametrize("module_type", MODULE_TYPES) # type: ignore -def test_local_throw(module_type: Type[Node]) -> None: - node = module_type() - runner = LocalRunner(module=node) - with pytest.raises(MyTestException): - runner.run() - class SubscriberNode(Node): TOPIC = Topic(MyTestMessage) @@ -156,18 +148,6 @@ def process_modules(self) -> Sequence[Module]: ) -@local_test -@pytest.mark.parametrize("graph_type", GRAPH_TYPES) # type: ignore -def test_parallel_throw(graph_type: Type[Graph]) -> None: - graph = graph_type() - runner = ParallelRunner(graph=graph) - with pytest.raises(ProcessManagerException) as ex: - runner.run() - assert [f for f in ex.value.failures.values() if f is not None] == [ - ProcessFailureType.EXCEPTION - ] - - class PublisherSubscriberGraph(Graph): PUBLISHER: PublisherNode SUBSCRIBER: SubscriberNode @@ -184,21 +164,4 @@ def process_modules(self) -> Sequence[Module]: class ThrowerLogger(Logger): def write(self, messages_by_logging_id: Mapping[str, Sequence[Message]]) -> None: - raise MyTestException() - - -@local_test -def test_logger_throw() -> None: - graph = PublisherSubscriberGraph() - runner = ParallelRunner( - graph=graph, options=RunnerOptions(logger_type=ThrowerLogger) - ) - with pytest.raises(ProcessManagerException) as ex: - runner.run() - - assert ex.value.failures == { - LOGGER_KEY: ProcessFailureType.EXCEPTION, - "PUBLISHER": None, - "SUBSCRIBER": None, - } - assert ex.value.exceptions[LOGGER_KEY] == "MyTestException()" + raise MyTestException() \ No newline at end of file diff --git a/labgraph/runners/tests/test_process_manager.py b/labgraph/runners/tests/test_process_manager.py index 1b8a41a41..51f491721 100644 --- a/labgraph/runners/tests/test_process_manager.py +++ b/labgraph/runners/tests/test_process_manager.py @@ -129,32 +129,6 @@ def proc( time.sleep(PROCESS_SLEEP_TIME) 
-@local_test -def test_normal() -> None: - """ - Tests that we can run multiple processes that terminate normally. - """ - manager = ProcessManager( - processes=( - ProcessInfo( - module=__name__, - name="proc1", - args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), - ), - ProcessInfo( - module=__name__, - name="proc2", - args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), - ), - ), - name="test_manager", - startup_period=TEST_STARTUP_PERIOD, - shutdown_period=TEST_SHUTDOWN_PERIOD, - ) - - manager.run() - - @pytest.mark.parametrize( "crash_phase", ( @@ -164,39 +138,6 @@ def test_normal() -> None: ProcessPhase.STOPPING, ), ) -@local_test -def test_crash(crash_phase: ProcessPhase) -> None: - """ - Tests that we can run multiple processes where one of them crashes. - """ - manager = ProcessManager( - processes=( - ProcessInfo( - module=__name__, - name="proc1", - args=( - "--manager-name", - "test_manager", - "--shutdown", - "CRASH", - "--last-phase", - crash_phase.name, - ), - ), - ProcessInfo( - module=__name__, - name="proc2", - args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), - ), - ), - name="test_manager", - startup_period=TEST_STARTUP_PERIOD, - shutdown_period=TEST_SHUTDOWN_PERIOD, - ) - - with pytest.raises(ProcessManagerException) as ex: - manager.run() - assert ex.value.failures == {"proc1": ProcessFailureType.CRASH, "proc2": None} @pytest.mark.parametrize( @@ -208,39 +149,6 @@ def test_crash(crash_phase: ProcessPhase) -> None: ProcessPhase.STOPPING, ), ) -@local_test -def test_exception(exception_phase: ProcessPhase) -> None: - """ - Tests that we can run multiple processes where one of them raises an exception. - """ - manager = ProcessManager( - processes=( - ProcessInfo( - module=__name__, - name="proc1", - args=( - "--manager-name", - "test_manager", - "--shutdown", - "EXCEPTION", - "--last-phase", - exception_phase.name, - ), - ), - ProcessInfo( - module=__name__, - name="proc2", - args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), - ), - ), - name="test_manager", - startup_period=TEST_STARTUP_PERIOD, - shutdown_period=TEST_SHUTDOWN_PERIOD, - ) - - with pytest.raises(ProcessManagerException) as ex: - manager.run() - assert ex.value.failures == {"proc1": ProcessFailureType.EXCEPTION, "proc2": None} @pytest.mark.parametrize( @@ -252,39 +160,6 @@ def test_exception(exception_phase: ProcessPhase) -> None: ProcessPhase.STOPPING, ), ) -@local_test -def test_hang(hang_phase: ProcessPhase) -> None: - """ - Tests that we can run multiple processes where one of them hangs. 
- """ - manager = ProcessManager( - processes=( - ProcessInfo( - module=__name__, - name="proc1", - args=( - "--manager-name", - "test_manager", - "--shutdown", - "HANG", - "--last-phase", - hang_phase.name, - ), - ), - ProcessInfo( - module=__name__, - name="proc2", - args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), - ), - ), - name="test_manager", - startup_period=TEST_STARTUP_PERIOD, - shutdown_period=TEST_SHUTDOWN_PERIOD, - ) - - with pytest.raises(ProcessManagerException) as ex: - manager.run() - assert ex.value.failures == {"proc1": ProcessFailureType.HANG, "proc2": None} @click.command() diff --git a/labgraph/runners/util.py b/labgraph/runners/util.py index 13228ef5b..a171e268f 100644 --- a/labgraph/runners/util.py +++ b/labgraph/runners/util.py @@ -27,7 +27,7 @@ def get_module_class(python_module: str, python_class: str) -> Type[Module]: break else: raise NameError( - f"Could not find LabGraph a class in module `{python_module}`` with " + f"Could not find Labgraph a class in module `{python_module}`` with " f"the following class.__name__: `{real_class_name}`. " f"If it refers to an anonymous class, consider moving it " f"to the module scope of {python_module}." diff --git a/labgraph/tests/test_imports.py b/labgraph/tests/test_imports.py index 70ea32084..34d4789e5 100644 --- a/labgraph/tests/test_imports.py +++ b/labgraph/tests/test_imports.py @@ -4,7 +4,7 @@ def test_imports() -> None: """ - Tests that we can import top-level LabGraph objects correctly. + Tests that we can import top-level Labgraph objects correctly. """ from .. import ( # noqa: F401 @@ -19,7 +19,7 @@ def test_imports() -> None: Connections, CPPNodeConfig, DeferredMessage, - LabGraphError, + LabgraphError, Event, EventGraph, EventPublishingHeap, diff --git a/labgraph/tests/test_typecheck.py b/labgraph/tests/test_typecheck.py deleted file mode 100644 index 2ba075e86..000000000 --- a/labgraph/tests/test_typecheck.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2004-present Facebook. All Rights Reserved. - -import os -import runpy -import shutil -import subprocess -import tempfile -from glob import glob -from pathlib import Path -from typing import Optional -from zipfile import ZipFile - -import labgraph as lg - -from ..runners.launch import _get_pex_path, _in_pex, launch -from ..util.logger import get_logger -from ..util.resource import get_resource_tempfile - - -SOURCE_PATH = "labgraph" - -logger = get_logger(__name__) - - -def test_typecheck() -> None: - """ - Typechecks LabGraph using mypy. Assumes that the test is running from a PEX (if - not, the test skips). 
- """ - mypy_ini_path = get_resource_tempfile(__name__, "mypy.ini") - mypy_args = ["--config-file", mypy_ini_path] - - zip_path: Optional[str] = None - try: - # If available, get the path to the typecheck_src.zip source archive - zip_path = get_resource_tempfile(__name__, "typecheck_src.zip") - except FileNotFoundError: - pass # Just let zip_path be None and handle this case below - - temp_dir: Optional[tempfile.TemporaryDirectory] = None - if zip_path is None: - # If the source archive is not available, typecheck the installed location - # for LabGraph - src_path = str(Path(lg.__file__).parent) - mypy_args += glob(f"{src_path}/**/*.py", recursive=True) - else: - # If available, typecheck the typecheck_src.zip source archive - temp_dir = tempfile.TemporaryDirectory() # noqa: P201 - src_path = temp_dir.name - # Extract the source files from the zip file - src_file = ZipFile(zip_path) - for file_path in src_file.namelist(): - if file_path.startswith(SOURCE_PATH) and file_path.endswith(".py"): - src_file.extract(file_path, src_path) - mypy_args.append(file_path) - - # Typecheck in a subprocess - mypy_proc = launch("mypy", mypy_args, cwd=src_path, stdout=subprocess.PIPE) - mypy_output: Optional[str] = None - if mypy_proc.stdout is not None: - mypy_output = mypy_proc.stdout.read().decode("utf-8") - mypy_proc.wait() - - if temp_dir is not None: - temp_dir.cleanup() - - if mypy_proc.returncode != 0: - error_message = f"Typechecking failed (exit code {mypy_proc.returncode})" - if mypy_output is not None: - logger.error(mypy_output) - error_message += f":\n\n{mypy_output}" - raise RuntimeError(error_message) diff --git a/labgraph/util/__init__.py b/labgraph/util/__init__.py index 28c03ef4a..1d3fd10c6 100644 --- a/labgraph/util/__init__.py +++ b/labgraph/util/__init__.py @@ -2,7 +2,7 @@ # Copyright 2004-present Facebook. All Rights Reserved. __all__ = [ - "LabGraphError", + "LabgraphError", "get_resource_tempfile", "get_test_filename", "async_test", @@ -10,6 +10,6 @@ "get_free_port", ] -from .error import LabGraphError +from .error import LabgraphError from .resource import get_resource_tempfile from .testing import async_test, get_free_port, get_test_filename, local_test diff --git a/labgraph/util/error.py b/labgraph/util/error.py index e62f9d85f..72ea34b2f 100644 --- a/labgraph/util/error.py +++ b/labgraph/util/error.py @@ -2,10 +2,10 @@ # Copyright 2004-present Facebook. All Rights Reserved. -class LabGraphError(Exception): +class LabgraphError(Exception): """ - Represents a LabGraph error. `LabGraphError` Will be raised when an error is tied - to particular LabGraph concepts such as graph construction. + Represents a Labgraph error. `LabgraphError` Will be raised when an error is tied + to particular Labgraph concepts such as graph construction. """ pass diff --git a/labgraph/util/typing.py b/labgraph/util/typing.py new file mode 100644 index 000000000..b5df9a20d --- /dev/null +++ b/labgraph/util/typing.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import sys +import typing + + +# TODO: Remove all 3.8 checks once 3.6 is dead. 
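+# `typing.get_origin` only exists on Python 3.8+, so older interpreters fall
+# back to `issubclass`, which still worked with the subscripted generics that
+# shipped in Python 3.6's `typing` module.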
+def is_py38() -> bool: + return sys.version_info > (3, 8) + + +# type_ is an instance of typing._GenericAlias +# generic is an instance of typing.Generic +def is_generic_subclass(type_: typing.Any, generic: typing.Any) -> bool: + if is_py38(): + return typing.get_origin(type_) is generic # type: ignore + else: + return issubclass(type_, generic) diff --git a/labgraph/util/version.py b/labgraph/util/version.py index dfb4605f4..131b6217b 100644 --- a/labgraph/util/version.py +++ b/labgraph/util/version.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. -# Constant representing the version of LabGraph. +# Constant representing the version of Labgraph. VERSION = "1.0.0" diff --git a/labgraph/zmq_node/tests/test_zmq_node.py b/labgraph/zmq_node/tests/test_zmq_node.py deleted file mode 100644 index fd6cd57d6..000000000 --- a/labgraph/zmq_node/tests/test_zmq_node.py +++ /dev/null @@ -1,272 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2004-present Facebook. All Rights Reserved. - -import asyncio -import os -import time -from multiprocessing import Process -from typing import List, Sequence - -import pytest -import zmq - -from ...graphs import ( - AsyncPublisher, - Config, - Connections, - Graph, - Module, - Node, - Topic, - publisher, - subscriber, -) -from ...runners import LocalRunner, NormalTermination, ParallelRunner -from ...util.testing import get_free_port, get_test_filename, local_test -from ..zmq_message import ZMQMessage -from ..zmq_poller_node import ZMQPollerConfig, ZMQPollerNode -from ..zmq_sender_node import ZMQSenderConfig, ZMQSenderNode - - -NUM_MESSAGES = 5 -SAMPLE_RATE = 10 -STARTUP_TIME = 1 - -ZMQ_ADDR = "tcp://127.0.0.1" -ZMQ_TOPIC = "zmq_topic" - -DATA_DELIMITER = b"\0" - - -class MySinkConfig(Config): - output_filename: str - - -class MySink(Node): - """ - Convenience node for receiving messages from a `ZMQPollerNode`. - """ - - TOPIC = Topic(ZMQMessage) - config: MySinkConfig - - def setup(self) -> None: - self.output_file = open(self.config.output_filename, "wb") - self.num_received = 0 - - @subscriber(TOPIC) - async def sink(self, message: ZMQMessage) -> None: - self.output_file.write(message.data) - self.output_file.write(DATA_DELIMITER) - - self.num_received += 1 - if self.num_received == NUM_MESSAGES: - raise NormalTermination() - - def cleanup(self) -> None: - self.output_file.close() - - -class MySourceConfig(Config): - should_terminate: bool = True - - -class MySource(Node): - """ - Convenience node for sending messages to a `ZMQSenderNode`. - """ - - TOPIC = Topic(ZMQMessage) - config: MySourceConfig - samples = [bytes([i]) for i in range(1, NUM_MESSAGES + 1)] - - @publisher(TOPIC) - async def source(self) -> AsyncPublisher: - for sample_bytes in self.samples: - yield self.TOPIC, ZMQMessage(sample_bytes) - await asyncio.sleep(1 / SAMPLE_RATE) - - if self.config.should_terminate: - raise NormalTermination() - - -def write_samples_to_zmq(address: str, samples: Sequence[bytes], topic: str) -> None: - context = zmq.Context() - socket = context.socket(zmq.PUB) - socket.bind(address) - - # NOTE: In place of waiting on a monitor socket, we wait for a startup time - # here. Otherwise we would need to listen on an extra socket in this function, - # increasing the complexity of this test. If this startup wait becomes flaky, we may - # want to consider adding the monitor socket here. 
- time.sleep(STARTUP_TIME) - for sample_bytes in samples: - socket.send_multipart((bytes(topic, "UTF-8"), sample_bytes)) - time.sleep(1 / SAMPLE_RATE) - socket.close() - - -def recv_samples_from_zmq(address: str, topic: str, output_fname: str) -> None: - context = zmq.Context() - socket = context.socket(zmq.SUB) - socket.setsockopt(zmq.SUBSCRIBE, bytes(topic, "UTF-8")) - socket.connect(address) - - with open(output_fname, "bw") as output_file: - for _ in range(NUM_MESSAGES): - received = socket.recv_multipart() - topic, data = received - output_file.write(data) - output_file.write(DATA_DELIMITER) - socket.close() - - -@local_test -def test_zmq_poller_node() -> None: - """ - Tests that a `ZMQPollerNode` is able to read samples from a ZMQ socket and echo - the samples back out of the graph. - """ - - class MyZMQPollerGraphConfig(Config): - read_addr: str - zmq_topic: str - output_filename: str - - class MyZMQPollerGraph(Graph): - MY_SOURCE: ZMQPollerNode - MY_SINK: MySink - config: MyZMQPollerGraphConfig - - def setup(self) -> None: - self.MY_SOURCE.configure( - ZMQPollerConfig( - read_addr=self.config.read_addr, zmq_topic=self.config.zmq_topic - ) - ) - self.MY_SINK.configure( - MySinkConfig(output_filename=self.config.output_filename) - ) - - def connections(self) -> Connections: - return ((self.MY_SOURCE.topic, self.MY_SINK.TOPIC),) - - graph = MyZMQPollerGraph() - output_filename = get_test_filename() - address = f"{ZMQ_ADDR}:{get_free_port()}" - graph.configure( - MyZMQPollerGraphConfig( - read_addr=address, zmq_topic=ZMQ_TOPIC, output_filename=output_filename - ) - ) - runner = LocalRunner(module=graph) - - samples = [bytes([i]) for i in range(1, NUM_MESSAGES + 1)] - p = Process(target=write_samples_to_zmq, args=(address, samples, ZMQ_TOPIC)) - p.start() - runner.run() - p.join() - - with open(output_filename, "br") as f: - data = f.read() - assert set(samples) == set(data.strip(DATA_DELIMITER).split(DATA_DELIMITER)) - - -@local_test -def test_zmq_sender_node() -> None: - """ - Tests that a `ZMQSenderNode` is able to read samples from the graph and write the - samples back out to a ZMQ socket. - """ - - class MyZMQSenderGraph(Graph): - MY_SOURCE: MySource - MY_SINK: ZMQSenderNode - - config: ZMQSenderConfig - - def setup(self) -> None: - self.MY_SOURCE.configure(MySourceConfig()) - self.MY_SINK.configure(self.config) - - def connections(self) -> Connections: - return ((self.MY_SOURCE.TOPIC, self.MY_SINK.topic),) - - output_filename = get_test_filename() - graph = MyZMQSenderGraph() - address = f"{ZMQ_ADDR}:{get_free_port()}" - graph.configure(ZMQSenderConfig(write_addr=address, zmq_topic=ZMQ_TOPIC)) - runner = LocalRunner(module=graph) - - p = Process( - target=recv_samples_from_zmq, args=(address, ZMQ_TOPIC, output_filename) - ) - p.start() - runner.run() - p.join() - - with open(output_filename, "br") as f: - data = f.read() - assert set(graph.MY_SOURCE.samples) == set( - data.strip(DATA_DELIMITER).split(DATA_DELIMITER) - ) - - -@local_test -def test_zmq_send_and_poll() -> None: - """ - Tests that a `ZMQSenderNode` and a `ZMQPollerNode` can work together. 
- """ - - class MyZMQGraphConfig(Config): - addr: str - zmq_topic: str - output_filename: str - - class MyZMQGraph(Graph): - DF_SOURCE: MySource - ZMQ_SENDER: ZMQSenderNode - ZMQ_POLLER: ZMQPollerNode - DF_SINK: MySink - - def setup(self) -> None: - self.DF_SOURCE.configure(MySourceConfig(should_terminate=False)) - self.ZMQ_SENDER.configure( - ZMQSenderConfig( - write_addr=self.config.addr, zmq_topic=self.config.zmq_topic - ) - ) - self.ZMQ_POLLER.configure( - ZMQPollerConfig( - read_addr=self.config.addr, zmq_topic=self.config.zmq_topic - ) - ) - self.DF_SINK.configure( - MySinkConfig(output_filename=self.config.output_filename) - ) - - def connections(self) -> Connections: - return ( - (self.DF_SOURCE.TOPIC, self.ZMQ_SENDER.topic), - (self.ZMQ_POLLER.topic, self.DF_SINK.TOPIC), - ) - - def process_modules(self) -> Sequence[Module]: - return (self.DF_SOURCE, self.ZMQ_SENDER, self.ZMQ_POLLER, self.DF_SINK) - - output_filename = get_test_filename() - graph = MyZMQGraph() - address = f"{ZMQ_ADDR}:{get_free_port()}" - graph.configure( - MyZMQGraphConfig( - addr=address, zmq_topic=ZMQ_TOPIC, output_filename=output_filename - ) - ) - runner = ParallelRunner(graph=graph) - runner.run() - - with open(output_filename, "br") as f: - data = f.read() - assert set(graph.DF_SOURCE.samples) == set( - data.strip(DATA_DELIMITER).split(DATA_DELIMITER) - ) diff --git a/labgraph/zmq_node/zmq_poller_node.py b/labgraph/zmq_node/zmq_poller_node.py index 5c7b7a359..a06d647c4 100644 --- a/labgraph/zmq_node/zmq_poller_node.py +++ b/labgraph/zmq_node/zmq_poller_node.py @@ -8,6 +8,7 @@ from zmq.utils.monitor import parse_monitor_message from ..graphs import AsyncPublisher, Config, Node, Topic, background, publisher +from ..util.error import LabgraphError from ..util.logger import get_logger from ..zmq_node import ZMQMessage from .constants import ZMQEvent diff --git a/labgraph/zmq_node/zmq_sender_node.py b/labgraph/zmq_node/zmq_sender_node.py index 38fbf65ba..02d85b8d3 100644 --- a/labgraph/zmq_node/zmq_sender_node.py +++ b/labgraph/zmq_node/zmq_sender_node.py @@ -2,12 +2,14 @@ # Copyright 2004-present Facebook. All Rights Reserved. import asyncio +import time import zmq import zmq.asyncio from zmq.utils.monitor import parse_monitor_message from ..graphs import Config, Node, Topic, background, subscriber +from ..util.error import LabgraphError from ..util.logger import get_logger from ..zmq_node import ZMQMessage from .constants import ZMQEvent @@ -25,8 +27,8 @@ class ZMQSenderConfig(Config): class ZMQSenderNode(Node): """ - Represents a node in a LabGraph graph that subscribes to messages in a - LabGraph topic and forwards them by writing to a ZMQ socket. + Represents a node in a Labgraph graph that subscribes to messages in a + Labgraph topic and forwards them by writing to a ZMQ socket. Args: write_addr: The address to which ZMQ data should be written. From ded3e82eec055dd119272627362a403802199b6d Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 15:50:37 -0800 Subject: [PATCH 05/13] Add additional tests. 
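
Restores the C++ interop and mypy typecheck tests that were deleted earlier in
the series. As a rough sketch (the exact invocation is assumed and not part of
this patch), each restored suite can be run on its own:

    python3 -m pytest -v labgraph/runners/tests/test_cpp.py
    python3 -m pytest -v labgraph/tests/test_typecheck.py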
--- labgraph/runners/tests/test_cpp.py | 132 +++++++++++++++++++++++++++++ labgraph/tests/test_typecheck.py | 73 ++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 labgraph/runners/tests/test_cpp.py create mode 100644 labgraph/tests/test_typecheck.py diff --git a/labgraph/runners/tests/test_cpp.py b/labgraph/runners/tests/test_cpp.py new file mode 100644 index 000000000..194fd2a87 --- /dev/null +++ b/labgraph/runners/tests/test_cpp.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import time +from pathlib import Path +from typing import Dict, Sequence + +import h5py +from MyCPPNodes import MyCPPSink, MyCPPSource # type: ignore + +from ...graphs.config import Config +from ...graphs.cpp_node import CPPNodeConfig +from ...graphs.graph import Graph +from ...graphs.group import Connections +from ...graphs.method import subscriber +from ...graphs.module import Module +from ...graphs.node import Node +from ...graphs.topic import Topic +from ...messages.message import Message +from ...runners.exceptions import NormalTermination +from ...runners.parallel_runner import ParallelRunner +from ...util.testing import get_test_filename, local_test + + +class MyMessage(Message): + value: int + + +class MySinkConfig(Config): + filename: str + + +class MyPythonSink(Node): + C = Topic(MyMessage) + + config: MySinkConfig + + def setup(self) -> None: + self.messages_seen: int = 0 + + @subscriber(C) + def sink(self, message: MyMessage) -> None: + with open(self.config.filename, "a") as f: + f.write(str(message.value) + "\n") + self.messages_seen += 1 + if self.messages_seen == MyCPPSource.NUM_SAMPLES: + time.sleep(2) + raise NormalTermination() + + +class MyGraphConfig(Config): + python_filename: str + cpp_filename: str + + +class MyMixedGraph(Graph): + CPP_SOURCE: MyCPPSource + CPP_SINK: MyCPPSink + PYTHON_SINK: MyPythonSink + + config: MyGraphConfig + + def setup(self) -> None: + self.CPP_SINK.configure(CPPNodeConfig(args=[self.config.cpp_filename])) + self.PYTHON_SINK.configure(MySinkConfig(filename=self.config.python_filename)) + + def connections(self) -> Connections: + return ( + (self.CPP_SOURCE.A, self.CPP_SINK.B), + (self.CPP_SOURCE.A, self.PYTHON_SINK.C), + ) + + def process_modules(self) -> Sequence[Module]: + return [self.CPP_SOURCE, self.CPP_SINK, self.PYTHON_SINK] + + def logging(self) -> Dict[str, Topic]: + return { + "cpp_source": self.CPP_SOURCE.A, + "cpp_sink": self.CPP_SINK.B, + "python_sink": self.PYTHON_SINK.C, + } + + +@local_test +def test_cpp_graph() -> None: + """ + Tests that we can run a graph with both C++ and Python nodes, and read the results + on disk. 
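+    The same run also exercises the HDF5 logger, since all three topics are
+    logged (see `logging()` above) and the resulting file is verified at the
+    end of the test.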
+ """ + # Run the graph + graph = MyMixedGraph() + python_filename = get_test_filename() + cpp_filename = get_test_filename() + graph.configure( + MyGraphConfig(python_filename=python_filename, cpp_filename=cpp_filename) + ) + runner = ParallelRunner(graph=graph) + # Get the HDF5 log path to verify the logs later + output_path = str( # noqa: F841 + Path(runner._options.logger_config.output_directory) + / Path(f"{runner._options.logger_config.recording_name}.h5") + ) + runner.run() + + # Check C++ sink output + cpp_nums = set(range(MyCPPSource.NUM_SAMPLES)) + with open(cpp_filename, "r") as f: + for line in f: + num = int(line.strip()) + cpp_nums.remove(num) + + assert len(cpp_nums) == 0, f"Missing numbers in C++ sink output: {cpp_nums}" + + # Check Python sink output + python_nums = set(range(MyCPPSource.NUM_SAMPLES)) + with open(python_filename, "r") as f: + for line in f: + num = int(line.strip()) + if num in python_nums: + python_nums.remove(num) + + assert ( + len(python_nums) == 0 + ), f"Missing numbers in Python sink output: {python_nums}" + + # Check HDF5 logger output + with h5py.File(output_path, "r") as h5py_file: + for hdf5_path in ("cpp_source", "cpp_sink", "python_sink"): + dataset = h5py_file[hdf5_path] + assert dataset.shape == (MyCPPSource.NUM_SAMPLES,) + dataset_nums = {int(num[0]) for num in dataset} + assert dataset_nums == set(range(MyCPPSource.NUM_SAMPLES)) diff --git a/labgraph/tests/test_typecheck.py b/labgraph/tests/test_typecheck.py new file mode 100644 index 000000000..2ba075e86 --- /dev/null +++ b/labgraph/tests/test_typecheck.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import os +import runpy +import shutil +import subprocess +import tempfile +from glob import glob +from pathlib import Path +from typing import Optional +from zipfile import ZipFile + +import labgraph as lg + +from ..runners.launch import _get_pex_path, _in_pex, launch +from ..util.logger import get_logger +from ..util.resource import get_resource_tempfile + + +SOURCE_PATH = "labgraph" + +logger = get_logger(__name__) + + +def test_typecheck() -> None: + """ + Typechecks LabGraph using mypy. Assumes that the test is running from a PEX (if + not, the test skips). 
+ """ + mypy_ini_path = get_resource_tempfile(__name__, "mypy.ini") + mypy_args = ["--config-file", mypy_ini_path] + + zip_path: Optional[str] = None + try: + # If available, get the path to the typecheck_src.zip source archive + zip_path = get_resource_tempfile(__name__, "typecheck_src.zip") + except FileNotFoundError: + pass # Just let zip_path be None and handle this case below + + temp_dir: Optional[tempfile.TemporaryDirectory] = None + if zip_path is None: + # If the source archive is not available, typecheck the installed location + # for LabGraph + src_path = str(Path(lg.__file__).parent) + mypy_args += glob(f"{src_path}/**/*.py", recursive=True) + else: + # If available, typecheck the typecheck_src.zip source archive + temp_dir = tempfile.TemporaryDirectory() # noqa: P201 + src_path = temp_dir.name + # Extract the source files from the zip file + src_file = ZipFile(zip_path) + for file_path in src_file.namelist(): + if file_path.startswith(SOURCE_PATH) and file_path.endswith(".py"): + src_file.extract(file_path, src_path) + mypy_args.append(file_path) + + # Typecheck in a subprocess + mypy_proc = launch("mypy", mypy_args, cwd=src_path, stdout=subprocess.PIPE) + mypy_output: Optional[str] = None + if mypy_proc.stdout is not None: + mypy_output = mypy_proc.stdout.read().decode("utf-8") + mypy_proc.wait() + + if temp_dir is not None: + temp_dir.cleanup() + + if mypy_proc.returncode != 0: + error_message = f"Typechecking failed (exit code {mypy_proc.returncode})" + if mypy_output is not None: + logger.error(mypy_output) + error_message += f":\n\n{mypy_output}" + raise RuntimeError(error_message) From a608aa773f3db50cfd5564de6d34d06f9cbd456a Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 16:00:07 -0800 Subject: [PATCH 06/13] Add additional tests part2. --- labgraph/runners/tests/test_exception.py | 448 ++++++++++++------ .../runners/tests/test_process_manager.py | 128 ++++- labgraph/zmq_node/tests/test_zmq_node.py | 272 +++++++++++ 3 files changed, 695 insertions(+), 153 deletions(-) create mode 100644 labgraph/zmq_node/tests/test_zmq_node.py diff --git a/labgraph/runners/tests/test_exception.py b/labgraph/runners/tests/test_exception.py index f60ab42fa..06addf7dd 100644 --- a/labgraph/runners/tests/test_exception.py +++ b/labgraph/runners/tests/test_exception.py @@ -1,167 +1,313 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. 
-from typing import Awaitable, Dict, Mapping, Sequence, Type +import enum +import functools +import sys +import time +from typing import Optional +import click import pytest -from ...graphs.graph import Graph -from ...graphs.group import Connections -from ...graphs.method import AsyncPublisher, background, main, publisher, subscriber -from ...graphs.module import Module -from ...graphs.node import Node -from ...graphs.topic import Topic -from ...loggers.logger import Logger -from ...messages.message import Message -from ...runners.local_runner import LocalRunner -from ...runners.parallel_runner import LOGGER_KEY, ParallelRunner -from ...runners.runner import RunnerOptions from ...util.testing import local_test -from ..process_manager import ProcessFailureType, ProcessManagerException - - -class MyTestException(Exception): - pass - - -class MyTestMessage(Message): - int_field: int - - -class MainThrowerNode(Node): - @main - def entry(self) -> None: - raise MyTestException() - - -class BackgroundThrowerNode(Node): - @background - async def background(self) -> None: - raise MyTestException() - - -class PublisherThrowerNode(Node): - TOPIC = Topic(MyTestMessage) - - @publisher(TOPIC) - async def publisher(self) -> AsyncPublisher: - raise MyTestException() - yield self.TOPIC, MyTestMessage(1) - - -class PublisherSubscriberThrowerNode(Node): - TOPIC = Topic(MyTestMessage) - - @subscriber(TOPIC) - async def subscriber(self, message: MyTestMessage) -> None: - raise MyTestException() - - @publisher(TOPIC) - async def publisher(self) -> AsyncPublisher: - yield self.TOPIC, MyTestMessage(1) - - -MODULE_TYPES = ( - MainThrowerNode, - BackgroundThrowerNode, - PublisherThrowerNode, - PublisherSubscriberThrowerNode, +from ..process_manager import ( + ProcessFailureType, + ProcessInfo, + ProcessManager, + ProcessManagerException, + ProcessManagerState, + ProcessPhase, ) +TEST_STARTUP_PERIOD = 5 +TEST_SHUTDOWN_PERIOD = 3 +PROCESS_WAIT_TIME = 0.1 +PROCESS_SLEEP_TIME = 0.01 -class SubscriberNode(Node): - TOPIC = Topic(MyTestMessage) - - @subscriber(TOPIC) - async def subscriber(self, message: MyTestMessage) -> None: - pass - - -class PublisherThrowerGraph(Graph): - PUBLISHER: PublisherThrowerNode - SUBSCRIBER: SubscriberNode - - def connections(self) -> Connections: - return ((self.PUBLISHER.TOPIC, self.SUBSCRIBER.TOPIC),) - - def process_modules(self) -> Sequence[Module]: - return (self.PUBLISHER, self.SUBSCRIBER) - - -class SubscriberThrowerNode(Node): - TOPIC = Topic(MyTestMessage) - @subscriber(TOPIC) - async def subscriber(self, message: MyTestMessage) -> None: - raise MyTestException() +class DummyException(Exception): + """ + Dummy exception for tests to raise. 
+ """ - -class PublisherNode(Node): - TOPIC = Topic(MyTestMessage) - - @publisher(TOPIC) - async def publisher(self) -> AsyncPublisher: - yield self.TOPIC, MyTestMessage(1) - - -class SubscriberThrowerGraph(Graph): - PUBLISHER: PublisherNode - SUBSCRIBER: SubscriberThrowerNode - - def connections(self) -> Connections: - return ((self.PUBLISHER.TOPIC, self.SUBSCRIBER.TOPIC),) - - def process_modules(self) -> Sequence[Module]: - return (self.PUBLISHER, self.SUBSCRIBER) - - -class MainThrowerGraph(Graph): - PUBLISHER: PublisherNode - SUBSCRIBER: SubscriberNode - MAIN: MainThrowerNode - - def connections(self) -> Connections: - return ((self.PUBLISHER.TOPIC, self.SUBSCRIBER.TOPIC),) - - def process_modules(self) -> Sequence[Module]: - return (self.PUBLISHER, self.SUBSCRIBER, self.MAIN) - - -class BackgroundThrowerGraph(Graph): - PUBLISHER: PublisherNode - SUBSCRIBER: SubscriberNode - BACKGROUND: BackgroundThrowerNode - - def connections(self) -> Connections: - return ((self.PUBLISHER.TOPIC, self.SUBSCRIBER.TOPIC),) - - def process_modules(self) -> Sequence[Module]: - return (self.PUBLISHER, self.SUBSCRIBER, self.BACKGROUND) + pass -GRAPH_TYPES = ( - PublisherThrowerGraph, - SubscriberThrowerGraph, - MainThrowerGraph, - BackgroundThrowerGraph, +class ShutdownBehavior(enum.Enum): + """ + A shutdown behavior that `proc` can observe. + """ + + NORMAL = enum.auto() + CRASH = enum.auto() + EXCEPTION = enum.auto() + HANG = enum.auto() + + +def proc( + state: ProcessManagerState, + name: str, + manager_name: str, + shutdown: ShutdownBehavior, + last_phase: ProcessPhase = ProcessPhase.TERMINATED, +) -> None: + """ + A minimal version of a process managed by a `ProcessManager`. Used for testing the + `ProcessManager`. The process simply updates its phase and sleeps. + Args: + state: The `ProcessManager`'s state. + name: The name of the process. + manager_name: The name of the `ProcessManager`. + shutdown: + The shutdown behavior that this process will observe: + - `ShutdownBehavior.NORMAL`: go through all phases and terminate normally. + - `ShutdownBehavior.CRASH`: exit the entire process suddenly at a certain + phase. + - `ShutdownBehavior.EXCEPTION`: raise an exception at a certain phase. + - `ShutdownBehavior.HANG: hang at a certain phase. + last_phase: + The last phase that `proc` will enter. This must be + `ProcessPhase.TERMINATED` if the shutdown behavior is normal. 
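+
+    Under `ShutdownBehavior.NORMAL` the process walks the full phase sequence
+    STARTING -> READY -> RUNNING -> STOPPING -> TERMINATED, sleeping for
+    `PROCESS_SLEEP_TIME` between polls and advancing to the next phase once
+    `PROCESS_WAIT_TIME` has elapsed (see the phase map in the loop below).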
+ """ + assert shutdown != ShutdownBehavior.NORMAL or last_phase == ProcessPhase.TERMINATED + last_phase_changed_at = time.perf_counter() + + while True: + time.sleep(PROCESS_SLEEP_TIME) + + # If the manager is stopping, set this process to be stopping + with state.lock: + if ( + state.get(manager_name) == ProcessPhase.STOPPING + and state.get(name) != ProcessPhase.STOPPING + ): + state.update(name, ProcessPhase.STOPPING) + last_phase_changed_at = time.perf_counter() + + # Leave the sleep loop if we are now in the last phase + if ProcessPhase.STOPPING.value >= last_phase.value: + break + + # Transition to the next phase if we have slept for long enough + current_time = time.perf_counter() + if current_time - last_phase_changed_at > PROCESS_WAIT_TIME: + current_phases = state.get_all() + current_phase = current_phases[name] + if ( + current_phase == ProcessPhase.READY + and current_phases[manager_name].value < ProcessPhase.READY.value + ): + continue + new_phase = { + ProcessPhase.STARTING: ProcessPhase.READY, + ProcessPhase.READY: ProcessPhase.RUNNING, + ProcessPhase.RUNNING: ProcessPhase.STOPPING, + ProcessPhase.STOPPING: ProcessPhase.TERMINATED, + }[current_phase] + state.update(name, new_phase) + last_phase_changed_at = time.perf_counter() + + # Leave the sleep loop if we are now in the last phase + if new_phase.value >= last_phase.value: + break + + if shutdown == ShutdownBehavior.EXCEPTION: + # Update the state with a dummy exception, then stop + state.set_exception(name, repr(DummyException())) + with state.lock: + if state.get(name) != ProcessPhase.STOPPING: + state.update(name, ProcessPhase.STOPPING) + time.sleep(PROCESS_WAIT_TIME) + state.update(name, ProcessPhase.TERMINATED) + return + elif shutdown == ShutdownBehavior.HANG: + # Hang forever; the `ProcessManager` should then kill this process + while True: + time.sleep(PROCESS_SLEEP_TIME) + + +@local_test +def test_normal() -> None: + """ + Tests that we can run multiple processes that terminate normally. + """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + manager.run() + + +@pytest.mark.parametrize( + "crash_phase", + ( + ProcessPhase.STARTING, + ProcessPhase.READY, + ProcessPhase.RUNNING, + ProcessPhase.STOPPING, + ), ) - - -class PublisherSubscriberGraph(Graph): - PUBLISHER: PublisherNode - SUBSCRIBER: SubscriberNode - - def connections(self) -> Connections: - return ((self.PUBLISHER.TOPIC, self.SUBSCRIBER.TOPIC),) - - def logging(self) -> Dict[str, Topic]: - return {"topic": self.PUBLISHER.TOPIC} - - def process_modules(self) -> Sequence[Module]: - return (self.PUBLISHER, self.SUBSCRIBER) - - -class ThrowerLogger(Logger): - def write(self, messages_by_logging_id: Mapping[str, Sequence[Message]]) -> None: - raise MyTestException() \ No newline at end of file +@local_test +def test_crash(crash_phase: ProcessPhase) -> None: + """ + Tests that we can run multiple processes where one of them crashes. 
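+
+    The crashing process should surface as `ProcessFailureType.CRASH` in the
+    raised `ProcessManagerException`, while the healthy process maps to `None`.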
+ """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=( + "--manager-name", + "test_manager", + "--shutdown", + "CRASH", + "--last-phase", + crash_phase.name, + ), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + with pytest.raises(ProcessManagerException) as ex: + manager.run() + assert ex.value.failures == {"proc1": ProcessFailureType.CRASH, "proc2": None} + + +@pytest.mark.parametrize( + "exception_phase", + ( + ProcessPhase.STARTING, + ProcessPhase.READY, + ProcessPhase.RUNNING, + ProcessPhase.STOPPING, + ), +) +@local_test +def test_exception(exception_phase: ProcessPhase) -> None: + """ + Tests that we can run multiple processes where one of them raises an exception. + """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=( + "--manager-name", + "test_manager", + "--shutdown", + "EXCEPTION", + "--last-phase", + exception_phase.name, + ), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + with pytest.raises(ProcessManagerException) as ex: + manager.run() + assert ex.value.failures == {"proc1": ProcessFailureType.EXCEPTION, "proc2": None} + + +@pytest.mark.parametrize( + "hang_phase", + ( + ProcessPhase.STARTING, + ProcessPhase.READY, + ProcessPhase.RUNNING, + ProcessPhase.STOPPING, + ), +) +@local_test +def test_hang(hang_phase: ProcessPhase) -> None: + """ + Tests that we can run multiple processes where one of them hangs. + """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=( + "--manager-name", + "test_manager", + "--shutdown", + "HANG", + "--last-phase", + hang_phase.name, + ), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + with pytest.raises(ProcessManagerException) as ex: + manager.run() + assert ex.value.failures == {"proc1": ProcessFailureType.HANG, "proc2": None} + + +@click.command() +@click.option(f"--{ProcessManagerState.SUBPROCESS_ARG}", required=True) +@click.option("--process-name", required=True) +@click.option("--manager-name", required=True) +@click.option("--shutdown", required=True) +@click.option("--last-phase") +def child_main( + process_manager_state_file: str, + process_name: str, + manager_name: str, + shutdown: str, + last_phase: Optional[str] = None, +) -> None: + proc( + ProcessManagerState.load(process_manager_state_file), + process_name, + manager_name, + ShutdownBehavior[shutdown], + ProcessPhase[last_phase] if last_phase is not None else ProcessPhase.TERMINATED, + ) + + +if __name__ == "__main__": + if f"--{ProcessManagerState.SUBPROCESS_ARG}" in sys.argv: + child_main() \ No newline at end of file diff --git a/labgraph/runners/tests/test_process_manager.py b/labgraph/runners/tests/test_process_manager.py index 51f491721..06addf7dd 100644 --- a/labgraph/runners/tests/test_process_manager.py +++ b/labgraph/runners/tests/test_process_manager.py @@ -56,7 +56,6 @@ def proc( """ A minimal version of a process managed by a `ProcessManager`. 
Used for testing the `ProcessManager`. The process simply updates its phase and sleeps. - Args: state: The `ProcessManager`'s state. name: The name of the process. @@ -129,6 +128,32 @@ def proc( time.sleep(PROCESS_SLEEP_TIME) +@local_test +def test_normal() -> None: + """ + Tests that we can run multiple processes that terminate normally. + """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + manager.run() + + @pytest.mark.parametrize( "crash_phase", ( @@ -138,6 +163,39 @@ def proc( ProcessPhase.STOPPING, ), ) +@local_test +def test_crash(crash_phase: ProcessPhase) -> None: + """ + Tests that we can run multiple processes where one of them crashes. + """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=( + "--manager-name", + "test_manager", + "--shutdown", + "CRASH", + "--last-phase", + crash_phase.name, + ), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + with pytest.raises(ProcessManagerException) as ex: + manager.run() + assert ex.value.failures == {"proc1": ProcessFailureType.CRASH, "proc2": None} @pytest.mark.parametrize( @@ -149,6 +207,39 @@ def proc( ProcessPhase.STOPPING, ), ) +@local_test +def test_exception(exception_phase: ProcessPhase) -> None: + """ + Tests that we can run multiple processes where one of them raises an exception. + """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=( + "--manager-name", + "test_manager", + "--shutdown", + "EXCEPTION", + "--last-phase", + exception_phase.name, + ), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + with pytest.raises(ProcessManagerException) as ex: + manager.run() + assert ex.value.failures == {"proc1": ProcessFailureType.EXCEPTION, "proc2": None} @pytest.mark.parametrize( @@ -160,6 +251,39 @@ def proc( ProcessPhase.STOPPING, ), ) +@local_test +def test_hang(hang_phase: ProcessPhase) -> None: + """ + Tests that we can run multiple processes where one of them hangs. 
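+
+    A hung process never reports a phase change or an exception on its own, so
+    the manager is expected to time out waiting on it and record
+    `ProcessFailureType.HANG`.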
+ """ + manager = ProcessManager( + processes=( + ProcessInfo( + module=__name__, + name="proc1", + args=( + "--manager-name", + "test_manager", + "--shutdown", + "HANG", + "--last-phase", + hang_phase.name, + ), + ), + ProcessInfo( + module=__name__, + name="proc2", + args=("--manager-name", "test_manager", "--shutdown", "NORMAL"), + ), + ), + name="test_manager", + startup_period=TEST_STARTUP_PERIOD, + shutdown_period=TEST_SHUTDOWN_PERIOD, + ) + + with pytest.raises(ProcessManagerException) as ex: + manager.run() + assert ex.value.failures == {"proc1": ProcessFailureType.HANG, "proc2": None} @click.command() @@ -186,4 +310,4 @@ def child_main( if __name__ == "__main__": if f"--{ProcessManagerState.SUBPROCESS_ARG}" in sys.argv: - child_main() + child_main() \ No newline at end of file diff --git a/labgraph/zmq_node/tests/test_zmq_node.py b/labgraph/zmq_node/tests/test_zmq_node.py new file mode 100644 index 000000000..fd6cd57d6 --- /dev/null +++ b/labgraph/zmq_node/tests/test_zmq_node.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import asyncio +import os +import time +from multiprocessing import Process +from typing import List, Sequence + +import pytest +import zmq + +from ...graphs import ( + AsyncPublisher, + Config, + Connections, + Graph, + Module, + Node, + Topic, + publisher, + subscriber, +) +from ...runners import LocalRunner, NormalTermination, ParallelRunner +from ...util.testing import get_free_port, get_test_filename, local_test +from ..zmq_message import ZMQMessage +from ..zmq_poller_node import ZMQPollerConfig, ZMQPollerNode +from ..zmq_sender_node import ZMQSenderConfig, ZMQSenderNode + + +NUM_MESSAGES = 5 +SAMPLE_RATE = 10 +STARTUP_TIME = 1 + +ZMQ_ADDR = "tcp://127.0.0.1" +ZMQ_TOPIC = "zmq_topic" + +DATA_DELIMITER = b"\0" + + +class MySinkConfig(Config): + output_filename: str + + +class MySink(Node): + """ + Convenience node for receiving messages from a `ZMQPollerNode`. + """ + + TOPIC = Topic(ZMQMessage) + config: MySinkConfig + + def setup(self) -> None: + self.output_file = open(self.config.output_filename, "wb") + self.num_received = 0 + + @subscriber(TOPIC) + async def sink(self, message: ZMQMessage) -> None: + self.output_file.write(message.data) + self.output_file.write(DATA_DELIMITER) + + self.num_received += 1 + if self.num_received == NUM_MESSAGES: + raise NormalTermination() + + def cleanup(self) -> None: + self.output_file.close() + + +class MySourceConfig(Config): + should_terminate: bool = True + + +class MySource(Node): + """ + Convenience node for sending messages to a `ZMQSenderNode`. + """ + + TOPIC = Topic(ZMQMessage) + config: MySourceConfig + samples = [bytes([i]) for i in range(1, NUM_MESSAGES + 1)] + + @publisher(TOPIC) + async def source(self) -> AsyncPublisher: + for sample_bytes in self.samples: + yield self.TOPIC, ZMQMessage(sample_bytes) + await asyncio.sleep(1 / SAMPLE_RATE) + + if self.config.should_terminate: + raise NormalTermination() + + +def write_samples_to_zmq(address: str, samples: Sequence[bytes], topic: str) -> None: + context = zmq.Context() + socket = context.socket(zmq.PUB) + socket.bind(address) + + # NOTE: In place of waiting on a monitor socket, we wait for a startup time + # here. Otherwise we would need to listen on an extra socket in this function, + # increasing the complexity of this test. If this startup wait becomes flaky, we may + # want to consider adding the monitor socket here. 
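+    # A rough sketch of that alternative (untested here) would be:
+    #
+    #     from zmq.utils.monitor import recv_monitor_message
+    #
+    #     monitor = socket.get_monitor_socket()
+    #     while recv_monitor_message(monitor)["event"] != zmq.EVENT_ACCEPTED:
+    #         pass  # block until a subscriber's connection is accepted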
+ time.sleep(STARTUP_TIME) + for sample_bytes in samples: + socket.send_multipart((bytes(topic, "UTF-8"), sample_bytes)) + time.sleep(1 / SAMPLE_RATE) + socket.close() + + +def recv_samples_from_zmq(address: str, topic: str, output_fname: str) -> None: + context = zmq.Context() + socket = context.socket(zmq.SUB) + socket.setsockopt(zmq.SUBSCRIBE, bytes(topic, "UTF-8")) + socket.connect(address) + + with open(output_fname, "bw") as output_file: + for _ in range(NUM_MESSAGES): + received = socket.recv_multipart() + topic, data = received + output_file.write(data) + output_file.write(DATA_DELIMITER) + socket.close() + + +@local_test +def test_zmq_poller_node() -> None: + """ + Tests that a `ZMQPollerNode` is able to read samples from a ZMQ socket and echo + the samples back out of the graph. + """ + + class MyZMQPollerGraphConfig(Config): + read_addr: str + zmq_topic: str + output_filename: str + + class MyZMQPollerGraph(Graph): + MY_SOURCE: ZMQPollerNode + MY_SINK: MySink + config: MyZMQPollerGraphConfig + + def setup(self) -> None: + self.MY_SOURCE.configure( + ZMQPollerConfig( + read_addr=self.config.read_addr, zmq_topic=self.config.zmq_topic + ) + ) + self.MY_SINK.configure( + MySinkConfig(output_filename=self.config.output_filename) + ) + + def connections(self) -> Connections: + return ((self.MY_SOURCE.topic, self.MY_SINK.TOPIC),) + + graph = MyZMQPollerGraph() + output_filename = get_test_filename() + address = f"{ZMQ_ADDR}:{get_free_port()}" + graph.configure( + MyZMQPollerGraphConfig( + read_addr=address, zmq_topic=ZMQ_TOPIC, output_filename=output_filename + ) + ) + runner = LocalRunner(module=graph) + + samples = [bytes([i]) for i in range(1, NUM_MESSAGES + 1)] + p = Process(target=write_samples_to_zmq, args=(address, samples, ZMQ_TOPIC)) + p.start() + runner.run() + p.join() + + with open(output_filename, "br") as f: + data = f.read() + assert set(samples) == set(data.strip(DATA_DELIMITER).split(DATA_DELIMITER)) + + +@local_test +def test_zmq_sender_node() -> None: + """ + Tests that a `ZMQSenderNode` is able to read samples from the graph and write the + samples back out to a ZMQ socket. + """ + + class MyZMQSenderGraph(Graph): + MY_SOURCE: MySource + MY_SINK: ZMQSenderNode + + config: ZMQSenderConfig + + def setup(self) -> None: + self.MY_SOURCE.configure(MySourceConfig()) + self.MY_SINK.configure(self.config) + + def connections(self) -> Connections: + return ((self.MY_SOURCE.TOPIC, self.MY_SINK.topic),) + + output_filename = get_test_filename() + graph = MyZMQSenderGraph() + address = f"{ZMQ_ADDR}:{get_free_port()}" + graph.configure(ZMQSenderConfig(write_addr=address, zmq_topic=ZMQ_TOPIC)) + runner = LocalRunner(module=graph) + + p = Process( + target=recv_samples_from_zmq, args=(address, ZMQ_TOPIC, output_filename) + ) + p.start() + runner.run() + p.join() + + with open(output_filename, "br") as f: + data = f.read() + assert set(graph.MY_SOURCE.samples) == set( + data.strip(DATA_DELIMITER).split(DATA_DELIMITER) + ) + + +@local_test +def test_zmq_send_and_poll() -> None: + """ + Tests that a `ZMQSenderNode` and a `ZMQPollerNode` can work together. 
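+
+    Data flows DF_SOURCE -> ZMQ_SENDER -> (ZMQ socket) -> ZMQ_POLLER -> DF_SINK,
+    with the four nodes spread across separate processes by `ParallelRunner`.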
+    """
+
+    class MyZMQGraphConfig(Config):
+        addr: str
+        zmq_topic: str
+        output_filename: str
+
+    class MyZMQGraph(Graph):
+        DF_SOURCE: MySource
+        ZMQ_SENDER: ZMQSenderNode
+        ZMQ_POLLER: ZMQPollerNode
+        DF_SINK: MySink
+
+        config: MyZMQGraphConfig
+
+        def setup(self) -> None:
+            self.DF_SOURCE.configure(MySourceConfig(should_terminate=False))
+            self.ZMQ_SENDER.configure(
+                ZMQSenderConfig(
+                    write_addr=self.config.addr, zmq_topic=self.config.zmq_topic
+                )
+            )
+            self.ZMQ_POLLER.configure(
+                ZMQPollerConfig(
+                    read_addr=self.config.addr, zmq_topic=self.config.zmq_topic
+                )
+            )
+            self.DF_SINK.configure(
+                MySinkConfig(output_filename=self.config.output_filename)
+            )
+
+        def connections(self) -> Connections:
+            return (
+                (self.DF_SOURCE.TOPIC, self.ZMQ_SENDER.topic),
+                (self.ZMQ_POLLER.topic, self.DF_SINK.TOPIC),
+            )
+
+        def process_modules(self) -> Sequence[Module]:
+            return (self.DF_SOURCE, self.ZMQ_SENDER, self.ZMQ_POLLER, self.DF_SINK)
+
+    output_filename = get_test_filename()
+    graph = MyZMQGraph()
+    address = f"{ZMQ_ADDR}:{get_free_port()}"
+    graph.configure(
+        MyZMQGraphConfig(
+            addr=address, zmq_topic=ZMQ_TOPIC, output_filename=output_filename
+        )
+    )
+    runner = ParallelRunner(graph=graph)
+    runner.run()
+
+    with open(output_filename, "br") as f:
+        data = f.read()
+    assert set(graph.DF_SOURCE.samples) == set(
+        data.strip(DATA_DELIMITER).split(DATA_DELIMITER)
+    )

From 512998e54fc25831405677412422d466bd70e401 Mon Sep 17 00:00:00 2001
From: jf
Date: Mon, 6 Dec 2021 16:04:34 -0800
Subject: [PATCH 07/13] Update py36 support in Docker.

---
 Dockerfile        | 4 ++--
 Dockerfile.Centos | 2 +-
 setup_py36.py     | 49 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 52 insertions(+), 3 deletions(-)
 create mode 100644 setup_py36.py

diff --git a/Dockerfile b/Dockerfile
index ce77756a3..0048e1d5a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -50,8 +50,8 @@ WORKDIR "/opt/labgraph"
 COPY . .
 
 # Build LabGraph Wheel
-RUN python3.6 setup.py install --user
-RUN python3.6 setup.py sdist bdist_wheel
+RUN python3.6 setup_py36.py install --user
+RUN python3.6 setup_py36.py sdist bdist_wheel
 RUN python3.6 -m pip install auditwheel
 RUN auditwheel repair dist/*whl -w dist/
 
diff --git a/Dockerfile.Centos b/Dockerfile.Centos
index e81044cda..e5da0a0cd 100644
--- a/Dockerfile.Centos
+++ b/Dockerfile.Centos
@@ -41,4 +41,4 @@ RUN chmod 2777 /usr/local/var/run/watchman
 # Copy LabGraph files
 WORKDIR "/opt/labgraph"
 COPY . .
-RUN python3.6 setup.py install --user
+RUN python3.6 setup_py36.py install --user
diff --git a/setup_py36.py b/setup_py36.py
new file mode 100644
index 000000000..58770bac4
--- /dev/null
+++ b/setup_py36.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright 2004-present Facebook. All Rights Reserved.
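+#
+# Python 3.6-specific variant of setup.py: dependencies are pinned to
+# 3.6-compatible versions and python_requires is capped below (">=3.6, <3.7").
+# The Dockerfiles in this patch build and install LabGraph through this file.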
+ +from buck_ext import BuckExtension, buck_build_ext +from setuptools import find_packages, setup + +LIBRARY_EXTENSIONS = [ + BuckExtension( + name="cthulhubindings", + target="//:cthulhubindings#default,shared", + ), + BuckExtension( + name="labgraph_cpp", + target="//:labgraph_cpp_bindings#default,shared", + ), + BuckExtension( + name="MyCPPNodes", + target="//:MyCPPNodes#default,shared", + ), +] + + +setup( + name="labgraph", + version="1.0.3", + description="Python streaming framework", + packages=find_packages(), + package_data={"labgraph": ["tests/mypy.ini"]}, + python_requires=">=3.6, <3.7", + ext_modules=LIBRARY_EXTENSIONS, + cmdclass={"build_ext": buck_build_ext}, + install_requires=[ + "appdirs==1.4.3", + "click==7.0", + "dataclasses==0.6", + "h5py==2.10.0", + "matplotlib==3.1.1", + "mypy==0.782", + "numpy==1.16.4", + "psutil==5.6.7", + "pytest==3.10.1", + "pytest_mock==2.0.0", + "pyzmq==18.1.0", + "typeguard==2.5.1", + "typing_extensions>=3.7.4.3", + "yappi==1.2.5", + "pylsl==1.15.0", + ], +) \ No newline at end of file From 21e02c4b0f44f4e2b74cd48f144e8ab1f7a0bdf8 Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 16:19:32 -0800 Subject: [PATCH 08/13] Update python/DEFS. --- third-party/python/DEFS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third-party/python/DEFS b/third-party/python/DEFS index 54f43c4bd..13a162cc9 100644 --- a/third-party/python/DEFS +++ b/third-party/python/DEFS @@ -18,7 +18,7 @@ def _config_var(executable, key): def config_var(key): - for executable in (python3", "python" + PYTHON_VERSION): + for executable in ("python3", "python" + PYTHON_VERSION): if not shutil.which(executable): continue if _config_var(executable, "py_version_short") != PYTHON_VERSION: From fb3d7d60f082c3ceb94089f4fe47fb6f95d515fd Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 16:48:48 -0800 Subject: [PATCH 09/13] Update for Dockerfile --- Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile b/Dockerfile index 0048e1d5a..2cea4c47f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -56,6 +56,7 @@ RUN python3.6 -m pip install auditwheel RUN auditwheel repair dist/*whl -w dist/ # Test LabGraph +WORKDIR "/tmp" RUN python3.6 -m pytest --pyargs -v labgraph._cthulhu RUN python3.6 -m pytest --pyargs -v labgraph.events RUN python3.6 -m pytest --pyargs -v labgraph.graphs From 7aa4069fb7eb49a39a67b225080179b1c8d809bd Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 17:02:01 -0800 Subject: [PATCH 10/13] Update for Docker part 2. 
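
Collapse the per-module pytest RUN lines into a single run over the whole
labgraph package. One pytest session collects every test module at once;
roughly, the change is equivalent to the following sketch (pytest.main is
pytest's programmatic entry point; the module list is abridged):

    import pytest

    # After this patch: one consolidated session for the whole package.
    pytest.main(["--pyargs", "-v", "labgraph"])

    # Before this patch: one session per module, one per RUN line.
    for module in ("labgraph._cthulhu", "labgraph.events", "labgraph.graphs"):
        pytest.main(["--pyargs", "-v", module])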
--- Dockerfile | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2cea4c47f..e9e4caff6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -57,14 +57,4 @@ RUN auditwheel repair dist/*whl -w dist/ # Test LabGraph WORKDIR "/tmp" -RUN python3.6 -m pytest --pyargs -v labgraph._cthulhu -RUN python3.6 -m pytest --pyargs -v labgraph.events -RUN python3.6 -m pytest --pyargs -v labgraph.graphs -RUN python3.6 -m pytest --pyargs -v labgraph.loggers -RUN python3.6 -m pytest --pyargs -v labgraph.messages -RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_process_manager -RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_aligner -RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_cpp -RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_exception -RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_launch -RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_runner +RUN python3.6 -m pytest --pyargs -v labgraph From f640c629c216e92c057d9e90cc4a1f5e176bc7f3 Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 17:16:58 -0800 Subject: [PATCH 11/13] Updates for LabgraphError. --- labgraph/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/labgraph/__init__.py b/labgraph/__init__.py index cc41712e2..e17ce77a4 100644 --- a/labgraph/__init__.py +++ b/labgraph/__init__.py @@ -14,7 +14,7 @@ "Connections", "CPPNodeConfig", "DeferredMessage", - "LabGraphError", + "LabgraphError", "Event", "EventGraph", "EventPublishingHeap", @@ -108,4 +108,4 @@ TimestampAligner, run, ) -from .util import LabGraphError +from .util import LabgraphError From 910c7448bdf4d623f27d7923508643b6c3877a7e Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 17:30:50 -0800 Subject: [PATCH 12/13] Update for setup_py36. --- setup_py36.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup_py36.py b/setup_py36.py index 58770bac4..dea4091f1 100644 --- a/setup_py36.py +++ b/setup_py36.py @@ -38,12 +38,13 @@ "mypy==0.782", "numpy==1.16.4", "psutil==5.6.7", + "pylsl==1.15.0", "pytest==3.10.1", "pytest_mock==2.0.0", "pyzmq==18.1.0", "typeguard==2.5.1", "typing_extensions>=3.7.4.3", + "websockets==8.1", "yappi==1.2.5", - "pylsl==1.15.0", ], ) \ No newline at end of file From f8acf63cd9b8471da6a2a240bd794c840c53c995 Mon Sep 17 00:00:00 2001 From: jf Date: Mon, 6 Dec 2021 17:46:41 -0800 Subject: [PATCH 13/13] Revert back Dockerfile labgraph test. 
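
Going back to one RUN per test module restores one image layer per module:
a failing build stops at the exact module that broke, and the layers for
modules that already passed stay cached on rebuild. A rough local equivalent
of what the per-RUN layout does (sketch only; module list abridged):

    import subprocess
    import sys

    MODULES = ["labgraph._cthulhu", "labgraph.events", "labgraph.graphs"]
    for module in MODULES:
        # One pytest process per module; the first non-zero exit stops the
        # sequence, just as a failing RUN stops the Docker build.
        result = subprocess.run(
            ["python3.6", "-m", "pytest", "--pyargs", "-v", module]
        )
        if result.returncode != 0:
            sys.exit(result.returncode)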
--- Dockerfile | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index e9e4caff6..0048e1d5a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -56,5 +56,14 @@ RUN python3.6 -m pip install auditwheel RUN auditwheel repair dist/*whl -w dist/ # Test LabGraph -WORKDIR "/tmp" -RUN python3.6 -m pytest --pyargs -v labgraph +RUN python3.6 -m pytest --pyargs -v labgraph._cthulhu +RUN python3.6 -m pytest --pyargs -v labgraph.events +RUN python3.6 -m pytest --pyargs -v labgraph.graphs +RUN python3.6 -m pytest --pyargs -v labgraph.loggers +RUN python3.6 -m pytest --pyargs -v labgraph.messages +RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_process_manager +RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_aligner +RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_cpp +RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_exception +RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_launch +RUN python3.6 -m pytest --pyargs -v labgraph.runners.tests.test_runner