Commit

Add specific model typing for nnx.Optimizer
marcelroed committed Jan 10, 2025
1 parent adbad95 commit 8ab5d5c
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions flax/nnx/training/optimizer.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 from __future__ import annotations
 
+import typing as tp
+
 import jax
 import jax.numpy as jnp
 import optax
@@ -23,6 +25,8 @@
 from flax.nnx.object import Object
 from flax.nnx.variablelib import Variable, VariableState
 
+M = tp.TypeVar('M', bound=nnx.Module)
+
 # TODO: add tests and docstrings


@@ -101,7 +105,7 @@ def optimizer_update_variables(x, update):
   return jax.tree.map(optimizer_update_variables, opt_state, updates)
 
 
-class Optimizer(Object):
+class Optimizer(Object, tp.Generic[M]):
   """Simple train state for the common case with a single Optax optimizer.
 
   Example usage::
@@ -168,7 +172,7 @@ class Optimizer(Object):

   def __init__(
     self,
-    model: nnx.Module,
+    model: M,
     tx: optax.GradientTransformation,
     wrt: filterlib.Filter = nnx.Param,
   ):
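Not part of the diff itself, but a minimal sketch of what the new `Generic[M]` parameter buys at type-checking time. The `Model` class and hyperparameters below are hypothetical; `nnx.Linear`, `nnx.Rngs`, and `optax.adamw` are public flax/optax APIs as of this commit, and the inference of the stored `model` attribute assumes the type checker propagates `M` from the `__init__` assignment:

```python
import optax
from flax import nnx


class Model(nnx.Module):  # hypothetical example model
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(2, 3, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


model = Model(nnx.Rngs(0))

# Because Optimizer is now Generic[M] with `model: M` in __init__,
# a static type checker infers this as Optimizer[Model] rather than
# an optimizer over a bare nnx.Module.
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

# Model-specific attributes then remain visible through the
# optimizer instead of being erased to nnx.Module:
print(optimizer.model.linear)  # inferred as nnx.Linear, not Any
```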
