Skip to content

Commit

Permalink
Merge branch 'master' of github.com:GAMES-UChile/mogptk
Browse files Browse the repository at this point in the history
  • Loading branch information
tdewolff committed Dec 7, 2023
2 parents 4d2e0aa + fbe7acf commit 29747a5
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 341 deletions.
2 changes: 1 addition & 1 deletion mogptk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
.. include:: ./documentation.md
"""
from .gpr.config import *
from .gpr import *
from .util import *

from .transformer import *
Expand Down
50 changes: 20 additions & 30 deletions mogptk/gpr/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@
import copy
from . import Parameter, config

class Kernel:
class Kernel(torch.nn.Module):
"""
Base kernel.
Args:
input_dims (int): Number of input dimensions.
active_dims (list of int): Indices of active dimensions of shape (input_dims,).
name (str): Kernel name.
"""
def __init__(self, input_dims=None, active_dims=None, name=None):
if name is None:
name = self.__class__.__name__
if name.endswith('Kernel') and name != 'Kernel':
name = name[:-6]
def __init__(self, input_dims=None, active_dims=None):
super().__init__()

self.input_dims = input_dims
self.active_dims = active_dims # checks input
self.output_dims = None
self.name = name

def __call__(self, X1, X2=None):
"""
Expand All @@ -38,14 +33,13 @@ def __call__(self, X1, X2=None):

def __setattr__(self, name, val):
if name == 'train':
from .util import _find_parameters
for _, p in _find_parameters(self):
for p in self.parameters():
p.train = val
return
if hasattr(self, name) and isinstance(getattr(self, name), Parameter):
raise AttributeError("parameter is read-only, use Parameter.assign()")
if isinstance(val, Parameter) and val.name is None:
val.name = name
if isinstance(val, Parameter) and val._name is None:
val._name = 'kernel.' + name
super().__setattr__(name, val)

def _active_input(self, X1, X2=None):
Expand Down Expand Up @@ -201,10 +195,9 @@ class Kernels(Kernel):
Args:
kernels (list of Kernel): Kernels.
name (str): Kernel name.
"""
def __init__(self, *kernels, name="Kernels"):
super().__init__(name=name)
def __init__(self, *kernels):
super().__init__()
kernels = self._check_kernels(kernels)

i = 0
Expand Down Expand Up @@ -237,10 +230,9 @@ class AddKernel(Kernels):
Args:
kernels (list of Kernel): Kernels.
name (str): Kernel name.
"""
def __init__(self, *kernels, name="Add"):
super().__init__(*kernels, name=name)
def __init__(self, *kernels):
super().__init__(*kernels)

def K(self, X1, X2=None):
return torch.stack([kernel(X1, X2) for kernel in self.kernels], dim=2).sum(dim=2)
Expand All @@ -254,10 +246,9 @@ class MulKernel(Kernels):
Args:
kernels (list of Kernel): Kernels.
name (str): Kernel name.
"""
def __init__(self, *kernels, name="Mul"):
super().__init__(*kernels, name=name)
def __init__(self, *kernels):
super().__init__(*kernels)

def K(self, X1, X2=None):
return torch.stack([kernel(X1, X2) for kernel in self.kernels], dim=2).prod(dim=2)
Expand All @@ -272,13 +263,12 @@ class MixtureKernel(AddKernel):
Args:
kernel (Kernel): Single kernel.
Q (int): Number of mixtures.
name (str): Kernel name.
"""
def __init__(self, kernel, Q, name="Mixture"):
def __init__(self, kernel, Q):
if not issubclass(type(kernel), Kernel):
raise ValueError("must pass kernel")
kernels = self._check_kernels(kernel, Q)
super().__init__(*kernels, name=name)
super().__init__(*kernels)

class AutomaticRelevanceDeterminationKernel(MulKernel):
"""
Expand All @@ -287,15 +277,16 @@ class AutomaticRelevanceDeterminationKernel(MulKernel):
Args:
kernel (Kernel): Single kernel.
input_dims (int): Number of input dimensions.
name (str): Kernel name.
"""
def __init__(self, kernel, input_dims, name="ARD"):
def __init__(self, kernel, input_dims):
if not issubclass(type(kernel), Kernel):
raise ValueError("must pass kernel")
kernels = self._check_kernels(kernel, input_dims)
for i, kernel in enumerate(kernels):
kernel.set_active_dims(i)
super().__init__(*kernels, name=name)
super().__init__(*kernels)

cb = None

class MultiOutputKernel(Kernel):
"""
Expand All @@ -307,12 +298,11 @@ class MultiOutputKernel(Kernel):
output_dims (int): Number of output dimensions.
input_dims (int): Number of input dimensions.
active_dims (list of int): Indices of active dimensions of shape (input_dims,).
name (str): Kernel name.
"""
# TODO: seems to accumulate a lot of memory in the loops to call Ksub, perhaps it's keeping the computational graph while indexing?

def __init__(self, output_dims, input_dims=None, active_dims=None, name=None):
super().__init__(input_dims, active_dims, name)
def __init__(self, output_dims, input_dims=None, active_dims=None):
super().__init__(input_dims, active_dims)
self.output_dims = output_dims

def _check_input(self, X1, X2=None):
Expand Down
Loading

0 comments on commit 29747a5

Please sign in to comment.