Save PyMC Deterministics to idata #21

Open
bwengals opened this issue Jan 19, 2024 · 7 comments
@bwengals

Hey, I have been playing around with this a bit from PyMC, so glad this exists now! Unfortunately, I'm not getting Deterministics recorded in idata.posterior. Happy to attempt a PR if you can point me to where to start.

@ColCarroll
Collaborator

Thanks, @bwengals!

There is no concept of Deterministic here, which is sort of funny and unsatisfying -- bayeux just takes initial parameters and a differentiable log density.

I'm not sure of a pleasant way of storing deterministics, either! If you have a function from the parameters to your deterministics, it is probably better for the user to apply it themselves.
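As a minimal sketch of what "applying it yourself" could look like: the helper below is hypothetical (it is not part of bayeux or PyMC) and assumes the posterior draws are available as a dict of NumPy arrays with leading (chain, draw) dimensions. The draws here are random stand-ins rather than real sampler output.

```python
import numpy as np


def apply_deterministic(fn, draws):
    """Apply `fn` to every posterior draw (hypothetical helper).

    `draws` maps parameter names to arrays shaped (chain, draw, ...).
    `fn` takes a dict holding a single draw and returns an array.
    """
    n_chains, n_draws = next(iter(draws.values())).shape[:2]
    out = [
        [
            fn({name: values[c, d] for name, values in draws.items()})
            for d in range(n_draws)
        ]
        for c in range(n_chains)
    ]
    return np.asarray(out)


# Stand-in posterior draws: 4 chains, 50 draws each (not real sampler output).
rng = np.random.default_rng(0)
draws = {
    "alpha": rng.normal(1.0, 0.1, size=(4, 50)),
    "beta": rng.normal(2.0, 0.1, size=(4, 50)),
}
x = np.linspace(0, 1, 100)

# The deterministic is an ordinary function of one draw of the parameters.
mu = apply_deterministic(lambda p: p["alpha"] + p["beta"] * x, draws)
print(mu.shape)  # (4, 50, 100)
```

The explicit loop keeps the sketch dependency-free; with JAX arrays the same idea would normally be expressed with jax.vmap over the chain and draw axes.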

I think these lines compute the deterministics in PyMC. I'd be happy to include a snippet in the documentation -- either as a standalone page, or with the bayeux and pymc demo.

LMK what you think!

@bwengals
Author

Ah of course,

> bayeux just takes initial parameters and a differentiable log density.

so it makes perfect sense that there's no mechanism to store deterministics. Maybe a snippet in the pymc demo is the easiest? Whatever you think makes sense! Thanks for pointing me to that spot in the code, I'll try it out on my end too.

@juanitorduz

It would be great to have this example/feature in the documentation 🙏

@juanitorduz

juanitorduz commented Apr 2, 2024

One possibility (which is not very elegant) is to sample the deterministic via pm.sample_posterior_predictive, as described in "Out of model predictions with PyMC":

import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

x = np.linspace(0, 1, 100)
y = 1 + 2 * x + np.random.normal(0, 0.1, 100)

with pm.Model(coords={"n": range(x.size)}) as model:
    alpha = pm.Normal("alpha", 0, 1)
    beta = pm.Normal("beta", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)
    mu = pm.Deterministic("mu", alpha + beta * x, dims=("n",))
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims=("n",))

bx_model = bx.Model.from_pymc(model)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))

with model:
    idata.extend(pm.sample_posterior_predictive(idata, var_names=["mu", "y_obs"]))


fig, ax = plt.subplots()
ax.plot(x, y, "o", label="data")
az.plot_hdi(
    x,
    idata.posterior_predictive["y_obs"],
    fill_kwargs={"alpha": 0.5, "label": "y_obs"},
    ax=ax,
)
az.plot_hdi(
    x,
    idata.posterior_predictive["mu"],
    fill_kwargs={"alpha": 0.8, "label": "mu"},
    ax=ax,
)
ax.legend()
ax.set(title="linear model", xlabel="x", ylabel="y")

[image: the plot produced by the code above]

Note that idata.posterior_predictive provides the coordinates:

[image: screenshot of the idata.posterior_predictive output]


I have not managed to do it via the functions in https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py .

@ColCarroll
Collaborator

That's interesting -- I'd like to continue with the stateless API, but I could imagine a workflow like

bx_model = bx.Model.from_pymc(model)
pytree = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0), return_pytree=True)

idata = bx.postprocess_pymc(pytree, model)

where bx.postprocess_pymc starts from: https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py#L581

What do you think?

@juanitorduz

Yes! That would be fantastic! I tried this path but realized I needed raw_mcmc_samples, so I thought it would be better to be handled by bayeux internally 👍
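To illustrate why the raw samples matter, here is a rough sketch of a stateless post-processing step for the linear model above. Everything in it is an assumption made for illustration: postprocess is a hypothetical function (it is neither the proposed bx.postprocess_pymc nor PyMC's own code), the raw draws are random stand-ins, and the name sigma_log__ assumes the sampler works on the log-transformed value of sigma.

```python
import numpy as np


def postprocess(raw, x):
    """Hypothetical sketch: raw (unconstrained) draws -> constrained draws + deterministics.

    `raw` maps names to arrays shaped (chain, draw). A new dict is returned
    and the input is left untouched, so no state is kept anywhere.
    """
    # Copy over the variables that are already on their natural scale.
    out = {k: v for k, v in raw.items() if not k.endswith("_log__")}
    # Undo the log transform on sigma.
    out["sigma"] = np.exp(raw["sigma_log__"])
    # Recompute the deterministic mu = alpha + beta * x for every draw.
    out["mu"] = raw["alpha"][..., None] + raw["beta"][..., None] * x
    return out


# Stand-in raw draws: 2 chains, 10 draws each (not real sampler output).
rng = np.random.default_rng(1)
raw = {
    "alpha": rng.normal(1.0, 0.1, size=(2, 10)),
    "beta": rng.normal(2.0, 0.1, size=(2, 10)),
    "sigma_log__": rng.normal(-2.0, 0.1, size=(2, 10)),
}
x = np.linspace(0, 1, 100)

samples = postprocess(raw, x)
print(sorted(samples))  # ['alpha', 'beta', 'mu', 'sigma']
```

A real implementation would take the transforms and the deterministic graph from the PyMC model rather than hard-coding them as done here.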

@sherna90

> bayeux just takes initial parameters and a differentiable log density.

Hi, I'm trying to implement a quasi-stateful RNG for use in an ELBO computation, along the lines of https://jax.readthedocs.io/en/latest/notebooks/vmapped_log_probs.html. Any ideas on how to accomplish this in bayeux?
