Reparametrisable PyTorch MixtureSameFamily distribution

A PyTorch implementation of the implicit reparametrisation trick for mixture distributions, based on Figurnov et al. (2019), "Implicit Reparameterization Gradients", and on the corresponding implementation in TensorFlow Probability.
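Consider a univariate mixture with CDF F(x; θ) = Σ_k w_k Φ((x − μ_k)/σ_k). Implicit reparametrisation differentiates the identity F(x; θ) = u at a fixed sample x. This gives dx/dθ = −(∂F(x; θ)/∂θ) / p(x; θ), so no inverse CDF is needed.

The framework-free Python sketch below illustrates the technique. It is not this repository's code, and every function name in it is hypothetical. It assumes Gaussian components and only differentiates with respect to the means and scales. The bisection inverse CDF is there solely to check the implicit gradients against finite differences.

```python
import math

def normal_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def normal_cdf(x, mu, sigma):
    return 0.5 * math.erfc(-(x - mu) / (sigma * math.sqrt(2.0)))

def mixture_pdf(x, w, mu, sigma):
    return sum(wk * normal_pdf(x, m, s) for wk, m, s in zip(w, mu, sigma))

def mixture_cdf(x, w, mu, sigma):
    return sum(wk * normal_cdf(x, m, s) for wk, m, s in zip(w, mu, sigma))

def mixture_icdf(u, w, mu, sigma, tol=1e-13):
    """Invert the monotone mixture CDF by bisection (used here only to check gradients)."""
    lo = min(m - 12.0 * s for m, s in zip(mu, sigma))
    hi = max(m + 12.0 * s for m, s in zip(mu, sigma))
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mixture_cdf(mid, w, mu, sigma) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

def implicit_grads(x, w, mu, sigma):
    """Implicit gradients dx/dmu_k and dx/dsigma_k at a fixed sample x,
    from dx/dtheta = -(dF/dtheta)(x) / p(x)."""
    p = mixture_pdf(x, w, mu, sigma)
    d_mu, d_sigma = [], []
    for wk, m, s in zip(w, mu, sigma):
        r = wk * normal_pdf(x, m, s) / p   # responsibility of component k at x
        z = (x - m) / s
        d_mu.append(r)                      # dF/dmu_k    = -w_k N(x; m, s)
        d_sigma.append(r * z)               # dF/dsigma_k = -w_k N(x; m, s) * z
    return d_mu, d_sigma

# Usage: compare against central finite differences of the inverse CDF at fixed u.
w, mu, sigma = [0.3, 0.7], [-1.0, 2.0], [0.5, 1.5]
u, h = 0.4, 1e-6
x = mixture_icdf(u, w, mu, sigma)
d_mu, d_sigma = implicit_grads(x, w, mu, sigma)
for k in range(len(w)):
    mu_p, mu_m = list(mu), list(mu)
    mu_p[k] += h
    mu_m[k] -= h
    fd = (mixture_icdf(u, w, mu_p, sigma) - mixture_icdf(u, w, mu_m, sigma)) / (2 * h)
    print(f"dx/dmu_{k}: implicit={d_mu[k]:.6f}  finite-diff={fd:.6f}")
```

For the means, dx/dμ_k is the posterior responsibility of component k at x, so these gradients sum to one. Intuitively, shifting every mean by the same amount shifts the sample by exactly that amount.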

Because samples are differentiable with respect to the mixture parameters, the distribution can be used directly for variational inference with mixture-distribution variational families.
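In variational inference, sample gradients of this kind enable pathwise Monte Carlo estimates of ∇_θ E_q[f(x)]. The sample x itself can come from ordinary ancestral sampling (pick a component, then sample from it), and the gradient is attached through the implicit formula.

The sketch below is again plain Python with Gaussian components and hypothetical names. It estimates the gradient of E_q[x²], for which the exact values dE/dμ_k = 2 w_k μ_k and dE/dσ_k = 2 w_k σ_k are known. In an actual VI objective, f would come from the ELBO and automatic differentiation would replace the hand-written chain rule.

```python
import math
import random

def normal_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def sample_mixture(rng, w, mu, sigma):
    # Ordinary (non-differentiable) ancestral sampling: choose a component, then sample it.
    k = rng.choices(range(len(w)), weights=w)[0]
    return rng.gauss(mu[k], sigma[k])

def grad_expectation(f_prime, w, mu, sigma, n=100_000, seed=0):
    """Monte Carlo estimate of dE_q[f(x)]/dmu_k and dE_q[f(x)]/dsigma_k for a
    univariate Gaussian mixture q, using implicit reparametrisation gradients."""
    rng = random.Random(seed)
    K = len(w)
    g_mu, g_sigma = [0.0] * K, [0.0] * K
    for _ in range(n):
        x = sample_mixture(rng, w, mu, sigma)
        dens = [wk * normal_pdf(x, m, s) for wk, m, s in zip(w, mu, sigma)]
        p = sum(dens)
        fp = f_prime(x)
        for k in range(K):
            r = dens[k] / p                      # implicit dx/dmu_k
            z = (x - mu[k]) / sigma[k]
            g_mu[k] += fp * r                    # chain rule: f'(x) * dx/dmu_k
            g_sigma[k] += fp * r * z             # chain rule: f'(x) * dx/dsigma_k
    return [g / n for g in g_mu], [g / n for g in g_sigma]

# Objective f(x) = x**2, so f'(x) = 2x; exact gradients are 2*w_k*mu_k and 2*w_k*sigma_k.
w, mu, sigma = [0.3, 0.7], [-1.0, 2.0], [0.5, 1.5]
g_mu, g_sigma = grad_expectation(lambda x: 2.0 * x, w, mu, sigma)
print("mu grads:", g_mu, "sigma grads:", g_sigma)
```

With these parameters, the estimates should land close to (−0.6, 2.8) for the means and (0.3, 2.1) for the scales.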

Remarks:

  • For multivariate mixtures, the class currently supports only component distributions that fully factorise across dimensions.
  • The repository also adds a StableNormal distribution, which overrides the default cdf method with a numerically more stable implementation taken from a comment on pytorch/pytorch#52973. It also provides a _log_cdf method; however, this method is not used for the implicit reparametrisation.
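torch.distributions.Normal.cdf evaluates 0.5 * (1 + erf(z / √2)). This formula loses all precision in the lower tail once erf(z / √2) rounds to −1.

The linked comment is not reproduced here. Instead, the plain-Python sketch below shows one standard remedy, Φ(z) = 0.5 · erfc(−z / √2). It also includes a hypothetical log-CDF helper that switches to the asymptotic Mills-ratio expansion in the far lower tail. The function names and the tail threshold of −20 are assumptions for illustration.

```python
import math

SQRT2 = math.sqrt(2.0)
LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)

def naive_normal_cdf(z):
    # 1 + erf(z / sqrt(2)) cancels catastrophically once erf(...) rounds to -1
    return 0.5 * (1.0 + math.erf(z / SQRT2))

def stable_normal_cdf(z):
    # Phi(z) = 0.5 * erfc(-z / sqrt(2)): the small lower-tail value is computed directly
    return 0.5 * math.erfc(-z / SQRT2)

def stable_normal_log_cdf(z, tail=-20.0):
    """log Phi(z) for a standard normal (hypothetical helper; threshold is an assumption)."""
    if z > 0.0:
        # Upper tail: log Phi(z) = log1p(-Phi(-z)) keeps tiny negative values
        return math.log1p(-0.5 * math.erfc(z / SQRT2))
    if z > tail:
        return math.log(stable_normal_cdf(z))
    # Far lower tail, where Phi(z) eventually underflows: Mills-ratio expansion
    # log Phi(z) ~ -z^2/2 - log(-z) - log(sqrt(2*pi)) + log(1 - 1/z^2 + 3/z^4)
    z2 = z * z
    return -0.5 * z2 - math.log(-z) - LOG_SQRT_2PI + math.log1p(-1.0 / z2 + 3.0 / (z2 * z2))

for z in (-10.0, -25.0, -40.0):
    print(z, naive_normal_cdf(z), stable_normal_cdf(z), stable_normal_log_cdf(z))
```

At z = −10, the erf form returns exactly 0.0, while the erfc form returns approximately 7.62e-24. Below roughly z = −38, even the erfc form underflows. This is why a dedicated log-CDF needs an analytic tail branch rather than taking the log of the CDF.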