Skip to content

Commit

Permalink
Merge branch 'main' into sg/9-emu-mps-number-of-qubits-bond-dimension…
Browse files Browse the repository at this point in the history
…-benchmark
  • Loading branch information
sgrava committed Jan 23, 2025
2 parents 01964ae + a505c43 commit a46d086
Show file tree
Hide file tree
Showing 13 changed files with 435 additions and 75 deletions.
65 changes: 63 additions & 2 deletions emu_sv/custom_callback_implementations.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,78 @@
import math
from typing import Any

import torch

from emu_base.base_classes.config import BackendConfig
from emu_base.base_classes.default_callbacks import QubitDensity
from emu_base.base_classes.default_callbacks import (
QubitDensity,
EnergyVariance,
SecondMomentOfEnergy,
CorrelationMatrix,
)
from emu_base.base_classes.operator import Operator

from emu_sv import StateVector
from emu_sv.hamiltonian import RydbergHamiltonian


def custom_qubit_density(
def qubit_density_sv_impl(
self: QubitDensity, config: BackendConfig, t: int, state: StateVector, H: Operator
) -> Any:

num_qubits = int(math.log2(len(state.vector)))
state_tensor = state.vector.reshape((2,) * num_qubits)
return [(state_tensor.select(i, 1).norm() ** 2).item() for i in range(num_qubits)]


def correlation_matrix_sv_impl(
self: CorrelationMatrix,
config: BackendConfig,
t: int,
state: StateVector,
H: Operator,
) -> Any:
"""'Sparse' implementation of <𝜓| nᵢ nⱼ | 𝜓 >"""
num_qubits = int(math.log2(len(state.vector)))
state_tensor = state.vector.reshape((2,) * num_qubits)

correlation_matrix = []
for numi in range(num_qubits):
one_correlation = []
select_i = state_tensor.select(numi, 1)
for numj in range(num_qubits):
if numj < numi:
one_correlation.append((select_i.select(numj, 1).norm() ** 2).item())
elif numj > numi: # the selected atom is deleted
one_correlation.append((select_i.select(numj - 1, 1).norm() ** 2).item())
else:
one_correlation.append((select_i.norm() ** 2).item())

correlation_matrix.append(one_correlation)
return correlation_matrix


def energy_variance_sv_impl(
self: EnergyVariance,
config: BackendConfig,
t: int,
state: StateVector,
H: RydbergHamiltonian,
) -> Any:
hstate = H * state.vector
h_squared = torch.vdot(hstate, hstate)
h_state = torch.vdot(state.vector, hstate)
return (h_squared.real - h_state.real**2).item()


def second_momentum_sv_impl(
self: SecondMomentOfEnergy,
config: BackendConfig,
t: int,
state: StateVector,
H: RydbergHamiltonian,
) -> Any:

hstate = H * state.vector
h_squared = torch.vdot(hstate, hstate)
return (h_squared.real).item()
12 changes: 10 additions & 2 deletions emu_sv/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import torch

from emu_sv.state_vector import StateVector


class RydbergHamiltonian:
"""
Expand Down Expand Up @@ -34,7 +36,7 @@ class RydbergHamiltonian:
strengths between each pair of qubits.
Methods:
__matmul__(vec): Performs matrix-vector multiplication with a vector.
__mul__(vec): Performs matrix-vector multiplication with a vector.
_diag_elemts(): Constructs the diagonal elements of the Hamiltonian
based on `deltas` and `interaction_matrix`.
_size(): Calculates the memory size of the `RydbergHamiltonian` object in MiB.
Expand All @@ -54,7 +56,7 @@ def __init__(
self.diag: torch.Tensor = self._create_diagonal().to(device=device)
self.inds = torch.tensor([1, 0], device=device) # flips the state, for 𝜎ₓ

def __matmul__(self, vec: torch.Tensor) -> torch.Tensor:
def __mul__(self, vec: torch.Tensor) -> torch.Tensor:
"""
Performs a matrix-vector multiplication between the `RydbergHamiltonian` form and
a torch vector
Expand Down Expand Up @@ -142,3 +144,9 @@ def _create_diagonal(self) -> torch.Tensor:
) # note the j-1 since i was already removed
i_j_fixed += self.interaction_matrix[i, j]
return diag

def expect(self, state: StateVector) -> float | complex:
assert isinstance(
state, StateVector
), "currently, only expectation values of StateVectors are supported"
return torch.vdot(state.vector, self * state.vector).item() # type: ignore [no-any-return]
2 changes: 1 addition & 1 deletion emu_sv/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def make(cls, num_sites: int, gpu: bool = False) -> StateVector:
tensor([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], dtype=torch.complex128)
"""

device = "gpu" if gpu else "cpu"
device = "cuda" if gpu else "cpu"
ground_state = torch.zeros(2**num_sites, dtype=dtype, device=device)
ground_state[0] = torch.tensor(1.0, dtype=dtype, device=device)
return cls(ground_state)
Expand Down
8 changes: 2 additions & 6 deletions emu_sv/sv_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pulser import Sequence
from emu_base.pulser_adapter import PulserData
from emu_sv.time_evolution import do_time_step
from emu_sv import StateVector, DenseOperator
from emu_sv import StateVector
import torch
from time import time
from resource import RUSAGE_SELF, getrusage
Expand Down Expand Up @@ -53,7 +53,7 @@ def run(self, sequence: Sequence, sv_config: BackendConfig) -> Results:

start = time()

state.vector = do_time_step(
state.vector, H = do_time_step(
dt,
omega[step],
delta[step],
Expand All @@ -62,10 +62,6 @@ def run(self, sequence: Sequence, sv_config: BackendConfig) -> Results:
sv_config.krylov_tolerance,
)

# TODO: Rydberg Mamiltonian should be a dense operator

H = DenseOperator # Energy, SecondMomentun... and Variance are not implemented

for callback in sv_config.callbacks:
callback(sv_config, (step + 1) * sv_config.dt, state, H, results)

Expand Down
44 changes: 23 additions & 21 deletions emu_sv/sv_config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from emu_base.base_classes import (
BitStrings,
StateResult,
CorrelationMatrix,
QubitDensity,
Fidelity,
EnergyVariance,
SecondMomentOfEnergy,
)

import copy


from emu_base import BackendConfig
from emu_sv import StateVector
from typing import Any

from emu_sv.custom_callback_implementations import (
qubit_density_sv_impl,
energy_variance_sv_impl,
second_momentum_sv_impl,
correlation_matrix_sv_impl,
)

from types import MethodType

from emu_sv.custom_callback_implementations import custom_qubit_density


class SVConfig(BackendConfig):
"""
Expand Down Expand Up @@ -55,26 +60,23 @@ def __init__(
):
super().__init__(**kwargs)

observables = set(map(type, self.callbacks))

supported_observables = {
BitStrings,
StateResult,
CorrelationMatrix,
QubitDensity,
Fidelity,
}

unsupported_observables = observables - supported_observables
if unsupported_observables:
raise ValueError(f"{unsupported_observables} are not supported in emu-sv")
self.initial_state = initial_state
self.dt = dt
self.max_krylov_dim = max_krylov_dim
self.gpu = gpu
self.krylov_tolerance = krylov_tolerance

for obs in self.callbacks:
for num, obs in enumerate(self.callbacks): # monkey patch
obs_copy = copy.deepcopy(obs)
if isinstance(obs, QubitDensity):
# mypy: ignoring dynamically replacing method
obs.apply = MethodType(custom_qubit_density, obs) # type: ignore[method-assign]
obs_copy.apply = MethodType(qubit_density_sv_impl, obs) # type: ignore[method-assign]
self.callbacks[num] = obs_copy
elif isinstance(obs, EnergyVariance):
obs_copy.apply = MethodType(energy_variance_sv_impl, obs) # type: ignore[method-assign]
self.callbacks[num] = obs_copy
elif isinstance(obs, SecondMomentOfEnergy):
obs_copy.apply = MethodType(second_momentum_sv_impl, obs) # type: ignore[method-assign]
self.callbacks[num] = obs_copy
elif isinstance(obs, CorrelationMatrix):
obs_copy.apply = MethodType(correlation_matrix_sv_impl, obs) # type: ignore[method-assign]
self.callbacks[num] = obs_copy
15 changes: 10 additions & 5 deletions emu_sv/time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@ def do_time_step(
full_interaction_matrix: torch.Tensor,
state_vector: torch.Tensor,
krylov_tolerance: float,
) -> torch.Tensor:
) -> tuple[torch.Tensor, RydbergHamiltonian]:
ham = RydbergHamiltonian(
omegas=omega,
deltas=delta,
interaction_matrix=full_interaction_matrix,
device=state_vector.device,
)
op = lambda x: -1j * dt * (ham @ x)

return krylov_exp(
op, state_vector, norm_tolerance=krylov_tolerance, exp_tolerance=krylov_tolerance
op = lambda x: -1j * dt * (ham * x)
return (
krylov_exp(
op,
state_vector,
norm_tolerance=krylov_tolerance,
exp_tolerance=krylov_tolerance,
),
ham,
)
Loading

0 comments on commit a46d086

Please sign in to comment.