Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update: inertia benefit from caching + move from pkg_resources to importlib.metada #714

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9ac44ce
starting to add caching to inertia moment calculation during the simu…
m-rauen Nov 26, 2024
37ad0e3
correct caching problem of not handle mutable objects
m-rauen Dec 3, 2024
014b59d
move ignore_unhashable to _misc module + trying to resolve circular i…
m-rauen Dec 3, 2024
7216147
circular import solved + movem from pkg_resources to importlib
m-rauen Dec 4, 2024
184a193
updated gitignore
m-rauen Dec 4, 2024
ddb4846
change ignore to copy strategy for cache + add docstring
m-rauen Dec 4, 2024
edae933
change maxsize of copy_unhashable
m-rauen Dec 5, 2024
bcce554
changing hash strategy to eliminate errors
m-rauen Dec 10, 2024
da26124
starting to add caching to inertia moment calculation during the simu…
m-rauen Nov 26, 2024
88b20e5
correct caching problem of not handle mutable objects
m-rauen Dec 3, 2024
07da1be
move ignore_unhashable to _misc module + trying to resolve circular i…
m-rauen Dec 3, 2024
6660655
circular import solved + movem from pkg_resources to importlib
m-rauen Dec 4, 2024
687d7c5
updated gitignore
m-rauen Dec 4, 2024
98be111
change ignore to copy strategy for cache + add docstring
m-rauen Dec 4, 2024
ca3e412
change maxsize of copy_unhashable
m-rauen Dec 5, 2024
6c1f0ad
changing hash strategy to eliminate errors
m-rauen Dec 10, 2024
70ae12e
changed scipy.misc.derivative to findiff.Diff (scipy derivative remov…
m-rauen Jan 28, 2025
37d2915
changed scipy.misc.derivative to findiff.Diff (scipy derivative remov…
m-rauen Jan 28, 2025
e9ce98d
resolve funky CI + cache corrected implemented
m-rauen Jan 28, 2025
8a32b23
added 'findiff' to the poetrylock file, since it's the new derivative…
m-rauen Jan 28, 2025
557092c
updated poetrylock file + cleaned some code
m-rauen Jan 28, 2025
29df5d7
testing funky CI
m-rauen Jan 28, 2025
61bde8b
testing
m-rauen Jan 29, 2025
f2298e4
added derivative func to misc module instead of using findiff
m-rauen Jan 31, 2025
3d638ac
forgetted to call _misc._derivative in the API module
m-rauen Jan 31, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions overreact/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__docformat__ = "restructuredtext"

import pkg_resources as _pkg_resources
from importlib.metadata import version

from overreact.api import (
get_enthalpies,
Expand Down Expand Up @@ -48,7 +48,7 @@
"unparse_reactions",
]

__version__ = _pkg_resources.get_distribution(__name__).version
__version__ = version(__name__)
__license__ = "MIT" # I'm too lazy to get it from setup.py...

__headline__ = "📈 Create and analyze chemical microkinetic models built from computational chemistry data."
Expand Down
207 changes: 205 additions & 2 deletions overreact/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,217 @@
from __future__ import annotations

import contextlib
from functools import lru_cache as cache
from functools import lru_cache as cache, wraps
from copy import deepcopy

import numpy as np
from numpy import arange, newaxis, hstack, prod, array
from scipy.stats import cauchy, norm

import overreact as rx
from overreact import _constants as constants

def _central_diff_weights(Np, ndiv=1):
"""
Return weights for an Np-point central derivative.

Assumes equally-spaced function points.

If weights are in the vector w, then
derivative is w[0] * f(x-ho*dx) + ... + w[-1] * f(x+h0*dx)

Parameters
----------
Np : int
Number of points for the central derivative.
ndiv : int, optional
Number of divisions. Default is 1.

Returns
-------
w : ndarray
Weights for an Np-point central derivative. Its size is `Np`.

Notes
-----
Can be inaccurate for a large number of points.

Examples
--------
We can calculate a derivative value of a function.

>>> def f(x):
... return 2 * x**2 + 3
>>> x = 3.0 # derivative point
>>> h = 0.1 # differential step
>>> Np = 3 # point number for central derivative
>>> weights = _central_diff_weights(Np) # weights for first derivative
>>> vals = [f(x + (i - Np/2) * h) for i in range(Np)]
>>> sum(w * v for (w, v) in zip(weights, vals))/h
11.79999999999998

This value is close to the analytical solution:
f'(x) = 4x, so f'(3) = 12

References
----------
.. [1] https://en.wikipedia.org/wiki/Finite_difference

"""
if Np < ndiv + 1:
raise ValueError(
"Number of points must be at least the derivative order + 1."
)
if Np % 2 == 0:
raise ValueError("The number of points must be odd.")
from scipy import linalg

ho = Np >> 1
x = arange(-ho, ho + 1.0)
x = x[:, newaxis]
X = x**0.0
for k in range(1, Np):
X = hstack([X, x**k])
w = prod(arange(1, ndiv + 1), axis=0) * linalg.inv(X)[ndiv]
return w


def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
"""
Find the nth derivative of a function at a point.

Given a function, use a central difference formula with spacing `dx` to
compute the nth derivative at `x0`.

Parameters
----------
func : function
Input function.
x0 : float
The point at which the nth derivative is found.
dx : float, optional
Spacing.
n : int, optional
Order of the derivative. Default is 1.
args : tuple, optional
Arguments
order : int, optional
Number of points to use, must be odd.

Notes
-----
Decreasing the step size too small can result in round-off error.

Examples
--------
>>> def f(x):
... return x**3 + x**2
>>> _derivative(f, 1.0, dx=1e-6)
4.9999999999217337

"""
first_deriv_weight_map = {
3: array([-1, 0, 1]) / 2.0,
5: array([1, -8, 0, 8, -1]) / 12.0,
7: array([-1, 9, -45, 0, 45, -9, 1]) / 60.0,
9: array([3, -32, 168, -672, 0, 672, -168, 32, -3]) / 840.0,
}

second_deriv_weight_map = {
3: array([1, -2.0, 1]),
5: array([-1, 16, -30, 16, -1]) / 12.0,
7: array([2, -27, 270, -490, 270, -27, 2]) / 180.0,
9: array([-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9]) / 5040.0
}

if order < n + 1:
raise ValueError(
"'order' (the number of points used to compute the derivative), "
"must be at least the derivative order 'n' + 1."
)
elif order % 2 == 0:
raise ValueError(
"'order' (the number of points used to compute the derivative) "
"must be odd."
)
else:
pass

# pre-computed for n=1 and 2 and low-order for speed.
if n == 1:
if order == 3:
weights = first_deriv_weight_map.get(3)
elif n == 1 and order == 5:
weights = first_deriv_weight_map.get(5)
elif n == 1 and order == 7:
weights = first_deriv_weight_map.get(7)
elif n == 1 and order == 9:
weights = first_deriv_weight_map.get(9)
else:
weights = _central_diff_weights(order, 1)
elif n == 2:
if order == 3:
weights = second_deriv_weight_map.get(3)
elif n == 2 and order == 5:
weights = second_deriv_weight_map.get(5)
elif n == 2 and order == 7:
weights = second_deriv_weight_map.get(7)
elif n == 2 and order == 9:
weights = second_deriv_weight_map.get(9)
else:
weights = _central_diff_weights(order, 2)
else:
weights = _central_diff_weights(order, n)

val = 0.0
ho = order >> 1
for k in range(order):
val += weights[k] * func(x0 + (k - ho) * dx, *args)
return val / prod((dx,) * n, axis=0)

# TODO(mrauen): write and add docstring here
def make_hashable(obj):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forget to add the python function annotations here and in the copy_unhashable() function. I will add this piece and send a new PR asap

if isinstance(obj, np.ndarray):
return (tuple(obj.shape), tuple(obj.ravel()))
else:
return obj

# TODO(mrauen): write and add docstring here
def copy_unhashable(maxsize=128, typed=False):
def decorator(func):
@cache(maxsize=maxsize, typed=typed)
@wraps(func)
def cached_func(*hashable_args, **hashable_kwargs):
args = []
kwargs = {}

def convert_back(arg):
if isinstance(arg, tuple) and len(arg) == 2:
shape, flat_data = arg
if isinstance(shape, tuple) and isinstance(flat_data, tuple):
return np.array(flat_data).reshape(shape)
return arg

for arg in hashable_args:
args.append(convert_back(arg))
for k, v in hashable_kwargs.items():
kwargs[k] = convert_back(v)
args = tuple(args)
return func(*args, **kwargs)

def wrapper(*args, **kwargs):
wrapper_hashable_args = []
wrapper_hashable_kwargs = {}

for arg in args:
wrapper_hashable_args.append(make_hashable(arg))
for k,v in kwargs.items():
wrapper_hashable_kwargs[k] = make_hashable(v)
wrapper_hashable_args = tuple(wrapper_hashable_args)
return deepcopy(cached_func(*wrapper_hashable_args, **wrapper_hashable_kwargs))

return wrapper
return decorator

def _find_package(package):
"""Check if a package exists without importing it.
Expand Down Expand Up @@ -739,7 +942,7 @@ def _is_prime(num):
return primes


@cache(maxsize=1000000)
@cache
def _vdc(n, b=2):
"""Help haltonspace."""
res, denom = 0, 1
Expand Down
2 changes: 1 addition & 1 deletion overreact/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from typing import TYPE_CHECKING

import numpy as np
from scipy.misc import derivative

import overreact as rx
from overreact import _constants as constants
from overreact import coords, rates, tunnel
from overreact._misc import _derivative as derivative

if TYPE_CHECKING:
from overreact.core import Scheme
Expand Down
4 changes: 3 additions & 1 deletion overreact/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import overreact as rx
from overreact import _constants as constants
from overreact import _misc as _misc

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1679,8 +1680,9 @@ def gyradius(atommasses, atomcoords, method="iupac"):
else:
msg = f"unavailable method: '{method}'"
raise ValueError(msg)



@rx._misc.copy_unhashable()
def inertia(atommasses, atomcoords, align=True):
r"""Calculate primary moments and axes from the inertia tensor.

Expand Down
2 changes: 1 addition & 1 deletion overreact/thermo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging

import numpy as np
from scipy.misc import derivative
from overreact._misc import _derivative as derivative
from scipy.special import factorial

import overreact as rx
Expand Down
6 changes: 3 additions & 3 deletions overreact/thermo/_solv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import logging

import numpy as np
from scipy.misc import derivative

import overreact as rx
from overreact import _constants as constants
from overreact._misc import _derivative as derivative
from overreact import coords

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -123,14 +123,14 @@ def func(temperature, solvent):
+ (y * solvent.Vm * pressure / (constants.R * temperature)) * ratio**3
)
return -constants.R * temperature * gamma

cavity_entropy = derivative(
func,
x0=temperature,
dx=dx,
n=1,
order=order,
args=(environment,),
order=order,
)
logger.info(f"cavity entropy = {cavity_entropy} J/mol·K")
return cavity_entropy
Expand Down
Loading