Save PyMC Deterministics to idata #21
Thanks, @bwengals! There is no concept of …, and I'm not sure of a pleasant way of storing deterministics either! If you have a function from parameters to your deterministics, it is probably best for the user to apply it themselves. I think these lines do the deterministics in PyMC. I'd be happy to include a snippet in the documentation, either as a standalone page or with the bayeux and PyMC demo. LMK what you think!
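Applying such a function yourself can be as simple as broadcasting it over the posterior draws. A minimal NumPy sketch, assuming the posterior is available as a dict of arrays with leading `(chain, draw)` axes; `posterior`, `mu_fn`, and the stand-in draws are illustrative names and data, not bayeux output. It uses the linear model `mu = alpha + beta * x` from the example further down in this thread:

```python
import numpy as np


def mu_fn(alpha, beta, x):
    """Deterministic of the linear model: mu = alpha + beta * x.

    alpha and beta have shape (chain, draw); x has shape (n,).
    Adding a trailing axis to the parameters broadcasts the result to (chain, draw, n).
    """
    return alpha[..., None] + beta[..., None] * x


rng = np.random.default_rng(0)
n_chains, n_draws = 4, 500
x = np.linspace(0, 1, 100)

# Stand-in posterior draws (hypothetical); in practice these come from the sampler output.
posterior = {
    "alpha": rng.normal(1.0, 0.05, size=(n_chains, n_draws)),
    "beta": rng.normal(2.0, 0.05, size=(n_chains, n_draws)),
}

# Compute the deterministic for every draw at once and store it next to the parameters.
posterior["mu"] = mu_fn(posterior["alpha"], posterior["beta"], x)
print(posterior["mu"].shape)  # (4, 500, 100)
```

With an ArviZ `InferenceData`, the same idea should work directly on `idata.posterior` if `x` is wrapped as `xr.DataArray(x, dims="n")`, since xarray broadcasts by dimension name.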
Ah, of course, so it makes perfect sense that there's no mechanism to store deterministics. Maybe a snippet in the PyMC demo is the easiest? Whatever you think makes sense! Thanks for pointing me to that spot in the code; I'll try it out on my end too.
It would be great to have this example/feature in the documentation 🙏
A possibility (which is not very elegant) is to sample the deterministic via `pm.sample_posterior_predictive`:

```python
import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

# Simulated data for a simple linear regression.
x = np.linspace(0, 1, 100)
y = 1 + 2 * x + np.random.normal(0, 0.1, 100)

with pm.Model(coords={"n": range(x.size)}) as model:
    alpha = pm.Normal("alpha", 0, 1)
    beta = pm.Normal("beta", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)
    mu = pm.Deterministic("mu", alpha + beta * x, dims=("n",))
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims=("n",))

# Sample with bayeux; the resulting idata.posterior does not contain "mu".
bx_model = bx.Model.from_pymc(model)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))

# Recompute the deterministic (and the observed variable) from the posterior draws.
with model:
    idata.extend(pm.sample_posterior_predictive(idata, var_names=["mu", "y_obs"]))

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="data")
az.plot_hdi(
    x,
    idata.posterior_predictive["y_obs"],
    fill_kwargs={"alpha": 0.5, "label": "y_obs"},
    ax=ax,
)
az.plot_hdi(
    x,
    idata.posterior_predictive["mu"],
    fill_kwargs={"alpha": 0.8, "label": "mu"},
    ax=ax,
)
ax.legend()
ax.set(title="linear model", xlabel="x", ylabel="y")
```

Note that `idata.posterior_predictive` carries the model coordinates (here `n`). I have not managed to do this via the functions in https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py.
That's interesting. I'd like to continue with the stateless API, but I could imagine a workflow like

```python
bx_model = bx.Model.from_pymc(model)
pytree = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0), return_pytree=True)
idata = bx.postprocess_pymc(pytree, model)
```

where … What do you think?
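One way a helper like the proposed `bx.postprocess_pymc` could add deterministics to a returned pytree is to map a per-draw function over the `(chain, draw)` axes with a nested `jax.vmap`. This is a sketch under assumptions, not taken from bayeux or PyMC: `add_deterministics` and `linear_mu` are hypothetical names, and the pytree is assumed to be a dict of arrays with leading `(chain, draw)` axes.

```python
import jax
import jax.numpy as jnp


def add_deterministics(pytree, fns):
    """Hypothetical helper: return a copy of `pytree` with extra deterministic entries.

    `pytree` maps names to arrays shaped (chain, draw, ...).
    `fns` maps new names to functions that take a single draw (a dict of
    per-draw values) and return that draw's deterministic value.
    """
    out = dict(pytree)
    for name, fn in fns.items():
        # Inner vmap maps over draws, outer vmap maps over chains.
        out[name] = jax.vmap(jax.vmap(fn))(pytree)
    return out


x = jnp.linspace(0.0, 1.0, 100)


def linear_mu(draw):
    # Deterministic for a single draw: alpha and beta are scalars here.
    return draw["alpha"] + draw["beta"] * x


# Stand-in posterior draws (hypothetical), shaped (chain, draw).
key_a, key_b = jax.random.split(jax.random.key(0))
pytree = {
    "alpha": 1.0 + 0.05 * jax.random.normal(key_a, (4, 500)),
    "beta": 2.0 + 0.05 * jax.random.normal(key_b, (4, 500)),
}
result = add_deterministics(pytree, {"mu": linear_mu})
print(result["mu"].shape)  # (4, 500, 100)
```

Writing the function for a single draw and letting `vmap` handle the batch axes keeps it independent of the number of chains and draws.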
Yes! That would be fantastic! I tried this path but realized I needed …
Hi, I'm trying to implement a quasi-stateful RNG to be used in an ELBO estimate (see https://jax.readthedocs.io/en/latest/notebooks/vmapped_log_probs.html). Any ideas on how to accomplish this in bayeux?
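JAX has no global RNG state, so a common workaround is a small object that holds a key and splits off a fresh subkey each time it is called. A minimal sketch in plain JAX (not a bayeux API), assuming a toy standard-normal target and a Gaussian variational family; `KeySequence`, `elbo`, and the step size are illustrative choices. The object mutates Python state, so it is called outside `jax.jit` and the resulting key is passed in as an argument:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

NUM_SAMPLES = 256  # Monte Carlo samples per ELBO estimate


class KeySequence:
    """Hypothetical helper: hands out a fresh PRNG key on every call."""

    def __init__(self, seed):
        self._key = jax.random.key(seed)

    def __call__(self):
        # Split the stored key: keep one half as the new state, return the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


def log_target(z):
    # Toy target: standard normal log density.
    return norm.logpdf(z, 0.0, 1.0)


@jax.jit
def elbo(params, key):
    """Reparameterized Monte Carlo ELBO for q(z) = Normal(mu, exp(log_sigma))."""
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    z = mu + sigma * jax.random.normal(key, (NUM_SAMPLES,))
    return jnp.mean(log_target(z) - norm.logpdf(z, mu, sigma))


keys = KeySequence(0)
grad_elbo = jax.jit(jax.grad(elbo))  # gradient with respect to params only
params = (jnp.array(1.0), jnp.array(-1.0))
for _ in range(200):
    # Each step pulls a new key, so every gradient estimate uses fresh noise.
    grads = grad_elbo(params, keys())
    params = tuple(p + 0.05 * g for p, g in zip(params, grads))  # gradient ascent
```

Because the key is threaded explicitly into `elbo`, the jitted function itself stays pure; only the Python-side loop is stateful.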
Hey, I've been playing around with this a bit from PyMC, so glad this exists now! Unfortunately, I'm not getting Deterministics recorded in `idata.posterior`. Happy to attempt a PR if you point me to where to start.