fix formatting via black
epens94 committed Oct 24, 2024
1 parent b2c8367 commit a5268e1
Showing 2 changed files with 549 additions and 500 deletions.
61 changes: 34 additions & 27 deletions src/schnetpack/nn/radial.py
@@ -3,7 +3,14 @@
 import torch
 import torch.nn as nn
 
-__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF","BernsteinRBF","PhysNetBasisRBF"]
+__all__ = [
+    "gaussian_rbf",
+    "GaussianRBF",
+    "GaussianRBFCentered",
+    "BesselRBF",
+    "BernsteinRBF",
+    "PhysNetBasisRBF",
+]
 
 from torch import nn as nn
 
@@ -111,18 +118,16 @@ def forward(self, inputs):


 class BernsteinRBF(torch.nn.Module):
-
-
     r"""Bernstein radial basis functions.
     According to
     B_{k,n}(x) = \binom{n}{k} x^k (1 - x)^{n - k}
     with
     B_{k,n} as the k-th Bernstein basis polynomial of degree n
     \binom{n}{k} as the binomial coefficient n! / (k! * (n - k)!),
     which in logarithmic form becomes log(n!) - log(k!) - log((n - k)!)
     k as the index running from 0 to the degree n
     The logarithmic form of the k-th Bernstein polynomial of degree n is
     log(B_{k,n}) = logBinomCoeff + k * log(x) + (n - k) * log(1 - x)
@@ -133,12 +138,11 @@ class BernsteinRBF(torch.nn.Module):
     logBinomCoeff is a scalar
     k_term is a vector
     n_k_term is also a vector
     logarithms are used to avoid numerical overflow errors and to ensure stability
     """
 
-    def __init__(
-        self, n_rbf: int, cutoff:float, init_alpha:float = 0.95):
+    def __init__(self, n_rbf: int, cutoff: float, init_alpha: float = 0.95):
         """
         Args:
             n_rbf: total number of Bernstein functions, :math:`N_g`.
@@ -154,52 +158,54 @@ def __init__(
         n_k_idx = n_rbf - 1 - n_idx
 
         # register buffers and parameters
-        self.register_buffer("cutoff",torch.tensor(cutoff))
+        self.register_buffer("cutoff", torch.tensor(cutoff))
         self.register_buffer("b", b)
         self.register_buffer("n", n_idx)
         self.register_buffer("n_k", n_k_idx)
-        self.register_buffer("init_alpha",torch.tensor(init_alpha))
+        self.register_buffer("init_alpha", torch.tensor(init_alpha))
 
     # log of factorial (n! or k! or (n - k)!)
-    def log_factorial(self,n):
+    def log_factorial(self, n):
         # log of factorial of degree n
         return torch.sum(torch.log(torch.arange(1, n + 1)))
 
     # calculate log binomial coefficient
-    def log_binomial_coefficient(self,n, k):
+    def log_binomial_coefficient(self, n, k):
         # n_factorial - k_factorial - n_k_factorial
-        return self.log_factorial(n) - (self.log_factorial(k) + self.log_factorial(n - k))
+        return self.log_factorial(n) - (
+            self.log_factorial(k) + self.log_factorial(n - k)
+        )
 
     # vector of log binomial coefficients
-    def calculate_log_binomial_coefficients(self,n_rbf):
+    def calculate_log_binomial_coefficients(self, n_rbf):
         # store the log binomial coefficients
         # loop through each value from 0 to n_rbf - 1
         log_binomial_coeffs = [
             self.log_binomial_coefficient(n_rbf - 1, x) for x in range(n_rbf)
         ]
         return torch.tensor(log_binomial_coeffs)
 
     def forward(self, inputs):
-        exp_x = -self.init_alpha * inputs[...,None]
+        exp_x = -self.init_alpha * inputs[..., None]
         x = torch.exp(exp_x)
         k_term = self.n * torch.where(self.n != 0, torch.log(x), torch.zeros_like(x))
-        n_k_term = self.n_k * torch.where(self.n_k != 0, torch.log(1 - x), torch.zeros_like(x))
+        n_k_term = self.n_k * torch.where(
+            self.n_k != 0, torch.log(1 - x), torch.zeros_like(x)
+        )
         y = torch.exp(self.b + k_term + n_k_term)
         return y
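As an aside for reviewers: the log-space evaluation in `forward` can be checked against a standalone re-implementation. The sketch below is illustrative only and not part of this commit; the helper name `bernstein_basis` is hypothetical, and `alpha` mirrors the module's `init_alpha` default.

```python
import torch

def bernstein_basis(r: torch.Tensor, n_rbf: int, alpha: float = 0.95) -> torch.Tensor:
    """Log-space Bernstein expansion of distances r; mirrors BernsteinRBF above."""
    n = n_rbf - 1              # polynomial degree; indices k = 0 .. n
    k = torch.arange(n_rbf)
    # log-factorials log(0!) .. log(n!) via a cumulative sum, so log C(n, k) never overflows
    log_fact = torch.cat(
        [torch.zeros(1), torch.cumsum(torch.log(torch.arange(1, n_rbf).float()), 0)]
    )
    log_binom = log_fact[n] - log_fact[k] - log_fact[n - k]
    # squash distances into (0, 1] with an exponential decay
    x = torch.exp(-alpha * r[..., None])
    # log B_{k,n}(x) = log C(n,k) + k*log(x) + (n-k)*log(1-x); mask the k=0 / k=n edges
    k_term = k * torch.where(k != 0, torch.log(x), torch.zeros_like(x))
    n_k_term = (n - k) * torch.where(n - k != 0, torch.log(1 - x), torch.zeros_like(x))
    return torch.exp(log_binom + k_term + n_k_term)

# Bernstein bases form a partition of unity, so each row should sum to 1:
vals = bernstein_basis(torch.linspace(0.5, 4.5, 3), n_rbf=8)
print(vals.sum(dim=-1))  # ~tensor([1., 1., 1.])
```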


-class PhysNetBasisRBF(torch.nn.Module):
-
+class PhysNetBasisRBF(torch.nn.Module):
     """
     Expand distances in the basis used in PhysNet (see https://arxiv.org/abs/1902.08408)
     width (beta_k) = ((2 / K) * (1 - exp(-cutoff)))^(-2)
     center (mu_k) = equally spaced between exp(-cutoff) and 1
     """
-    def __init__(self, n_rbf: int, cutoff:float, trainable:bool):
-        """
+
+    def __init__(self, n_rbf: int, cutoff: float, trainable: bool):
+        """
         Args:
             n_rbf: total number of basis functions.
@@ -212,7 +218,7 @@ def __init__(self, n_rbf: int, cutoff: float, trainable: bool):
         # compute offset and width of Gaussian functions
         widths = ((2 / self.n_rbf) * (1 - torch.exp(torch.Tensor([-cutoff])))) ** (-2)
         r_0 = torch.exp(torch.Tensor([-cutoff])).item()
-        centers = torch.linspace(r_0,1,self.n_rbf)
+        centers = torch.linspace(r_0, 1, self.n_rbf)
 
         if trainable:
             self.widths = torch.nn.Parameter(widths)
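For reference, the width/center construction above works in PhysNet's exp(-r) coordinate. A minimal sketch of the resulting expansion as a plain function, not the committed module (the name `physnet_expand` is hypothetical):

```python
import torch

def physnet_expand(r: torch.Tensor, n_rbf: int, cutoff: float) -> torch.Tensor:
    """g_k(r) = exp(-beta_k * (exp(-r) - mu_k)^2), as in PhysNet (arXiv:1902.08408)."""
    # shared width beta_k = ((2/K) * (1 - exp(-cutoff)))^(-2)
    widths = ((2 / n_rbf) * (1 - torch.exp(torch.tensor(-cutoff)))) ** (-2)
    # centers mu_k equally spaced between exp(-cutoff) and 1
    centers = torch.linspace(torch.exp(torch.tensor(-cutoff)).item(), 1.0, n_rbf)
    return torch.exp(-widths.abs() * (torch.exp(-r[..., None]) - centers) ** 2)

print(physnet_expand(torch.linspace(0.5, 4.5, 3), n_rbf=16, cutoff=5.0).shape)
# torch.Size([3, 16])
```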
@@ -221,6 +227,7 @@ def __init__(self, n_rbf: int, cutoff: float, trainable: bool):
self.register_buffer("widths", widths)
self.register_buffer("centers", centers)


def forward(self, inputs: torch.Tensor):
return torch.exp(-abs(self.widths) * (torch.exp(-inputs[...,None]) - self.centers) ** 2)
return torch.exp(
-abs(self.widths) * (torch.exp(-inputs[..., None]) - self.centers) ** 2
)
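Assuming the classes are importable from `schnetpack.nn.radial` as the updated `__all__` suggests, usage might look like the following sketch (shapes illustrative):

```python
import torch
from schnetpack.nn.radial import BernsteinRBF, PhysNetBasisRBF

r = torch.linspace(0.5, 4.5, 10)  # pairwise distances, shape (10,)

bernstein = BernsteinRBF(n_rbf=16, cutoff=5.0)
physnet = PhysNetBasisRBF(n_rbf=16, cutoff=5.0, trainable=False)

print(bernstein(r).shape)  # torch.Size([10, 16])
print(physnet(r).shape)    # torch.Size([10, 16])
```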
