Fix mypy with Torch (#21)
Add torch as a mypy dependency, which revealed a lot of new typing errors to fix.
pablolh authored Jan 28, 2025
1 parent d086260 commit 604dede
Showing 21 changed files with 165 additions and 120 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/lint_test.yml
@@ -11,7 +11,8 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.10' #don't cache this, it will create a cache without all the good stuff
+          python-version: '3.10'
+          cache: 'pip'
       - run: pip install pre-commit pyproject-flake8
       - run: pre-commit install
       - run: pre-commit run --all-files
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -27,3 +27,4 @@ repos:
     hooks:
       - id: mypy
         exclude: examples|test|ci
+        additional_dependencies: [torch==2.5.0] # The version in pyproject.toml must match
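
Making torch visible to the mypy pre-commit hook lets it import the real stubs instead of skipping the package. A minimal sketch of the class of error this surfaces (the function names are illustrative, not from the repo):

import torch

# `torch.tensor` (lowercase) is a function, so using it as a type annotation
# is invalid; with torch available, mypy reports something like
# 'Function "torch.tensor" is not valid as a type'.
def dist_bad(left: torch.tensor, right: torch.tensor) -> torch.tensor:
    return torch.dist(left, right)

# `torch.Tensor` is the actual class, so this version type-checks cleanly.
def dist_good(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    return torch.dist(left, right)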
4 changes: 2 additions & 2 deletions emu_base/math/krylov_exp.py
@@ -21,7 +21,7 @@ def __init__(


 def krylov_exp_impl(
-    op: Callable,
+    op: Callable[[torch.Tensor], torch.Tensor],
     v: torch.Tensor,
     is_hermitian: bool,  # note: complex-proportional to its adjoint is enough
     exp_tolerance: float,
@@ -103,7 +103,7 @@ def krylov_exp_impl(


 def krylov_exp(
-    op: torch.Tensor,
+    op: Callable[[torch.Tensor], torch.Tensor],
     v: torch.Tensor,
     exp_tolerance: float,
     norm_tolerance: float,
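
The old annotation on krylov_exp even declared op as a torch.Tensor, although it is called as a function; the tightened Callable[[torch.Tensor], torch.Tensor] fixes both signatures. A sketch of how such a matrix-free operator can be supplied (the wrapper is illustrative, not library code):

from typing import Callable

import torch

def as_matvec(h: torch.Tensor) -> Callable[[torch.Tensor], torch.Tensor]:
    # Wrap a dense matrix as the matrix-free callable the signature expects;
    # a real operator would apply the Hamiltonian without materializing it.
    def op(v: torch.Tensor) -> torch.Tensor:
        return h @ v

    return op

h = torch.eye(4, dtype=torch.complex128)
v = torch.randn(4, dtype=torch.complex128)
# hypothetical call, tolerances chosen arbitrarily:
# w = krylov_exp(as_matvec(h), v, exp_tolerance=1e-10, norm_tolerance=1e-10)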
7 changes: 5 additions & 2 deletions emu_base/pulser_adapter.py
@@ -62,15 +62,18 @@ def _xy_interaction(sequence: pulser.Sequence) -> torch.Tensor:
     qubit_positions = _get_qubit_positions(sequence.register)
     interaction_matrix = torch.zeros(num_qubits, num_qubits)
     mag_field = torch.tensor(sequence.magnetic_field)  # by default [0.0,0.0,30.0]
-    mag_norm = torch.norm(mag_field)
+    mag_norm = torch.linalg.norm(mag_field)

     for numi in range(len(qubit_positions)):
         for numj in range(numi + 1, len(qubit_positions)):
             cosine = 0
             if mag_norm >= 1e-8:  # selected by hand
                 cosine = torch.dot(
                     (qubit_positions[numi] - qubit_positions[numj]), mag_field
-                ) / (torch.norm(qubit_positions[numi] - qubit_positions[numj]) * mag_norm)
+                ) / (
+                    torch.linalg.norm(qubit_positions[numi] - qubit_positions[numj])
+                    * mag_norm
+                )

             interaction_matrix[numi][numj] = (
                 c3  # check this value with pulser people
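
torch.norm is deprecated in favor of torch.linalg.norm, which also carries precise type stubs; for 1-D inputs both compute the Euclidean norm, so the cosine above is numerically unchanged by the migration. A toy check:

import torch

displacement = torch.tensor([0.0, 3.0, 4.0])
mag_field = torch.tensor([0.0, 0.0, 30.0])

# torch.linalg.norm of a 1-D tensor is the Euclidean norm, matching the
# deprecated torch.norm here.
cosine = torch.dot(displacement, mag_field) / (
    torch.linalg.norm(displacement) * torch.linalg.norm(mag_field)
)
print(cosine)  # 0.8, i.e. 120 / (5 * 30)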
8 changes: 4 additions & 4 deletions emu_base/utils.py
@@ -1,9 +1,9 @@
 import torch


-def dist2(left: torch.tensor, right: torch.tensor) -> torch.Tensor:
-    return torch.norm(left - right) ** 2
+def dist2(left: torch.Tensor, right: torch.Tensor) -> float:
+    return torch.dist(left, right).item() ** 2


-def dist3(left: torch.tensor, right: torch.tensor) -> torch.Tensor:
-    return torch.norm(left - right) ** 3
+def dist3(left: torch.Tensor, right: torch.Tensor) -> float:
+    return torch.dist(left, right).item() ** 3
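
torch.dist returns a 0-dim tensor, so the added .item() makes the new -> float annotations accurate. A quick check of the pattern:

import torch

left = torch.tensor([1.0, 2.0])
right = torch.tensor([4.0, 6.0])

# torch.dist(left, right) is the Euclidean norm of (left - right) as a
# 0-dim tensor; .item() unwraps it into a plain Python float.
d2 = torch.dist(left, right).item() ** 2
assert d2 == 25.0 and isinstance(d2, float)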
22 changes: 11 additions & 11 deletions emu_mps/algebra.py
@@ -10,8 +10,8 @@


 def add_factors(
-    left: list[torch.tensor], right: list[torch.tensor]
-) -> list[torch.tensor]:
+    left: list[torch.Tensor], right: list[torch.Tensor]
+) -> list[torch.Tensor]:
     """
     Direct sum algorithm implementation to sum two tensor trains (MPS/MPO).
     It assumes the left and right bond are along the dimension 0 and -1 of each tensor.
@@ -52,19 +52,19 @@ def add_factors(


 def scale_factors(
-    factors: list[torch.tensor], scalar: complex, *, which: int
-) -> list[torch.tensor]:
+    factors: list[torch.Tensor], scalar: complex, *, which: int
+) -> list[torch.Tensor]:
     """
     Returns a new list of factors where the tensor at the given index is scaled by `scalar`.
     """
     return [scalar * f if i == which else f for i, f in enumerate(factors)]


 def zip_right_step(
-    slider: torch.tensor,
-    top: torch.tensor,
-    bottom: torch.tensor,
-) -> torch.tensor:
+    slider: torch.Tensor,
+    top: torch.Tensor,
+    bottom: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Returns a new `MPS/O` factor of the result of the multiplication MPO @ MPS/O,
     and the updated slider, performing a single step of the
@@ -117,10 +117,10 @@


 def zip_right(
-    top_factors: list[torch.tensor],
-    bottom_factors: list[torch.tensor],
+    top_factors: list[torch.Tensor],
+    bottom_factors: list[torch.Tensor],
     config: Optional[MPSConfig] = None,
-) -> list[torch.tensor]:
+) -> list[torch.Tensor]:
     """
     Returns a new matrix product, resulting from applying `top` to `bottom`.
     The resulting factors are:
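
Besides replacing the lowercase torch.tensor (a function, not a type) with torch.Tensor, the signature of zip_right_step now records that it returns a pair. A self-contained toy with the same shape of API (not the library's actual contraction):

import torch

def step(carry: torch.Tensor, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for zip_right_step: produce the new factor and the updated
    # slider, so callers unpack two tensors per sweep step.
    out = carry * x
    return out, out.sum().reshape(1)

carry = torch.ones(1)
for x in (torch.tensor([2.0]), torch.tensor([3.0])):
    out, carry = step(carry, x)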
4 changes: 2 additions & 2 deletions emu_mps/hamiltonian.py
@@ -363,7 +363,7 @@ def make_H(

     interactions_to_keep = _get_interactions_to_keep(interaction_matrix)

-    cores = [_first_factor(interactions_to_keep[0].item())]
+    cores = [_first_factor(interactions_to_keep[0].item() != 0.0)]

     if nqubits > 2:
         for i in range(1, middle):
@@ -395,7 +395,7 @@ def make_H(
             )
         )
     if nqubits == 2:
-        scale = interaction_matrix[0, 1]
+        scale = interaction_matrix[0, 1].item()
     elif interactions_to_keep[-1][0]:
         scale = 1.0
     else:
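
Indexing a tensor yields a 0-dim tensor rather than a Python scalar, which is presumably what mypy flagged in both hunks; .item() unwraps it, and the explicit != 0.0 produces a genuine bool instead of relying on truthiness. In isolation:

import torch

interaction_matrix = torch.tensor([[0.0, 1.5], [1.5, 0.0]])

# .item() turns the 0-dim tensor into the plain float/bool the annotations
# elsewhere in make_H expect.
scale: float = interaction_matrix[0, 1].item()        # 1.5 as a Python float
keep: bool = interaction_matrix[0, 1].item() != 0.0   # True as a Python bool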
43 changes: 23 additions & 20 deletions emu_mps/mpo.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import itertools
-from typing import Any, List, cast, Iterable, Optional
+from typing import Any, List, Iterable, Optional

 import torch

@@ -175,14 +175,15 @@ def from_operator_string(
         Returns:
             the operator in MPO form.
         """
+        operators_with_tensors: dict[str, torch.Tensor | QuditOp] = dict(operators)
+
         _validate_operator_targets(operations, nqubits)

         basis = set(basis)
         if basis == {"r", "g"}:
-            # operators will now contain the basis for single qubit ops, and potentially
-            # user defined strings in terms of these
-            operators |= {
+            # operators_with_tensors will now contain the basis for single qubit ops,
+            # and potentially user defined strings in terms of these
+            operators_with_tensors |= {
                 "gg": torch.tensor(
                     [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
                 ).reshape(1, 2, 2, 1),
@@ -197,9 +198,9 @@
                 ).reshape(1, 2, 2, 1),
             }
         elif basis == {"0", "1"}:
-            # operators will now contain the basis for single qubit ops, and potentially
-            # user defined strings in terms of these
-            operators |= {
+            # operators_with_tensors will now contain the basis for single qubit ops,
+            # and potentially user defined strings in terms of these
+            operators_with_tensors |= {
                 "00": torch.tensor(
                     [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
                 ).reshape(1, 2, 2, 1),
@@ -218,26 +219,28 @@

         mpos = []
         for coeff, tensorop in operations:
-            # this function will recurse through the operators, and replace any definitions
+            # this function will recurse through the operators_with_tensors,
+            # and replace any definitions
             # in terms of strings by the computed tensor
             def replace_operator_string(op: QuditOp | torch.Tensor) -> torch.Tensor:
-                if isinstance(op, dict):
-                    for opstr, coeff in op.items():
-                        tensor = replace_operator_string(operators[opstr])
-                        operators[opstr] = tensor
-                        op[opstr] = tensor * coeff
-                    op = sum(cast(list[torch.Tensor], op.values()))
-                return op
+                if isinstance(op, torch.Tensor):
+                    return op
+
+                result = torch.zeros(1, 2, 2, 1, dtype=torch.complex128)
+                for opstr, coeff in op.items():
+                    tensor = replace_operator_string(operators_with_tensors[opstr])
+                    operators_with_tensors[opstr] = tensor
+                    result += tensor * coeff
+                return result

             factors = [
                 torch.eye(2, 2, dtype=torch.complex128).reshape(1, 2, 2, 1)
             ] * nqubits

-            for i, op in enumerate(tensorop):
-                tensorop[i] = (replace_operator_string(op[0]), op[1])
-
             for op in tensorop:
-                for i in op[1]:
-                    factors[i] = op[0]
+                factor = replace_operator_string(op[0])
+                for target_qubit in op[1]:
+                    factors[target_qubit] = factor

             mpos.append(coeff * MPO(factors, **kwargs))
         return sum(mpos[1:], start=mpos[0])
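
The refactor avoids mutating the caller's operators dict (which previously needed a cast to type-check) by working on a typed copy and accumulating into a fresh tensor. A self-contained sketch with simplified 2x2 shapes; the alias "z" is an invented example:

import torch

QuditOp = dict[str, complex]

operators_with_tensors: dict[str, torch.Tensor | QuditOp] = {
    "gg": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128),
    "rr": torch.tensor([[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128),
    "z": {"gg": 1.0, "rr": -1.0},  # user-defined string in terms of the basis
}

def replace_operator_string(op: torch.Tensor | QuditOp) -> torch.Tensor:
    if isinstance(op, torch.Tensor):
        return op

    result = torch.zeros(2, 2, dtype=torch.complex128)
    for opstr, coeff in op.items():
        tensor = replace_operator_string(operators_with_tensors[opstr])
        operators_with_tensors[opstr] = tensor  # memoize the resolved tensor
        result += tensor * coeff
    return result

print(replace_operator_string(operators_with_tensors["z"]))  # diag(1, -1)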
2 changes: 1 addition & 1 deletion emu_mps/mps.py
@@ -225,7 +225,7 @@ def _sample_implementation(self, rnd_vector: torch.Tensor) -> str:
         num_qubits = len(self.factors)

         bitstring = ""
-        acc_mps_j: torch.tensor = self.factors[0]
+        acc_mps_j: torch.Tensor = self.factors[0]

         for qubit in range(num_qubits):
             # comp_basis is a projector: 0 is for ket |0> and 1 for ket |1>
7 changes: 3 additions & 4 deletions emu_mps/mps_backend_impl.py
@@ -298,7 +298,7 @@ def progress(self) -> None:
             )
             if not self.has_lindblad_noise:
                 # Free memory because it won't be used anymore
-                self.right_baths[-2] = None
+                self.right_baths[-2] = torch.zeros(0)

             self._evolve(self.tdvp_index, dt=-delta_time / 2)
             self.left_baths.pop()
@@ -453,13 +453,12 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
     """

     jump_threshold: float
-    aggregated_lindblad_ops: Optional[torch.Tensor]
+    aggregated_lindblad_ops: torch.Tensor
     norm_gap_before_jump: float
     root_finder: Optional[BrentsRootFinder]

     def __init__(self, config: MPSConfig, pulser_data: PulserData):
         super().__init__(config, pulser_data)
-        self.aggregated_lindblad_ops = None
         self.lindblad_ops = pulser_data.lindblad_ops
         self.root_finder = None

@@ -520,7 +519,7 @@ def do_random_quantum_jump(self) -> None:
                 for qubit in range(self.state.num_sites)
                 for op in self.lindblad_ops
             ],
-            weights=jump_operator_weights.reshape(-1),
+            weights=jump_operator_weights.reshape(-1).tolist(),
         )[0]

         self.state.apply(jumped_qubit_index, jump_operator)
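
typeshed annotates random.choices as taking a sequence of floats for weights; a tensor happens to work at runtime but fails the type check, hence the .tolist(). The pattern in miniature (the labels are illustrative):

import random
import torch

jump_operator_weights = torch.tensor([[0.1, 0.4], [0.3, 0.2]])

# Flatten the (qubit, operator) weight matrix and convert it to the plain
# list of floats that random.choices is typed to accept.
choice = random.choices(
    ["q0-op0", "q0-op1", "q1-op0", "q1-op1"],
    weights=jump_operator_weights.reshape(-1).tolist(),
)[0]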
3 changes: 2 additions & 1 deletion emu_mps/tdvp.py
@@ -11,7 +11,8 @@ def new_right_bath(
 ) -> torch.Tensor:
     bath = torch.tensordot(state, bath, ([2], [2]))
     bath = torch.tensordot(op.to(bath.device), bath, ([2, 3], [1, 3]))
-    return torch.tensordot(state.conj(), bath, ([1, 2], [1, 3]))
+    bath = torch.tensordot(state.conj(), bath, ([1, 2], [1, 3]))
+    return bath


 """
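
Routing the result through the bath variable before returning presumably silences mypy's no-any-return warning: some torch overloads are stubbed as returning Any, but bath already carries the declared type torch.Tensor. In miniature:

import torch

def contract(state: torch.Tensor, bath: torch.Tensor) -> torch.Tensor:
    # If the stub for tensordot yields Any, `return torch.tensordot(...)`
    # trips --warn-return-any; assigning to the Tensor-typed `bath` first
    # keeps the declared return type checkable.
    bath = torch.tensordot(state, bath, ([0], [0]))
    return bath

print(contract(torch.ones(3), torch.ones(3)))  # tensor(3.)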
9 changes: 5 additions & 4 deletions emu_mps/utils.py
@@ -12,15 +12,16 @@ def new_left_bath(
     # this order is more efficient than contracting the op first in general
     bath = torch.tensordot(bath, state.conj(), ([0], [0]))
     bath = torch.tensordot(bath, op.to(bath.device), ([0, 2], [0, 1]))
-    return torch.tensordot(bath, state, ([0, 2], [0, 1]))
+    bath = torch.tensordot(bath, state, ([0, 2], [0, 1]))
+    return bath


 def _determine_cutoff_index(d: torch.Tensor, max_error: float) -> int:
     assert max_error > 0
     squared_max_error = max_error * max_error
-    acc = 0
+    acc = 0.0
     for i in range(d.shape[0]):
-        acc += d[i]
+        acc += d[i].item()
         if acc > squared_max_error:
             return i
     return 0  # type: ignore[no-any-return]
@@ -60,7 +61,7 @@ def split_tensor(


 def truncate_impl(
-    factors: list[torch.tensor],
+    factors: list[torch.Tensor],
     config: MPSConfig,
 ) -> None:
     """
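
The acc = 0.0 and .item() changes keep the accumulator a plain Python float throughout, instead of an int that silently becomes a 0-dim tensor after the first addition. The pattern in isolation:

import torch

def determine_cutoff_index(d: torch.Tensor, max_error: float) -> int:
    # Sketch of _determine_cutoff_index: accumulate Python floats so `acc`
    # has a single consistent type for mypy.
    acc = 0.0
    for i in range(d.shape[0]):
        acc += d[i].item()
        if acc > max_error * max_error:
            return i
    return 0

print(determine_cutoff_index(torch.tensor([1e-10, 1e-9, 1.0]), 1e-4))  # 2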
52 changes: 27 additions & 25 deletions emu_sv/dense_operator.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import itertools
-from typing import Any, cast, Iterable
+from typing import Any, Iterable

 import torch
 from emu_base.base_classes.operator import FullOp, QuditOp
@@ -114,7 +114,7 @@ def expect(self, state: State) -> float | complex:
         ), "currently, only expectation values of StateVectors are \
             supported"

-        return torch.vdot(state.vector, self.matrix @ state.vector)  # type: ignore [no-any-return]
+        return torch.vdot(state.vector, self.matrix @ state.vector).item()

     @staticmethod
     def from_operator_string(
@@ -140,20 +140,22 @@ def from_operator_string(

         _validate_operator_targets(operations, nqubits)

+        operators_with_tensors: dict[str, torch.Tensor | QuditOp] = dict(operators)
+
         basis = set(basis)
         if basis == {"r", "g"}:
-            # operators will now contain the basis for single qubit ops, and potentially
-            # user defined strings in terms of these
-            operators |= {
+            # operators_with_tensors will now contain the basis for single qubit ops,
+            # and potentially user defined strings in terms of these
+            operators_with_tensors |= {
                 "gg": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128),
                 "gr": torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128),
                 "rg": torch.tensor([[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128),
                 "rr": torch.tensor([[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128),
             }
         elif basis == {"0", "1"}:
-            # operators will now contain the basis for single qubit ops, and potentially
-            # user defined strings in terms of these
-            operators |= {
+            # operators_with_tensors will now contain the basis for single qubit ops,
+            # and potentially user defined strings in terms of these
+            operators_with_tensors |= {
                 "00": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128),
                 "01": torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128),
                 "10": torch.tensor([[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128),
@@ -164,29 +166,29 @@

         accum_res = torch.zeros(2**nqubits, 2**nqubits, dtype=torch.complex128)
         for coeff, tensorop in operations:
-            # this function will recurse through the operators, and replace any definitions
-            # in terms of strings by the computed matrix
+            # this function will recurse through the operators_with_tensors,
+            # and replace any definitions in terms of strings by the computed matrix
             def replace_operator_string(op: QuditOp | torch.Tensor) -> torch.Tensor:
-                if isinstance(op, dict):
-                    for opstr, coeff in op.items():
-                        tensor = replace_operator_string(operators[opstr])
-                        operators[opstr] = tensor
-                        op[opstr] = tensor * coeff
-                    op = sum(cast(list[torch.Tensor], op.values()))
-                return op
+                if isinstance(op, torch.Tensor):
+                    return op

-            pre_dense_op = [torch.eye(2, 2, dtype=torch.complex128)] * nqubits
+                result = torch.zeros(2, 2, dtype=torch.complex128)
+                for opstr, coeff in op.items():
+                    tensor = replace_operator_string(operators_with_tensors[opstr])
+                    operators_with_tensors[opstr] = tensor
+                    result += tensor * coeff
+                return result

-            for i, op in enumerate(tensorop):
-                tensorop[i] = (replace_operator_string(op[0]), op[1])
+            total_op_per_qubit = [torch.eye(2, 2, dtype=torch.complex128)] * nqubits

             for op in tensorop:
-                for i in op[1]:
-                    pre_dense_op[i] = op[0]
+                factor = replace_operator_string(op[0])
+                for target_qubit in op[1]:
+                    total_op_per_qubit[target_qubit] = factor

-            dense_op = pre_dense_op[0]
-            for i in pre_dense_op[1:]:
-                dense_op = torch.kron(dense_op, i)
+            dense_op = total_op_per_qubit[0]
+            for single_qubit_operator in total_op_per_qubit[1:]:
+                dense_op = torch.kron(dense_op, single_qubit_operator)

             accum_res += coeff * dense_op
         return DenseOperator(accum_res)
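
The .item() in expect converts the 0-dim complex tensor from torch.vdot into the declared float | complex, and the renamed kron loop is the standard embedding of single-qubit operators into the full 2**nqubits space. A small worked case (Z on the first of three qubits, my own example):

import torch

sz = torch.tensor([[1.0, 0.0], [0.0, -1.0]], dtype=torch.complex128)
eye = torch.eye(2, dtype=torch.complex128)

# Kronecker-accumulate the per-qubit factors into the dense operator,
# identity everywhere except the targeted qubit.
total_op_per_qubit = [sz, eye, eye]
dense_op = total_op_per_qubit[0]
for single_qubit_operator in total_op_per_qubit[1:]:
    dense_op = torch.kron(dense_op, single_qubit_operator)

assert dense_op.shape == (8, 8)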