Description
Pyro's HMC and NUTS implementations are feature-complete and well-tested, but they are quite slow in models that operate on small tensors, like the one in our Bayesian regression tutorial. The reasons are largely beyond our control (mostly having to do with the design and implementation of torch.autograd). This is unfortunate because such models are often the subjects of new users' first encounters with Pyro. Running multiple MCMC chains in parallel is one way of dealing with this problem, but Pyro's MCMC algorithms currently only support process-level parallelism with torch.multiprocessing, which is slow, memory-hungry, and error-prone on some platforms like Google Colab and Windows.
This issue proposes implementing the "ChEES-HMC" algorithm described in this paper as a new MCMC kernel where vectorization over individual MCMC chains happens via broadcasting in an additional plate context, similar to the vectorization over guide samples in our Trace*_ELBO implementations. While this algorithm is unlikely to replace NUTS in all contexts, vectorization over a large number of independent chains may be especially useful in alleviating PyTorch-related performance issues in small models.
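To make the vectorization idea concrete, here is a minimal sketch in plain PyTorch (not Pyro's kernel API) of a fixed-length HMC step batched over chains along a leading tensor dimension. It is only an illustration under stated assumptions: the standard normal target and the function names `log_prob`, `grad_log_prob`, `leapfrog` and `hmc_step` are hypothetical, and the step size and trajectory length adaptation that defines ChEES-HMC is omitted entirely. What it does show is that when every chain takes the same number of leapfrog steps, all chains can advance in lockstep with a single autograd call per step.

```python
import torch


def log_prob(z):
    # Hypothetical target: a standard normal, evaluated independently per chain.
    # z has shape (num_chains, dim); the result has shape (num_chains,).
    return -0.5 * (z ** 2).sum(-1)


def grad_log_prob(z):
    z = z.detach().requires_grad_(True)
    # Chains are independent, so differentiating the sum over chains yields
    # each chain's own gradient in the corresponding row.
    (grad,) = torch.autograd.grad(log_prob(z).sum(), z)
    return grad


def leapfrog(z, r, step_size, num_steps):
    # Standard leapfrog integrator; every chain takes the same number of steps.
    r = r + 0.5 * step_size * grad_log_prob(z)
    for i in range(num_steps):
        z = z + step_size * r
        # Full momentum step between position updates, half step at the end.
        scale = step_size if i < num_steps - 1 else 0.5 * step_size
        r = r + scale * grad_log_prob(z)
    return z, r


def hmc_step(z, step_size, num_steps):
    # One Metropolis-adjusted HMC transition for all chains at once.
    r0 = torch.randn_like(z)
    z_new, r_new = leapfrog(z, r0, step_size, num_steps)
    h0 = -log_prob(z) + 0.5 * (r0 ** 2).sum(-1)
    h1 = -log_prob(z_new) + 0.5 * (r_new ** 2).sum(-1)
    # Per-chain accept/reject decision, broadcast over the parameter dimension.
    u = torch.rand(z.shape[0], dtype=z.dtype, device=z.device)
    accept = u < torch.exp(h0 - h1)
    return torch.where(accept.unsqueeze(-1), z_new, z), accept


torch.manual_seed(0)
z = torch.zeros(64, 3)  # 64 chains, 3 parameters each
for _ in range(10):
    z, accept = hmc_step(z, step_size=0.1, num_steps=5)
```

In an actual Pyro kernel, the chain dimension would presumably come from an enclosing plate rather than a hand-managed leading axis, as described above. The sketch also hints at why this approach is simpler than vectorizing NUTS: NUTS trajectories can have a different length in each chain, whereas a fixed number of leapfrog steps keeps all chains in lockstep.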
Note that this proposal is more narrowly scoped than the general suggestion in #2539 to support broadcasting-based parallelization in our existing MCMC kernels. As @fehiepsi said in #2539, that broader change is probably best deferred until better auto-vectorization functionality a la JAX's vmap is added to PyTorch.
This would be a great starting point for a contributor with some probabilistic ML expertise who is interested in adding a high-impact feature to Pyro while learning more about some of the internal inference APIs. If that sounds like you, please speak up! We're happy to help review draft code or discuss design issues.