"
@@ -399,84 +410,26 @@
{
"data": {
"text/html": [
- " MLP Summary \n",
- "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
- "┃ path ┃ type ┃ BatchStat ┃ Param ┃ RngState ┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
- "│ bn │ BatchNorm │ mean: float32[5,32] │ bias: float32[5,32] │ │\n",
- "│ │ │ var: float32[5,32] │ scale: float32[5,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ 320 (1.3 KB) │ 320 (1.3 KB) │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ dropout/rngs/default │ RngStream │ │ │ count: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: uint32[5] │\n",
- "│ │ │ │ │ key: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: key<fry>[5] │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ │ 10 (60 B) │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear1 │ Linear │ │ b: float32[5,32] │ │\n",
- "│ │ │ │ w: float32[5,10,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ 1,760 (7.0 KB) │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear2 │ Linear │ │ b: float32[5,10] │ │\n",
- "│ │ │ │ w: float32[5,32,10] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ 1,650 (6.6 KB) │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ │ Total │ 320 (1.3 KB) │ 3,730 (14.9 KB) │ 10 (60 B) │\n",
- "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n",
- " \n",
- " Total Parameters: 4,060 (16.3 KB) \n",
- "
\n"
+ "
"
],
"text/plain": [
- "\u001b[3m MLP Summary \u001b[0m\n",
- "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mtype \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mBatchStat \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mParam \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mRngState \u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
- "│ bn │ BatchNorm │ mean: \u001b[2mfloat32\u001b[0m[5,32] │ bias: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
- "│ │ │ var: \u001b[2mfloat32\u001b[0m[5,32] │ scale: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ dropout/rngs/default │ RngStream │ │ │ count: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: \u001b[2muint32\u001b[0m[5] │\n",
- "│ │ │ │ │ key: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: \u001b[2mkey\u001b[0m[5] │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ │ \u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear1 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
- "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,10,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ \u001b[1m1,760 \u001b[0m\u001b[1;2m(7.0 KB)\u001b[0m │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear2 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,10] │ │\n",
- "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,32,10] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ \u001b[1m1,650 \u001b[0m\u001b[1;2m(6.6 KB)\u001b[0m │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m3,730 \u001b[0m\u001b[1;2m(14.9 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n",
- "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n",
- "\u001b[1m \u001b[0m\n",
- "\u001b[1m Total Parameters: 4,060 \u001b[0m\u001b[1;2m(16.3 KB)\u001b[0m\u001b[1m \u001b[0m\n"
+ ""
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n"
- ]
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -528,7 +481,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -540,7 +493,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -589,7 +542,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -601,7 +554,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -613,7 +566,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -714,7 +667,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -726,7 +679,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -738,7 +691,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -750,7 +703,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md
index 51e0cda53f..61b96e2d34 100644
--- a/docs_nnx/nnx_basics.md
+++ b/docs_nnx/nnx_basics.md
@@ -12,7 +12,18 @@ jupytext:
Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.
-To begin, install Flax with `pip` and import necessary dependencies:
+In this guide you will learn about:
+
+- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.
+ - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).
+ - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.
+ - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.
+- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.
+ - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.
+- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.
+ - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).
+ - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`
+ - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.
## Setup
@@ -95,7 +106,7 @@ to handle them, as demonstrated in later sections of this guide.
Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.
-The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer.
+The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
```{code-cell} ipython3
class MLP(nnx.Module):
diff --git a/flax/linen/summary.py b/flax/linen/summary.py
index 5d1b214249..d6676729f0 100644
--- a/flax/linen/summary.py
+++ b/flax/linen/summary.py
@@ -48,13 +48,6 @@
LogicalNames,
)
-try:
- from IPython import get_ipython
-
- in_ipython = get_ipython() is not None
-except ImportError:
- in_ipython = False
-
class _ValueRepresentation(ABC):
"""A class that represents a value in the summary table."""
@@ -249,6 +242,11 @@ def tabulate(
Total Parameters: 50 (200 B)
+
+ **Note**: rows order in the table does not represent execution order,
+ instead it aligns with the order of keys in `variables` which are sorted
+ alphabetically.
+
**Note**: `vjp_flops` returns `0` if the module is not differentiable.
Args:
@@ -269,9 +267,7 @@ def tabulate(
mutable.
console_kwargs: An optional dictionary with additional keyword arguments
that are passed to `rich.console.Console` when rendering the table.
- Default arguments are ``'force_terminal': True``, and ``'force_jupyter'``
- is set to ``True`` if the code is running in a Jupyter notebook, otherwise
- it is set to ``False``.
+ Default arguments are `{'force_terminal': True, 'force_jupyter': False}`.
table_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table` constructor.
column_kwargs: An optional dictionary with additional keyword arguments that
@@ -568,7 +564,7 @@ def _render_table(
non_params_cols: list[str],
) -> str:
"""A function that renders a Table to a string representation using rich."""
- console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython}
+ console_kwargs = {'force_terminal': True, 'force_jupyter': False}
if console_extras is not None:
console_kwargs.update(console_extras)
diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py
index 1028efb2b1..63ed371be9 100644
--- a/flax/nnx/filterlib.py
+++ b/flax/nnx/filterlib.py
@@ -54,9 +54,7 @@ def to_predicate(filter: Filter) -> Predicate:
else:
raise TypeError(f'Invalid collection filter: {filter:!r}. ')
-def filters_to_predicates(
- filters: tp.Sequence[Filter],
-) -> tuple[Predicate, ...]:
+def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index 8cc272f8eb..a29999d34f 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -24,7 +24,7 @@
import numpy as np
import typing_extensions as tpe
-from flax.nnx import filterlib, reprlib, visualization
+from flax.nnx import filterlib, reprlib
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
@@ -63,7 +63,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, Variable)
-class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin):
+class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]):
"""A mapping that uses object id as the hash for the keys."""
def __init__(
@@ -248,7 +248,8 @@ def __nnx_repr__(self):
yield reprlib.Attr('index', self.index)
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
+ import treescope # type: ignore[import-not-found,import-untyped]
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={'type': self.type, 'index': self.index},
path=path,
@@ -271,7 +272,9 @@ def __nnx_repr__(self):
yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
+ import treescope # type: ignore[import-not-found,import-untyped]
+
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
@@ -350,7 +353,8 @@ def __nnx_repr__(self):
)
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
+ import treescope # type: ignore[import-not-found,import-untyped]
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
diff --git a/flax/nnx/module.py b/flax/nnx/module.py
index b07efa7711..795bb9a088 100644
--- a/flax/nnx/module.py
+++ b/flax/nnx/module.py
@@ -403,6 +403,23 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
flatten_func=partial(_module_flatten, with_keys=False),
)
+ def __treescope_repr__(self, path, subtree_renderer):
+ import treescope # type: ignore[import-not-found,import-untyped]
+ children = {}
+ for name, value in vars(self).items():
+ if name.startswith('_'):
+ continue
+ children[name] = value
+ return treescope.repr_lib.render_object_constructor(
+ object_type=type(self),
+ attributes=children,
+ path=path,
+ subtree_renderer=subtree_renderer,
+ color=treescope.formatting_util.color_from_string(
+ type(self).__qualname__
+ )
+ )
+
# -------------------------
# Pytree Definition
# -------------------------
diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py
index 230f1d356e..364b5dac1e 100644
--- a/flax/nnx/nn/linear.py
+++ b/flax/nnx/nn/linear.py
@@ -1063,7 +1063,7 @@ class Embed(Module):
>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
- 'embedding': VariableState( # 15 (60 B)
+ 'embedding': VariableState(
type=Param,
value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
[ 0.01070483, 0.27923733, 1.7487359 ],
diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py
index 928d9cf251..b5cbaf99b6 100644
--- a/flax/nnx/nn/normalization.py
+++ b/flax/nnx/nn/normalization.py
@@ -395,11 +395,11 @@ class LayerNorm(Module):
>>> nnx.state(layer)
State({
- 'bias': VariableState( # 6 (24 B)
+ 'bias': VariableState(
type=Param,
value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
),
- 'scale': VariableState( # 6 (24 B)
+ 'scale': VariableState(
type=Param,
value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)
@@ -531,7 +531,7 @@ class RMSNorm(Module):
>>> nnx.state(layer)
State({
- 'scale': VariableState( # 6 (24 B)
+ 'scale': VariableState(
type=Param,
value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)
@@ -655,11 +655,11 @@ class GroupNorm(Module):
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
- 'bias': VariableState( # 6 (24 B)
+ 'bias': VariableState(
type=Param,
value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
),
- 'scale': VariableState( # 6 (24 B)
+ 'scale': VariableState(
type=Param,
value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)
diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py
index add545634a..2a495826a4 100644
--- a/flax/nnx/nn/stochastic.py
+++ b/flax/nnx/nn/stochastic.py
@@ -24,7 +24,7 @@
from flax.nnx.module import Module, first_from
-@dataclasses.dataclass(repr=False)
+@dataclasses.dataclass
class Dropout(Module):
"""Create a dropout layer.
diff --git a/flax/nnx/object.py b/flax/nnx/object.py
index b1f7478eef..afa41cdb7b 100644
--- a/flax/nnx/object.py
+++ b/flax/nnx/object.py
@@ -20,67 +20,27 @@
from abc import ABCMeta
from copy import deepcopy
+
import jax
import numpy as np
-import treescope # type: ignore[import-untyped]
-from treescope import rendering_parts
-from flax.nnx import visualization
-from flax import errors
from flax.nnx import (
- graph,
reprlib,
tracers,
)
-from flax import nnx
+from flax.nnx import graph
from flax.nnx.variablelib import Variable, VariableState
-from flax.typing import SizeBytes, value_stats
+from flax import errors
G = tp.TypeVar('G', bound='Object')
-def _collect_stats(
- node: tp.Any, node_stats: dict[int, dict[type[Variable], SizeBytes]]
-):
- if not graph.is_node(node) and not isinstance(node, Variable):
- raise ValueError(f'Expected a graph node or Variable, got {type(node)!r}.')
-
- if id(node) in node_stats:
- return
-
- stats: dict[type[Variable], SizeBytes] = {}
- node_stats[id(node)] = stats
-
- if isinstance(node, Variable):
- var_type = type(node)
- if issubclass(var_type, nnx.RngState):
- var_type = nnx.RngState
- size_bytes = value_stats(node.value)
- if size_bytes:
- stats[var_type] = size_bytes
-
- else:
- node_dict = graph.get_node_impl(node).node_dict(node)
- for key, value in node_dict.items():
- if id(value) in node_stats:
- continue
- if graph.is_node(value) or isinstance(value, Variable):
- _collect_stats(value, node_stats)
- child_stats = node_stats[id(value)]
- for var_type, size_bytes in child_stats.items():
- if var_type in stats:
- stats[var_type] += size_bytes
- else:
- stats[var_type] = size_bytes
-
-
@dataclasses.dataclass
-class ObjectContext(threading.local):
+class GraphUtilsContext(threading.local):
seen_modules_repr: set[int] | None = None
- node_stats: dict[int, dict[type[Variable], SizeBytes]] | None = None
-OBJECT_CONTEXT = ObjectContext()
+CONTEXT = GraphUtilsContext()
class ObjectState(reprlib.Representable):
@@ -103,14 +63,14 @@ def __nnx_repr__(self):
yield reprlib.Attr('trace_state', self._trace_state)
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
- object_type=type(self),
- attributes={'trace_state': self._trace_state},
- path=path,
- subtree_renderer=subtree_renderer,
+ import treescope # type: ignore[import-not-found,import-untyped]
+ return treescope.repr_lib.render_object_constructor(
+ object_type=type(self),
+ attributes={'trace_state': self._trace_state},
+ path=path,
+ subtree_renderer=subtree_renderer,
)
-
class ObjectMeta(ABCMeta):
if not tp.TYPE_CHECKING:
@@ -130,14 +90,12 @@ def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G:
@dataclasses.dataclass(frozen=True, repr=False)
-class Array(reprlib.Representable):
+class Array:
shape: tp.Tuple[int, ...]
dtype: tp.Any
- def __nnx_repr__(self):
- yield reprlib.Object(type='Array', same_line=True)
- yield reprlib.Attr('shape', self.shape)
- yield reprlib.Attr('dtype', self.dtype)
+ def __repr__(self):
+ return f'Array(shape={self.shape}, dtype={self.dtype.name})'
class Object(reprlib.Representable, metaclass=ObjectMeta):
@@ -179,41 +137,20 @@ def __deepcopy__(self: G, memo=None) -> G:
return graph.merge(graphdef, state)
def __nnx_repr__(self):
- if OBJECT_CONTEXT.node_stats is None:
- node_stats: dict[int, dict[type[Variable], SizeBytes]] = {}
- _collect_stats(self, node_stats)
- OBJECT_CONTEXT.node_stats = node_stats
- stats = node_stats[id(self)]
- clear_node_stats = True
- else:
- stats = OBJECT_CONTEXT.node_stats[id(self)]
- clear_node_stats = False
-
- if OBJECT_CONTEXT.seen_modules_repr is None:
- OBJECT_CONTEXT.seen_modules_repr = set()
+ if CONTEXT.seen_modules_repr is None:
+ CONTEXT.seen_modules_repr = set()
clear_seen = True
else:
clear_seen = False
- if id(self) in OBJECT_CONTEXT.seen_modules_repr:
+ if id(self) in CONTEXT.seen_modules_repr:
yield reprlib.Object(type=type(self), empty_repr='...')
return
- try:
- if stats:
- stats_repr = ' # ' + ', '.join(
- f'{var_type.__name__}: {size_bytes}'
- for var_type, size_bytes in stats.items()
- )
- if len(stats) > 1:
- total_bytes = sum(stats.values(), SizeBytes(0, 0))
- stats_repr += f', Total: {total_bytes}'
- else:
- stats_repr = ''
-
- yield reprlib.Object(type=type(self), comment=stats_repr)
- OBJECT_CONTEXT.seen_modules_repr.add(id(self))
+ yield reprlib.Object(type=type(self))
+ CONTEXT.seen_modules_repr.add(id(self))
+ try:
for name, value in vars(self).items():
if name.startswith('_'):
continue
@@ -231,64 +168,24 @@ def to_shape_dtype(value):
return value
value = jax.tree.map(to_shape_dtype, value)
- yield reprlib.Attr(name, value)
+ yield reprlib.Attr(name, repr(value))
finally:
if clear_seen:
- OBJECT_CONTEXT.seen_modules_repr = None
- if clear_node_stats:
- OBJECT_CONTEXT.node_stats = None
+ CONTEXT.seen_modules_repr = None
def __treescope_repr__(self, path, subtree_renderer):
- from flax import nnx
-
- if OBJECT_CONTEXT.node_stats is None:
- node_stats: dict[int, dict[type[Variable], SizeBytes]] = {}
- _collect_stats(self, node_stats)
- OBJECT_CONTEXT.node_stats = node_stats
- stats = node_stats[id(self)]
- clear_node_stats = True
- else:
- stats = OBJECT_CONTEXT.node_stats[id(self)]
- clear_node_stats = False
-
- try:
- if stats:
- stats_repr = ' # ' + ', '.join(
- f'{var_type.__name__}: {size_bytes}'
- for var_type, size_bytes in stats.items()
- )
- if len(stats) > 1:
- total_bytes = sum(stats.values(), SizeBytes(0, 0))
- stats_repr += f', Total: {total_bytes}'
-
- first_line_annotation = rendering_parts.comment_color(
- rendering_parts.text(f'{stats_repr}')
- )
- else:
- first_line_annotation = None
- children = {}
- for name, value in vars(self).items():
- if name.startswith('_'):
- continue
- children[name] = value
-
- if isinstance(self, nnx.Module):
- color = treescope.formatting_util.color_from_string(
- type(self).__qualname__
- )
- else:
- color = None
- return visualization.render_object_constructor(
+ import treescope # type: ignore[import-not-found,import-untyped]
+ children = {}
+ for name, value in vars(self).items():
+ if name.startswith('_'):
+ continue
+ children[name] = value
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
- first_line_annotation=first_line_annotation,
- color=color,
- )
- finally:
- if clear_node_stats:
- OBJECT_CONTEXT.node_stats = None
+ )
# Graph Definition
def _graph_node_flatten(self):
@@ -328,4 +225,4 @@ def _graph_node_clear(self):
module_vars['_object__state'] = module_state
def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
- vars(self).update(attributes)
+ vars(self).update(attributes)
\ No newline at end of file
diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py
index 155c2e7e90..6ed7660cdf 100644
--- a/flax/nnx/reprlib.py
+++ b/flax/nnx/reprlib.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
import dataclasses
-import os
-import sys
import threading
import typing as tp
@@ -22,125 +21,22 @@
B = tp.TypeVar('B')
-def supports_color() -> bool:
- """
- Returns True if the running system's terminal supports color, and False otherwise.
- """
- try:
- from IPython import get_ipython
-
- ipython_available = get_ipython() is not None
- except ImportError:
- ipython_available = False
-
- supported_platform = sys.platform != 'win32' or 'ANSICON' in os.environ
- is_a_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
- return (supported_platform and is_a_tty) or ipython_available
-
-
-class Color(tp.NamedTuple):
- TYPE: str
- ATTRIBUTE: str
- SEP: str
- PAREN: str
- COMMENT: str
- INT: str
- STRING: str
- FLOAT: str
- BOOL: str
- NONE: str
- END: str
-
-
-NO_COLOR = Color(
- TYPE='',
- ATTRIBUTE='',
- SEP='',
- PAREN='',
- COMMENT='',
- INT='',
- STRING='',
- FLOAT='',
- BOOL='',
- NONE='',
- END='',
-)
-
-
-# Use python vscode theme colors
-if supports_color():
- COLOR = Color(
- TYPE='\x1b[38;2;79;201;177m',
- ATTRIBUTE='\033[38;2;156;220;254m',
- SEP='\x1b[38;2;212;212;212m',
- PAREN='\x1b[38;2;255;213;3m',
- # COMMENT='\033[38;2;87;166;74m',
- COMMENT='\033[38;2;105;105;105m', # Dark gray
- INT='\x1b[38;2;182;207;169m',
- STRING='\x1b[38;2;207;144;120m',
- FLOAT='\x1b[38;2;182;207;169m',
- BOOL='\x1b[38;2;86;156;214m',
- NONE='\x1b[38;2;86;156;214m',
- END='\x1b[0m',
- )
-else:
- COLOR = NO_COLOR
-
-
@dataclasses.dataclass
class ReprContext(threading.local):
- current_color: Color = COLOR
+ indent_stack: tp.List[str] = dataclasses.field(default_factory=lambda: [''])
REPR_CONTEXT = ReprContext()
-def colorized(x, /):
- c = REPR_CONTEXT.current_color
- if isinstance(x, list):
- return f'{c.PAREN}[{c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}]{c.END}'
- elif isinstance(x, tuple):
- if len(x) == 1:
- return f'{c.PAREN}({c.END}{colorized(x[0])},{c.PAREN}){c.END}'
- return f'{c.PAREN}({c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}){c.END}'
- elif isinstance(x, dict):
- open, close = '{', '}'
- return f'{c.PAREN}{open}{c.END}{", ".join(f"{c.STRING}{k!r}{c.END}: {colorized(v)}" for k, v in x.items())}{c.PAREN}{close}{c.END}'
- elif isinstance(x, set):
- open, close = '{', '}'
- return f'{c.PAREN}{open}{c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}{close}{c.END}'
- elif isinstance(x, type):
- return f'{c.TYPE}{x.__name__}{c.END}'
- elif isinstance(x, bool):
- return f'{c.BOOL}{x}{c.END}'
- elif isinstance(x, int):
- return f'{c.INT}{x}{c.END}'
- elif isinstance(x, str):
- return f'{c.STRING}{x!r}{c.END}'
- elif isinstance(x, float):
- return f'{c.FLOAT}{x}{c.END}'
- elif x is None:
- return f'{c.NONE}{x}{c.END}'
- elif isinstance(x, Representable):
- return get_repr(x)
- else:
- return repr(x)
-
-
@dataclasses.dataclass
class Object:
type: tp.Union[str, type]
start: str = '('
end: str = ')'
- kv_sep: str = '='
- indent: str = ' '
+ value_sep: str = '='
+ elem_indent: str = ' '
empty_repr: str = ''
- comment: str = ''
- same_line: bool = False
-
- @property
- def elem_sep(self):
- return ', ' if self.same_line else ',\n'
@dataclasses.dataclass
@@ -149,8 +45,6 @@ class Attr:
value: tp.Union[str, tp.Any]
start: str = ''
end: str = ''
- use_raw_value: bool = False
- use_raw_key: bool = False
class Representable:
@@ -160,96 +54,79 @@ def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]:
raise NotImplementedError
def __repr__(self) -> str:
- current_color = REPR_CONTEXT.current_color
- REPR_CONTEXT.current_color = NO_COLOR
- try:
- return get_repr(self)
- finally:
- REPR_CONTEXT.current_color = current_color
-
- def __str__(self) -> str:
return get_repr(self)
+@contextlib.contextmanager
+def add_indent(indent: str) -> tp.Iterator[None]:
+ REPR_CONTEXT.indent_stack.append(REPR_CONTEXT.indent_stack[-1] + indent)
+
+ try:
+ yield
+ finally:
+ REPR_CONTEXT.indent_stack.pop()
+
+
+def get_indent() -> str:
+ return REPR_CONTEXT.indent_stack[-1]
+
+
def get_repr(obj: Representable) -> str:
if not isinstance(obj, Representable):
raise TypeError(f'Object {obj!r} is not representable')
- c = REPR_CONTEXT.current_color
iterator = obj.__nnx_repr__()
config = next(iterator)
-
if not isinstance(config, Object):
raise TypeError(f'First item must be Config, got {type(config).__name__}')
- kv_sep = f'{c.SEP}{config.kv_sep}{c.END}'
-
def _repr_elem(elem: tp.Any) -> str:
if not isinstance(elem, Attr):
raise TypeError(f'Item must be Elem, got {type(elem).__name__}')
- value_repr = elem.value if elem.use_raw_value else colorized(elem.value)
- value_repr = value_repr.replace('\n', '\n' + config.indent)
- key = elem.key if elem.use_raw_key else f'{c.ATTRIBUTE}{elem.key}{c.END}'
- indent = '' if config.same_line else config.indent
+ value = elem.value if isinstance(elem.value, str) else repr(elem.value)
+
+ value = value.replace('\n', '\n' + config.elem_indent)
- return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}'
+ return f'{config.elem_indent}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}'
- elems = config.elem_sep.join(map(_repr_elem, iterator))
+ with add_indent(config.elem_indent):
+ elems = ',\n'.join(map(_repr_elem, iterator))
if elems:
- if config.same_line:
- elems_repr = elems
- comment = ''
- else:
- elems_repr = '\n' + elems + '\n'
- comment = f'{c.COMMENT}{config.comment}{c.END}'
+ elems = '\n' + elems + '\n'
else:
- elems_repr = config.empty_repr
- comment = ''
+ elems = config.empty_repr
type_repr = (
config.type if isinstance(config.type, str) else config.type.__name__
)
- type_repr = f'{c.TYPE}{type_repr}{c.END}' if type_repr else ''
- start = f'{c.PAREN}{config.start}{c.END}' if config.start else ''
- end = f'{c.PAREN}{config.end}{c.END}' if config.end else ''
- out = f'{type_repr}{start}{comment}{elems_repr}{end}'
- return out
+ return f'{type_repr}{config.start}{elems}{config.end}'
-class MappingReprMixin(Representable):
+class MappingReprMixin(tp.Mapping[A, B]):
def __nnx_repr__(self):
- yield Object(type='', kv_sep=': ', start='{', end='}')
+ yield Object(type='', value_sep=': ', start='{', end='}')
- for key, value in self.items(): # type: ignore
- yield Attr(colorized(key), value, use_raw_key=True)
+ for key, value in self.items():
+ yield Attr(repr(key), value)
@dataclasses.dataclass(repr=False)
class PrettyMapping(Representable):
mapping: tp.Mapping
def __nnx_repr__(self):
- yield Object(type=type(self), kv_sep=': ', start='({', end='})')
+ yield Object(type='', value_sep=': ', start='{', end='}')
for key, value in self.mapping.items():
- yield Attr(colorized(key), value, use_raw_key=True)
-
-@dataclasses.dataclass(repr=False)
-class SequenceReprMixin(Representable):
- def __nnx_repr__(self):
- yield Object(type=type(self), kv_sep='', start='([', end='])')
-
- for value in self: # type: ignore
- yield Attr('', value, use_raw_key=True)
-
+ yield Attr(repr(key), value)
@dataclasses.dataclass(repr=False)
class PrettySequence(Representable):
- sequence: tp.Sequence
+ list: tp.Sequence
def __nnx_repr__(self):
- yield Object(type=type(self), kv_sep='', start='([', end='])')
+ yield Object(type='', value_sep='', start='[', end=']')
- for value in self.sequence:
- yield Attr('', value, use_raw_key=True)
\ No newline at end of file
+ for value in self.list:
+ yield Attr('', value)
\ No newline at end of file
diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py
index 38cb3da759..42a2604042 100644
--- a/flax/nnx/statelib.py
+++ b/flax/nnx/statelib.py
@@ -38,7 +38,7 @@ def __init__(self, state: State):
self.state = state
def __nnx_repr__(self):
- yield reprlib.Object('', kv_sep=': ', start='{', end='}')
+ yield reprlib.Object('', value_sep=': ', start='{', end='}')
for r in self.state.__nnx_repr__():
if isinstance(r, reprlib.Object):
@@ -54,7 +54,7 @@ def __treescope_repr__(self, path, subtree_renderer):
# Render as the dictionary itself at the same path.
return subtree_renderer(children, path=path)
-class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.SequenceReprMixin):
+class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence):
_keys: tuple[PathParts, ...]
_values: list[V]
@@ -66,14 +66,6 @@ def __init__(self, items: tp.Iterable[tuple[PathParts, V]]):
self._keys = tuple(keys)
self._values = values
- @property
- def paths(self) -> tp.Sequence[PathParts]:
- return self._keys
-
- @property
- def leaves(self) -> tp.Sequence[V]:
- return self._values
-
@tp.overload
def __getitem__(self, index: int) -> tuple[PathParts, V]: ...
@tp.overload
@@ -181,7 +173,7 @@ def __len__(self) -> int:
return len(self._mapping)
def __nnx_repr__(self):
- yield reprlib.Object(type(self), kv_sep=': ', start='({', end='})')
+ yield reprlib.Object(type(self), value_sep=': ', start='({', end='})')
for k, v in self.items():
if isinstance(v, State):
diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py
index a7b72b1540..c53bbd5c4d 100644
--- a/flax/nnx/tracers.py
+++ b/flax/nnx/tracers.py
@@ -18,7 +18,7 @@
import jax
import jax.core
-from flax.nnx import reprlib, visualization
+from flax.nnx import reprlib
def current_jax_trace():
@@ -47,11 +47,12 @@ def __nnx_repr__(self):
yield reprlib.Attr('jax_trace', self._jax_trace)
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
- object_type=type(self),
- attributes={'jax_trace': self._jax_trace},
- path=path,
- subtree_renderer=subtree_renderer,
+ import treescope # type: ignore[import-not-found,import-untyped]
+ return treescope.repr_lib.render_object_constructor(
+ object_type=type(self),
+ attributes={'jax_trace': self._jax_trace},
+ path=path,
+ subtree_renderer=subtree_renderer,
)
def __eq__(self, other):
diff --git a/flax/nnx/training/metrics.py b/flax/nnx/training/metrics.py
index 4facf42787..2073787b0d 100644
--- a/flax/nnx/training/metrics.py
+++ b/flax/nnx/training/metrics.py
@@ -276,45 +276,45 @@ class MultiMetric(Metric):
... )
>>> metrics
- MultiMetric( # MetricState: 4 (16 B)
- accuracy=Accuracy( # MetricState: 2 (8 B)
+ MultiMetric(
+ accuracy=Accuracy(
argname='values',
- total=MetricState( # 1 (4 B)
+ total=MetricState(
value=Array(0., dtype=float32)
),
- count=MetricState( # 1 (4 B)
+ count=MetricState(
value=Array(0, dtype=int32)
)
),
- loss=Average( # MetricState: 2 (8 B)
+ loss=Average(
argname='values',
- total=MetricState( # 1 (4 B)
+ total=MetricState(
value=Array(0., dtype=float32)
),
- count=MetricState( # 1 (4 B)
+ count=MetricState(
value=Array(0, dtype=int32)
)
)
)
>>> metrics.accuracy
- Accuracy( # MetricState: 2 (8 B)
+ Accuracy(
argname='values',
- total=MetricState( # 1 (4 B)
+ total=MetricState(
value=Array(0., dtype=float32)
),
- count=MetricState( # 1 (4 B)
+ count=MetricState(
value=Array(0, dtype=int32)
)
)
>>> metrics.loss
- Average( # MetricState: 2 (8 B)
+ Average(
argname='values',
- total=MetricState( # 1 (4 B)
+ total=MetricState(
value=Array(0., dtype=float32)
),
- count=MetricState( # 1 (4 B)
+ count=MetricState(
value=Array(0, dtype=int32)
)
)
diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index b2c0660962..4752a9b7bd 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -21,15 +21,10 @@
from typing import Any
import jax
-import treescope # type: ignore[import-untyped]
from flax import errors
-from flax.nnx import filterlib, reprlib, tracers, visualization
-from flax.typing import (
- Missing,
- PathParts,
- value_stats,
-)
+from flax.nnx import filterlib, reprlib, tracers
+from flax.typing import Missing, PathParts
import jax.tree_util as jtu
A = tp.TypeVar('A')
@@ -47,7 +42,6 @@
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
-
@dataclasses.dataclass
class VariableMetadata(tp.Generic[A]):
raw_value: A
@@ -317,34 +311,20 @@ def to_state(self: Variable[A]) -> VariableState[A]:
return VariableState(type(self), self.raw_value, **self._var_metadata)
def __nnx_repr__(self):
- stats = value_stats(self.value)
- if stats:
- comment = f' # {stats}'
- else:
- comment = ''
-
- yield reprlib.Object(type=type(self).__name__, comment=comment)
+ yield reprlib.Object(type=type(self))
yield reprlib.Attr('value', self.raw_value)
for name, value in self._var_metadata.items():
yield reprlib.Attr(name, repr(value))
def __treescope_repr__(self, path, subtree_renderer):
- size_bytes = value_stats(self.value)
- if size_bytes:
- stats_repr = f' # {size_bytes}'
- first_line_annotation = treescope.rendering_parts.comment_color(
- treescope.rendering_parts.text(f'{stats_repr}')
- )
- else:
- first_line_annotation = None
+ import treescope # type: ignore[import-not-found,import-untyped]
children = {'value': self.raw_value, **self._var_metadata}
- return visualization.render_object_constructor(
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
- first_line_annotation=first_line_annotation,
)
# hooks API
@@ -784,35 +764,22 @@ def __delattr__(self, name: str) -> None:
del self._var_metadata[name]
def __nnx_repr__(self):
- stats = value_stats(self.value)
- if stats:
- comment = f' # {stats}'
- else:
- comment = ''
-
- yield reprlib.Object(type=type(self), comment=comment)
- yield reprlib.Attr('type', self.type)
+ yield reprlib.Object(type=type(self))
+ yield reprlib.Attr('type', self.type.__name__)
yield reprlib.Attr('value', self.value)
for name, value in self._var_metadata.items():
- yield reprlib.Attr(name, value)
+ yield reprlib.Attr(name, repr(value))
def __treescope_repr__(self, path, subtree_renderer):
- size_bytes = value_stats(self.value)
- if size_bytes:
- stats_repr = f' # {size_bytes}'
- first_line_annotation = treescope.rendering_parts.comment_color(
- treescope.rendering_parts.text(f'{stats_repr}')
- )
- else:
- first_line_annotation = None
+ import treescope # type: ignore[import-not-found,import-untyped]
+
children = {'type': self.type, 'value': self.value, **self._var_metadata}
- return visualization.render_object_constructor(
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
- first_line_annotation=first_line_annotation,
)
def replace(self, value: B) -> VariableState[B]:
@@ -944,7 +911,7 @@ def wrapper(*args):
def split_flat_state(
flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]],
- filters: tp.Sequence[filterlib.Filter],
+ filters: tuple[filterlib.Filter, ...],
) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]:
predicates = filterlib.filters_to_predicates(filters)
# we have n + 1 states, where n is the number of predicates
diff --git a/flax/nnx/visualization.py b/flax/nnx/visualization.py
index 8c548d040c..d49eed7cf7 100644
--- a/flax/nnx/visualization.py
+++ b/flax/nnx/visualization.py
@@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import typing as tp
-
-import treescope # type: ignore[import-untyped]
-from treescope import rendering_parts, renderers
+import importlib.util
+treescope_installed = importlib.util.find_spec('treescope') is not None
try:
from IPython import get_ipython
@@ -31,112 +29,12 @@ def display(*args):
If treescope is not installed or the code is not running in IPython,
``display`` will print the objects instead.
"""
- if not in_ipython:
+ if not treescope_installed or not in_ipython:
for x in args:
print(x)
return
+ import treescope # type: ignore[import-not-found,import-untyped]
+
for x in args:
treescope.display(x, ignore_exceptions=True, autovisualize=True)
-
-
-def render_object_constructor(
- object_type: type[tp.Any],
- attributes: tp.Mapping[str, tp.Any],
- path: str | None,
- subtree_renderer: renderers.TreescopeSubtreeRenderer,
- roundtrippable: bool = False,
- color: str | None = None,
- first_line_annotation: rendering_parts.RenderableTreePart | None = None,
-) -> rendering_parts.Rendering:
- """Renders an object in "constructor format", similar to a dataclass.
-
- This produces a rendering like `Foo(bar=1, baz=2)`, where Foo identifies the
- type of the object, and bar and baz are the names of the attributes of the
- object. It is a *requirement* that these are the actual attributes of the
- object, which can be accessed via `obj.bar` or similar; otherwise, the
- path renderings will break.
-
- This can be used from within a `__treescope_repr__` implementation via ::
-
- def __treescope_repr__(self, path, subtree_renderer):
- return repr_lib.render_object_constructor(
- object_type=type(self),
- attributes=,
- path=path,
- subtree_renderer=subtree_renderer,
- )
-
- Args:
- object_type: The type of the object.
- attributes: The attributes of the object, which will be rendered as keyword
- arguments to the constructor.
- path: The path to the object. When `render_object_constructor` is called
- from `__treescope_repr__`, this should come from the `path` argument to
- `__treescope_repr__`.
- subtree_renderer: The renderer to use to render subtrees. When
- `render_object_constructor` is called from `__treescope_repr__`, this
- should come from the `subtree_renderer` argument to `__treescope_repr__`.
- roundtrippable: Whether evaluating the rendering as Python code will produce
- an object that is equal to the original object. This implies that the
- keyword arguments are actually the keyword arguments to the constructor,
- and not some other attributes of the object.
- color: The background color to use for the object rendering. If None, does
- not use a background color. A utility for assigning a random color based
- on a string key is given in `treescope.formatting_util`.
- first_line_annotation: An annotation for the first line of the node when it
- is expanded.
-
- Returns:
- A rendering of the object, suitable for returning from `__treescope_repr__`.
- """
- if roundtrippable:
- constructor = rendering_parts.siblings(
- rendering_parts.maybe_qualified_type_name(object_type), '('
- )
- closing_suffix = rendering_parts.text(')')
- else:
- constructor = rendering_parts.siblings(
- rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text('<')),
- rendering_parts.maybe_qualified_type_name(object_type),
- '(',
- )
- closing_suffix = rendering_parts.siblings(
- ')',
- rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text('>')),
- )
-
- children = []
- for i, (name, value) in enumerate(attributes.items()):
- child_path = None if path is None else f'{path}.{name}'
-
- if i < len(attributes) - 1:
- # Not the last child. Always show a comma, and add a space when
- # collapsed.
- comma_after = rendering_parts.siblings(
- ',',
- rendering_parts.fold_condition(collapsed=rendering_parts.text(' ')),
- )
- else:
- # Last child: only show the comma when the node is expanded.
- comma_after = rendering_parts.fold_condition(
- expanded=rendering_parts.text(',')
- )
-
- child_line = rendering_parts.build_full_line_with_annotations(
- rendering_parts.siblings_with_annotations(
- f'{name}=',
- subtree_renderer(value, path=child_path),
- ),
- comma_after,
- )
- children.append(child_line)
-
- return rendering_parts.build_foldable_tree_node_from_children(
- prefix=constructor,
- children=children,
- suffix=closing_suffix,
- path=path,
- background_color=color,
- first_line_annotation=first_line_annotation,
- )
\ No newline at end of file
diff --git a/flax/struct.py b/flax/struct.py
index 6c18651aaa..4e8de0a7fe 100644
--- a/flax/struct.py
+++ b/flax/struct.py
@@ -123,7 +123,7 @@ class method that provides the smart constructor.
"""
# Support passing arguments to the decorator (e.g. @dataclass(kw_only=True))
if clz is None:
- return functools.partial(dataclass, **kwargs) # type: ignore[bad-return-type]
+ return functools.partial(dataclass, **kwargs)
# check if already a flax dataclass
if '_flax_dataclass' in clz.__dict__:
diff --git a/flax/typing.py b/flax/typing.py
index 0ae990d95a..a630a3571e 100644
--- a/flax/typing.py
+++ b/flax/typing.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import annotations
from collections import deque
from functools import partial
@@ -27,8 +26,6 @@
from collections.abc import Callable, Hashable, Mapping, Sequence
import jax
-import jax.numpy as jnp
-import numpy as np
from flax.core import FrozenDict
import dataclasses
@@ -164,63 +161,3 @@ class Missing:
MISSING = Missing()
-
-
-def _bytes_repr(num_bytes):
- count, units = (
- (f'{num_bytes / 1e9 :,.1f}', 'GB')
- if num_bytes > 1e9
- else (f'{num_bytes / 1e6 :,.1f}', 'MB')
- if num_bytes > 1e6
- else (f'{num_bytes / 1e3 :,.1f}', 'KB')
- if num_bytes > 1e3
- else (f'{num_bytes:,}', 'B')
- )
-
- return f'{count} {units}'
-
-
-class ShapeDtype(Protocol):
- shape: Shape
- dtype: Dtype
-
-
-def has_shape_dtype(x: Any) -> TypeGuard[ShapeDtype]:
- return hasattr(x, 'shape') and hasattr(x, 'dtype')
-
-
-@dataclasses.dataclass(frozen=True, slots=True)
-class SizeBytes: # type: ignore[misc]
- size: int
- bytes: int
-
- @staticmethod
- def from_array(x: ShapeDtype) -> SizeBytes:
- size = int(np.prod(x.shape))
- dtype: jnp.dtype
- if isinstance(x.dtype, str):
- dtype = jnp.dtype(x.dtype)
- else:
- dtype = x.dtype # type: ignore
- bytes = size * dtype.itemsize # type: ignore
- return SizeBytes(size, bytes)
-
- def __add__(self, other: SizeBytes) -> SizeBytes:
- return SizeBytes(self.size + other.size, self.bytes + other.bytes)
-
- def __bool__(self) -> bool:
- return bool(self.size)
-
- def __repr__(self) -> str:
- bytes_repr = _bytes_repr(self.bytes)
- return f'{self.size:,} ({bytes_repr})'
-
-
-def value_stats(x):
- leaves = jax.tree.leaves(x)
- size_bytes = SizeBytes(0, 0)
- for leaf in leaves:
- if has_shape_dtype(leaf):
- size_bytes += SizeBytes.from_array(leaf)
-
- return size_bytes
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index f7a890fad0..658b2f15d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"rich>=11.1",
"typing_extensions>=4.2",
"PyYAML>=5.4.1",
- "treescope>=0.1.7",
+ "treescope>=0.1.2",
]
classifiers = [
"Development Status :: 3 - Alpha",
diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py
index 64928f46b8..ce65186dd2 100644
--- a/tests/nnx/module_test.py
+++ b/tests/nnx/module_test.py
@@ -25,7 +25,6 @@
import jax.numpy as jnp
import numpy as np
-
A = TypeVar('A')
class List(nnx.Module):
@@ -551,46 +550,6 @@ def __call__(self, x):
y2 = model(jnp.ones((5, 2)))
np.testing.assert_allclose(y1, y2)
- def test_repr(self):
- class Block(nnx.Module):
- def __init__(self, din, dout, rngs: nnx.Rngs):
- self.linear = nnx.Linear(din, dout, rngs=rngs)
- self.bn = nnx.BatchNorm(dout, rngs=rngs)
- self.dropout = nnx.Dropout(0.2, rngs=rngs)
-
- def __call__(self, x):
- return nnx.relu(self.dropout(self.bn(self.linear(x))))
-
- class Foo(nnx.Module):
- def __init__(self, rngs: nnx.Rngs):
- self.block1 = Block(32, 128, rngs=rngs)
- self.block2 = Block(128, 10, rngs=rngs)
-
- def __call__(self, x):
- return self.block2(self.block1(x))
-
- obj = Foo(nnx.Rngs(0))
-
- leaves = nnx.state(obj).flat_state().leaves
-
- expected_total = sum(int(np.prod(x.value.shape)) for x in leaves)
- expected_total_params = sum(
- int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.Param
- )
- expected_total_batch_stats = sum(
- int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.BatchStat
- )
- expected_total_rng_states = sum(
- int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.RngState
- )
-
- foo_repr = repr(obj).replace(',', '').splitlines()
-
- self.assertIn(str(expected_total), foo_repr[0])
- self.assertIn(str(expected_total_params), foo_repr[0])
- self.assertIn(str(expected_total_batch_stats), foo_repr[0])
- self.assertIn(str(expected_total_rng_states), foo_repr[0])
-
class TestModulePytree:
def test_tree_map(self):
diff --git a/uv.lock b/uv.lock
index 48bda4f756..e08e2dbf53 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3,13 +3,13 @@ requires-python = ">=3.10"
resolution-markers = [
"python_full_version < '3.11' and platform_system == 'Darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version == '3.11.*' and platform_system == 'Darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version >= '3.12' and platform_system == 'Darwin'",
"python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
[[package]]
@@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version < '3.11' and platform_system == 'Darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 }
wheels = [
@@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version == '3.11.*' and platform_system == 'Darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version >= '3.12' and platform_system == 'Darwin'",
"python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 }
wheels = [
@@ -890,7 +890,7 @@ requires-dist = [
{ name = "tensorflow-text", marker = "platform_system != 'Darwin' and extra == 'testing'", specifier = ">=2.11.0" },
{ name = "tensorstore" },
{ name = "torch", marker = "extra == 'testing'" },
- { name = "treescope", specifier = ">=0.1.7" },
+ { name = "treescope", specifier = ">=0.1.2" },
{ name = "treescope", marker = "python_full_version >= '3.10' and extra == 'testing'", specifier = ">=0.1.1" },
{ name = "typing-extensions", specifier = ">=4.2" },
]
@@ -1202,7 +1202,7 @@ name = "ipython"
version = "8.26.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "decorator" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "jedi" },
@@ -1246,7 +1246,7 @@ wheels = [
[[package]]
name = "jax"
-version = "0.4.38"
+version = "0.4.37"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jaxlib" },
@@ -1255,14 +1255,14 @@ dependencies = [
{ name = "opt-einsum" },
{ name = "scipy" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034 }
+sdist = { url = "https://files.pythonhosted.org/packages/50/30/ad7617a960c86782587540a179cef676962322d1e5411415b1aa24f02ce0/jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b", size = 1915966 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864 },
+ { url = "https://files.pythonhosted.org/packages/5f/3f/6c5553baaa7faa3fa8bae8279b1e46cb54c7ce52360139eae53498786ea5/jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e", size = 2221192 },
]
[[package]]
name = "jaxlib"
-version = "0.4.38"
+version = "0.4.36"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ml-dtypes" },
@@ -1270,26 +1270,26 @@ dependencies = [
{ name = "scipy" },
]
wheels = [
- { url = "https://files.pythonhosted.org/packages/ee/d4/e6a0881a88b8f17491c2ee271fd77c348b0221d9e2ec92dad23a2c9e41bc/jaxlib-0.4.38-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:55c19b9d3f33a6fc59f644aa5a21fba02639ccdd776cb4a9b5526625f57839ff", size = 99663603 },
- { url = "https://files.pythonhosted.org/packages/b6/6d/11569ce873f04c82ec22e58d822f4187dccae1d400c0d6dd05ed314d5328/jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:30b2f52cb50d74734af2f477c2533a7a583e3bb7b2c8acdeb361ee77d940577a", size = 79475708 },
- { url = "https://files.pythonhosted.org/packages/72/61/1de2405d13089c83b1ad87ec0266479c9d00080659dae2474892ae356306/jaxlib-0.4.38-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ee19c163a8fdf0839d4c18b88a5fbfb4e731ba7c437416d3e5483e570bb764e4", size = 93219045 },
- { url = "https://files.pythonhosted.org/packages/9c/24/0829decf233c6af9efe7c53888ae8ac72395e0979869cd9cee487e35dac3/jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:61aeccb9a27c67fdb8450f6357240019cd4511cb9d62a44e4764756d384853ad", size = 101732107 },
- { url = "https://files.pythonhosted.org/packages/0d/04/120c4caac6151f7297fedf9dd776362aa2d417d3f87bda826050b4da45e8/jaxlib-0.4.38-cp310-cp310-win_amd64.whl", hash = "sha256:d6ab745a89d0fb737a36fe1d8b86659e3fffe6ee8303b20651b26193d5edc0ef", size = 64223924 },
- { url = "https://files.pythonhosted.org/packages/b0/6a/b9fba73eb5e758e40a514919e096a039d27dc0ab4776a6cc977f5153a55f/jaxlib-0.4.38-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:b67fdeabd6dfed08b7768f3bdffb521160085f8305669bd197beef61d08de08b", size = 99679916 },
- { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377 },
- { url = "https://files.pythonhosted.org/packages/94/96/7d9a0b9f35af4727df44b68ade4c6f15163840727d1cb47251b1ea515e30/jaxlib-0.4.38-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:43db58c4c427627296366a56c10318e1f00f503690e17f94bb4344293e1995e0", size = 93241543 },
- { url = "https://files.pythonhosted.org/packages/a3/2d/68f85037e60c981b37b18b23ace458c677199dea4722ddce541b48ddfc63/jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2751ff7037d6a997d0be0e77cc4be381c5a9f9bb8b314edb755c13a6fd969f45", size = 101751923 },
- { url = "https://files.pythonhosted.org/packages/cc/24/a9c571c8a189f58e0b54b14d53fc7f5a0a06e4f1d7ab9edcf8d1d91d07e7/jaxlib-0.4.38-cp311-cp311-win_amd64.whl", hash = "sha256:35226968fc9de6873d1571670eac4117f5ed80e955f7a1775204d1044abe16c6", size = 64255189 },
- { url = "https://files.pythonhosted.org/packages/49/df/08b94c593c0867c7eaa334592807ba74495de4be90580f360db8b96221dc/jaxlib-0.4.38-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:3fefea985f0415816f3bbafd3f03a437050275ef9bac9a72c1314e1644ac57c1", size = 99737849 },
- { url = "https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242 },
- { url = "https://files.pythonhosted.org/packages/53/25/dd670d8bdf3799ece76d12cfe6a6a250ea256057aa4b0fcace4753a99d2d/jaxlib-0.4.38-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:496f45b0e001a2341309cd0c74af0b670537dced79c168cb230cfcc773f0aa86", size = 93251503 },
- { url = "https://files.pythonhosted.org/packages/f9/cc/37fce5162f6b9070203fd76cc0f298d9b3bfdf01939a78935a6078d63621/jaxlib-0.4.38-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dad6c0a96567c06d083c0469fec40f201210b099365bd698be31a6d2ec88fd59", size = 101792792 },
- { url = "https://files.pythonhosted.org/packages/6f/7a/8515950a60a4ea5b13cc98fc0a42e36553b2db5a6eedc00d3bd7836f77b5/jaxlib-0.4.38-cp312-cp312-win_amd64.whl", hash = "sha256:966cdec36cfa978f5b4582bcb4147fe511725b94c1a752dac3a5f52ce46b6fa3", size = 64288223 },
- { url = "https://files.pythonhosted.org/packages/91/03/aee503c7077c6dbbd568842303426c6ec1cef9bff330c418c9e71906cccd/jaxlib-0.4.38-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:41e55ae5818a882e5789e848f6f16687ac132bcfbb5a5fa114a5d18b78d05f2d", size = 99739026 },
- { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735 },
- { url = "https://files.pythonhosted.org/packages/e4/0b/8cbff0b6d62a4694351c49baf53b7ed8deb8a6854d129408c38158e11676/jaxlib-0.4.38-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:248cca3771ebf24b070f49701364ceada33e6139445b06c782cca5ac5ad92bf4", size = 93251882 },
- { url = "https://files.pythonhosted.org/packages/15/57/7f0283273b69c417071bcd2f4c2ed076479ec5ffc22a647f13c21da8d071/jaxlib-0.4.38-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:2ce77ba8cda9259a4bca97afc1c722e4291a6c463a63f8d372c6edc85117d625", size = 101791137 },
- { url = "https://files.pythonhosted.org/packages/de/de/d6c4d234cd426b97459cb070af90792b48643967a0d28641379ee9e10fc9/jaxlib-0.4.38-cp313-cp313-win_amd64.whl", hash = "sha256:4103db0b3a38a5dc132741237453c24d8547290a22079ba1b577d6c88c95300a", size = 64288459 },
+ { url = "https://files.pythonhosted.org/packages/23/8d/8a44618f3493f29d769b2b40778d24075689cc8697b98e2c43bafbe50edf/jaxlib-0.4.36-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d69f991833b6dca794767049843462805936c89553b136a8ebb8485334204457", size = 98648230 },
+ { url = "https://files.pythonhosted.org/packages/78/b8/207485eab566dcfbc29bb833714ac1ca47a1665ca605b1ff7d3d5dd2afbe/jaxlib-0.4.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:807814c1ba3ec69cffaa93d3f90651c694a9b8a750b43832cc167ed590c821dd", size = 78553787 },
+ { url = "https://files.pythonhosted.org/packages/26/42/3c2b0dc86a17aafd8f46ba0e4388f39f55706ee25f6c463c3dadea7a71e2/jaxlib-0.4.36-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1bc27d9ae09549d7652eafe1fdb10c21546cd2fd02bb24a49a7e6208b69163b0", size = 84008742 },
+ { url = "https://files.pythonhosted.org/packages/b9/b2/29be712098342df10075fe085c0b39d783a579bd3325fb0d69c22712cf27/jaxlib-0.4.36-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3379f03a794d6a30b75765d2786f6e31052f364196fcd49aaae292a3c16f12ec", size = 100263041 },
+ { url = "https://files.pythonhosted.org/packages/63/a9/93404a2f1d59647749d4d6dbab7bee9f5a7bfaeb9ade25b7e66c0ca0949a/jaxlib-0.4.36-cp310-cp310-win_amd64.whl", hash = "sha256:63e575ac8a515dee8171dd4a88c460d538bbcc9d959cabc9781e961763678f84", size = 63270658 },
+ { url = "https://files.pythonhosted.org/packages/e4/7d/9394ff39af5c23bb98a241c33742a328df5a43c21d569855ea7e096aaf5e/jaxlib-0.4.36-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:213792db3b876206b45f6a9fbea15e4dd22a9e80be25b03136f20c94784fecfa", size = 98669744 },
+ { url = "https://files.pythonhosted.org/packages/34/5a/9f3c9e5cec23e60f78bb3c3da108a5ef664601862dbc4e84fc4be3654f5d/jaxlib-0.4.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d7a89adf4c9d3cddd20482931dedc7a9e2669e904196a9599d9a605b3d9e552", size = 78574312 },
+ { url = "https://files.pythonhosted.org/packages/ff/5c/bf78ed9b8d0f174a562f6496049a4872e14a3bb3a80de09c4292d04be5f0/jaxlib-0.4.36-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c395fe8cc5bd6558dd2fbce78e24172b6f27762e17628720ae03d693001283f3", size = 84038323 },
+ { url = "https://files.pythonhosted.org/packages/67/af/6a9dd26e8a6bedd4c9fe702059767256b0d9ed18c29a180a4598d5795bb4/jaxlib-0.4.36-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc324c6b1c64fe68400934c653e4e622f12576120dcdb451c3b4ea4dcaba2ae9", size = 100285487 },
+ { url = "https://files.pythonhosted.org/packages/b7/46/31c3a519a94e84c672ca264c4151998e3e3fd11c481d8fa5af5885b91a1e/jaxlib-0.4.36-cp311-cp311-win_amd64.whl", hash = "sha256:c9e0c45a79e63aea65447f82bd0fa21c17b9afe884aa18dd5362b9965abe9d72", size = 63308064 },
+ { url = "https://files.pythonhosted.org/packages/e3/0e/3b4a99c09431ee5820624d4dcf4efa7becd3c83b56ff0f09a078f4c421a2/jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d", size = 98718357 },
+ { url = "https://files.pythonhosted.org/packages/d3/46/05e70a1236ec3782333b3e9469f971c9d45af2aa0aebf602acd9d76292eb/jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e", size = 78596060 },
+ { url = "https://files.pythonhosted.org/packages/8e/76/6b969cbf197b8c53c84c2642069722e84a3a260af084a8acbbf90ca444ea/jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442", size = 84053202 },
+ { url = "https://files.pythonhosted.org/packages/fe/f2/7624a304426daa7b135b85caf1b8eccf879e7cb10bc074656ce628309cb0/jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357", size = 100325610 },
+ { url = "https://files.pythonhosted.org/packages/bb/8b/ded8420cd9198eb677869ffd557d9880af5833c7bf39e604e80b56550e09/jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3", size = 63338518 },
+ { url = "https://files.pythonhosted.org/packages/5d/22/b72811c61e8b594951d3ee03245cb0932c723ac35e75569005c3c976eec2/jaxlib-0.4.36-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:02df9c0e1323dde01e966c22eb12432905d2d4de8aac7b603cad2083101b0e6b", size = 98719384 },
+ { url = "https://files.pythonhosted.org/packages/f1/66/3f4a97097983914899100db9e5312493fe1d6adc924e47a0e47e15c553f5/jaxlib-0.4.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ec980e85983f41999c4dc84137dec70507d958e23d7eefa104da93053d135f", size = 78596150 },
+ { url = "https://files.pythonhosted.org/packages/3a/6f/cf02f56d1532962d8ca77a6548acab8204294b96b5a153ca4a2caf4971fc/jaxlib-0.4.36-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7ce9368515348d869d6c59d9904c3cb3c81f22ff3e9e969eae0e3563fe472080", size = 84055851 },
+ { url = "https://files.pythonhosted.org/packages/28/10/4fc4e9719c065c6455491730011e87fe4b5120a9a008161cc32663feb9ce/jaxlib-0.4.36-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:93f1c502d08e517f842fe7b18428bb086cfd077db0ea9a2418fb21e5b4e06d3d", size = 100325986 },
+ { url = "https://files.pythonhosted.org/packages/ba/28/fece5385e736ef2f1b5bed133f8001f0fc66dd0104707381343e047b341a/jaxlib-0.4.36-cp313-cp313-win_amd64.whl", hash = "sha256:bddf436a243e83ec6bc16bcbb74d15b1960a69318c9ea796fb2109492bc52575", size = 63338694 },
]
[[package]]
@@ -1431,7 +1431,7 @@ version = "5.7.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "platformdirs" },
- { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
+ { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" },
{ name = "traitlets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 }
@@ -2095,7 +2095,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -2122,9 +2122,9 @@ name = "nvidia-cusolver-cu12"
version = "11.4.5.107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
- { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
- { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -2135,7 +2135,7 @@ name = "nvidia-cusparse-cu12"
version = "12.1.0.106"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -2262,7 +2262,7 @@ wheels = [
[[package]]
name = "orbax-checkpoint"
-version = "0.11.0"
+version = "0.10.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "absl-py" },
@@ -2280,9 +2280,9 @@ dependencies = [
{ name = "tensorstore" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/de/b3/a9a8a6bc08ded7634a9d85ba440400172f0a11f9341897b8fd3389fad245/orbax_checkpoint-0.11.0.tar.gz", hash = "sha256:d4a0dcc81edd29191cf5a4feb9cf2a4edd31fc5da79d7be616a04f11f2a4d484", size = 253035 }
+sdist = { url = "https://files.pythonhosted.org/packages/d1/06/c42e2f1563dbaaf5ed1464d7b634324fb9a2da04021073c45777e61af78d/orbax_checkpoint-0.10.2.tar.gz", hash = "sha256:e575ebe1f94e5cb6353ab8c9df81de0ca7cddc118645c3bfc17b8344f19d42f1", size = 248170 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/87/32/3779fa524a2272f408ab51d869fde9ff1c0ca731eedd01e40436bcf7ba2c/orbax_checkpoint-0.11.0-py3-none-any.whl", hash = "sha256:892a124fce71f3e7c71451a2b2090c0251db1097803a119a00baa377113bc9ba", size = 360423 },
+ { url = "https://files.pythonhosted.org/packages/61/19/ed366f8894923f3c8db0370e4bdd57ef843d68011dafa00d8175f4a66e1a/orbax_checkpoint-0.10.2-py3-none-any.whl", hash = "sha256:dcfc425674bd8d4934986143bd22a37cd634d034652c5d30d83c539ef8587941", size = 354306 },
]
[[package]]
@@ -2436,7 +2436,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version < '3.11' and platform_system == 'Darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 }
wheels = [
@@ -2454,10 +2454,10 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version == '3.11.*' and platform_system == 'Darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version >= '3.12' and platform_system == 'Darwin'",
"python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 }
wheels = [
@@ -2606,7 +2606,7 @@ name = "pytest"
version = "8.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "iniconfig" },
{ name = "packaging" },
@@ -3195,7 +3195,7 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "alabaster" },
{ name = "babel" },
- { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "docutils" },
{ name = "imagesize" },
{ name = "jinja2" },
@@ -3669,14 +3669,14 @@ wheels = [
[[package]]
name = "treescope"
-version = "0.1.7"
+version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/40/34/8ad5475c26837ca400c77951bcc0788b5f291d1509ae2eda5f97b042c24a/treescope-0.1.7.tar.gz", hash = "sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3", size = 530052 }
+sdist = { url = "https://files.pythonhosted.org/packages/2f/5d/ecb176971c78d90a3f74b7878ab9d013995fed285e3386a503ca008c9b03/treescope-0.1.2.tar.gz", hash = "sha256:2e4b35780884dfdbdcf44315d1c1c98fcf41daa0ea48a5b45ecc716920f88c86", size = 402255 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/59/7d/f6da2b223749c58ec8ff95c87319196765fed05bd44dd86fb9bc4bf35f77/treescope-0.1.7-py3-none-any.whl", hash = "sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102", size = 175566 },
+ { url = "https://files.pythonhosted.org/packages/af/11/1a4d1877e5f7202bb3d0778a77b6ca222848b9b36fa65cbbc1fe12cb82b7/treescope-0.1.2-py3-none-any.whl", hash = "sha256:1811df6fbf79a5f54804e3ce2230b100547dc6350c99d973a6b9ba2bcd932e57", size = 172154 },
]
[[package]]
@@ -3684,7 +3684,7 @@ name = "triton"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },