Skip to content

Commit

Permalink
[patch] Introduce caching (#395)
Browse files Browse the repository at this point in the history
* Refactor: output signals to emission

Wrap them in `emit()` and `emitting_channels` instead of manually calling them. This lets us tighten up If-like nodes too.

* Introduce caching

To shortcut actually running a node and just return existing output if its cached input matches its current input (by `==` test)

* Extend speedup test to include caching

* Add docstring

* Expose use_cache as a class attribute

So it can be set at class definition time, even by decorators

* Discuss caching in the deepdive

* Format black

---------

Co-authored-by: pyiron-runner <[email protected]>
  • Loading branch information
liamhuber and pyiron-runner authored Jul 30, 2024
1 parent ddc1b1e commit 41d8d42
Show file tree
Hide file tree
Showing 11 changed files with 1,154 additions and 863 deletions.
1,711 changes: 910 additions & 801 deletions notebooks/deepdive.ipynb

Large diffs are not rendered by default.

44 changes: 42 additions & 2 deletions pyiron_workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, Literal, Optional, TYPE_CHECKING

from pyiron_snippets.colors import SeabornColors
from pyiron_snippets.dotdict import DotDict

from pyiron_workflow.draw import Node as GraphvizNode
from pyiron_workflow.logging import logger
Expand All @@ -34,6 +35,7 @@
import graphviz
from pyiron_snippets.files import DirectoryObject

from pyiron_workflow.channels import OutputSignal
from pyiron_workflow.nodes.composite import Composite


Expand Down Expand Up @@ -128,6 +130,8 @@ class Node(
- NOTE: Don't forget to :meth:`shutdown` any created executors outside of a
`with` context when you're done with them; we give a convenience method for
this.
- Nodes can optionally cache their input to skip running altogether and use
existing output when their current input matches the cached input.
- Nodes created from a registered package store their package identifier as a class
attribute.
- [ALPHA FEATURE] Nodes can be saved to and loaded from file if python >= 3.11.
Expand Down Expand Up @@ -261,6 +265,11 @@ class Node(
Additional signal channels in derived classes can be added to
:attr:`signals.inputs` and :attr:`signals.outputs` after this mixin class is
initialized.
use_cache (bool): Whether or not to cache the inputs and, when the current
inputs match the cached input (by `==` comparison), to bypass running the
node and simply continue using the existing outputs. Note that you may be
able to trigger a false cache hit in some special case of non-idempotent
nodes working on mutable data.
Methods:
__call__: An alias for :meth:`pull` that aggressively runs upstream nodes even
Expand Down Expand Up @@ -303,6 +312,7 @@ class Node(
"""

package_identifier = None
use_cache = True

# This isn't nice, just a technical necessity in the current implementation
# Eventually, of course, this needs to be _at least_ file-format independent
Expand Down Expand Up @@ -336,6 +346,7 @@ def __init__(
storage_backend=storage_backend,
)
self.save_after_run = save_after_run
self.cached_inputs = None
self._user_data = {} # A place for power-users to bypass node-injection

self._setup_node()
Expand Down Expand Up @@ -491,6 +502,20 @@ def run(
if fetch_input:
self.inputs.fetch()

if self.use_cache and self.cache_hit: # Read and use cache

if self.parent is None and emit_ran_signal:
self.emit()
elif self.parent is not None:
self.parent.register_child_starting(self)
self.parent.register_child_finished(self)
if emit_ran_signal:
self.parent.register_child_emitting(self)

return self._outputs_to_run_return()
elif self.use_cache: # Write cache and continue
self.cached_inputs = self.inputs.to_value_dict()

if self.parent is not None:
self.parent.register_child_starting(self)

Expand Down Expand Up @@ -603,6 +628,13 @@ def run_data_tree(self, run_parent_trees_too=False) -> None:
if self.parent is not None:
self.parent.starting_nodes = parent_starting_nodes

@property
def cache_hit(self):
    """
    Whether the node's current input values are equal (by `==`) to the input
    values that were cached at the last cache write.
    """
    current_input = self.inputs.to_value_dict()
    return current_input == self.cached_inputs

def _outputs_to_run_return(self):
    """Package the current output values as the return value of a run."""
    value_dict = self.outputs.to_value_dict()
    return DotDict(value_dict)

def _finish_run(self, run_output: tuple | Future) -> Any | tuple:
try:
processed_output = super()._finish_run(run_output=run_output)
Expand All @@ -616,9 +648,9 @@ def _finish_run(self, run_output: tuple | Future) -> Any | tuple:
def _finish_run_and_emit_ran(self, run_output: tuple | Future) -> Any | tuple:
processed_output = self._finish_run(run_output)
if self.parent is None:
self.signals.output.ran()
self.emit()
else:
self.parent.register_child_emitting_ran(self)
self.parent.register_child_emitting(self)
return processed_output

_finish_run_and_emit_ran.__doc__ = (
Expand All @@ -629,6 +661,14 @@ def _finish_run_and_emit_ran(self, run_output: tuple | Future) -> Any | tuple:
"""
)

@property
def emitting_channels(self) -> tuple[OutputSignal, ...]:
    """
    The output signal channels to fire after a run; just `ran` here, but
    subclasses may override this to emit additional channels.
    """
    # tuple[OutputSignal, ...] (variadic), not tuple[OutputSignal] (exactly
    # one element): overrides are expected to return more channels.
    return (self.signals.output.ran,)

def emit(self):
    """Fire each channel listed in :attr:`emitting_channels`, in order."""
    for signal in self.emitting_channels:
        signal()

def execute(self, *args, **kwargs):
"""
A shortcut for :meth:`run` with particular flags.
Expand Down
25 changes: 16 additions & 9 deletions pyiron_workflow/nodes/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def on_run(self):

while len(self.running_children) > 0 or len(self.signal_queue) > 0:
try:
ran_signal, receiver = self.signal_queue.pop(0)
receiver(ran_signal)
firing, receiving = self.signal_queue.pop(0)
receiving(firing)
except IndexError:
# The signal queue is empty, but there is still someone running...
sleep(self._child_sleep_interval)
Expand Down Expand Up @@ -192,17 +192,18 @@ def register_child_finished(self, child: Node) -> None:
f"{self.provenance_by_execution}, {self.provenance_by_completion}"
) from e

def register_child_emitting_ran(self, child: Node) -> None:
def register_child_emitting(self, child: Node) -> None:
    """
    To be called by children when they want to emit their signals.

    Queues each (firing channel, receiving channel) pair so the parent's run
    loop can deliver the signals in order.

    Args:
        child [Node]: The child that is finished and would like to fire its
            `ran` signal (and possibly others). Should always be a child of
            `self`, but this is not explicitly verified at runtime.
    """
    # NOTE(review): the scraped diff interleaved the pre-change body (which
    # queued only `ran` connections) with the post-change body; only the
    # post-change version iterating `child.emitting_channels` is kept.
    for firing in child.emitting_channels:
        for receiving in firing.connections:
            self.signal_queue.append((firing, receiving))

@property
def run_args(self) -> tuple[tuple, dict]:
Expand All @@ -211,7 +212,7 @@ def run_args(self) -> tuple[tuple, dict]:
def process_run_result(self, run_output):
    """
    Absorb the state of a remotely-executed copy of this composite (if the run
    happened on another process), then build the standard run return value.

    Args:
        run_output: Either `self` (local run) or a remotely-executed copy of
            `self` whose state must be parsed back in.

    Returns:
        The packaged output values, via :meth:`_outputs_to_run_return`.
    """
    # NOTE(review): the scraped diff showed both the pre-change return
    # (`DotDict(self.outputs.to_value_dict())`) and the post-change return on
    # consecutive lines; only the post-change delegation introduced by this
    # commit is kept.
    if run_output is not self:
        self._parse_remotely_executed_self(run_output)
    return self._outputs_to_run_return()

def _parse_remotely_executed_self(self, other_self):
# Un-parent existing nodes before ditching them
Expand Down Expand Up @@ -259,6 +260,7 @@ def add_child(
f"Only new {Node.__name__} instances may be added, but got "
f"{type(child)}."
)
self.cached_inputs = None # Reset cache after graph change
return super().add_child(child, label=label, strict_naming=strict_naming)

def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]:
Expand All @@ -276,6 +278,7 @@ def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]:
disconnected = child.disconnect()
if child in self.starting_nodes:
self.starting_nodes.remove(child)
self.cached_inputs = None # Reset cache after graph change
return disconnected

def replace_child(
Expand Down Expand Up @@ -354,6 +357,10 @@ def replace_child(
for sending_channel, receiving_channel in inbound_links + outbound_links:
sending_channel.value_receiver = receiving_channel

# Clear caches
self.cached_inputs = None
replacement.cached_inputs = None

return owned_node

def executor_shutdown(self, wait=True, *, cancel_futures=False):
Expand Down
9 changes: 8 additions & 1 deletion pyiron_workflow/nodes/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def for_node_factory(
iter_on: tuple[str, ...] = (),
zip_on: tuple[str, ...] = (),
output_column_map: dict | None = None,
use_cache: bool = True,
/,
):
combined_docstring = (
Expand All @@ -444,6 +445,7 @@ def for_node_factory(
"_iter_on": iter_on,
"_zip_on": zip_on,
"__doc__": combined_docstring,
"use_cache": use_cache,
},
{"output_column_map": output_column_map},
)
Expand All @@ -455,6 +457,7 @@ def for_node(
iter_on=(),
zip_on=(),
output_column_map: Optional[dict[str, str]] = None,
use_cache: bool = True,
**node_kwargs,
):
"""
Expand Down Expand Up @@ -482,6 +485,8 @@ def for_node(
Necessary iff the body node has the same label for an output channel and
an input channel being looped over. (Default is None, just use the output
channel labels as column names.)
use_cache (bool): Whether this node should default to caching its values.
(Default is True.)
**node_kwargs: Regular keyword node arguments.
Returns:
Expand Down Expand Up @@ -555,6 +560,8 @@ def for_node(
"""
for_node_factory.clear(_for_node_class_name(body_node_class, iter_on, zip_on))
cls = for_node_factory(body_node_class, iter_on, zip_on, output_column_map)
cls = for_node_factory(
body_node_class, iter_on, zip_on, output_column_map, use_cache
)
cls.preview_io()
return cls(*node_args, **node_kwargs)
32 changes: 27 additions & 5 deletions pyiron_workflow/nodes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,13 @@ def process_run_result(self, function_output: Any | tuple) -> Any | tuple:
(function_output,) if len(self.outputs) == 1 else function_output,
):
out.value = value
return function_output
return self._outputs_to_run_return()

def _outputs_to_run_return(self):
output = tuple(self.outputs.to_value_dict().values())
if len(output) == 1:
output = output[0]
return output

def to_dict(self):
return {
Expand All @@ -348,7 +354,11 @@ def color(self) -> str:

@classfactory
def function_node_factory(
node_function: callable, validate_output_labels: bool, /, *output_labels
node_function: callable,
validate_output_labels: bool,
use_cache: bool = True,
/,
*output_labels,
):
"""
Create a new :class:`Function` node class based on the given node function. This
Expand All @@ -358,6 +368,8 @@ def function_node_factory(
node_function (callable): The function to be wrapped by the node.
validate_output_labels (bool): Flag to indicate if output labels should be
validated.
use_cache (bool): Whether nodes of this type should default to caching their
values.
*output_labels: Optional labels for the function's output channels.
Returns:
Expand All @@ -373,12 +385,17 @@ def function_node_factory(
"_output_labels": None if len(output_labels) == 0 else output_labels,
"_validate_output_labels": validate_output_labels,
"__doc__": node_function.__doc__,
"use_cache": use_cache,
},
{},
)


def as_function_node(*output_labels: str, validate_output_labels=True):
def as_function_node(
*output_labels: str,
validate_output_labels=True,
use_cache=True,
):
"""
Decorator to create a new :class:`Function` node class from a given function. This
function gets executed on each :meth:`run` of the resulting function.
Expand All @@ -388,6 +405,8 @@ def as_function_node(*output_labels: str, validate_output_labels=True):
validate_output_labels (bool): Flag to indicate if output labels should be
validated against the return values in the function node source code.
Defaults to True.
use_cache (bool): Whether nodes of this type should default to caching their
values. (Default is True.)
Returns:
Callable: A decorator that converts a function into a :class:`Function` node
Expand All @@ -397,7 +416,7 @@ def as_function_node(*output_labels: str, validate_output_labels=True):
def decorator(node_function):
function_node_factory.clear(node_function.__name__) # Force a fresh class
factory_made = function_node_factory(
node_function, validate_output_labels, *output_labels
node_function, validate_output_labels, use_cache, *output_labels
)
factory_made._class_returns_from_decorated_function = node_function
factory_made.preview_io()
Expand All @@ -411,6 +430,7 @@ def function_node(
*node_args,
output_labels: str | tuple[str, ...] | None = None,
validate_output_labels: bool = True,
use_cache: bool = True,
**node_kwargs,
):
"""
Expand All @@ -427,6 +447,8 @@ def function_node(
validated against the return values in the function source code. Defaults
to True. Disabling this may be useful if the source code is not available
or if the function has multiple return statements.
use_cache (bool): Whether this node should default to caching its values.
(Default is True.)
**node_kwargs: Keyword arguments for the :class:`Function` initialization --
parsed as node input data when the keyword matches an input channel.
Expand All @@ -439,7 +461,7 @@ def function_node(
output_labels = (output_labels,)
function_node_factory.clear(node_function.__name__) # Force a fresh class
factory_made = function_node_factory(
node_function, validate_output_labels, *output_labels
node_function, validate_output_labels, use_cache, *output_labels
)
factory_made.preview_io()
return factory_made(*node_args, **node_kwargs)
Loading

0 comments on commit 41d8d42

Please sign in to comment.