Skip to content

Commit

Permalink
Store result times as integers instead of floats (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
pablolh authored Jan 28, 2025
1 parent 604dede commit 546e35d
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 14 deletions.
7 changes: 5 additions & 2 deletions emu_base/base_classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import logging
import sys
import pathlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from emu_base.base_classes.callback import Callback


class BackendConfig:
Expand Down Expand Up @@ -30,8 +34,7 @@ class BackendConfig:
def __init__(
self,
*,
# "Callback" is a forward type reference because of the circular import otherwise.
observables: list["Callback"] | None = None, # type: ignore # noqa: F821
observables: list[Callback] | None = None,
with_modulation: bool = False,
noise_model: NoiseModel = None,
interaction_matrix: list[list[float]] | None = None,
Expand Down
9 changes: 6 additions & 3 deletions emu_mps/mps_backend_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,14 @@ def save_simulation(self) -> None:
def fill_results(self) -> None:
normalized_state = 1 / self.state.norm() * self.state

current_time_int: int = round(self.current_time)
assert abs(self.current_time - current_time_int) < 1e-10

if self.well_prepared_qubits_filter is None:
for callback in self.config.callbacks:
callback(
self.config,
self.current_time,
current_time_int,
normalized_state,
self.hamiltonian,
self.results,
Expand All @@ -388,7 +391,7 @@ def fill_results(self) -> None:

full_mpo, full_state = None, None
for callback in self.config.callbacks:
if self.current_time not in callback.evaluation_times:
if current_time_int not in callback.evaluation_times:
continue

if full_mpo is None or full_state is None:
Expand All @@ -409,7 +412,7 @@ def fill_results(self) -> None:
),
)

callback(self.config, self.current_time, full_state, full_mpo, self.results)
callback(self.config, current_time_int, full_state, full_mpo, self.results)

def log_step_statistics(self, *, duration: float) -> None:
if self.state.factors[0].is_cuda:
Expand Down
8 changes: 7 additions & 1 deletion emu_sv/sv_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ def run(self, sequence: Sequence, sv_config: BackendConfig) -> Results:
)

for callback in sv_config.callbacks:
callback(sv_config, (step + 1) * sv_config.dt, state, H, results)
callback(
sv_config,
(step + 1) * sv_config.dt,
state,
H, # type: ignore[arg-type]
results,
)

end = time()
self.log_step_statistics(
Expand Down
20 changes: 12 additions & 8 deletions emu_sv/sv_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,22 @@ def __init__(
for num, obs in enumerate(self.callbacks): # monkey patch
obs_copy = copy.deepcopy(obs)
if isinstance(obs, QubitDensity):
# type: ignore[method-assign]
obs_copy.apply = MethodType(qubit_density_sv_impl, obs)
obs_copy.apply = MethodType( # type: ignore[method-assign]
qubit_density_sv_impl, obs
)
self.callbacks[num] = obs_copy
elif isinstance(obs, EnergyVariance):
# type: ignore[method-assign]
obs_copy.apply = MethodType(energy_variance_sv_impl, obs)
obs_copy.apply = MethodType( # type: ignore[method-assign]
energy_variance_sv_impl, obs
)
self.callbacks[num] = obs_copy
elif isinstance(obs, SecondMomentOfEnergy):
# type: ignore[method-assign]
obs_copy.apply = MethodType(second_momentum_sv_impl, obs)
obs_copy.apply = MethodType( # type: ignore[method-assign]
second_momentum_sv_impl, obs
)
self.callbacks[num] = obs_copy
elif isinstance(obs, CorrelationMatrix):
# type: ignore[method-assign]
obs_copy.apply = MethodType(correlation_matrix_sv_impl, obs)
obs_copy.apply = MethodType( # type: ignore[method-assign]
correlation_matrix_sv_impl, obs
)
self.callbacks[num] = obs_copy

0 comments on commit 546e35d

Please sign in to comment.