Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy with Torch #21

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/lint_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10' #don't cache this, it will create a cache without all the good stuff
python-version: '3.10'
cache: 'pip'
- run: pip install pre-commit pyproject-flake8
- run: pre-commit install
- run: pre-commit run --all-files
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ repos:
hooks:
- id: mypy
exclude: examples|test|ci
additional_dependencies: [torch==2.5.0] # The version in pyproject.toml must match
4 changes: 2 additions & 2 deletions emu_base/math/krylov_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(


def krylov_exp_impl(
op: Callable,
op: Callable[[torch.Tensor], torch.Tensor],
v: torch.Tensor,
is_hermitian: bool, # note: complex-proportional to its adjoint is enough
exp_tolerance: float,
Expand Down Expand Up @@ -103,7 +103,7 @@ def krylov_exp_impl(


def krylov_exp(
op: torch.Tensor,
op: Callable[[torch.Tensor], torch.Tensor],
v: torch.Tensor,
exp_tolerance: float,
norm_tolerance: float,
Expand Down
7 changes: 5 additions & 2 deletions emu_base/pulser_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,18 @@ def _xy_interaction(sequence: pulser.Sequence) -> torch.Tensor:
qubit_positions = _get_qubit_positions(sequence.register)
interaction_matrix = torch.zeros(num_qubits, num_qubits)
mag_field = torch.tensor(sequence.magnetic_field) # by default [0.0,0.0,30.0]
mag_norm = torch.norm(mag_field)
mag_norm = torch.linalg.norm(mag_field)

for numi in range(len(qubit_positions)):
for numj in range(numi + 1, len(qubit_positions)):
cosine = 0
if mag_norm >= 1e-8: # selected by hand
cosine = torch.dot(
(qubit_positions[numi] - qubit_positions[numj]), mag_field
) / (torch.norm(qubit_positions[numi] - qubit_positions[numj]) * mag_norm)
) / (
torch.linalg.norm(qubit_positions[numi] - qubit_positions[numj])
* mag_norm
)

interaction_matrix[numi][numj] = (
c3 # check this value with pulser people
Expand Down
8 changes: 4 additions & 4 deletions emu_base/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch


def dist2(left: torch.tensor, right: torch.tensor) -> torch.Tensor:
return torch.norm(left - right) ** 2
pablolh marked this conversation as resolved.
Show resolved Hide resolved
def dist2(left: torch.Tensor, right: torch.Tensor) -> float:
return torch.dist(left, right).item() ** 2


def dist3(left: torch.tensor, right: torch.tensor) -> torch.Tensor:
return torch.norm(left - right) ** 3
def dist3(left: torch.Tensor, right: torch.Tensor) -> float:
return torch.dist(left, right).item() ** 3
22 changes: 11 additions & 11 deletions emu_mps/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@


def add_factors(
left: list[torch.tensor], right: list[torch.tensor]
) -> list[torch.tensor]:
left: list[torch.Tensor], right: list[torch.Tensor]
) -> list[torch.Tensor]:
"""
Direct sum algorithm implementation to sum two tensor trains (MPS/MPO).
It assumes the left and right bond are along the dimension 0 and -1 of each tensor.
Expand Down Expand Up @@ -52,19 +52,19 @@ def add_factors(


def scale_factors(
factors: list[torch.tensor], scalar: complex, *, which: int
) -> list[torch.tensor]:
factors: list[torch.Tensor], scalar: complex, *, which: int
) -> list[torch.Tensor]:
"""
Returns a new list of factors where the tensor at the given index is scaled by `scalar`.
"""
return [scalar * f if i == which else f for i, f in enumerate(factors)]


def zip_right_step(
slider: torch.tensor,
top: torch.tensor,
bottom: torch.tensor,
) -> torch.tensor:
slider: torch.Tensor,
top: torch.Tensor,
bottom: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Returns a new `MPS/O` factor of the result of the multiplication MPO @ MPS/O,
and the updated slider, performing a single step of the
Expand Down Expand Up @@ -117,10 +117,10 @@ def zip_right_step(


def zip_right(
top_factors: list[torch.tensor],
bottom_factors: list[torch.tensor],
top_factors: list[torch.Tensor],
bottom_factors: list[torch.Tensor],
config: Optional[MPSConfig] = None,
) -> list[torch.tensor]:
) -> list[torch.Tensor]:
"""
Returns a new matrix product, resulting from applying `top` to `bottom`.
The resulting factors are:
Expand Down
4 changes: 2 additions & 2 deletions emu_mps/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def make_H(

interactions_to_keep = _get_interactions_to_keep(interaction_matrix)

cores = [_first_factor(interactions_to_keep[0].item())]
cores = [_first_factor(interactions_to_keep[0].item() != 0.0)]

if nqubits > 2:
for i in range(1, middle):
Expand Down Expand Up @@ -395,7 +395,7 @@ def make_H(
)
)
if nqubits == 2:
scale = interaction_matrix[0, 1]
scale = interaction_matrix[0, 1].item()
elif interactions_to_keep[-1][0]:
scale = 1.0
else:
Expand Down
43 changes: 23 additions & 20 deletions emu_mps/mpo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import itertools
from typing import Any, List, cast, Iterable, Optional
from typing import Any, List, Iterable, Optional

import torch

Expand Down Expand Up @@ -175,14 +175,15 @@ def from_operator_string(
Returns:
the operator in MPO form.
"""
operators_with_tensors: dict[str, torch.Tensor | QuditOp] = dict(operators)

_validate_operator_targets(operations, nqubits)

basis = set(basis)
if basis == {"r", "g"}:
# operators will now contain the basis for single qubit ops, and potentially
# user defined strings in terms of these
operators |= {
# operators_with_tensors will now contain the basis for single qubit ops,
# and potentially user defined strings in terms of these
operators_with_tensors |= {
"gg": torch.tensor(
[[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
).reshape(1, 2, 2, 1),
Expand All @@ -197,9 +198,9 @@ def from_operator_string(
).reshape(1, 2, 2, 1),
}
elif basis == {"0", "1"}:
# operators will now contain the basis for single qubit ops, and potentially
# user defined strings in terms of these
operators |= {
# operators_with_tensors will now contain the basis for single qubit ops,
# and potentially user defined strings in terms of these
operators_with_tensors |= {
"00": torch.tensor(
[[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
).reshape(1, 2, 2, 1),
Expand All @@ -218,26 +219,28 @@ def from_operator_string(

mpos = []
for coeff, tensorop in operations:
# this function will recurse through the operators, and replace any definitions
# this function will recurse through the operators_with_tensors,
# and replace any definitions
# in terms of strings by the computed tensor
def replace_operator_string(op: QuditOp | torch.Tensor) -> torch.Tensor:
if isinstance(op, dict):
for opstr, coeff in op.items():
tensor = replace_operator_string(operators[opstr])
operators[opstr] = tensor
op[opstr] = tensor * coeff
op = sum(cast(list[torch.Tensor], op.values()))
return op
if isinstance(op, torch.Tensor):
return op

result = torch.zeros(1, 2, 2, 1, dtype=torch.complex128)
for opstr, coeff in op.items():
tensor = replace_operator_string(operators_with_tensors[opstr])
operators_with_tensors[opstr] = tensor
result += tensor * coeff
return result

factors = [
torch.eye(2, 2, dtype=torch.complex128).reshape(1, 2, 2, 1)
] * nqubits

for i, op in enumerate(tensorop):
tensorop[i] = (replace_operator_string(op[0]), op[1])

for op in tensorop:
for i in op[1]:
factors[i] = op[0]
factor = replace_operator_string(op[0])
for target_qubit in op[1]:
factors[target_qubit] = factor

mpos.append(coeff * MPO(factors, **kwargs))
return sum(mpos[1:], start=mpos[0])
2 changes: 1 addition & 1 deletion emu_mps/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _sample_implementation(self, rnd_vector: torch.Tensor) -> str:
num_qubits = len(self.factors)

bitstring = ""
acc_mps_j: torch.tensor = self.factors[0]
acc_mps_j: torch.Tensor = self.factors[0]

for qubit in range(num_qubits):
# comp_basis is a projector: 0 is for ket |0> and 1 for ket |1>
Expand Down
7 changes: 3 additions & 4 deletions emu_mps/mps_backend_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def progress(self) -> None:
)
if not self.has_lindblad_noise:
# Free memory because it won't be used anymore
self.right_baths[-2] = None
self.right_baths[-2] = torch.zeros(0)

self._evolve(self.tdvp_index, dt=-delta_time / 2)
self.left_baths.pop()
Expand Down Expand Up @@ -453,13 +453,12 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
"""

jump_threshold: float
aggregated_lindblad_ops: Optional[torch.Tensor]
aggregated_lindblad_ops: torch.Tensor
norm_gap_before_jump: float
root_finder: Optional[BrentsRootFinder]

def __init__(self, config: MPSConfig, pulser_data: PulserData):
super().__init__(config, pulser_data)
self.aggregated_lindblad_ops = None
self.lindblad_ops = pulser_data.lindblad_ops
self.root_finder = None

Expand Down Expand Up @@ -520,7 +519,7 @@ def do_random_quantum_jump(self) -> None:
for qubit in range(self.state.num_sites)
for op in self.lindblad_ops
],
weights=jump_operator_weights.reshape(-1),
weights=jump_operator_weights.reshape(-1).tolist(),
)[0]

self.state.apply(jumped_qubit_index, jump_operator)
Expand Down
3 changes: 2 additions & 1 deletion emu_mps/tdvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def new_right_bath(
) -> torch.Tensor:
bath = torch.tensordot(state, bath, ([2], [2]))
bath = torch.tensordot(op.to(bath.device), bath, ([2, 3], [1, 3]))
return torch.tensordot(state.conj(), bath, ([1, 2], [1, 3]))
bath = torch.tensordot(state.conj(), bath, ([1, 2], [1, 3]))
return bath


"""
Expand Down
9 changes: 5 additions & 4 deletions emu_mps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ def new_left_bath(
# this order is more efficient than contracting the op first in general
bath = torch.tensordot(bath, state.conj(), ([0], [0]))
bath = torch.tensordot(bath, op.to(bath.device), ([0, 2], [0, 1]))
return torch.tensordot(bath, state, ([0, 2], [0, 1]))
bath = torch.tensordot(bath, state, ([0, 2], [0, 1]))
return bath


def _determine_cutoff_index(d: torch.Tensor, max_error: float) -> int:
assert max_error > 0
squared_max_error = max_error * max_error
acc = 0
acc = 0.0
for i in range(d.shape[0]):
acc += d[i]
acc += d[i].item()
if acc > squared_max_error:
return i
return 0 # type: ignore[no-any-return]
Expand Down Expand Up @@ -60,7 +61,7 @@ def split_tensor(


def truncate_impl(
factors: list[torch.tensor],
factors: list[torch.Tensor],
config: MPSConfig,
) -> None:
"""
Expand Down
52 changes: 27 additions & 25 deletions emu_sv/dense_operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import itertools
from typing import Any, cast, Iterable
from typing import Any, Iterable

import torch
from emu_base.base_classes.operator import FullOp, QuditOp
Expand Down Expand Up @@ -114,7 +114,7 @@ def expect(self, state: State) -> float | complex:
), "currently, only expectation values of StateVectors are \
supported"

return torch.vdot(state.vector, self.matrix @ state.vector) # type: ignore [no-any-return]
return torch.vdot(state.vector, self.matrix @ state.vector).item()
pablolh marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def from_operator_string(
Expand All @@ -140,20 +140,22 @@ def from_operator_string(

_validate_operator_targets(operations, nqubits)

operators_with_tensors: dict[str, torch.Tensor | QuditOp] = dict(operators)

basis = set(basis)
if basis == {"r", "g"}:
# operators will now contain the basis for single qubit ops, and potentially
# user defined strings in terms of these
operators |= {
# operators_with_tensors will now contain the basis for single qubit ops,
# and potentially user defined strings in terms of these
operators_with_tensors |= {
"gg": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128),
"gr": torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128),
"rg": torch.tensor([[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128),
"rr": torch.tensor([[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128),
}
elif basis == {"0", "1"}:
# operators will now contain the basis for single qubit ops, and potentially
# user defined strings in terms of these
operators |= {
# operators_with_tensors will now contain the basis for single qubit ops,
# and potentially user defined strings in terms of these
operators_with_tensors |= {
"00": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128),
"01": torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128),
"10": torch.tensor([[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128),
Expand All @@ -164,29 +166,29 @@ def from_operator_string(

accum_res = torch.zeros(2**nqubits, 2**nqubits, dtype=torch.complex128)
for coeff, tensorop in operations:
# this function will recurse through the operators, and replace any definitions
# in terms of strings by the computed matrix
# this function will recurse through the operators_with_tensors,
# and replace any definitions in terms of strings by the computed matrix
def replace_operator_string(op: QuditOp | torch.Tensor) -> torch.Tensor:
if isinstance(op, dict):
for opstr, coeff in op.items():
tensor = replace_operator_string(operators[opstr])
operators[opstr] = tensor
op[opstr] = tensor * coeff
op = sum(cast(list[torch.Tensor], op.values()))
return op
if isinstance(op, torch.Tensor):
return op

pre_dense_op = [torch.eye(2, 2, dtype=torch.complex128)] * nqubits
result = torch.zeros(2, 2, dtype=torch.complex128)
for opstr, coeff in op.items():
tensor = replace_operator_string(operators_with_tensors[opstr])
operators_with_tensors[opstr] = tensor
result += tensor * coeff
return result

for i, op in enumerate(tensorop):
tensorop[i] = (replace_operator_string(op[0]), op[1])
total_op_per_qubit = [torch.eye(2, 2, dtype=torch.complex128)] * nqubits

for op in tensorop:
for i in op[1]:
pre_dense_op[i] = op[0]
factor = replace_operator_string(op[0])
for target_qubit in op[1]:
total_op_per_qubit[target_qubit] = factor

dense_op = pre_dense_op[0]
for i in pre_dense_op[1:]:
dense_op = torch.kron(dense_op, i)
dense_op = total_op_per_qubit[0]
for single_qubit_operator in total_op_per_qubit[1:]:
dense_op = torch.kron(dense_op, single_qubit_operator)

accum_res += coeff * dense_op
return DenseOperator(accum_res)
Loading
Loading