ADF + EEKF demo (#664)
* refactor: adf_logistic_regression

Add jsl dependency

* feat: adf_logistic_regression_demo

Add ADF posterior marginal plot

* feat: adf_logistic_regression_demo

Add weight-estimation over time

* style: adf_logistic_regression_demo
gerdm authored Jan 4, 2022
1 parent 0fefa01 commit 28e82aa
Showing 1 changed file with 41 additions and 128 deletions.
169 changes: 41 additions & 128 deletions scripts/adf_logistic_regression_demo.py
@@ -16,50 +16,30 @@

import jax
import jax.numpy as jnp
import blackjax.rwmh as mh
import matplotlib.pyplot as plt
import pyprobml_utils as pml
from sklearn.datasets import make_biclusters
from jax import random
from jax.scipy.optimize import minimize
from jax.scipy.stats import norm
from jax_cosmo.scipy import integrate
from functools import partial
from jsl.demos import eekf_logistic_regression_demo as demo

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

figures = demo.main()

def sigmoid(z): return jnp.exp(z) / (1 + jnp.exp(z))

data = figures.pop("data")
X = data["X"]
y = data["y"]
Phi = data["Phi"]
Xspace = data["Xspace"]
Phispace = data["Phispace"]
w_laplace = data["w_laplace"]

def sigmoid(z): return jnp.exp(z) / (1 + jnp.exp(z))
def log_sigmoid(z): return z - jnp.log1p(jnp.exp(z))


def inference_loop(rng_key, kernel, initial_state, num_samples):
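    # Run the MH transition kernel for num_samples steps with jax.lax.scan,
    # collecting and returning the chain of sampler states.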
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states


def E_base(w, Phi, y, alpha):
"""
Base function containing the Energy of a logistic
regression with
"""
an = Phi @ w
log_an = log_sigmoid(an)
log_likelihood_term = y * log_an + (1 - y) * jnp.log(1 - sigmoid(an))
prior_term = alpha * w @ w / 2

return prior_term - log_likelihood_term.sum()


def Zt_func(eta, y, mu, v):
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)
@@ -111,54 +91,11 @@ def adf_step(state, xs, prior_variance, lbound, ubound):

    return (mu_t, tau_t), (mu_t, tau_t)


def plot_posterior_predictive(ax, X, Z, title, colors, cmap="RdBu_r"):
    ax.contourf(*Xspace, Z, cmap=cmap, alpha=0.7, levels=20)
    ax.scatter(*X.T, c=colors, edgecolors="gray", s=80)
    ax.set_title(title)
    ax.axis("off")
    plt.tight_layout()


# ** Generating training data **
key = random.PRNGKey(314)
n_datapoints, ndims = 50, 2
X, rows, cols = make_biclusters((n_datapoints, ndims), 2, noise=0.6,
                                random_state=3141, minval=-4, maxval=4)
y = rows[0] * 1.0

alpha = 1.0
init_noise = 1.0
Phi = jnp.c_[jnp.ones(n_datapoints)[:, None], X] # Design matrix
ndata, ndims = Phi.shape


# ** MCMC Sampling with BlackJAX **
sigma_mcmc = 0.8
w0 = random.multivariate_normal(key, jnp.zeros(ndims), jnp.eye(ndims) * init_noise)
energy = partial(E_base, Phi=Phi, y=y, alpha=alpha)
initial_state = mh.new_state(w0, energy)

mcmc_kernel = mh.kernel(energy, jnp.ones(ndims) * sigma_mcmc)
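# Random-walk Metropolis kernel (blackjax.rwmh); the second argument presumably
# sets the per-dimension scale of the Gaussian proposals.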
mcmc_kernel = jax.jit(mcmc_kernel)

n_samples = 5_000
burnin = 300
key_init = jax.random.PRNGKey(0)
states = inference_loop(key_init, mcmc_kernel, initial_state, n_samples)

chains = states.position[burnin:, :]
nsamp, _ = chains.shape

# ** Laplace approximation **
res = minimize(lambda x: energy(x) / len(y), w0, method="BFGS")
w_map = res.x
SN = jnp.linalg.inv(jax.hessian(energy)(w_map))  # posterior covariance = inverse Hessian of the energy at the MAP

# ** ADF inference **
prior_variance = 0.0
# Lower and upper bounds of integration. Ideally, we would like to
# integrate from -inf to inf, but we run into numerical issues.
n_datapoints, ndims = Phi.shape
lbound, ubound = -20, 20
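# Each adf_step call (definition partly collapsed in this diff) presumably
# moment-matches the Gaussian posterior to the tilted distribution of one
# observation, with the required one-dimensional integrals computed numerically
# over [lbound, ubound] (cf. Zt_func and the jax_cosmo integrate import).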
mu_t = jnp.zeros(ndims)
tau_t = jnp.ones(ndims) * 1.0
@@ -168,75 +105,51 @@ def plot_posterior_predictive(ax, X, Z, title, colors, cmap="RdBu_r"):

adf_loop = partial(adf_step, prior_variance=prior_variance, lbound=lbound, ubound=ubound)
(mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs)
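# mu_t_hist and tau_t_hist store the per-weight posterior mean and variance
# after each observation; they drive the online-learning plots below.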
print("ADF weigths")
print(mu_t)


# ** Estimating posterior predictive distribution **
xmin, ymin = X.min(axis=0) - 0.1
xmax, ymax = X.max(axis=0) + 0.1
step = 0.1
Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
_, nx, ny = Xspace.shape
Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])
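# Phispace has shape (ndims, nx, ny): a bias feature of ones stacked on the two
# grid coordinates, i.e. the design matrix evaluated at every grid point.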

# MCMC posterior predictive distribution
# maps m-dimensional features on an (i,j) grid times "s" m-dimensional samples to get
# "s" samples on an (i,j) grid of predictions
Z_mcmc = sigmoid(jnp.einsum("mij,sm->sij", Phispace, chains))
Z_mcmc = Z_mcmc.mean(axis=0)
# Laplace posterior predictive distribution
key = random.PRNGKey(314)
laplace_samples = random.multivariate_normal(key, w_map, SN, (n_samples,))
Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples))
Z_laplace = Z_laplace.mean(axis=0)
# ADF posterior predictive distribution
n_samples = 5000
key = random.PRNGKey(3141)
adf_samples = random.multivariate_normal(key, mu_t, jnp.diag(tau_t), (n_samples,))
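# ADF here tracks an independent Gaussian per weight, so the sampling covariance
# is the diagonal matrix jnp.diag(tau_t).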
Z_adf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, adf_samples))
Z_adf = Z_adf.mean(axis=0)


# ** Plotting predictive distribution **
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
colors = ["black" if el else "white" for el in y]

fig, ax = plt.subplots()
title = "(MCMC) Predictive distribution"
plot_posterior_predictive(ax, X, Z_mcmc, title, colors)
pml.savefig("mcmc-logreg-predictive-surface.pdf")
## Add posterior marginal for ADF-estimated weights
for i in range(ndims):
    mean, std = mu_t[i], jnp.sqrt(tau_t[i])
    fig = figures[f"weights_marginals_w{i}"]
    ax = fig.gca()
    x = jnp.linspace(mean - 4 * std, mean + 4 * std, 500)
    ax.plot(x, norm.pdf(x, mean, std), label="posterior (ADF)", linestyle="dashdot")
    ax.legend()

fig, ax = plt.subplots()
title = "(Laplace) Predictive distribution"
plot_posterior_predictive(ax, X, Z_laplace, title, colors)
pml.savefig("laplace-logreg-predictive-surface.pdf")

fig, ax = plt.subplots()
fig_adf, ax = plt.subplots()
title = "(ADF) Predictive distribution"
plot_posterior_predictive(ax, X, Z_adf, title, colors)
pml.savefig("adf-logreg-predictive-surface.pdf")

demo.plot_posterior_predictive(ax, X, Xspace, Z_adf, title, colors)
figures["predictive_distribution_adf"] = fig_adf

# ** Plotting training history **
w_batch_all = chains.mean(axis=0)
w_batch_laplace_all = w_map
w_batch_laplace_std_all = jnp.sqrt(jnp.diag(SN))

w_batch_std_all = chains.std(axis=0)
timesteps = jnp.arange(n_datapoints)
lcolors = ["black", "tab:blue", "tab:red"]
elements = mu_t_hist.T, tau_t_hist.T, w_laplace, lcolors
timesteps = jnp.arange(n_datapoints) + 1

for k, (wk, Pk, wk_laplace, c) in enumerate(zip(*elements)):
    fig_weight_k, ax = plt.subplots()
    ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (adf)")
    ax.axhline(y=wk_laplace, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)

elements = zip(mu_t_hist.T, tau_t_hist.T, w_batch_all, w_batch_std_all, w_batch_laplace_all, lcolors)
for i, (w_online, w_err_online, w_batch, w_batch_err, w_batch_laplace, c) in enumerate(elements):
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.errorbar(timesteps, w_online, jnp.sqrt(w_err_online), c=c, label=f"$w_{i}$ online")
    ax.axhline(y=w_batch, c=lcolors[i], linestyle="--", label=f"$w_{i}$ batch (mcmc)")
    ax.axhline(y=w_batch_laplace, c=lcolors[i], linestyle="dotted",
               label=f"$w_{i}$ batch (Laplace)", linewidth=2)
    ax.fill_between(timesteps, w_batch - w_batch_err, w_batch + w_batch_err, color=c, alpha=0.1)
    ax.legend()  # loc="lower left"
    ax.set_xlim(0, n_datapoints - 0.9)
    ax.set_xlim(1, n_datapoints)
    ax.legend(framealpha=0.7, loc="upper right")
    ax.set_xlabel("number of samples")
    ax.set_ylabel(f"weights ({i})")
    ax.set_ylabel("weights")
    plt.tight_layout()
    pml.savefig(f"adf-mcmc-online-hist-w{i}.pdf")
    figures[f"adf_logistic_regression_hist_w{k}"] = fig_weight_k

for name, figure in figures.items():
    filename = f"./../figures/{name}.pdf"
    figure.savefig(filename)

plt.show()
