From 02d4d7a685659b6f64608ac6bb64d784806c21f3 Mon Sep 17 00:00:00 2001 From: Pablo Le Henaff Date: Mon, 27 Jan 2025 14:28:20 +0100 Subject: [PATCH] Store result times as integers instead of floats --- emu_base/base_classes/config.py | 7 +++++-- emu_mps/mps_backend_impl.py | 9 ++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/emu_base/base_classes/config.py b/emu_base/base_classes/config.py index 5d43a40..e6cba6d 100644 --- a/emu_base/base_classes/config.py +++ b/emu_base/base_classes/config.py @@ -3,6 +3,10 @@ import logging import sys import pathlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from emu_base.base_classes.callback import Callback class BackendConfig: @@ -30,8 +34,7 @@ class BackendConfig: def __init__( self, *, - # "Callback" is a forward type reference because of the circular import otherwise. - observables: list["Callback"] | None = None, # type: ignore # noqa: F821 + observables: list[Callback] | None = None, with_modulation: bool = False, noise_model: NoiseModel = None, interaction_matrix: list[list[float]] | None = None, diff --git a/emu_mps/mps_backend_impl.py b/emu_mps/mps_backend_impl.py index 6c50d58..3ddbfb0 100644 --- a/emu_mps/mps_backend_impl.py +++ b/emu_mps/mps_backend_impl.py @@ -375,11 +375,14 @@ def save_simulation(self) -> None: def fill_results(self) -> None: normalized_state = 1 / self.state.norm() * self.state + current_time_int: int = round(self.current_time) + assert abs(self.current_time - current_time_int) < 1e-10 + if self.well_prepared_qubits_filter is None: for callback in self.config.callbacks: callback( self.config, - self.current_time, + current_time_int, normalized_state, self.hamiltonian, self.results, @@ -388,7 +391,7 @@ def fill_results(self) -> None: full_mpo, full_state = None, None for callback in self.config.callbacks: - if self.current_time not in callback.evaluation_times: + if current_time_int not in callback.evaluation_times: continue if full_mpo is None or full_state is None: @@ -409,7 +412,7 @@ def fill_results(self) -> None: ), ) - callback(self.config, self.current_time, full_state, full_mpo, self.results) + callback(self.config, current_time_int, full_state, full_mpo, self.results) def log_step_statistics(self, *, duration: float) -> None: if self.state.factors[0].is_cuda: