Allow resuming simulation by regularly saving to file
- Adds autosave_prefix and autosave_dt config fields (see the configuration sketch below).
	By default:
		emu_mps_save_<uuid>.dat
		every 10 minutes
- As a side effect, fix the various config entries and
carry the whole MPSConfig along. This was part of my
research into the optimal implementation and would take
too much time to separate into an independent MR.
- Some fixes to the unpickled impl are necessary,
such as autosave_file (so it keeps saving to the same existing file)
and the logger, which gets messed up by pickling.
- Fix make_H modifying its interaction_matrix argument.
That was also part of the process but is in fact
unrelated now.

---> How to resume a simulation from a file

import emu_mps
results = emu_mps.MPSBackend().resume("emu_mps_save_xxx.dat")
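
---> Configuring autosave (illustrative sketch)

The commit adds the autosave_prefix and autosave_dt config fields. Passing them as MPSConfig keyword arguments, and the pulser sequence below, are assumptions made for illustration; only the field names, the 10-minute default, and the run() signature come from this commit.

import emu_mps

# autosave_dt is in seconds (600 s matches the 10-minute default);
# autosave_prefix controls the emu_mps_save_<uuid>.dat file name.
config = emu_mps.MPSConfig(autosave_prefix="emu_mps_save_", autosave_dt=600)
results = emu_mps.MPSBackend().run(sequence, config)  # sequence: a pulser.Sequence built elsewhere
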
pablolh committed Jan 21, 2025
1 parent b09849b commit d20d248
Showing 15 changed files with 242 additions and 122 deletions.
31 changes: 19 additions & 12 deletions emu_base/base_classes/config.py
@@ -58,18 +58,11 @@ def __init__(
self.interaction_matrix = interaction_matrix
self.interaction_cutoff = interaction_cutoff
self.logger = logging.getLogger("global_logger")
if log_file is None:
logging.basicConfig(
level=log_level, format="%(message)s", stream=sys.stdout, force=True
) # default to stream = sys.stderr
else:
logging.basicConfig(
level=log_level,
format="%(message)s",
filename=str(log_file),
filemode="w",
force=True,
)
self.log_file = log_file
self.log_level = log_level

self.init_logging()

if noise_model is not None and (
noise_model.runs != 1
or noise_model.samples_per_run != 1
@@ -79,3 +72,17 @@ def __init__(
self.logger.warning(
"Warning: The runs and samples_per_run values of the NoiseModel are ignored!"
)

def init_logging(self) -> None:
if self.log_file is None:
logging.basicConfig(
level=self.log_level, format="%(message)s", stream=sys.stdout, force=True
) # default to stream = sys.stderr
else:
logging.basicConfig(
level=self.log_level,
format="%(message)s",
filename=str(self.log_file),
filemode="w",
force=True,
)
2 changes: 1 addition & 1 deletion emu_mps/__init__.py
@@ -10,10 +10,10 @@
StateResult,
SecondMomentOfEnergy,
)
from .mps_config import MPSConfig
from .mpo import MPO
from .mps import MPS, inner
from .mps_backend import MPSBackend
from .mps_config import MPSConfig


__all__ = [
11 changes: 8 additions & 3 deletions emu_mps/algebra.py
@@ -1,7 +1,11 @@
from __future__ import annotations

from typing import Optional

import torch
import math

from emu_mps import MPSConfig
from emu_mps.utils import truncate_impl


@@ -115,8 +119,7 @@ def zip_right_step(
def zip_right(
top_factors: list[torch.tensor],
bottom_factors: list[torch.tensor],
max_error: float = 1e-5,
max_rank: int = 1024,
config: Optional[MPSConfig] = None,
) -> list[torch.tensor]:
"""
Returns a new matrix product, resulting from applying `top` to `bottom`.
@@ -136,6 +139,8 @@ def zip_right(
A final truncation sweep, from right to left,
moves back the orthogonal center to the first element.
"""
config = config if config is not None else MPSConfig()

if len(top_factors) != len(bottom_factors):
raise ValueError("Cannot multiply two matrix products of different lengths.")

@@ -146,6 +151,6 @@
new_factors.append(res)
new_factors[-1] @= slider[:, :, 0]

truncate_impl(new_factors, max_error=max_error, max_rank=max_rank)
truncate_impl(new_factors, config=config)

return new_factors
4 changes: 4 additions & 0 deletions emu_mps/constants.py
@@ -0,0 +1,4 @@
import torch


DEVICE_COUNT = torch.cuda.device_count()
1 change: 1 addition & 0 deletions emu_mps/hamiltonian.py
@@ -285,6 +285,7 @@ def _get_interactions_to_keep(interaction_matrix: torch.Tensor) -> list[torch.Te
returns a list of bool valued tensors,
indicating which interaction terms to keep for each bond in the MPO
"""
interaction_matrix = interaction_matrix.clone()
nqubits = interaction_matrix.size(dim=1)
middle = nqubits // 2
interaction_matrix += torch.eye(
6 changes: 3 additions & 3 deletions emu_mps/mpo.py
@@ -8,7 +8,8 @@
from emu_base.base_classes.operator import FullOp, QuditOp
from emu_base import Operator, State
from emu_mps.mps import MPS
from emu_mps.utils import new_left_bath, assign_devices, DEVICE_COUNT
from emu_mps.utils import new_left_bath, assign_devices
from emu_mps.constants import DEVICE_COUNT


def _validate_operator_targets(operations: FullOp, nqubits: int) -> None:
@@ -80,8 +81,7 @@ def __mul__(self, other: State) -> MPS:
factors = zip_right(
self.factors,
other.factors,
max_error=other.precision,
max_rank=other.max_bond_dim,
config=other.config,
)
return MPS(factors, orthogonality_center=0)

40 changes: 14 additions & 26 deletions emu_mps/mps.py
@@ -7,15 +7,16 @@
import torch

from emu_base import State
from emu_mps import MPSConfig
from emu_mps.algebra import add_factors, scale_factors
from emu_mps.utils import (
DEVICE_COUNT,
apply_measurement_errors,
assign_devices,
truncate_impl,
tensor_trace,
n_operator,
)
from emu_mps.constants import DEVICE_COUNT


class MPS(State):
@@ -27,17 +28,13 @@ class MPS(State):
Only qubits are supported.
"""

DEFAULT_MAX_BOND_DIM: int = 1024
DEFAULT_PRECISION: float = 1e-5

def __init__(
self,
factors: List[torch.Tensor],
/,
*,
orthogonality_center: Optional[int] = None,
precision: float = DEFAULT_PRECISION,
max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
config: Optional[MPSConfig] = None,
num_gpus_to_use: Optional[int] = DEVICE_COUNT,
):
"""
@@ -58,9 +55,7 @@ def __init__(
num_gpus_to_use: distribute the factors over this many GPUs
0=all factors to cpu, None=keep the existing device assignment.
"""
self.precision = precision
self.max_bond_dim = max_bond_dim

self.config = config if config is not None else MPSConfig()
assert all(
factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
), "The dimensions of consecutive tensors should match"
@@ -84,20 +79,20 @@ def __init__(
def make(
cls,
num_sites: int,
precision: float = DEFAULT_PRECISION,
max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
config: Optional[MPSConfig] = None,
num_gpus_to_use: int = DEVICE_COUNT,
) -> MPS:
"""
Returns a MPS in ground state |000..0>.
Args:
num_sites: the number of qubits
precision: the precision with which to truncate here or in tdvp
max_bond_dim: the maximum bond dimension to allow
config: the MPSConfig
num_gpus_to_use: distribute the factors over this many GPUs
0=all factors to cpu
"""
config = config if config is not None else MPSConfig()

if num_sites <= 1:
raise ValueError("For 1 qubit states, do state vector")

@@ -106,8 +101,7 @@ def make(
torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)
for _ in range(num_sites)
],
precision=precision,
max_bond_dim=max_bond_dim,
config=config,
num_gpus_to_use=num_gpus_to_use,
orthogonality_center=0, # Arbitrary: every qubit is an orthogonality center.
)
@@ -165,15 +159,11 @@ def truncate(self) -> None:
"""
SVD based truncation of the state. Puts the orthogonality center at the first qubit.
Calls orthogonalize on the last qubit, and then sweeps a series of SVDs right-left.
Uses self.precision and self.max_bond_dim for determining accuracy.
Uses self.config for determining accuracy.
An in-place operation.
"""
self.orthogonalize(self.num_sites - 1)
truncate_impl(
self.factors,
max_error=self.precision,
max_rank=self.max_bond_dim,
)
truncate_impl(self.factors, config=self.config)
self.orthogonality_center = 0

def get_max_bond_dim(self) -> int:
@@ -303,7 +293,7 @@ def __add__(self, other: State) -> MPS:
"""
Returns the sum of two MPSs, computed with a direct algorithm.
The resulting MPS is orthogonalized on the first site and truncated
up to `self.precision`.
up to `self.config.precision`.
Args:
other: the other state
@@ -315,8 +305,7 @@ def __add__(self, other: State) -> MPS:
new_tt = add_factors(self.factors, other.factors)
result = MPS(
new_tt,
precision=self.precision,
max_bond_dim=self.max_bond_dim,
config=self.config,
num_gpus_to_use=None,
orthogonality_center=None, # Orthogonality is lost.
)
@@ -341,8 +330,7 @@ def __rmul__(self, scalar: complex) -> MPS:
factors = scale_factors(self.factors, scalar, which=which)
return MPS(
factors,
precision=self.precision,
max_bond_dim=self.max_bond_dim,
config=self.config,
num_gpus_to_use=None,
orthogonality_center=self.orthogonality_center,
)
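
The mps.py changes above replace the per-state precision and max_bond_dim arguments with a single MPSConfig object (truncation now reads self.config.precision). A minimal sketch of building a state through the new plumbing, assuming the default MPSConfig values are acceptable:

from emu_mps import MPS, MPSConfig

# |000...0> product state on 4 qubits; truncation settings now live in the config.
state = MPS.make(4, config=MPSConfig())
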
43 changes: 40 additions & 3 deletions emu_mps/mps_backend.py
@@ -1,8 +1,12 @@
from pulser import Sequence

from emu_base import Backend, BackendConfig, Results
from emu_mps.mps_config import MPSConfig
from emu_mps.mps_backend_impl import create_impl
from emu_mps.mps_backend_impl import create_impl, MPSBackendImpl
from pulser import Sequence
import pickle
import os
import time
import logging
import pathlib


class MPSBackend(Backend):
@@ -11,6 +15,32 @@ class MPSBackend(Backend):
aka tensor trains.
"""

def resume(self, autosave_file: str | pathlib.Path) -> Results:
"""
Resume simulation from autosave file.
Only resume simulations from data you trust!
Unpickling of untrusted data is not safe.
"""
if isinstance(autosave_file, str):
autosave_file = pathlib.Path(autosave_file)

if not autosave_file.is_file():
raise ValueError(f"Not a file: {autosave_file}")

with open(autosave_file, "rb") as f:
impl: MPSBackendImpl = pickle.load(f)

impl.autosave_file = autosave_file
impl.last_save_time = time.time()
impl.config.init_logging() # FIXME: might be best to take logger object out of config.

logging.getLogger("global_logger").warning(
f"Resuming simulation from file {autosave_file}\n"
f"Saving simulation state every {impl.config.autosave_dt} seconds"
)

return self._run(impl)

def run(self, sequence: Sequence, mps_config: BackendConfig) -> Results:
"""
Emulates the given sequence.
@@ -29,7 +59,14 @@ def run(self, sequence: Sequence, mps_config: BackendConfig) -> Results:
impl = create_impl(sequence, mps_config)
impl.init() # This is separate from the constructor for testing purposes.

return self._run(impl)

@staticmethod
def _run(impl: MPSBackendImpl) -> Results:
while not impl.is_finished():
impl.progress()

if impl.autosave_file.is_file():
os.remove(impl.autosave_file)

return impl.results
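
Putting the pieces together: resume() unpickles the saved MPSBackendImpl, restores the autosave file path and logging, and hands the impl to the shared _run loop, which deletes the autosave file once the simulation finishes. A minimal end-to-end sketch of the intended workflow (the sequence and config objects are placeholders, and the saved file name is whatever the backend actually wrote):

import emu_mps

backend = emu_mps.MPSBackend()
results = backend.run(sequence, config)  # may be interrupted, e.g. by a cluster time limit

# In a new process, continue from the most recent autosave:
results = emu_mps.MPSBackend().resume("emu_mps_save_xxx.dat")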