-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor: nlds_lib, lds_lib, and scripts -> JSL (#663)
* refactor ekf_vs_ukf_demo add jsl dependency * refactor eekf_logistic_regression_demo import from jsl * refactor: ekf_vs_ukf_demo * refactor: ekf_vs_ukf_mlp_demo.py Import demo from jsl library * refactor: ekf_continuous_demo.py Import demo from jsl library * refactor: linreg_kf_demo.py Import demo from jsl library * refactor: pendulum_1d_demo.py Import demo from jsl library * refactor: ekf_mlp_anim_demo.py Import demo from jsl library * refactor: bootstrap_filter_demo.py Import demo from jsl library * refactor: kf_parallel_demo.py Import demo from jsl library * refactor: kf_spiral_demo.py Import demo from jsl library * refactor: kf_tracking_demo.py Import demo from jsl library * refactor: kf_continuous_circle_demo.py Import demo from jsl library * remove: lds_lib, nlds_lib, lds_cts_time_lib
- Loading branch information
Showing
15 changed files
with
82 additions
and
2,211 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,14 @@ | ||
# Demo of the bootstrap filter under a | ||
# nonlinear discrete system | ||
|
||
# Author: Gerardo Durán-Martín (@gerdm) | ||
|
||
import superimport | ||
# !pip install git+git://github.com/probml/jsl | ||
|
||
import jax | ||
import nlds_lib as ds | ||
import jax.numpy as jnp | ||
from jsl.demos import bootstrap_filter_demo as demo | ||
import matplotlib.pyplot as plt | ||
from jax import random | ||
import pyprobml_utils as pml | ||
|
||
|
||
def plot_samples(sample_state, sample_obs, ax=None): | ||
fig, ax = plt.subplots() | ||
ax.plot(*sample_state.T, label="state space") | ||
ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+") | ||
ax.scatter(*sample_state[0], c="black", zorder=3) | ||
ax.legend() | ||
ax.set_title("Noisy observations from hidden trajectory") | ||
plt.axis("equal") | ||
|
||
|
||
def plot_inference(sample_obs, mean_hist): | ||
fig, ax = plt.subplots() | ||
ax.scatter(*sample_obs.T, marker="+", color="tab:green", s=60) | ||
ax.plot(*mean_hist.T, c="tab:orange", label="filtered") | ||
ax.scatter(*mean_hist[0], c="black", zorder=3) | ||
plt.legend() | ||
plt.axis("equal") | ||
|
||
|
||
if __name__ == "__main__": | ||
key = random.PRNGKey(314) | ||
plt.rcParams["axes.spines.right"] = False | ||
plt.rcParams["axes.spines.top"] = False | ||
|
||
def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])]) | ||
def fx(x): return x | ||
|
||
dt = 0.4 | ||
nsteps = 100 | ||
# Initial state vector | ||
x0 = jnp.array([1.5, 0.0]) | ||
# State noise | ||
Qt = jnp.eye(2) * 0.001 | ||
# Observed noise | ||
Rt = jnp.eye(2) * 0.05 | ||
|
||
key = random.PRNGKey(314) | ||
model = ds.NLDS(lambda x: fz(x, dt), fx, Qt, Rt) | ||
sample_state, sample_obs = model.sample(key, x0, nsteps) | ||
|
||
n_particles = 3_000 | ||
fz_vec = jax.vmap(fz, in_axes=(0, None)) | ||
particle_filter = ds.BootstrapFiltering(lambda x: fz_vec(x, dt), fx, Qt, Rt) | ||
pf_mean = particle_filter.filter(key, x0, sample_obs, n_particles) | ||
|
||
|
||
plot_inference(sample_obs, pf_mean) | ||
pml.savefig("nlds2d_bootstrap.pdf") | ||
|
||
plot_samples(sample_state, sample_obs) | ||
pml.savefig("nlds2d_data.pdf") | ||
|
||
plt.show() | ||
figures = demo.main() | ||
for name, figure in figures.items(): | ||
filename = f"./../figures/{name}.pdf" | ||
figure.savefig(filename) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,125 +1,15 @@ | ||
# Online learning of a logistic | ||
# regression model using the Exponential-family | ||
# Extended Kalman Filter (EEKF) algorithm | ||
|
||
# Author: Gerardo Durán-Martín (@gerdm) | ||
|
||
import superimport | ||
# !pip install git+git://github.com/probml/jsl | ||
|
||
import jax | ||
import nlds_lib as ds | ||
import jax.numpy as jnp | ||
from jsl.demos import eekf_logistic_regression_demo | ||
import matplotlib.pyplot as plt | ||
import pyprobml_utils as pml | ||
from jax import random | ||
from jax.scipy.optimize import minimize | ||
from sklearn.datasets import make_biclusters | ||
|
||
def sigmoid(x): return jnp.exp(x) / (1 + jnp.exp(x)) | ||
def log_sigmoid(z): return z - jnp.log(1 + jnp.exp(z)) | ||
def fz(x): return x | ||
def fx(w, x): return sigmoid(w[None, :] @ x) | ||
def Rt(w, x): return sigmoid(w @ x) * (1 - sigmoid(w @ x)) | ||
|
||
## Data generating process | ||
n_datapoints = 50 | ||
m = 2 | ||
X, rows, cols = make_biclusters((n_datapoints, m), 2, | ||
noise=0.6, random_state=314, | ||
minval=-4, maxval=4) | ||
# whether datapoints belong to class 1 | ||
y = rows[0] * 1.0 | ||
|
||
Phi = jnp.c_[jnp.ones(n_datapoints)[:, None], X] | ||
N, M = Phi.shape | ||
|
||
colors = ["black" if el else "white" for el in y] | ||
|
||
# Predictive domain | ||
xmin, ymin = X.min(axis=0) - 0.1 | ||
xmax, ymax = X.max(axis=0) + 0.1 | ||
step = 0.1 | ||
Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step] | ||
_, nx, ny = Xspace.shape | ||
Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace]) | ||
|
||
### EEKF Approximation | ||
mu_t = jnp.zeros(M) | ||
Pt = jnp.eye(M) * 0.0 | ||
P0 = jnp.eye(M) * 2.0 | ||
|
||
model = ds.ExtendedKalmanFilter(fz, fx, Pt, Rt) | ||
w_eekf_hist, P_eekf_hist = model.filter(mu_t, y, Phi, P0) | ||
|
||
w_eekf = w_eekf_hist[-1] | ||
P_eekf = P_eekf_hist[-1] | ||
|
||
### Laplace approximation | ||
key = random.PRNGKey(314) | ||
init_noise = 0.6 | ||
w0 = random.multivariate_normal(key, jnp.zeros(M), jnp.eye(M) * init_noise) | ||
alpha = 1.0 | ||
def E(w): | ||
an = Phi @ w | ||
log_an = log_sigmoid(an) | ||
log_likelihood_term = y * log_an + (1 - y) * jnp.log1p(-sigmoid(an)) | ||
prior_term = alpha * w @ w / 2 | ||
|
||
return prior_term - log_likelihood_term.sum() | ||
|
||
res = minimize(lambda x: E(x) / len(y), w0, method="BFGS") | ||
w_laplace = res.x | ||
SN = jax.hessian(E)(w_laplace) | ||
|
||
### Ploting surface predictive distribution | ||
key = random.PRNGKey(31415) | ||
nsamples = 5000 | ||
|
||
# EEKF surface predictive distribution | ||
eekf_samples = random.multivariate_normal(key, w_eekf, P_eekf, (nsamples,)) | ||
Z_eekf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, eekf_samples)) | ||
Z_eekf = Z_eekf.mean(axis=0) | ||
|
||
fig, ax = plt.subplots() | ||
ax.contourf(*Xspace, Z_eekf, cmap="RdBu_r", alpha=0.7, levels=20) | ||
ax.scatter(*X.T, c=colors, edgecolors="black", s=80) | ||
ax.set_title("(EEKF) Predictive distribution") | ||
pml.savefig("logistic-regression-surface-eekf.pdf") | ||
|
||
# Laplace surface predictive distribution | ||
laplace_samples = random.multivariate_normal(key, w_laplace, SN, (nsamples,)) | ||
Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples)) | ||
Z_laplace = Z_laplace.mean(axis=0) | ||
|
||
fig, ax = plt.subplots() | ||
ax.contourf(*Xspace, Z_laplace, cmap="RdBu_r", alpha=0.7, levels=20) | ||
ax.scatter(*X.T, c=colors, edgecolors="black", s=80) | ||
ax.set_title("(Laplace) Predictive distribution") | ||
pml.savefig("logistic-regression-surface-laplace.pdf") | ||
|
||
### Plot EEKF and Laplace training history | ||
P_eekf_hist_diag = jnp.diagonal(P_eekf_hist, axis1=1, axis2=2) | ||
P_laplace_diag = jnp.sqrt(jnp.diagonal(SN)) | ||
lcolors = ["black", "tab:blue", "tab:red"] | ||
elements = w_eekf_hist.T, P_eekf_hist_diag.T, w_laplace, P_laplace_diag, lcolors | ||
timesteps = jnp.arange(n_datapoints) + 1 | ||
|
||
for k, (wk, Pk, wk_laplace, Pk_laplace, c) in enumerate(zip(*elements)): | ||
fig, ax = plt.subplots() | ||
ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (EEKF)") | ||
ax.axhline(y=wk_laplace, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3) | ||
|
||
ax.set_xlim(1, n_datapoints) | ||
ax.legend(framealpha=0.7, loc="upper right") | ||
ax.set_xlabel("number samples") | ||
ax.set_ylabel("weights") | ||
plt.tight_layout() | ||
pml.savefig(f"eekf-laplace-hist-w{k}.pdf") | ||
|
||
print("EEKF weights") | ||
print(w_eekf, end="\n"*2) | ||
|
||
print("Laplace weights") | ||
print(w_laplace, end="\n"*2) | ||
figures = eekf_logistic_regression_demo.main() | ||
for name, figure in figures.items(): | ||
filename = f"./../figures/{name}.pdf" | ||
figure.savefig(filename) | ||
plt.show() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,14 @@ | ||
# Example showcasing the learning process of the EKF algorithm. | ||
# This demo is based on the ekf_mlp_anim_demo.py demo. | ||
# Author: Gerardo Durán-Martín (@gerdm) | ||
import superimport | ||
|
||
import jax | ||
import nlds_lib as ds | ||
import jax.numpy as jnp | ||
import matplotlib.pyplot as plt | ||
import ekf_vs_ukf_mlp_demo as demo | ||
import matplotlib.animation as animation | ||
from functools import partial | ||
from jax.random import PRNGKey, split, normal, multivariate_normal | ||
# !pip install git+git://github.com/probml/jsl | ||
|
||
import jax.numpy as jnp | ||
from jsl.demos import ekf_mlp_anim_demo | ||
|
||
plt.rcParams["axes.spines.right"] = False | ||
plt.rcParams["axes.spines.top"] = False | ||
def f(x): return x -10 * jnp.cos(x) * jnp.sin(x) + x ** 3 | ||
filepath = "./../figures/ekf_mlp_demo.mp4" | ||
def fx(x): return x -10 * jnp.cos(x) * jnp.sin(x) + x ** 3 | ||
def fz(W): return W | ||
|
||
# *** MLP configuration *** | ||
n_hidden = 6 | ||
n_in, n_out = 1, 1 | ||
n_params = (n_in + 1) * n_hidden + (n_hidden + 1) * n_out | ||
fwd_mlp = partial(demo.mlp, n_hidden=n_hidden) | ||
# vectorised for multiple observations | ||
fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0]) | ||
# vectorised for multiple weights | ||
fwd_mlp_weights = jax.vmap(fwd_mlp, in_axes=[1, None]) | ||
# vectorised for multiple observations and weights | ||
fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None]) | ||
|
||
# *** Generating training and test data *** | ||
n_obs = 200 | ||
key = PRNGKey(314) | ||
key_sample_obs, key_weights = split(key, 2) | ||
xmin, xmax = -3, 3 | ||
sigma_y = 3.0 | ||
x, y = demo.sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y) | ||
xtest = jnp.linspace(x.min(), x.max(), n_obs) | ||
|
||
# *** MLP Training with EKF *** | ||
W0 = normal(key_weights, (n_params,)) * 1 # initial random guess | ||
Q = jnp.eye(n_params) * 1e-4; # parameters do not change | ||
R = jnp.eye(1) * sigma_y**2; # observation noise is fixed | ||
Vinit = jnp.eye(n_params) * 100 # vague prior | ||
|
||
ekf = ds.ExtendedKalmanFilter(fz, fwd_mlp, Q, R) | ||
ekf_mu_hist, ekf_Sigma_hist = ekf.filter(W0, y[:, None], x[:, None], Vinit) | ||
|
||
xtest = jnp.linspace(x.min(), x.max(), 200) | ||
nframes = n_obs | ||
fig, ax = plt.subplots() | ||
|
||
def func(i): | ||
plt.cla() | ||
W, SW = ekf_mu_hist[i], ekf_Sigma_hist[i] | ||
W_samples = multivariate_normal(key, W, SW, (100,)) | ||
sample_yhat = fwd_mlp_obs_weights(W_samples, xtest[:, None]) | ||
for sample in sample_yhat: | ||
ax.plot(xtest, sample, c="tab:gray", alpha=0.07) | ||
ax.plot(xtest, sample_yhat.mean(axis=0)) | ||
ax.scatter(x[:i], y[:i], s=14, c="none", edgecolor="black", label="observations") | ||
ax.scatter(x[i], y[i], s=30, c="tab:red") | ||
ax.set_title(f"EKF+MLP ({i+1:03}/{n_obs})") | ||
ax.set_xlim(x.min(), x.max()) | ||
ax.set_ylim(y.min(), y.max()) | ||
|
||
return ax | ||
|
||
ani = animation.FuncAnimation(fig, func, frames=n_obs) | ||
ani.save("../figures/samples_hist_ekf.mp4", dpi=200, bitrate=-1, fps=10) | ||
ekf_mlp_anim_demo.main(fx, fz, filepath) |
Oops, something went wrong.