Labels: discussion, enhancement, help wanted
Description
This issue proposes adding an implementation of nonparametric Hamiltonian Monte Carlo, a variant of HMC that, unlike vanilla HMC, is provably correct for models with stochastic control flow.
@fzaiser (one of the paper's authors) has open-sourced a PyTorch reference implementation. Turning it into a Pyro inference algorithm would be a great project for contributors who already have some interest in and knowledge of the line of research the paper belongs to: extensions of gradient-based approximate inference algorithms to discontinuous models.
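To make "stochastic control flow" concrete for the discussion, here is a minimal sketch in plain Python (standard library only, no Pyro). The model, the function name `random_walk_trace`, and the `max_steps` cap are all illustrative assumptions, not code from the paper or the reference implementation. The point it shows is that the number of latent draws depends on earlier draws, so different runs produce traces of different lengths, and there is no fixed-dimensional parameter vector for vanilla HMC to operate on.

```python
import random


def random_walk_trace(rng, max_steps=10_000):
    """Toy model (hypothetical): walk from a random start until crossing zero.

    Returns the list of all latent draws and the total distance travelled.
    """
    trace = []                           # every latent draw, in sampling order
    position = rng.uniform(0.0, 3.0)     # latent start position
    trace.append(position)
    distance = 0.0
    steps = 0
    # The loop condition depends on latent values, so the number of
    # draws (and hence the dimension of the trace) is itself random.
    # max_steps is an illustrative safety cap, not part of the model.
    while position > 0.0 and steps < max_steps:
        step = rng.uniform(-1.0, 1.0)    # one more latent draw per iteration
        trace.append(step)
        position += step
        distance += abs(step)
        steps += 1
    return trace, distance


rng = random.Random(0)
lengths = {len(random_walk_trace(rng)[0]) for _ in range(200)}
print(f"distinct trace lengths over 200 runs: {len(lengths)}")
```

In a Pyro version of such a model, each draw inside the loop would presumably become its own sample site, so the set of sites would vary from run to run. That is the structure the questions below are concerned with.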
Some questions for discussion:
- What extra information about the model is necessary to support the algorithms used in the paper's experiments?
- Can this information be extracted automatically, or would it require additional user-supplied annotations?
- Would it make more sense to implement this functionality in a new MCMC kernel or to modify our existing HMC kernels?
- Would NumPyro be a better substrate for this algorithm, since JAX does not compile away control flow?