Refactor: nlds_lib, lds_lib, and scripts -> JSL #663

Merged: 14 commits, Dec 30, 2021
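All four scripts are refactored the same way: the inline nlds_lib code is deleted and each script becomes a thin wrapper around the matching module in jsl.demos. For the first three, the demo's main() returns a dict of named Matplotlib figures which the wrapper saves under ../figures/ (the last script, ekf_mlp_anim_demo.py, instead hands its model functions and an output path to main()). A minimal sketch of the shared wrapper pattern, shown here for the bootstrap filter demo and assuming main() returns {name: Figure} as the diffs below indicate:

# Sketch of the wrapper pattern the refactored scripts converge on.
# Assumes (as the diffs below show) that the jsl.demos module exposes a
# main() returning a dict of {figure_name: matplotlib Figure}.
import matplotlib.pyplot as plt
from jsl.demos import bootstrap_filter_demo as demo

figures = demo.main()  # run the demo and collect its figures
for name, figure in figures.items():
    figure.savefig(f"./../figures/{name}.pdf")  # one PDF per named figure
plt.show()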
68 changes: 7 additions & 61 deletions scripts/bootstrap_filter_demo.py
@@ -1,68 +1,14 @@
# Demo of the bootstrap filter under a
# nonlinear discrete system

# Author: Gerardo Durán-Martín (@gerdm)

import superimport
# !pip install git+git://github.com/probml/jsl

import jax
import nlds_lib as ds
import jax.numpy as jnp
from jsl.demos import bootstrap_filter_demo as demo
import matplotlib.pyplot as plt
from jax import random
import pyprobml_utils as pml


def plot_samples(sample_state, sample_obs, ax=None):
    fig, ax = plt.subplots()
    ax.plot(*sample_state.T, label="state space")
    ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+")
    ax.scatter(*sample_state[0], c="black", zorder=3)
    ax.legend()
    ax.set_title("Noisy observations from hidden trajectory")
    plt.axis("equal")


def plot_inference(sample_obs, mean_hist):
    fig, ax = plt.subplots()
    ax.scatter(*sample_obs.T, marker="+", color="tab:green", s=60)
    ax.plot(*mean_hist.T, c="tab:orange", label="filtered")
    ax.scatter(*mean_hist[0], c="black", zorder=3)
    plt.legend()
    plt.axis("equal")


if __name__ == "__main__":
    key = random.PRNGKey(314)
    plt.rcParams["axes.spines.right"] = False
    plt.rcParams["axes.spines.top"] = False

    def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])
    def fx(x): return x

    dt = 0.4
    nsteps = 100
    # Initial state vector
    x0 = jnp.array([1.5, 0.0])
    # State noise
    Qt = jnp.eye(2) * 0.001
    # Observed noise
    Rt = jnp.eye(2) * 0.05

    key = random.PRNGKey(314)
    model = ds.NLDS(lambda x: fz(x, dt), fx, Qt, Rt)
    sample_state, sample_obs = model.sample(key, x0, nsteps)

    n_particles = 3_000
    fz_vec = jax.vmap(fz, in_axes=(0, None))
    particle_filter = ds.BootstrapFiltering(lambda x: fz_vec(x, dt), fx, Qt, Rt)
    pf_mean = particle_filter.filter(key, x0, sample_obs, n_particles)

    plot_inference(sample_obs, pf_mean)
    pml.savefig("nlds2d_bootstrap.pdf")

    plot_samples(sample_state, sample_obs)
    pml.savefig("nlds2d_data.pdf")

    plt.show()
figures = demo.main()
for name, figure in figures.items():
    filename = f"./../figures/{name}.pdf"
    figure.savefig(filename)
plt.show()
122 changes: 6 additions & 116 deletions scripts/eekf_logistic_regression_demo.py
@@ -1,125 +1,15 @@
# Online learning of a logistic
# regression model using the Exponential-family
# Extended Kalman Filter (EEKF) algorithm

# Author: Gerardo Durán-Martín (@gerdm)

import superimport
# !pip install git+git://github.com/probml/jsl

import jax
import nlds_lib as ds
import jax.numpy as jnp
from jsl.demos import eekf_logistic_regression_demo
import matplotlib.pyplot as plt
import pyprobml_utils as pml
from jax import random
from jax.scipy.optimize import minimize
from sklearn.datasets import make_biclusters

def sigmoid(x): return jnp.exp(x) / (1 + jnp.exp(x))
def log_sigmoid(z): return z - jnp.log(1 + jnp.exp(z))
def fz(x): return x
def fx(w, x): return sigmoid(w[None, :] @ x)
def Rt(w, x): return sigmoid(w @ x) * (1 - sigmoid(w @ x))

## Data generating process
n_datapoints = 50
m = 2
X, rows, cols = make_biclusters((n_datapoints, m), 2,
                                noise=0.6, random_state=314,
                                minval=-4, maxval=4)
# whether datapoints belong to class 1
y = rows[0] * 1.0

Phi = jnp.c_[jnp.ones(n_datapoints)[:, None], X]
N, M = Phi.shape

colors = ["black" if el else "white" for el in y]

# Predictive domain
xmin, ymin = X.min(axis=0) - 0.1
xmax, ymax = X.max(axis=0) + 0.1
step = 0.1
Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
_, nx, ny = Xspace.shape
Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])

### EEKF Approximation
mu_t = jnp.zeros(M)
Pt = jnp.eye(M) * 0.0
P0 = jnp.eye(M) * 2.0

model = ds.ExtendedKalmanFilter(fz, fx, Pt, Rt)
w_eekf_hist, P_eekf_hist = model.filter(mu_t, y, Phi, P0)

w_eekf = w_eekf_hist[-1]
P_eekf = P_eekf_hist[-1]

### Laplace approximation
key = random.PRNGKey(314)
init_noise = 0.6
w0 = random.multivariate_normal(key, jnp.zeros(M), jnp.eye(M) * init_noise)
alpha = 1.0
def E(w):
    an = Phi @ w
    log_an = log_sigmoid(an)
    log_likelihood_term = y * log_an + (1 - y) * jnp.log1p(-sigmoid(an))
    prior_term = alpha * w @ w / 2

    return prior_term - log_likelihood_term.sum()

res = minimize(lambda x: E(x) / len(y), w0, method="BFGS")
w_laplace = res.x
SN = jax.hessian(E)(w_laplace)

### Plotting surface predictive distribution
key = random.PRNGKey(31415)
nsamples = 5000

# EEKF surface predictive distribution
eekf_samples = random.multivariate_normal(key, w_eekf, P_eekf, (nsamples,))
Z_eekf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, eekf_samples))
Z_eekf = Z_eekf.mean(axis=0)

fig, ax = plt.subplots()
ax.contourf(*Xspace, Z_eekf, cmap="RdBu_r", alpha=0.7, levels=20)
ax.scatter(*X.T, c=colors, edgecolors="black", s=80)
ax.set_title("(EEKF) Predictive distribution")
pml.savefig("logistic-regression-surface-eekf.pdf")

# Laplace surface predictive distribution
laplace_samples = random.multivariate_normal(key, w_laplace, SN, (nsamples,))
Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples))
Z_laplace = Z_laplace.mean(axis=0)

fig, ax = plt.subplots()
ax.contourf(*Xspace, Z_laplace, cmap="RdBu_r", alpha=0.7, levels=20)
ax.scatter(*X.T, c=colors, edgecolors="black", s=80)
ax.set_title("(Laplace) Predictive distribution")
pml.savefig("logistic-regression-surface-laplace.pdf")

### Plot EEKF and Laplace training history
P_eekf_hist_diag = jnp.diagonal(P_eekf_hist, axis1=1, axis2=2)
P_laplace_diag = jnp.sqrt(jnp.diagonal(SN))
lcolors = ["black", "tab:blue", "tab:red"]
elements = w_eekf_hist.T, P_eekf_hist_diag.T, w_laplace, P_laplace_diag, lcolors
timesteps = jnp.arange(n_datapoints) + 1

for k, (wk, Pk, wk_laplace, Pk_laplace, c) in enumerate(zip(*elements)):
    fig, ax = plt.subplots()
    ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (EEKF)")
    ax.axhline(y=wk_laplace, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)

    ax.set_xlim(1, n_datapoints)
    ax.legend(framealpha=0.7, loc="upper right")
    ax.set_xlabel("number samples")
    ax.set_ylabel("weights")
    plt.tight_layout()
    pml.savefig(f"eekf-laplace-hist-w{k}.pdf")

print("EEKF weights")
print(w_eekf, end="\n"*2)

print("Laplace weights")
print(w_laplace, end="\n"*2)
figures = eekf_logistic_regression_demo.main()
for name, figure in figures.items():
    filename = f"./../figures/{name}.pdf"
    figure.savefig(filename)
plt.show()

66 changes: 7 additions & 59 deletions scripts/ekf_continuous_demo.py
@@ -6,65 +6,13 @@
# * Nonlinear Dynamics and Chaos - Steven Strogatz
# Author: Gerardo Durán-Martín (@gerdm)

import superimport
# !pip install git+git://github.com/probml/jsl

import nlds_lib as ds
import matplotlib.pyplot as plt
import pyprobml_utils as pml
import jax.numpy as jnp
import numpy as np
from jax import random

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False

def fz(x):
    x, y = x
    return jnp.asarray([y, x - x ** 3])

def fx(x):
    x, y = x
    return jnp.asarray([x, y])

dt = 0.01
T = 7.5
nsamples = 70
x0 = jnp.array([0.5, -0.75])

# State noise
Qt = jnp.eye(2) * 0.001
# Observed noise
Rt = jnp.eye(2) * 0.01

key = random.PRNGKey(314)
ekf = ds.ContinuousExtendedKalmanFilter(fz, fx, Qt, Rt)
sample_state, sample_obs, jump = ekf.sample(key, x0, T, nsamples)
mu_hist, V_hist = ekf.estimate(sample_state, sample_obs, jump, dt)

vmin, vmax, step = -1.5, 1.5 + 0.5, 0.5
X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1]
X_dot = jnp.apply_along_axis(fz, 0, X)

fig, ax = plt.subplots()
ax.plot(*sample_state.T, label="state space")
ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations")
field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
ax.legend()
plt.axis("equal")
ax.set_title("State Space")
pml.savefig("ekf-state-space.pdf")

fig, ax = plt.subplots()
ax.plot(*sample_state.T, c="tab:orange", label="EKF estimation")
ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations")
ax.scatter(*mu_hist[0], c="black", zorder=3)
for mut, Vt in zip(mu_hist[::4], V_hist[::4]):
    pml.plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3)
plt.legend()
field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
ax.legend()
plt.axis("equal")
ax.set_title("Approximate Space")
pml.savefig("ekf-estimated-space.pdf")
from jsl.demos import ekf_continuous_demo
import matplotlib.pyplot as plt

figures = ekf_continuous_demo.main()
for name, figure in figures.items():
    filename = f"./../figures/{name}.pdf"
    figure.savefig(filename)
plt.show()
70 changes: 6 additions & 64 deletions scripts/ekf_mlp_anim_demo.py
@@ -1,72 +1,14 @@
# Example showcasing the learning process of the EKF algorithm.
# This demo is based on the ekf_mlp_anim_demo.py demo.
# Author: Gerardo Durán-Martín (@gerdm)
import superimport

import jax
import nlds_lib as ds
import jax.numpy as jnp
import matplotlib.pyplot as plt
import ekf_vs_ukf_mlp_demo as demo
import matplotlib.animation as animation
from functools import partial
from jax.random import PRNGKey, split, normal, multivariate_normal
# !pip install git+git://github.com/probml/jsl

import jax.numpy as jnp
from jsl.demos import ekf_mlp_anim_demo

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
def f(x): return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3
filepath = "./../figures/ekf_mlp_demo.mp4"
def fx(x): return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3  # true regression function to fit
def fz(W): return W  # identity dynamics on the weights (parameters do not change)

# *** MLP configuration ***
n_hidden = 6
n_in, n_out = 1, 1
n_params = (n_in + 1) * n_hidden + (n_hidden + 1) * n_out
fwd_mlp = partial(demo.mlp, n_hidden=n_hidden)
# vectorised for multiple observations
fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0])
# vectorised for multiple weights
fwd_mlp_weights = jax.vmap(fwd_mlp, in_axes=[1, None])
# vectorised for multiple observations and weights
fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None])

# *** Generating training and test data ***
n_obs = 200
key = PRNGKey(314)
key_sample_obs, key_weights = split(key, 2)
xmin, xmax = -3, 3
sigma_y = 3.0
x, y = demo.sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y)
xtest = jnp.linspace(x.min(), x.max(), n_obs)

# *** MLP Training with EKF ***
W0 = normal(key_weights, (n_params,)) * 1 # initial random guess
Q = jnp.eye(n_params) * 1e-4; # parameters do not change
R = jnp.eye(1) * sigma_y**2; # observation noise is fixed
Vinit = jnp.eye(n_params) * 100 # vague prior

ekf = ds.ExtendedKalmanFilter(fz, fwd_mlp, Q, R)
ekf_mu_hist, ekf_Sigma_hist = ekf.filter(W0, y[:, None], x[:, None], Vinit)

xtest = jnp.linspace(x.min(), x.max(), 200)
nframes = n_obs
fig, ax = plt.subplots()

def func(i):
    plt.cla()
    W, SW = ekf_mu_hist[i], ekf_Sigma_hist[i]
    W_samples = multivariate_normal(key, W, SW, (100,))
    sample_yhat = fwd_mlp_obs_weights(W_samples, xtest[:, None])
    for sample in sample_yhat:
        ax.plot(xtest, sample, c="tab:gray", alpha=0.07)
    ax.plot(xtest, sample_yhat.mean(axis=0))
    ax.scatter(x[:i], y[:i], s=14, c="none", edgecolor="black", label="observations")
    ax.scatter(x[i], y[i], s=30, c="tab:red")
    ax.set_title(f"EKF+MLP ({i+1:03}/{n_obs})")
    ax.set_xlim(x.min(), x.max())
    ax.set_ylim(y.min(), y.max())

    return ax

ani = animation.FuncAnimation(fig, func, frames=n_obs)
ani.save("../figures/samples_hist_ekf.mp4", dpi=200, bitrate=-1, fps=10)
ekf_mlp_anim_demo.main(fx, fz, filepath)
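Unlike the figure-saving wrappers above, this script defines the regression function fx and the identity weight-dynamics fz locally and passes them, together with an output path, to ekf_mlp_anim_demo.main, which presumably renders the training animation and writes it to ./../figures/ekf_mlp_demo.mp4 rather than returning figures.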