Potential bug in MRVI's get_aggregated_posterior #3188

Open
rastogiruchir opened this issue Feb 13, 2025 · 3 comments · May be fixed by #3189

@rastogiruchir

The `MixtureSameFamily` constructed in `get_aggregated_posterior` computes different log probs from the ones I get by manually iterating over the mixture components. On real datasets, the manual computation yields a narrower range of log probs for a given sample. I'm not sure whether I'm misunderstanding the intent of `get_aggregated_posterior`, and I haven't checked in detail why this discrepancy exists.

Here's a small notebook to reproduce: https://colab.research.google.com/drive/1tT8rbLnK5zJqNnZa_Fn5sUpcwNB8Ge90#scrollTo=Zn7z1tEB_zJz.
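A minimal JAX sketch of the manual check described above. It assumes the aggregated posterior is a uniform mixture of diagonal Gaussians whose means and standard deviations are stored as `(C, D)` arrays (C components, D latent dimensions); `manual_log_prob` and this array layout are illustrative assumptions, not MRVI's API. Each component's log density is summed over all D dimensions first, and only then are the components combined with `logsumexp`.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def manual_log_prob(qu_locs, qu_scales, x):
    """Log density of x under a uniform mixture of diagonal Gaussians.

    Hypothetical helper. qu_locs, qu_scales: shape (C, D); x: shape (D,).
    """
    n_components = qu_locs.shape[0]
    component_log_probs = []
    for c in range(n_components):
        # Joint log density of the full D-dimensional vector under component c.
        lp = norm.logpdf(x, loc=qu_locs[c], scale=qu_scales[c]).sum()
        component_log_probs.append(lp)
    # log((1/C) * sum_c exp(lp_c)), computed stably.
    return logsumexp(jnp.stack(component_log_probs)) - jnp.log(n_components)
```

Comparing this value with the `MixtureSameFamily` log prob for the same `x` is the comparison the issue describes.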

@ori-kron-wis
Collaborator

@PierreBoyeau @justjhong

@justjhong
Contributor

justjhong commented Feb 13, 2025

Hey @rastogiruchir, thanks for raising this issue. I believe you are right: this is a bug. If the event dimension of the `Normal` distribution is not declared, `MixtureSameFamily` treats each latent dimension as a separate batch element, so it computes a mixture log prob per dimension instead of a mixture over the joint log prob of the full latent vector.

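In more detail, writing $x_d$ for dimension $d$ of the query point, $C$ for the number of mixture components, and $D$ for the number of latent dimensions:

Bugged version (a separate mixture in each dimension):
$$\log \prod_{d=1}^{D} \left( \frac{1}{C} \sum_{c=1}^{C} \mathcal{N}(x_d; \mu_{dc}, \sigma_{dc}) \right)$$

Correct version (a mixture over the full latent vector):
$$\log \left( \frac{1}{C} \sum_{c=1}^{C} \prod_{d=1}^{D} \mathcal{N}(x_d; \mu_{dc}, \sigma_{dc}) \right)$$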
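A small numerical check of the two expressions in plain JAX, assuming `(C, D)` parameter arrays; `bugged_log_prob`, `correct_log_prob`, and the toy parameters are illustrative. With two well-separated components and a query point that takes one coordinate from each, the per-dimension mixture lets every coordinate pick its best-fitting component independently, so it assigns a much higher log density than the joint mixture.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def bugged_log_prob(locs, scales, x):
    # log prod_d (1/C) sum_c N(x_d; mu_dc, sigma_dc)
    lp = norm.logpdf(x[None, :], locs, scales)  # shape (C, D)
    per_dim = logsumexp(lp, axis=0) - jnp.log(locs.shape[0])  # shape (D,)
    return per_dim.sum()


def correct_log_prob(locs, scales, x):
    # log (1/C) sum_c prod_d N(x_d; mu_dc, sigma_dc)
    lp = norm.logpdf(x[None, :], locs, scales).sum(axis=1)  # shape (C,)
    return logsumexp(lp) - jnp.log(locs.shape[0])


# Two well-separated components in D = 2; x mixes coordinates from both.
locs = jnp.array([[0.0, 0.0], [10.0, 10.0]])
scales = jnp.ones_like(locs)
x = jnp.array([0.0, 10.0])
print(bugged_log_prob(locs, scales, x), correct_log_prob(locs, scales, x))
# approximately -3.224 and -51.838
```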

I was able to reproduce your fix (which uses a multivariate distribution) by adding `to_event(1)` to the original implementation:

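```python
import jax.numpy as jnp
from numpyro.distributions import Categorical, MixtureSameFamily, Normal


def ap_log_prob(qu_locs, qu_scales, x):  # illustrative function name
    """Following the code in get_aggregated_posterior and differential_abundance"""
    ap = MixtureSameFamily(
        Categorical(probs=jnp.ones(qu_locs.shape[0]) / qu_locs.shape[0]),
        # to_event(1) makes the latent dim an event dim, not a batch dim
        Normal(qu_locs, qu_scales).to_event(1),
    )
    return ap.log_prob(x).item()
```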

Will make a PR for this. Thanks again, Ruchir!

@canergen
Member

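Thanks @rastogiruchir. The notebook shows that Justin's suggested solution fixes the issue in JAX: https://colab.research.google.com/drive/1tT8rbLnK5zJqNnZa_Fn5sUpcwNB8Ge90#scrollTo=Zn7z1tEB_zJz.
It also shows that the mixture-of-Gaussians (MoG) computations in torch were already correct. Note that `torch.distributions.Normal` takes a standard deviation (`scale`), whereas `torch.distributions.MultivariateNormal` is parameterized by a covariance matrix (or an equivalent precision matrix or Cholesky factor), not per-dimension standard deviations.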
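A PyTorch sketch of that parameterization difference, using random toy parameters (an illustration, not scvi-tools' torch code): a mixture built from `Independent(Normal(locs, scales), 1)` matches one built from `MultivariateNormal` with a diagonal covariance only when the covariance holds the squared standard deviations.

```python
import torch
from torch.distributions import (
    Categorical,
    Independent,
    MixtureSameFamily,
    MultivariateNormal,
    Normal,
)

torch.manual_seed(0)
C, D = 3, 4
locs = torch.randn(C, D)
scales = torch.rand(C, D) + 0.5  # standard deviations, shape (C, D)
x = torch.randn(D)
mix = Categorical(probs=torch.full((C,), 1.0 / C))

# Normal takes standard deviations; Independent(..., 1) makes D the event dim.
diag_mog = MixtureSameFamily(mix, Independent(Normal(locs, scales), 1))

# MultivariateNormal takes a covariance matrix, so square the scales.
full_mog = MixtureSameFamily(
    mix, MultivariateNormal(locs, covariance_matrix=torch.diag_embed(scales**2))
)
# Passing the unsquared scales as a covariance generally gives a different value.
wrong_mog = MixtureSameFamily(
    mix, MultivariateNormal(locs, covariance_matrix=torch.diag_embed(scales))
)

diag_lp = diag_mog.log_prob(x).item()
full_lp = full_mog.log_prob(x).item()
wrong_lp = wrong_mog.log_prob(x).item()
print(diag_lp, full_lp, wrong_lp)
```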
