Description
This issue proposes a streaming architecture for MCMC on models with a large memory footprint.
The problem this addresses is that, in models with high-dimensional latents (say >1M latent variables), it becomes difficult to save a list of samples, especially on GPUs with limited memory. The proposed solution is to compute statistics on those samples eagerly during inference and to discard each sample once it has been consumed.
@fehiepsi suggested creating a new MCMC class (say StreamingMCMC) with an interface similar to MCMC. It would remain independent of the kernel (either HMC or NUTS) but follow an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory constrained, it is reasonable to omit multiprocessing support from StreamingMCMC.
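To make the control flow concrete, here is a minimal sketch of the streaming idea in plain Python. Everything in it is hypothetical: `toy_sampler` stands in for an HMC or NUTS kernel, and `RunningMean` and `run_streaming` are illustrative names, not part of Pyro's API.

```python
import random


class RunningMean:
    """Hypothetical accumulator that stores only a count and a running mean."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def update(self, sample):
        self.count += 1
        self.mean += (sample - self.mean) / self.count


def toy_sampler(num_samples, seed=0):
    """Stand-in for an MCMC kernel: yields one draw at a time."""
    rng = random.Random(seed)
    for _ in range(num_samples):
        yield rng.gauss(0.0, 1.0)


def run_streaming(sampler, stats):
    """Feed each draw to the accumulator; no list of samples is ever built."""
    for sample in sampler:
        stats.update(sample)
    return stats


stats = run_streaming(toy_sampler(1000), RunningMean())
```

Memory use is constant in the number of draws because only the accumulator state survives each iteration. A real implementation would hold tensors rather than floats.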
Along with the new StreamingMCMC class I think there should be a set of helpers that compute statistics incrementally from sample streams, e.g. mean, variance, covariance, r_hat statistics.
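As a rough illustration of what such a helper could look like, the sketch below uses Welford's online algorithm to track mean and variance in a single pass. The class name `StreamingMeanVar` and its methods are assumptions made for this example and do not describe the eventual helper API. It works on scalar floats for brevity.

```python
class StreamingMeanVar:
    """Hypothetical one-pass mean/variance accumulator (Welford's algorithm)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        # Uses the old mean (in delta) and the updated mean, which keeps
        # the update numerically stable.
        self.m2 += delta * (x - self.mean)

    def variance(self):
        # Unbiased sample variance; requires at least two observations.
        return self.m2 / (self.count - 1)


acc = StreamingMeanVar()
for x in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    acc.update(x)
```

The same recurrence applies elementwise to tensors, so state stays at a fixed multiple of a single sample's size regardless of chain length.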
Tasks (to be split into multiple PRs)
@mtsokol
- #2857 Create a `StreamingMCMC` class with interface identical to `MCMC` (except disallowing parallel chains).
- #2857 Generalize unit tests of `MCMC` to parametrize over both `MCMC` and `StreamingMCMC`
- Add some tests ensuring `StreamingMCMC` and `MCMC` perform identical computations, up to numerical precision
- Create a tutorial using `StreamingMCMC` on a big model
@fritzo
- #2856 Create streaming helpers for mean, variance, etc.
- Add `r_hat` to `pyro.ops.streaming`
- Add `n_eff = ess` to `pyro.ops.streaming`
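For the `r_hat` item, one possible approach is to compute the classic (non-split) Gelman-Rubin statistic from per-chain summaries alone, since it needs only each chain's mean, variance and length. The sketch below is an assumption about how that could work, not the planned implementation. The function name `r_hat_from_summaries` is hypothetical, and all chains are assumed to have the same length `n`.

```python
import math
import statistics


def r_hat_from_summaries(chain_means, chain_variances, n):
    """Gelman-Rubin r_hat from per-chain means and unbiased variances.

    Assumes at least two chains, each of length n, with scalar summaries.
    """
    w = statistics.mean(chain_variances)          # within-chain variance W
    b_over_n = statistics.variance(chain_means)   # between-chain variance B / n
    var_plus = (n - 1) / n * w + b_over_n         # pooled variance estimate
    return math.sqrt(var_plus / w)


# Two chains of length 10 whose means disagree noticeably.
r = r_hat_from_summaries([1.0, 3.0], [2.0, 2.0], n=10)
```

Because the inputs are exactly what a streaming mean/variance accumulator produces per chain, no samples need to be retained to evaluate it.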