Refactor external model to gate functions #336

Closed
41 commits
16b2761
made book keeping fold indices buffers
loreloc Dec 18, 2024
bc72154
replaced fold indexing in the address book with more efficient unsque…
loreloc Dec 18, 2024
819f26e
re-run all notebooks
loreloc Dec 18, 2024
9b66a93
Merge branch 'april-tools:main' into source/symbolic-neural-conditioning
n28div Dec 18, 2024
31d2bcd
Merge branch 'symbolic-neural-conditioning' of https://github.com/apr…
n28div Dec 19, 2024
a97e903
Merge branch 'main' into address-book-gpu
loreloc Dec 19, 2024
a577a47
refactored external model to gate function; parametrize circuit based…
n28div Dec 20, 2024
c81c9fe
add tucker circuit construction template and generalize kronecker/tuc…
loreloc Jan 3, 2025
a8c1dcf
clean import
loreloc Jan 3, 2025
f367966
add tensor train circuit template
loreloc Jan 3, 2025
b287092
add probabilistic versions of cp and tucker, and improved unit tests
loreloc Jan 10, 2025
f294141
removed dangling test file
loreloc Jan 10, 2025
2931c04
fix torchvision versioning mess
loreloc Jan 10, 2025
2a3452d
re-run compression-cp-factorization.ipynb
loreloc Jan 10, 2025
a5839aa
removed channels from everywhere
loreloc Jan 11, 2025
4b130b9
fix pics
loreloc Jan 11, 2025
6dcc375
fix logic circuit construction
loreloc Jan 11, 2025
159dea1
fix integrate of embedding layers
loreloc Jan 11, 2025
6f87531
add PySDD to notebooks deps
loreloc Jan 11, 2025
e2f9eef
re-run notebooks
loreloc Jan 11, 2025
aa6574f
Added gaussian opt for input nodes to doc string
turnmanh Jan 27, 2025
38ddf45
random fix in sampling unit test
loreloc Feb 5, 2025
c446634
Merge pull request #335 from april-tools/address-book-gpu
loreloc Feb 5, 2025
226c66d
Merge branch 'main' into tucker-tensor-train
loreloc Feb 5, 2025
c733a96
strengthen tensor train circuit template unite test
loreloc Feb 5, 2025
503dc96
Merge pull request #339 from april-tools/tucker-tensor-train
loreloc Feb 5, 2025
2627047
Merge branch 'main' into remove-channels
loreloc Feb 5, 2025
134cd5d
Merge pull request #340 from april-tools/remove-channels
loreloc Feb 5, 2025
b200550
Merge branch 'main' into patch-1
loreloc Feb 5, 2025
0cd00d8
Merge pull request #345 from turnmanh/patch-1
loreloc Feb 5, 2025
72394ca
minor fix device
loreloc Feb 5, 2025
9a9e288
allow labeling of symbolic layers
n28div Jan 7, 2025
2dc3272
moved layer label type to utils
n28div Jan 7, 2025
4a69061
simplify logic circuit
n28div Feb 10, 2025
cd94a7f
use layer label in plotting
n28div Feb 10, 2025
356bb86
upgrade to 0.2.1
loreloc Feb 11, 2025
675a7c8
Layer label into initializer
n28div Feb 20, 2025
5dda1fb
Merge branch 'logic-circuits' of github.com:n28div/cirkit into logic-…
n28div Feb 20, 2025
87a642f
format code with black
n28div Feb 20, 2025
2e116d9
caching pre-processing on logic-circuit construction
n28div Feb 20, 2025
88a5fe4
Merge branch 'main' into source/symbolic-neural-conditioning
n28div Feb 27, 2025
18 changes: 10 additions & 8 deletions cirkit/backend/compiler.py
@@ -11,7 +11,7 @@
SUPPORTED_BACKENDS = ["torch"]

CompiledCircuit = TypeVar("CompiledCircuit")
ExternalModel = TypeVar("ExternalModel")
GateFunction = TypeVar("GateFunction")


class CompiledCircuitsMap:
@@ -153,7 +153,7 @@ def _retrieve_signature(cls, func: ParameterCompilationFunc) -> InitializerCompi
return cast(InitializerCompilationSign, ann[args[-1]])


CompilerModelRegistry = dict[str, ExternalModel]
CompilerGateFunctionRegistry = dict[str, GateFunction]


class AbstractCompiler(ABC):
@@ -162,13 +162,15 @@ def __init__(
layers_registry: CompilerLayerRegistry,
parameters_registry: CompilerParameterRegistry,
initializers_registry: CompilerInitializerRegistry,
models_registry: CompilerModelRegistry | None = None,
gate_function_registry: CompilerGateFunctionRegistry | None = None,
**flags,
):
self._layers_registry = layers_registry
self._parameters_registry = parameters_registry
self._initializers_registry = initializers_registry
self._model_registry = {} if models_registry is None else models_registry
self._gate_function_registry = (
{} if gate_function_registry is None else gate_function_registry
)
self._flags = flags
self._compiled_circuits = CompiledCircuitsMap()

@@ -209,11 +211,11 @@ def retrieve_initializer_rule(
) -> InitializerCompilationFunc:
return self._initializers_registry.retrieve_rule(signature)

def add_external_model(self, model_id: str, model: ExternalModel):
self._model_registry[model_id] = model
def add_gate_function(self, function_id: str, gate_function: GateFunction):
self._gate_function_registry[function_id] = gate_function

def get_external_model(self, model_id: str) -> ExternalModel:
return self._model_registry[model_id]
def get_gate_function(self, function_id: str) -> GateFunction:
return self._gate_function_registry[function_id]

def compile(self, sc: Circuit) -> CompiledCircuit:
if self.is_compiled(sc):
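To make the renamed compiler API concrete, below is a minimal sketch of registering and retrieving a gate function with the PyTorch backend compiler. The import path and the method names come from this PR's diff; the default `TorchCompiler()` construction, the `nn.Sequential` module, and the `"temperature"` identifier are illustrative assumptions — since `GateFunction` is a plain `TypeVar`, any object can be registered.

```python
from torch import nn

from cirkit.backend.torch.compiler import TorchCompiler

# A hypothetical neural network whose outputs will gate (parametrize) circuit layers.
temperature_net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

# Default construction is assumed here; compiler flags are orthogonal to this PR.
compiler = TorchCompiler()

# Register the module under an identifier, then retrieve it later by the same id.
compiler.add_gate_function("temperature", temperature_net)
assert compiler.get_gate_function("temperature") is temperature_net
```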
83 changes: 41 additions & 42 deletions cirkit/backend/torch/circuits.py
@@ -15,7 +15,7 @@
TorchDiAcyclicGraph,
)
from cirkit.backend.torch.layers import TorchInputLayer, TorchLayer
from cirkit.backend.torch.utils import ExternalModelEval
from cirkit.backend.torch.utils import CachedGateFunctionEval
from cirkit.symbolic.circuit import StructuralProperties
from cirkit.utils.scope import Scope

@@ -33,31 +33,39 @@ def lookup(
self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
) -> Iterator[tuple[TorchLayer | None, tuple]]:
# Loop through the entries and yield inputs
for entry in self._entries:
for entry in self:
layer = entry.module
in_layer_ids = entry.in_module_ids
in_fold_idx = entry.in_fold_idx
# Catch the case there are some inputs coming from other modules
if entry.in_module_ids:
(in_fold_idx,) = entry.in_fold_idx
(in_module_ids,) = entry.in_module_ids
if len(in_module_ids) == 1:
x = module_outputs[in_module_ids[0]]
if in_layer_ids:
in_fold_idx_h = in_fold_idx[0]
in_layer_ids_h = in_layer_ids[0]
if len(in_layer_ids_h) == 1:
x = module_outputs[in_layer_ids_h[0]]
else:
x = torch.cat([module_outputs[mid] for mid in in_module_ids], dim=0)
x = x[in_fold_idx]
yield entry.module, (x,)
x = torch.cat([module_outputs[mid] for mid in in_layer_ids_h], dim=0)
x = x[in_fold_idx_h]
yield layer, (x,)
continue

# Catch the case there are no inputs coming from other modules
# That is, we are gathering the inputs of input layers
layer = entry.module
assert isinstance(layer, TorchInputLayer)
if layer.num_variables:
if in_graph is None:
yield layer, ()
continue
# in_graph: An input batch (assignments to variables) of shape (B, C, D)
# in_graph: An input batch (assignments to variables) of shape (B, D)
# scope_idx: The scope of the layers in each fold, a tensor of shape (F, D'), D' < D
# x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
x = in_graph[..., layer.scope_idx].permute(2, 1, 0, 3)
# x: (B, D) -> (B, F, D') -> (F, B, D')
if len(in_graph.shape) != 2:
raise ValueError(
"The input to the circuit should have shape (B, D), "
"where B is the batch size and D is the number of variables "
"the circuit is defined on"
)
x = in_graph[..., layer.scope_idx].permute(1, 0, 2)
yield layer, (x,)
continue

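As a side note on the new address-book indexing, the sketch below (plain PyTorch, with made-up sizes) reproduces the gather-and-permute used in `lookup` above: an input batch of shape (B, D) is indexed with the per-fold scopes of shape (F, D') and rearranged to (F, B, D').

```python
import torch

B, D = 3, 5                      # batch size and number of variables (made-up values)
in_graph = torch.randn(B, D)     # an input batch of shape (B, D); note: no channel axis

# scope_idx: the per-fold variable scopes, a tensor of shape (F, D') with D' <= D
scope_idx = torch.tensor([[0, 1], [2, 3], [3, 4]])  # F = 3 folds over D' = 2 variables each

# (B, D) -> (B, F, D') -> (F, B, D'), as in LayerAddressBook.lookup
x = in_graph[..., scope_idx].permute(1, 0, 2)
print(x.shape)  # torch.Size([3, 3, 2]), i.e. (F, B, D')
```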
@@ -121,27 +129,25 @@ class AbstractTorchCircuit(TorchDiAcyclicGraph[TorchLayer]):
def __init__(
self,
scope: Scope,
num_channels: int,
layers: Sequence[TorchLayer],
in_layers: Mapping[TorchLayer, Sequence[TorchLayer]],
outputs: Sequence[TorchLayer],
*,
properties: StructuralProperties,
fold_idx_info: FoldIndexInfo | None = None,
ext_model_evals: Mapping[str, ExternalModelEval] | None = None,
gate_function_evals: Mapping[str, CachedGateFunctionEval] | None = None,
) -> None:
"""Initializes a torch circuit.

Args:
scope: The variables scope.
num_channels: The number of channels per variable.
layers: The sequence of layers.
in_layers: A dictionary mapping layers to their inputs, if any.
outputs: A list of output layers.
properties: The structural properties of the circuit.
fold_idx_info: The folding index information.
It can be None if the circuit is not folded.
ext_model_evals: A mapping from external model identifiers to a cached evaluator.
gate_function_evals: A mapping from gate function identifiers to a cached evaluator.
"""
super().__init__(
layers,
@@ -150,10 +156,9 @@ def __init__(
fold_idx_info=fold_idx_info,
)
self._scope = scope
self._num_channels = num_channels
self._properties = properties
ext_model_evals = {} if ext_model_evals is None else ext_model_evals
self._ext_model_evals: Mapping[str, ExternalModelEval] = ext_model_evals
gate_function_evals = {} if gate_function_evals is None else gate_function_evals
self._gate_function_evals: Mapping[str, CachedGateFunctionEval] = gate_function_evals

@property
def scope(self) -> Scope:
@@ -173,15 +178,6 @@ def num_variables(self) -> int:
"""
return len(self.scope)

@property
def num_channels(self) -> int:
"""Retrieve the number of channels of each variable.

Returns:
The number of variables.
"""
return self._num_channels

@property
def properties(self) -> StructuralProperties:
"""Retrieve the structural properties of the circuit.
@@ -192,8 +188,8 @@ def properties(self) -> StructuralProperties:
return self._properties

@property
def ext_model_evals(self) -> Mapping[str, ExternalModelEval]:
return self._ext_model_evals
def gate_function_evals(self) -> Mapping[str, CachedGateFunctionEval]:
return self._gate_function_evals

@property
def layers(self) -> Sequence[TorchLayer]:
@@ -260,15 +256,18 @@ def _build_address_book(self, fold_idx_info: FoldIndexInfo) -> LayerAddressBook:
return LayerAddressBook.from_index_info(fold_idx_info, incomings_fn=self.layer_inputs)

def _evaluate_layers(
self, x: Tensor | None, *, ext_model_kwargs: Mapping[str, Mapping[str, Any]] | None = None
self,
x: Tensor | None,
*,
gate_function_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
) -> Tensor:
# Evaluate the gate functions and cache their results.
# This will be called just before the invocation of the
# [evaluate][cirkit.backend.torch.graph.modules.TorchDiAcyclicGraph.evaluate] method.
ext_model_kwargs = {} if ext_model_kwargs is None else ext_model_kwargs
for ext_model_id, ext_model_eval in self._ext_model_evals.items():
kwargs = ext_model_kwargs.get(ext_model_id, {})
ext_model_eval.cache_forward(**kwargs)
gate_function_kwargs = {} if gate_function_kwargs is None else gate_function_kwargs
for gate_function_id, gate_function_eval in self._gate_function_evals.items():
kwargs = gate_function_kwargs.get(gate_function_id, {})
gate_function_eval.cache_forward(**kwargs)

# Evaluate layers on the given input
y = self.evaluate(x) # (O, B, K)
@@ -282,28 +281,28 @@ class TorchCircuit(AbstractTorchCircuit):
"""

def __call__(
self, x: Tensor, *, ext_model_kwargs: Mapping[str, Mapping[str, Any]] | None = None
self, x: Tensor, *, gate_function_kwargs: Mapping[str, Mapping[str, Any]] | None = None
) -> Tensor:
# IGNORE: Idiom for nn.Module.__call__.
return super().__call__(x, ext_model_kwargs=ext_model_kwargs) # type: ignore[no-any-return,misc]
return super().__call__(x, gate_function_kwargs=gate_function_kwargs) # type: ignore[no-any-return,misc]

def forward(
self, x: Tensor, *, ext_model_kwargs: Mapping[str, Mapping[str, Any]] | None = None
self, x: Tensor, *, gate_function_kwargs: Mapping[str, Mapping[str, Any]] | None = None
) -> Tensor:
"""Evaluate the circuit layers in forward mode, i.e., by evaluating each layer by
following the topological ordering.

Args:
x: The tensor input of the circuit, with shape $(B, D)$, where $B$ is the batch size
and $D$ is the number of variables.
ext_model_kwargs: The arguments to pass to each external models.
gate_function_kwargs: The keyword arguments to pass to each gate function.

Returns:
Tensor: The tensor output of the circuit, with shape $(B, O, K)$,
where $O$ is the number of vectorized outputs (i.e., the number of output layers),
and $K$ is the number of scalars in each output (e.g., the number of classes).
"""
return self._evaluate_layers(x, ext_model_kwargs=ext_model_kwargs)
return self._evaluate_layers(x, gate_function_kwargs=gate_function_kwargs)


class TorchConstantCircuit(AbstractTorchCircuit):
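The snippet below is a self-contained sketch (not the cirkit implementation) of the keyword-argument dispatch introduced in `_evaluate_layers` and `forward` above: arguments are grouped per gate-function identifier, and each cached evaluator runs once via `cache_forward(**kwargs)` before the layers are evaluated. The `CachedEval` class, the `"temperature"` id, and the tensor sizes are illustrative assumptions.

```python
import torch
from torch import nn


class CachedEval:
    """A toy stand-in for the cached gate-function evaluator."""

    def __init__(self, module: nn.Module):
        self.module = module
        self.output = None

    def cache_forward(self, **kwargs) -> None:
        # Run the wrapped module once and cache its output for later parameter lookups.
        self.output = self.module(**kwargs)


gate_function_evals = {"temperature": CachedEval(nn.Linear(16, 8))}


def evaluate_layers(x: torch.Tensor, *, gate_function_kwargs=None) -> torch.Tensor:
    gate_function_kwargs = {} if gate_function_kwargs is None else gate_function_kwargs
    # Mirror of the loop in AbstractTorchCircuit._evaluate_layers.
    for gate_function_id, gate_function_eval in gate_function_evals.items():
        kwargs = gate_function_kwargs.get(gate_function_id, {})
        gate_function_eval.cache_forward(**kwargs)
    return x.sum(dim=-1)  # stand-in for the actual layer evaluation


y = evaluate_layers(
    torch.randn(4, 8),
    gate_function_kwargs={"temperature": {"input": torch.randn(4, 16)}},
)
```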
58 changes: 29 additions & 29 deletions cirkit/backend/torch/compiler.py
@@ -40,7 +40,7 @@
ParameterOptRegistry,
)
from cirkit.backend.torch.parameters.nodes import (
TorchModelParameter,
TorchGateFunctionParameter,
TorchParameterNode,
TorchParameterOp,
TorchPointerParameter,
@@ -53,7 +53,7 @@
DEFAULT_PARAMETER_COMPILATION_RULES,
)
from cirkit.backend.torch.semiring import Semiring, SemiringImpl
from cirkit.backend.torch.utils import ExternalModelEval
from cirkit.backend.torch.utils import CachedGateFunctionEval
from cirkit.symbolic.circuit import Circuit, pipeline_topological_ordering
from cirkit.symbolic.initializers import Initializer
from cirkit.symbolic.layers import Layer
@@ -74,27 +74,27 @@ def __init__(self):
# Since this is useful only for folding, it will be cleared after each circuit compilation.
self._symbolic_parameters: dict[TorchTensorParameter, TensorParameter] = {}

# A map from external model identifiers to the corresponding object used to evaluate them
self._ext_model_evals: dict[str, ExternalModelEval] = {}
# A map from external gate function identifiers to the corresponding object used to evaluate them
self._gate_functions_evals: dict[str, CachedGateFunctionEval] = {}

@property
def ext_model_evals(self) -> Mapping[str, ExternalModelEval]:
return self._ext_model_evals
def gate_functions(self) -> Mapping[str, CachedGateFunctionEval]:
return self._gate_functions_evals

def finish_compilation(self):
# Clear the map from (unfolded) compiled parameter tensors to symbolic ones
self._symbolic_parameters = {}

# Clear the map of external models
self._ext_model_evals = {}
# Clear the map of gate functions
self._gate_functions_evals = {}

def has_compiled_parameter(self, p: TensorParameter) -> bool:
# Retrieve whether a tensor parameter has already been compiled
return p in self._compiled_parameters

def has_ext_model_eval(self, model_id: str) -> bool:
# Retrieve whether a function has already been compiled
return model_id in self._ext_model_evals
def has_gate_function(self, gate_function_id: str) -> bool:
# Retrieve whether a gate function has already been compiled
return gate_function_id in self._gate_functions_evals

def retrieve_compiled_parameter(self, p: TensorParameter) -> tuple[TorchTensorParameter, int]:
# Retrieve the compiled parameter: we return the fold index as well.
@@ -104,9 +104,9 @@ def retrieve_symbolic_parameter(self, p: TorchTensorParameter) -> TensorParamete
# Retrieve the symbolic parameter tensor associated to the compiled one (which is unfolded)
return self._symbolic_parameters[p]

def retrieve_ext_model_eval(self, model_id: str) -> ExternalModelEval:
# Retrieve the external model evaluator
return self._ext_model_evals[model_id]
def retrieve_gate_function(self, gate_function_id: str) -> CachedGateFunctionEval:
# Retrieve the external gate function evaluator
return self._gate_functions_evals[gate_function_id]

def register_compiled_parameter(
self, sp: TensorParameter, cp: TorchTensorParameter, *, fold_idx: int | None = None
@@ -123,9 +123,9 @@ def register_compiled_parameter(
# folded compiled parameter tensor, which is specified by the 'fold_idx'.
self._compiled_parameters[sp] = (cp, fold_idx)

def register_ext_model_eval(self, model_id: str, model_eval: ExternalModelEval):
# Register the external model evaluator to the running state of the compiler
self._ext_model_evals[model_id] = model_eval
def register_gate_function(self, function_id: str, gate_function_eval: CachedGateFunctionEval):
# Register the gate function evaluator to the running state of the compiler
self._gate_functions_evals[function_id] = gate_function_eval


class TorchCompiler(AbstractCompiler):
@@ -248,19 +248,18 @@ def _compile_circuit(self, sc: Circuit) -> AbstractTorchCircuit:
# Construct the sequence of output layers
outputs = [compiled_layers_map[sl] for sl in sc.outputs]

# Retrieve the external model evaluators
ext_model_evals = self._state.ext_model_evals
# Retrieve the external gate function evaluators
gate_function_evals = self._state.gate_functions

# Construct the tensorized circuit
layers = list(compiled_layers_map.values())
cc = cc_cls(
sc.scope,
sc.num_channels,
layers=layers,
in_layers=in_layers,
outputs=outputs,
properties=sc.properties,
ext_model_evals=ext_model_evals,
gate_function_evals=gate_function_evals,
)

# Post-process the compiled circuit, i.e.,
@@ -303,13 +302,12 @@ def _fold_circuit(compiler: TorchCompiler, cc: AbstractTorchCircuit) -> Abstract
# Instantiate a folded circuit
return type(cc)(
cc.scope,
cc.num_channels,
layers,
in_layers,
outputs,
properties=cc.properties,
fold_idx_info=fold_idx_info,
ext_model_evals=cc.ext_model_evals,
gate_function_evals=cc.gate_function_evals,
)


@@ -428,16 +426,19 @@ def _fold_parameter_nodes_group(
)
return TorchPointerParameter(in_folded_node, fold_idx=in_fold_idx)
# Catch the case we are folding parameters obtained from an external function
if issubclass(fold_node_cls, TorchModelParameter):
assert all(isinstance(p, TorchModelParameter) for p in group)
if issubclass(fold_node_cls, TorchGateFunctionParameter):
assert all(isinstance(p, TorchGateFunctionParameter) for p in group)
if len(group) == 1:
# Catch the case we are folding a single torch function parameter
# In such a case, we just return it as it is
return group[0]
# Catch the case we are folding multiple torch function parameters
fold_idx: list[int] = list(chain.from_iterable(p.fold_idx for p in group))
return TorchModelParameter(
*group[0].shape, model_eval=group[0].model_eval, name=group[0].name, fold_idx=fold_idx
return TorchGateFunctionParameter(
*group[0].shape,
gate_function_eval=group[0].gate_function_eval,
name=group[0].name,
fold_idx=fold_idx,
)
# We are folding an operator: just set the number of folds and copy the configuration parameters
assert all(isinstance(p, TorchParameterOp) for p in group)
@@ -550,12 +551,11 @@ def match_optimizer_fuse(match: LayerOptMatch) -> tuple[TorchLayer, ...]:
layers, in_layers, outputs = optimize_result
cc = type(cc)(
cc.scope,
cc.num_channels,
layers,
in_layers,
outputs,
properties=cc.properties,
ext_model_evals=cc.ext_model_evals,
gate_function_evals=cc.gate_function_evals,
)
return cc, True

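Finally, to illustrate the fold merge performed in `_fold_parameter_nodes_group` above, here is a small standalone sketch (plain PyTorch, invented shapes) of how several parameter nodes that slice the same cached gate-function output collapse into one node with a concatenated fold index, which is what the `chain.from_iterable` call computes.

```python
import torch

# A cached gate-function output shared by several parameter nodes: (num_folds, K), shapes assumed.
cached_output = torch.randn(6, 4)

# Before folding: two parameter nodes, each selecting some folds of the cached output.
fold_idx_a = [0, 2]
fold_idx_b = [1, 5]
param_a = cached_output[fold_idx_a]  # (2, 4)
param_b = cached_output[fold_idx_b]  # (2, 4)

# After folding: a single node with the concatenated fold index performs one gather.
fold_idx = fold_idx_a + fold_idx_b
param_folded = cached_output[fold_idx]  # (4, 4)

assert torch.equal(param_folded, torch.cat([param_a, param_b], dim=0))
```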