From 828797423cb81e1474b6a33bd7f6346fdcbdb318 Mon Sep 17 00:00:00 2001 From: Sebastian Bodenstein Date: Mon, 2 Dec 2024 02:42:43 -0800 Subject: [PATCH] Optimize array stacking in the output by using NumPy rather than JAX. PiperOrigin-RevId: 701895415 --- torax/output.py | 12 +++++++++--- torax/simulation_app.py | 3 ++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/torax/output.py b/torax/output.py index 5bb8560d..14605190 100644 --- a/torax/output.py +++ b/torax/output.py @@ -13,6 +13,7 @@ # limitations under the License. """Module containing functions for saving and loading simulation output.""" + from __future__ import annotations import dataclasses @@ -20,7 +21,8 @@ from absl import logging import chex import jax -from jax import numpy as jnp +import numpy as np + from torax import state from torax.config import runtime_params from torax.geometry import geometry @@ -200,7 +202,11 @@ def __init__( post_processed_output = [ state.post_processed_outputs for state in sim_outputs.sim_history ] - stack = lambda *ys: jnp.stack(ys) + + def stack(*x): + out = np.stack([np.asarray(i) for i in x]) + return out + self.core_profiles: state.CoreProfiles = jax.tree_util.tree_map( stack, *core_profiles ) @@ -213,7 +219,7 @@ def __init__( self.post_processed_outputs: state.PostProcessedOutputs = ( jax.tree_util.tree_map(stack, *post_processed_output) ) - self.times = jnp.array([state.t for state in sim_outputs.sim_history]) + self.times = np.array([state.t for state in sim_outputs.sim_history]) chex.assert_rank(self.times, 1) self.sim_error = sim_outputs.sim_error self.source_models = source_models diff --git a/torax/simulation_app.py b/torax/simulation_app.py index dee056ce..6baff83b 100644 --- a/torax/simulation_app.py +++ b/torax/simulation_app.py @@ -58,6 +58,7 @@ def run(_): from typing import Callable, Final from absl import logging +import chex import jax from torax import output from torax import sim as sim_lib @@ -141,7 +142,7 @@ def _log_single_state( def log_simulation_output_to_stdout( core_profile_history: state.CoreProfiles, geo: geometry.Geometry, - t: jax.Array, + t: chex.Array, ) -> None: del geo _log_single_state(core_profile_history.index(0), t[0])