Potential bug in MRVI's get_aggregated_posterior #3188
Hey @rastogiruchir, thanks for raising this issue. I believe you are right; this is in fact a bug. If you don't declare the event dimension of the Normal distribution, MixtureSameFamily computes the mixture log probs over each latent dimension separately instead of over the full (joint) log prob. I was able to reproduce the result of your fix with the multivariate distribution by calling to_event on the Normal in the original implementation.
Will make a PR for this. Thanks again, Ruchir!
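To spell out the difference between the two behaviours, assume the aggregated posterior is a mixture of K diagonal Gaussian components with weights w_k, means \mu_{kd} and scales \sigma_{kd} over D latent dimensions (the notation is introduced here for illustration and is not taken from the scvi-tools code). Without a declared event dimension, each dimension gets its own mixture; if those per-dimension values are then summed over d, the result is the first expression. With the Normal marked via to_event(1), each component is a single D-dimensional density, which gives the second expression:

```latex
% Per-dimension mixtures (no event dimension declared), summed over d:
\log \tilde{p}(z) = \sum_{d=1}^{D} \log \sum_{k=1}^{K} w_k \, \mathcal{N}\!\left(z_d \mid \mu_{kd}, \sigma_{kd}^2\right)

% Joint mixture (Normal(...).to_event(1)), the intended aggregated posterior:
\log p(z) = \log \sum_{k=1}^{K} w_k \prod_{d=1}^{D} \mathcal{N}\!\left(z_d \mid \mu_{kd}, \sigma_{kd}^2\right)
```

The two agree when K = 1, but in general the first lets every dimension pick its component independently, so it is not the density of the mixture of diagonal Gaussians.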
Thanks @rastogiruchir. To highlight, Justin's suggested solution fixes the discrepancy in JAX: https://colab.research.google.com/drive/1tT8rbLnK5zJqNnZa_Fn5sUpcwNB8Ge90#scrollTo=Zn7z1tEB_zJz.
The MixtureSameFamily distribution constructed in get_aggregated_posterior computes different log_probs from those I compute manually by iterating over the mixture components. On real datasets, the manual computation shrinks the range of the log_probs calculated for a given sample. I'm not sure whether I'm misunderstanding the intention of get_aggregated_posterior, and I haven't checked in detail why this discrepancy exists.
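A minimal NumPy sketch of the two computations, assuming uniform mixture weights and diagonal Gaussian components whose locs and scales have shape (K, D). The function names below are hypothetical helpers for illustration, not scvi-tools or numpyro APIs. manual_aggregated_log_prob mirrors the manual loop over mixture components; per_dim_mixture_log_prob mimics a mixture over a Normal with no declared event dimension (one mixture per dimension), summed over dimensions only so the two can be compared.

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    # Elementwise log density of a univariate Normal(loc, scale).
    return -0.5 * np.log(2.0 * np.pi) - np.log(scale) - 0.5 * ((x - loc) / scale) ** 2


def logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along `axis`.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def manual_aggregated_log_prob(z, locs, scales):
    """Reference computation: loop over the K components.

    Each component is a diagonal Gaussian over all D dims, so its log density
    is the SUM over dims; the mixture is taken over these joint values.
    z: (D,), locs/scales: (K, D). Weights are assumed uniform (1/K).
    """
    K = locs.shape[0]
    comp_log_probs = []
    for k in range(K):
        comp_log_probs.append(normal_logpdf(z, locs[k], scales[k]).sum())
    return logsumexp(np.array(comp_log_probs), axis=0) - np.log(K)


def per_dim_mixture_log_prob(z, locs, scales):
    """Mixture taken independently in each dimension, then summed over dims.

    This is what a mixture over an unstructured Normal (no event dim)
    amounts to once its per-dimension log probs are added up.
    """
    K = locs.shape[0]
    per_dim = normal_logpdf(z[None, :], locs, scales)  # shape (K, D)
    mixed = logsumexp(per_dim, axis=0) - np.log(K)     # shape (D,)
    return mixed.sum()


if __name__ == "__main__":
    # Two well-separated components in 2D; z mixes coordinates from both.
    locs = np.array([[0.0, 0.0], [3.0, 3.0]])
    scales = np.ones_like(locs)
    z = np.array([0.0, 3.0])
    print(manual_aggregated_log_prob(z, locs, scales))  # about -6.34
    print(per_dim_mixture_log_prob(z, locs, scales))    # about -3.20
```

In this toy case the per-dimension version assigns z a much higher density, because each coordinate can be "explained" by a different component; with a single component the two functions return the same value.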
Here's a small notebook to reproduce: https://colab.research.google.com/drive/1tT8rbLnK5zJqNnZa_Fn5sUpcwNB8Ge90#scrollTo=Zn7z1tEB_zJz.