
Commit

No public description
PiperOrigin-RevId: 639942359
Change-Id: Id6aeb72477e130df5ff99280351b4d1ddfa733bf
peterjliu committed Jun 3, 2024
1 parent d13005b commit 446b874
Showing 3 changed files with 3 additions and 99 deletions.
42 changes: 2 additions & 40 deletions nanodo/metrics.py
@@ -15,9 +15,8 @@
 
 # pylint: disable=invalid-name,g-importing-member,g-import-not-at-top
 
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
-from absl import logging
 from flax.struct import dataclass
 import jax
 import jax.numpy as jnp
@@ -35,12 +34,10 @@
 
 
 def get_init_metrics(
-    step_fn: Callable[["TrainState", jax.Array], "TrainState"],
     state: "TrainState",
-    in_BxL: jax.Array,
 ) -> dict[str, float | int]:
   """Compute metrics only at init, as they are constant throughout training."""
-  metrics = _get_costs(step_fn, state, in_BxL)
+  metrics = {}
 
   n_params_all = _size(state.params)
 
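For orientation, a minimal sketch (not part of the diff) of what the pared-down `get_init_metrics` still measures: parameter counts, which `_size` computes by folding `jnp.size` over the params PyTree. The nested dict below is a hypothetical stand-in for a real Flax params tree:

```python
import jax
import jax.numpy as jnp

# Hypothetical params PyTree standing in for TrainState.params.
params = {"dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))}}

# Same reduction as metrics.py's _size: sum jnp.size over all leaves.
n_params = jax.tree_util.tree_reduce(lambda x, y: x + jnp.size(y), params, 0)
print(n_params)  # 40 == 4 * 8 + 8
```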
@@ -220,41 +217,6 @@ def _tree_to_dict(prefix: str, g: PyTree) -> dict[str, Any]:
           for k, v in jax.tree_util.tree_leaves_with_path(g)}
 
 
-def _get_costs(f, *args, **kwargs) -> dict[str, float]:
-  """Compute FLOPS cost of evaluating `f(*args, **kwargs)`.
-
-  WARNING: `flops_compiled` are returned as `-1` on GPU:
-  https://github.com/google/jax/issues/16008, and in general are unreliable on
-  CPU/GPU: http://b/202218145.
-
-  Args:
-    f: JITtable function.
-    *args: args for `f`.
-    **kwargs: kwargs for `f`.
-
-  Returns:
-    FLOPS cost of evaluating `f(*args, **kwargs)`.
-  """
-  e = jax.jit(f).lower(*args, **kwargs)
-  cost_lowered = e.cost_analysis()
-
-  try:
-    cost_compiled = e.compile().cost_analysis()[0]
-  except jax.interpreters.xla.xc.XlaRuntimeError as e:
-    logging.exception(e)
-    cost_compiled = {}
-
-  costs = {}
-  # Note that `bytes accessed_lowered` is very bloated since read-write
-  # operations overlap a lot.
-  for k in ["flops", "bytes accessed"]:
-    costs[k + "_lowered"] = cost_lowered[k]
-    if k in cost_compiled:
-      costs[k + "_compiled"] = cost_compiled[k]
-
-  return costs
-
-
 def _size(g: PyTree) -> int:
   return jax.tree_util.tree_reduce(lambda x, y: x + jnp.size(y), g, 0)
 
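The deleted helper is easy to reproduce outside the repo if the cost numbers are still wanted. A standalone sketch of the same lowered-vs-compiled cost-analysis trick, assuming a recent JAX where `Lowered` and `Compiled` objects expose `cost_analysis()` (exact keys and values are backend-dependent, per the warnings in the removed docstring):

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return x @ y

x = jnp.ones((3, 10))
y = jnp.ones((10, 2))

lowered = jax.jit(f).lower(x, y)
print(lowered.cost_analysis()["flops"])  # 120.0: 2 * 3 * 10 * 2 for the matmul

compiled = lowered.compile()
print(compiled.cost_analysis())  # unreliable on CPU/GPU; flops may be -1 (jax#16008)
```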
4 changes: 1 addition & 3 deletions nanodo/train.py
@@ -280,16 +280,14 @@ def get_metrics(
     metrics = jax.device_get(metrics)
     if step == 0:
       metrics |= self.init_metrics
-    metrics["total_flops"] = self.init_metrics["flops_lowered"] * step
     return metrics
 
   def do_step(self, step: int, in_BxL: jax.Array) -> dict[str, float]:
     """Async dispatch one training step and return metrics."""
     # Note that the device may be busy with the previous step.
     # Avoid calling self.step as that would block until the device is ready.
     if step == 0 or self.init_metrics is None:
-      self.init_metrics = metrics_lib.get_init_metrics(
-          self.step_fn, self.state, in_BxL)
+      self.init_metrics = metrics_lib.get_init_metrics(self.state)
 
     self.state, metrics = self.step_fn(self.state, in_BxL)
     return metrics
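The retained comment describes JAX's usual async-dispatch behavior: a jitted call returns as soon as the work is enqueued, and only reading values back (e.g. `jax.device_get`, as in `get_metrics` above) blocks. A minimal illustration with made-up shapes:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
  return x * 2.0  # stand-in for a real training step

x = jnp.ones((1024, 1024))
y = step(x)  # returns immediately; the work is enqueued on the device
y = step(y)  # the next step can be enqueued while the first still runs
y.block_until_ready()  # only an explicit sync (or device_get) waits
```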
56 changes: 0 additions & 56 deletions tests/metrics_test.py
@@ -142,62 +142,6 @@ def test_aggregate_microbatch_metrics(self):
     chex.assert_trees_all_close(
         state_single.opt_state, state_multistep.opt_state, rtol=1e-2, atol=1e-1)
 
-  @parameterized.product(
-      x_shape=(
-          (3, 10),
-          (5, 10)
-      ),
-      y_shape=(
-          (10, 2),
-          (10, 7)
-      ),
-      fn_and_cost_fn=(
-          (lambda x, y: x @ y,
-           # Cost of (n, m) @ (m, k) is (n m k) * 2 (multiplies and additions).
-           # Memory is storing/reading two matrices + writing the result.
-           # Each 32-bit entry is 4 bytes.
-           lambda x, y: (  # pylint: disable=g-long-lambda
-               2 * x[0] * x[1] * y[1],
-               4 * ((x[0] * x[1]) + (y[0] * y[1]) + (x[0] * y[1])),
-           )
-          ),
-          (lambda x, y: jnp.maximum(x @ y, 0.),
-           # Cost of matmul + cost of ReLU = size of the output matrix.
-           # Memory is storing all inputs (incl. a scalar) + 5 * output size:
-           # 1) write and read the output of matmul
-           # 2) write and read a matrix of zeros
-           # 3) write a matrix of results
-           lambda x, y: (  # pylint: disable=g-long-lambda
-               2 * x[0] * x[1] * y[1] + x[0] * y[1],
-               4 * ((x[0] * x[1]) + (y[0] * y[1]) + 5 * (x[0] * y[1]) + 1),
-           )
-          ),
-      )
-  )
-  def test_get_flops(
-      self,
-      x_shape,
-      y_shape,
-      fn_and_cost_fn
-  ):
-    x = random.normal(random.PRNGKey(1), x_shape)
-    y = random.normal(random.PRNGKey(2), y_shape)
-
-    fn, cost_fn = fn_and_cost_fn
-    costs = metrics_lib._get_costs(fn, x, y)
-    flops_ref, memory_ref = cost_fn(x.shape, y.shape)
-    self.assertEqual(costs["flops_lowered"], flops_ref)
-
-    if jax.config.read("jax_enable_x64"):
-      memory_ref *= 2
-
-    if jax.default_backend() != "tpu":
-      self.assertEqual(costs["bytes accessed_lowered"], memory_ref)
-
-    if jax.default_backend() != "gpu":
-      # TODO: revisit after https://github.com/google/jax/issues/16008.
-      self.assertEqual(costs["flops_compiled"], flops_ref)
-
   def test_gaussian(self):
     rng = jax.random.PRNGKey(0)
     data = jax.random.normal(rng, (100,))
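For reference, a worked instance of the removed test's cost model, taking `x_shape=(3, 10)`, `y_shape=(10, 2)` and the plain-matmul case:

```python
n, m, k = 3, 10, 2                 # x: (n, m), y: (m, k)
flops = 2 * n * m * k              # one multiply + one add per MAC -> 120
mem = 4 * (n*m + m*k + n*k)        # float32 bytes: read x, read y, write x @ y -> 224
```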
