Skip to content

Commit

Permalink
streamlined projection options
Browse files Browse the repository at this point in the history
Still things to do, but for now this catches
many things.

One thing is that 'basis' projections are now called
hadamard (the proper name of the operation). While 'basis'
is still allowed it seems better to streamline a against
a common name.

Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Nov 25, 2024
1 parent 6b9dd25 commit f479933
Show file tree
Hide file tree
Showing 15 changed files with 222 additions and 88 deletions.
6 changes: 6 additions & 0 deletions docs/api/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ The typing types are shown below:
LatticeLike
LatticeOrGeometry
LatticeOrGeometryLike
ProjectionTypeMatrix
ProjectionTypeTrace
ProjectionTypeDiag
ProjectionTypeHadamard
ProjectionTypeHadamardAtoms
ProjectionType
SileLike
SparseMatrix
SparseMatrixExt
Expand Down
4 changes: 4 additions & 0 deletions src/sisl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ def __getattr__(attr):
import sisl.constant as constant

return constant
if attr == "typing":
import sisl.typing as typing

return typing

raise AttributeError(f"module {__name__} has no attribute {attr}")

Expand Down
1 change: 1 addition & 0 deletions src/sisl/physics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"""

from ._common import *
from ._feature import *
from .distribution import *
from .sparse import *
Expand Down
44 changes: 44 additions & 0 deletions src/sisl/physics/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

from enum import StrEnum

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'StrEnum' is not used.
from typing import get_args

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'get_args' is not used.

from sisl.typing import GaugeType, ProjectionType

__all__ = ["comply_gauge", "comply_projection"]


def comply_gauge(gauge: GaugeType) -> str:
"""Comply the gauge to one of two words: atom | cell"""
return {
"R": "cell",
"cell": "cell",
"r": "atom",
"orbital": "atom",
"orbitals": "atom",
"atom": "atom",
"atoms": "atom",
}[gauge]


def comply_projection(projection: ProjectionType) -> str:
"""Comply the projection to one of the allowed variants"""
return {
"matrix": "matrix",
"ij": "matrix",
"trace": "trace",
"sum": "trace",
"diagonal": "diagonal",
"diag": "diagonal",
"ii": "diagonal",
"hadamard": "hadamard",
"basis": "hadamard",
"orbital": "hadamard",
"orbitals": "hadamard",
"hadamard:atoms": "hadamard:atoms",
"atoms": "hadamard:atoms",
"atom": "hadamard:atoms",
}[projection]
15 changes: 1 addition & 14 deletions src/sisl/physics/_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,7 @@

import numpy as np

__all__ = ["yield_manifolds", "comply_gauge"]


def comply_gauge(gauge: GaugeType) -> str:
"""Comply the gauge to one of two words: atom | cell"""
return {
"R": "cell",
"cell": "cell",
"r": "atom",
"orbital": "atom",
"orbitals": "atom",
"atom": "atom",
"atoms": "atom",
}[gauge]
__all__ = ["yield_manifolds"]


def yield_manifolds(values, atol: float = 0.1, axis: int = -1) -> Iterator[list]:
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/physics/_matrix_ddk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import numpy as np

cimport numpy as np

from ._feature import comply_gauge
from ._common import comply_gauge
from ._matrix_phase3 import *
from ._matrix_phase3_nc import *
from ._matrix_phase3_so import *
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/physics/_matrix_dk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import numpy as np

cimport numpy as np

from ._feature import comply_gauge
from ._common import comply_gauge
from ._matrix_phase3 import *
from ._matrix_phase3_nc import *
from ._matrix_phase3_so import *
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/physics/_matrix_k.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import numpy as np

cimport numpy as np

from ._feature import comply_gauge
from ._common import comply_gauge
from ._matrix_phase import *
from ._matrix_phase_nc import *
from ._matrix_phase_nc_diag import *
Expand Down
81 changes: 66 additions & 15 deletions src/sisl/physics/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@
progressbar,
warn,
)
from sisl.typing import CartesianAxisStrLiteral
from sisl.physics._common import comply_projection
from sisl.typing import (
CartesianAxisStrLiteral,
ProjectionType,
ProjectionTypeHadamard,
ProjectionTypeHadamardAtoms,
)
from sisl.typing._physics import ProjectionTypeDiag
from sisl.utils.misc import direction

if TYPE_CHECKING:
Expand Down Expand Up @@ -329,7 +336,7 @@ def COP(E, eig, state, M, distribution="gaussian", atol: float = 1e-10):
distribution : func or str, optional
a function that accepts :math:`E-\epsilon` as argument and calculates the
distribution function.
atol : float, optional
atol :
tolerance value where the distribution should be above before
considering an eigenstate to contribute to an energy point,
a higher value means that more energy points are discarded and so the calculation
Expand Down Expand Up @@ -437,7 +444,20 @@ def new_list(bools, tmp, we):


@set_module("sisl.physics.electron")
def spin_moment(state, S=None, project: bool = False):
@deprecate_argument(
"project",
"projection",
"argument project has been deprecated in favor of projection",
"0.15",
"0.16",
)
def spin_moment(
state,
S=None,
projection: Union[
ProjectionTypeTrace, ProjectionTypeDiag, ProjectionTypeHadamard, True, False
] = "diagonal",
):
r""" Spin magnetic moment (spin texture) and optionally orbitally resolved moments
This calculation only makes sense for non-colinear calculations.
Expand All @@ -458,7 +478,7 @@ def spin_moment(state, S=None, project: bool = False):
\\
\mathbf{S}_\alpha^z &= \langle \psi_\alpha | \boldsymbol\sigma_z \mathbf S | \psi_\alpha \rangle
If `project` is true, the above will be the orbitally resolved quantities.
If `projection` is orbitals/basis/true, the above will be the orbitally resolved quantities.
Parameters
----------
Expand All @@ -468,8 +488,8 @@ def spin_moment(state, S=None, project: bool = False):
overlap matrix used in the :math:`\langle\psi|\mathbf S|\psi\rangle` calculation. If `None` the identity
matrix is assumed. The overlap matrix should correspond to the system and :math:`\mathbf k` point the eigenvectors
has been evaluated at.
project: bool, optional
whether the spin-moments will be orbitally resolved or not
projection:
how the projection should be done
Notes
-----
Expand All @@ -485,10 +505,15 @@ def spin_moment(state, S=None, project: bool = False):
Returns
-------
numpy.ndarray
spin moments per state with final dimension ``(3, state.shape[0])``, or ``(3, state.shape[0], state.shape[1]//2)`` if project is true
spin moments per state with final dimension ``(3, state.shape[0])``, or ``(3,
state.shape[0], state.shape[1]//2)`` if projection is orbitals/basis/true
"""
if state.ndim == 1:
return spin_moment(state.reshape(1, -1), S, project)[0]
return spin_moment(state.reshape(1, -1), S, projection)[0]

if isinstance(projection, bool):
projection = "hadamard" if projection else "diagonal"
projection = comply_projection(projection)

if S is None:
S = _FakeMatrix(state.shape[1] // 2, state.shape[1] // 2)
Expand All @@ -498,7 +523,7 @@ def spin_moment(state, S=None, project: bool = False):

# see PDOS for details related to the spin-box calculations

if project:
if projection == "hadamard":
s = empty(
[3, state.shape[0], state.shape[1] // 2],
dtype=state.real.dtype,
Expand All @@ -514,7 +539,7 @@ def spin_moment(state, S=None, project: bool = False):
s[0, i] = D1.real + D2.real
s[1, i] = D2.imag - D1.imag

else:
elif projection == "diagonal":
s = empty([3, state.shape[0]], dtype=state.real.dtype)

# TODO consider doing this all in a few lines
Expand All @@ -529,6 +554,20 @@ def spin_moment(state, S=None, project: bool = False):
s[0, i] = D[1, 0].real + D[0, 1].real
s[1, i] = D[0, 1].imag - D[1, 0].imag

elif projection == "trace":
s = empty([3], dtype=state.real.dtype)

for i in range(len(state)):
cs = conj(state[i]).reshape(-1, 2)
Sstate = S @ state[i].reshape(-1, 2)
D = cs.T @ Sstate
s[2] = (D[0, 0].real - D[1, 1].real).sum()
s[0] = (D[1, 0].real + D[0, 1].real).sum()
s[1] = (D[0, 1].imag - D[1, 0].imag).sum()

else:
raise ValueError(f"spin_moment got wrong 'projection' argument: {projection}.")

return s


Expand Down Expand Up @@ -561,7 +600,7 @@ def spin_contamination(state_alpha, state_beta, S=None, sum: bool = True):
have been evaluated at.
sum:
whether the spin-contamination should be summed for all states (a single number returned).
If false, a spin-contamination per state per spin-channel will be returned.
If sum, a spin-contamination per state per spin-channel will be returned.
Notes
-----
Expand Down Expand Up @@ -1671,7 +1710,12 @@ def Sk(self, format=None):
"0.15",
"0.16",
)
def norm2(self, projection: Literal["sum", "orbitals", "basis", "atoms"] = "sum"):
def norm2(
self,
projection: Union[
ProjectionType, ProjectionTypeHadamard, ProjectionTypeHadamardAtoms
] = "diagonal",
):
r"""Return a vector with the norm of each state :math:`\langle\psi|\mathbf S|\psi\rangle`
:math:`\mathbf S` is the overlap matrix (or basis), for orthogonal basis
Expand All @@ -1693,7 +1737,14 @@ def norm2(self, projection: Literal["sum", "orbitals", "basis", "atoms"] = "sum"
"""
return self.inner(matrix=self.Sk(), projection=projection)

def spin_moment(self, project=False):
@deprecate_argument(
"project",
"projection",
"argument project has been deprecated in favor of projection",
"0.15",
"0.16",
)
def spin_moment(self, projection="diagonal"):
r"""Calculate spin moment from the states
This routine calls `~sisl.physics.electron.spin_moment` with appropriate arguments
Expand All @@ -1703,10 +1754,10 @@ def spin_moment(self, project=False):
Parameters
----------
project : bool, optional
projection:
whether the moments are orbitally resolved or not
"""
return spin_moment(self.state, self.Sk(), project=project)
return spin_moment(self.state, self.Sk(), projection=projection)

def wavefunction(self, grid, spinor=0, eta=None):
r"""Expand the coefficients as the wavefunction on `grid` *as-is*
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/physics/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sisl._internal import set_module
from sisl.typing import GaugeType

from ._feature import comply_gauge
from ._common import comply_gauge
from .distribution import get_distribution
from .electron import EigenstateElectron, EigenvalueElectron
from .sparse import SparseOrbitalBZSpin
Expand Down
Loading

0 comments on commit f479933

Please sign in to comment.