Skip to content

Commit

Permalink
Store result times as integers instead of floats
Browse files Browse the repository at this point in the history
  • Loading branch information
pablolh committed Jan 28, 2025
1 parent 604dede commit 02d4d7a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
7 changes: 5 additions & 2 deletions emu_base/base_classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import logging
import sys
import pathlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from emu_base.base_classes.callback import Callback


class BackendConfig:
Expand Down Expand Up @@ -30,8 +34,7 @@ class BackendConfig:
def __init__(
self,
*,
# "Callback" is a forward type reference because of the circular import otherwise.
observables: list["Callback"] | None = None, # type: ignore # noqa: F821
observables: list[Callback] | None = None,
with_modulation: bool = False,
noise_model: NoiseModel = None,
interaction_matrix: list[list[float]] | None = None,
Expand Down
9 changes: 6 additions & 3 deletions emu_mps/mps_backend_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,14 @@ def save_simulation(self) -> None:
def fill_results(self) -> None:
normalized_state = 1 / self.state.norm() * self.state

current_time_int: int = round(self.current_time)
assert abs(self.current_time - current_time_int) < 1e-10

if self.well_prepared_qubits_filter is None:
for callback in self.config.callbacks:
callback(
self.config,
self.current_time,
current_time_int,
normalized_state,
self.hamiltonian,
self.results,
Expand All @@ -388,7 +391,7 @@ def fill_results(self) -> None:

full_mpo, full_state = None, None
for callback in self.config.callbacks:
if self.current_time not in callback.evaluation_times:
if current_time_int not in callback.evaluation_times:
continue

if full_mpo is None or full_state is None:
Expand All @@ -409,7 +412,7 @@ def fill_results(self) -> None:
),
)

callback(self.config, self.current_time, full_state, full_mpo, self.results)
callback(self.config, current_time_int, full_state, full_mpo, self.results)

def log_step_statistics(self, *, duration: float) -> None:
if self.state.factors[0].is_cuda:
Expand Down

0 comments on commit 02d4d7a

Please sign in to comment.