diff --git a/nanodo/metrics.py b/nanodo/metrics.py
index bc49dd6..ea36b89 100644
--- a/nanodo/metrics.py
+++ b/nanodo/metrics.py
@@ -15,9 +15,8 @@
 
 # pylint: disable=invalid-name,g-importing-member,g-import-not-at-top
 
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
-from absl import logging
 from flax.struct import dataclass
 import jax
 import jax.numpy as jnp
@@ -35,12 +34,10 @@
 
 
 def get_init_metrics(
-    step_fn: Callable[["TrainState", jax.Array], "TrainState"],
     state: "TrainState",
-    in_BxL: jax.Array,
 ) -> dict[str, float | int]:
   """Compute metrics only at init, as they are constant throughout training."""
-  metrics = _get_costs(step_fn, state, in_BxL)
+  metrics = {}
 
   n_params_all = _size(state.params)
 
@@ -220,41 +217,6 @@ def _tree_to_dict(prefix: str, g: PyTree) -> dict[str, Any]:
           for k, v in jax.tree_util.tree_leaves_with_path(g)}
 
 
-def _get_costs(f, *args, **kwargs) -> dict[str, float]:
-  """Compute FLOPS cost of evaluating `f(*args, **kwargs)`.
-
-  WARNING: `flops_compiled` are returned as `-1` on GPU:
-  https://github.com/google/jax/issues/16008, and in general are unreliable on
-  CPU/GPU: http://b/202218145.
-
-  Args:
-    f: JITtable function.
-    *args: args for `f`.
-    **kwargs: kwargs for `f`.
-
-  Returns:
-    FLOPS cost of evaluating `f(*args, **kwargs)`.
-  """
-  e = jax.jit(f).lower(*args, **kwargs)
-  cost_lowered = e.cost_analysis()
-
-  try:
-    cost_compiled = e.compile().cost_analysis()[0]
-  except jax.interpreters.xla.xc.XlaRuntimeError as e:
-    logging.exception(e)
-    cost_compiled = {}
-
-  costs = {}
-  # Note that `bytes accessed_lowered` is very bloated since read-write
-  # operations overlap a lot.
-  for k in ["flops", "bytes accessed"]:
-    costs[k + "_lowered"] = cost_lowered[k]
-    if k in cost_compiled:
-      costs[k + "_compiled"] = cost_compiled[k]
-
-  return costs
-
-
 def _size(g: PyTree) -> int:
   return jax.tree_util.tree_reduce(lambda x, y: x + jnp.size(y), g, 0)
 
diff --git a/nanodo/train.py b/nanodo/train.py
index decea3f..c4edf6b 100644
--- a/nanodo/train.py
+++ b/nanodo/train.py
@@ -280,7 +280,6 @@ def get_metrics(
     metrics = jax.device_get(metrics)
     if step == 0:
       metrics |= self.init_metrics
-      metrics["total_flops"] = self.init_metrics["flops_lowered"] * step
     return metrics
 
   def do_step(self, step: int, in_BxL: jax.Array) -> dict[str, float]:
@@ -288,8 +287,7 @@ def do_step(self, step: int, in_BxL: jax.Array) -> dict[str, float]:
     # Note that the device may be busy with the previous step.
    # Avoid calling self.step as that would block until the device is ready.
     if step == 0 or self.init_metrics is None:
-      self.init_metrics = metrics_lib.get_init_metrics(
-          self.step_fn, self.state, in_BxL)
+      self.init_metrics = metrics_lib.get_init_metrics(self.state)
     self.state, metrics = self.step_fn(self.state, in_BxL)
     return metrics
 
diff --git a/tests/metrics_test.py b/tests/metrics_test.py
index 8bfdaf8..e712d6e 100644
--- a/tests/metrics_test.py
+++ b/tests/metrics_test.py
@@ -142,62 +142,6 @@ def test_aggregate_microbatch_metrics(self):
     chex.assert_trees_all_close(
         state_single.opt_state, state_multistep.opt_state, rtol=1e-2, atol=1e-1)
 
-  @parameterized.product(
-      x_shape=(
-          (3, 10),
-          (5, 10)
-      ),
-      y_shape=(
-          (10, 2),
-          (10, 7)
-      ),
-      fn_and_cost_fn=(
-          (lambda x, y: x @ y,
-           # Cost of (n, m) @ (m, k) is (n m k) * 2 (multiplies and additions).
-           # Memory is storing reading two matrices + writing the result.
-           # Each 32-bit entry is 4 bytes.
-           lambda x, y: (  # pylint: disable=g-long-lambda
-               2 * x[0] * x[1] * y[1],
-               4 * ((x[0] * x[1]) + (y[0] * y[1]) + (x[0] * y[1])),
-           )
-          ),
-          (lambda x, y: jnp.maximum(x @ y, 0.),
-           # Cost of matmul + cost of ReLU = size of the output matrix.
-           # Memory is store all inputs (incl a scalar) + 5 * output size:
-           # 1) write and read the output of matmul
-           # 2) write and read a matrix of zeros
-           # 3) write a matrix of results
-           lambda x, y: (  # pylint: disable=g-long-lambda
-               2 * x[0] * x[1] * y[1] + x[0] * y[1],
-               4 * ((x[0] * x[1]) + (y[0] * y[1]) + 5 * (x[0] * y[1]) + 1),
-           )
-          ),
-      )
-  )
-  def test_get_flops(
-      self,
-      x_shape,
-      y_shape,
-      fn_and_cost_fn
-  ):
-    x = random.normal(random.PRNGKey(1), x_shape)
-    y = random.normal(random.PRNGKey(2), y_shape)
-
-    fn, cost_fn = fn_and_cost_fn
-    costs = metrics_lib._get_costs(fn, x, y)
-    flops_ref, memory_ref = cost_fn(x.shape, y.shape)
-    self.assertEqual(costs["flops_lowered"], flops_ref)
-
-    if jax.config.read("jax_enable_x64"):
-      memory_ref *= 2
-
-    if jax.default_backend() != "tpu":
-      self.assertEqual(costs["bytes accessed_lowered"], memory_ref)
-
-    if jax.default_backend() != "gpu":
-      # TODO: revisit after https://github.com/google/jax/issues/16008).
-      self.assertEqual(costs["flops_compiled"], flops_ref)
-
   def test_gaussian(self):
     rng = jax.random.PRNGKey(0)
     data = jax.random.normal(rng, (100,))
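Note: if an ad-hoc FLOPS estimate is still needed for local debugging after this removal, it can be reproduced in a few lines using the same jax.jit(...).lower(...).cost_analysis() API that the deleted _get_costs helper relied on. A minimal sketch (the estimate_flops name is illustrative, not part of the repo); as the removed docstring warned, the numbers are unreliable on CPU/GPU:

import jax
import jax.numpy as jnp

def estimate_flops(f, *args, **kwargs):
  # Lower (but do not compile) f to get XLA's static cost estimate.
  lowered = jax.jit(f).lower(*args, **kwargs)
  cost = lowered.cost_analysis()  # dict with "flops", "bytes accessed", ...
  return cost["flops"]

x = jnp.ones((3, 10))
y = jnp.ones((10, 2))
# Matches the removed test's reference cost: 2 * 3 * 10 * 2 = 120 FLOPS.
print(estimate_flops(lambda a, b: a @ b, x, y))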