Optimize array stacking in the output by using NumPy rather than JAX.
PiperOrigin-RevId: 701895415
sbodenstein authored and Torax team committed Dec 2, 2024
1 parent 5d6007f commit b650202
Showing 3 changed files with 60 additions and 15 deletions.
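The heart of the change is stacking per-timestep outputs with host-side np.stack instead of jnp.stack: the stacked history is only consumed for output, so NumPy avoids per-leaf JAX dispatch and device round-trips. A minimal sketch of the pattern, assuming a hypothetical history pytree standing in for the real state history:

import jax
import numpy as np

def stack(*xs):
  # Convert each leaf (possibly a jax.Array) to NumPy, then stack on the host.
  return np.stack([np.asarray(x) for x in xs])

# Hypothetical per-timestep records; the real code stacks state.CoreProfiles
# and related dataclasses via jax.tree_util.tree_map.
history = [{'t': float(i), 'ne': np.full(4, float(i))} for i in range(3)]
stacked = jax.tree_util.tree_map(stack, *history)
assert stacked['ne'].shape == (3, 4)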
61 changes: 50 additions & 11 deletions torax/interpolated_param.py
@@ -27,6 +27,9 @@

RHO_NORM = 'rho_norm'

interp_fn = jax_utils.jit(jnp.interp)
interp_fn_vmap = jax_utils.jit(jax.vmap(jnp.interp, in_axes=(None, None, 1)))


class InterpolatedParamBase(abc.ABC):
"""Base class for interpolated params.
@@ -70,23 +73,31 @@ class PiecewiseLinearInterpolatedParam(InterpolatedParamBase):

def __init__(self, xs: chex.Array, ys: chex.Array):
"""Initialises a piecewise-linear interpolated param, xs must be sorted."""
self._xs = xs
self._ys = ys

# TODO(b/323504363): Without this cast, some tests (e.g. sim.py) fail
# because the inputs can be int64, while the new interp implementation
# assumes ys is a float. Fix this upstream, then remove the explicit cast
# and add a check that the inputs are floats.
self._xs = (
xs.astype(np.float64) if np.issubdtype(xs.dtype, np.integer) else xs
)
self._ys = (
ys.astype(np.float64) if np.issubdtype(ys.dtype, np.integer) else ys
)

jax_utils.assert_rank(self.xs, 1)
if self.xs.shape[0] != self.ys.shape[0]:
raise ValueError(
'xs and ys must have the same number of elements in the first '
f'dimension. Given: {self.xs.shape} and {self.ys.shape}.'
)
diff = jnp.sum(jnp.abs(jnp.sort(self.xs) - self.xs))
jax_utils.error_if(diff, diff > 1e-8, 'xs must be sorted.')
if self.ys.ndim == 1:
self._fn = jax_utils.jit(jnp.interp)
elif self.ys.ndim == 2:
self._fn = jax_utils.jit(jax.vmap(jnp.interp, in_axes=(None, None, 1)))
else:
if ys.ndim not in (1, 2):
raise ValueError(f'ys must be either 1D or 2D. Given: {self.ys.shape}.')

xs_np = np.array(self.xs)
if not np.array_equal(np.sort(xs_np), xs_np):
raise RuntimeError('xs must be sorted.')

@property
def xs(self) -> chex.Array:
return self._xs
@@ -99,7 +110,35 @@ def get_value(
self,
x: chex.Numeric,
) -> chex.Array:
return self._fn(x, self.xs, self.ys)
x_shape = getattr(x, 'shape', ())
is_jax = isinstance(x, jax.Array)
# This function can be used inside a JITted function, where x is a
# tracer. In that case we are required to use the JAX versions of these
# functions.
interp = interp_fn if is_jax else np.interp
full = jnp.full if is_jax else np.full

match self.ys.ndim:
# This is simply interp, but with fast paths for common special cases.
case 1:
# When ys is size 1, no interpolation is needed: all values are just
# ys.
if self.ys.size == 1:
if x_shape == (): # pylint: disable=g-explicit-bool-comparison
return self.ys[0]
else:
return full(x_shape, self.ys[0], dtype=self.ys.dtype)
else:
return interp(x, self.xs, self.ys)
# The 2D case is mapped across the last dimension.
case 2:
# Special case: no interpolation needed.
if len(self.ys) == 1 and x_shape == (): # pylint: disable=g-explicit-bool-comparison
return self.ys[0]
else:
return interp_fn_vmap(x, self.xs, self.ys)
case _:
raise ValueError(f'ys must be either 1D or 2D. Given: {self.ys.shape}.')


@jax_utils.jit
@@ -203,7 +242,7 @@ def rhonorm1_defined_in_timerhoinput(
elif len(values) == 3:
_, rho_norm, _ = values
else:
# pytype: enable=bad-unpacking
# pytype: enable=bad-unpacking
raise ValueError('Only array tuples of length 2 or 3 are supported.')
if 1.0 not in rho_norm:
return False
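A short usage sketch of the new fast paths (assuming torax.interpolated_param is importable as below): a size-1 ys returns ys[0] with no interpolation at all, and plain NumPy inputs now dispatch to np.interp on the host rather than the JITted JAX version.

import numpy as np
from torax import interpolated_param

# Size-1 ys: get_value returns ys[0] directly, skipping interp entirely.
p = interpolated_param.PiecewiseLinearInterpolatedParam(
    xs=np.array([0.0]), ys=np.array([2.5]))
assert p.get_value(0.7) == 2.5

# NumPy input with a real profile: handled by np.interp, not the JAX path.
q = interpolated_param.PiecewiseLinearInterpolatedParam(
    xs=np.array([0.0, 1.0]), ys=np.array([0.0, 2.0]))
assert q.get_value(0.5) == 1.0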
11 changes: 8 additions & 3 deletions torax/output.py
@@ -13,14 +13,15 @@
# limitations under the License.

"""Module containing functions for saving and loading simulation output."""

from __future__ import annotations

import dataclasses

from absl import logging
import chex
import jax
from jax import numpy as jnp
import numpy as np
from torax import geometry
from torax import state
from torax.config import runtime_params
@@ -195,7 +196,11 @@ def __init__(
post_processed_output = [
state.post_processed_outputs for state in sim_outputs.sim_history
]
stack = lambda *ys: jnp.stack(ys)

def stack(*x):
out = np.stack([np.asarray(i) for i in x])
return out

self.core_profiles: state.CoreProfiles = jax.tree_util.tree_map(
stack, *core_profiles
)
@@ -208,7 +213,7 @@ def __init__(
self.post_processed_outputs: state.PostProcessedOutputs = (
jax.tree_util.tree_map(stack, *post_processed_output)
)
self.times = jnp.array([state.t for state in sim_outputs.sim_history])
self.times = np.array([state.t for state in sim_outputs.sim_history])
chex.assert_rank(self.times, 1)
self.sim_error = sim_outputs.sim_error
self.source_models = source_models
3 changes: 2 additions & 1 deletion torax/simulation_app.py
@@ -60,6 +60,7 @@ def run(_):
from absl import logging
import jax
from matplotlib import pyplot as plt
import numpy as np
import torax
from torax import geometry_provider
from torax import output
@@ -141,7 +142,7 @@ def _log_single_state(
def log_simulation_output_to_stdout(
core_profile_history: torax.CoreProfiles,
geo: torax.Geometry,
t: jax.Array,
t: np.ndarray,
) -> None:
del geo
_log_single_state(core_profile_history.index(0), t[0])
