pyro-ppl/pyro

[FR] Streaming MCMC interface for big models

Open

#2843 opened on May 14, 2021

9 comments, 0 reactions, 2 assignees. Python, 8,211 stars, 981 forks.
Labels: enhancement, help wanted

Description

This issue proposes a streaming architecture for MCMC on models with a large memory footprint.

The problem this addresses is that, in models with high-dimensional latents (say, more than 1M latent variables), it becomes difficult to store a list of samples, especially on GPUs with limited memory. The proposed solution is to compute statistics eagerly as each sample is drawn during inference, then discard the sample instead of keeping it.

@fehiepsi suggested creating a new MCMC class (say StreamingMCMC) with an interface similar to MCMC, still independent of the kernel (HMC or NUTS), but built on an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory constrained, it is reasonable to leave multiprocessing support out of StreamingMCMC.

Along with the new StreamingMCMC class, I think there should be a set of helpers that compute statistics incrementally from sample streams, e.g. mean, variance, covariance, and r_hat.
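To make the idea of a streaming helper concrete, here is a minimal sketch of a running mean/variance accumulator based on Welford's online algorithm. The class name `StreamingMeanVar` and its methods are hypothetical and written in plain Python for illustration; they are not the proposed Pyro API, and a real helper would presumably operate on tensors.

```python
class StreamingMeanVar:
    """Hypothetical helper: running mean and variance over a sample stream.

    Uses Welford's online algorithm, so memory stays O(dimension)
    no matter how many samples are seen.
    """

    def __init__(self, dim):
        self.count = 0
        self.mean = [0.0] * dim
        self._m2 = [0.0] * dim  # running sum of squared deviations

    def update(self, sample):
        """Fold one sample into the running statistics."""
        self.count += 1
        for i, x in enumerate(sample):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.count
            self._m2[i] += delta * (x - self.mean[i])

    def variance(self):
        """Unbiased (n - 1) estimate; undefined for fewer than two samples."""
        if self.count < 2:
            raise ValueError("need at least two samples")
        return [m2 / (self.count - 1) for m2 in self._m2]


stats = StreamingMeanVar(dim=2)
for sample in ([1.0, 10.0], [2.0, 20.0], [3.0, 30.0]):
    stats.update(sample)  # the sample can be discarded after this call
```

Only `count`, `mean` and `_m2` persist between updates, which is what lets the sampler drop each draw as soon as it has been folded in.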

Tasks (to be split into multiple PRs)

@mtsokol

  • #2857 Create a StreamingMCMC class with interface identical to MCMC (except disallowing parallel chains).
  • #2857 Generalize unit tests of MCMC to parametrize over both MCMC and StreamingMCMC
  • Add some tests ensuring StreamingMCMC and MCMC perform identical computations, up to numerical precision
  • Create a tutorial using StreamingMCMC on a big model
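As a rough illustration of the equivalence check in the task list above, the sketch below runs the same seeded chain twice, once storing every sample and once streaming, and compares the resulting means. The toy random-walk Metropolis sampler stands in for an HMC or NUTS kernel, and every name here (`toy_chain`, `run_stored`, `run_streaming`) is an assumption made for the example rather than part of Pyro.

```python
import math
import random


def toy_chain(num_samples, seed):
    """Toy random-walk Metropolis chain targeting a standard normal.

    Stands in for a real HMC/NUTS kernel; yields one sample at a time.
    """
    rng = random.Random(seed)
    x = 0.0
    for _ in range(num_samples):
        proposal = x + rng.gauss(0.0, 1.0)
        # Accept with probability min(1, p(proposal) / p(x)).
        log_ratio = 0.5 * (x * x - proposal * proposal)
        if rng.random() < math.exp(min(0.0, log_ratio)):
            x = proposal
        yield x


def run_stored(chain):
    """Baseline: keep every sample, then compute the mean at the end."""
    samples = list(chain)
    return sum(samples) / len(samples)


def run_streaming(chain):
    """Streaming: fold each sample into a running mean, then drop it."""
    count, mean = 0, 0.0
    for x in chain:
        count += 1
        mean += (x - mean) / count
    return mean


stored_mean = run_stored(toy_chain(5000, seed=0))
streaming_mean = run_streaming(toy_chain(5000, seed=0))
# Same seed and same kernel: the results should agree up to rounding.
assert abs(stored_mean - streaming_mean) < 1e-9
```

The streaming path never holds more than one sample, yet it consumes the same sequence of draws, so any disagreement beyond floating-point rounding would point to a bug in the streaming statistics.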

@fritzo

  • #2856 Create streaming helpers for mean, variance, etc.
  • Add r_hat to pyro.ops.streaming
  • Add n_eff (effective sample size, ESS) to pyro.ops.streaming
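One reason r_hat fits a streaming design is that the classic Gelman-Rubin statistic needs only each chain's count, mean, and variance. The sketch below assumes that classic (non-split) form and equal-length chains; the function name `r_hat_from_summaries` is hypothetical, and this is not necessarily the variant Pyro's existing diagnostics use. n_eff is not covered here because it also needs autocorrelation estimates.

```python
import math


def r_hat_from_summaries(summaries):
    """Gelman-Rubin potential scale reduction from per-chain summaries.

    `summaries` is a list of (count, mean, variance) tuples, one per chain,
    where variance is the unbiased (n - 1) estimate. Assumes at least two
    chains, all with the same count n >= 2.
    """
    counts = {n for n, _, _ in summaries}
    if len(summaries) < 2 or len(counts) != 1:
        raise ValueError("need >= 2 chains of equal length")
    n = counts.pop()
    m = len(summaries)
    means = [mean for _, mean, _ in summaries]
    grand_mean = sum(means) / m
    # W: average within-chain variance; B / n: variance of the chain means.
    within = sum(var for _, _, var in summaries) / m
    between_over_n = sum((mu - grand_mean) ** 2 for mu in means) / (m - 1)
    var_plus = (n - 1) / n * within + between_over_n
    return math.sqrt(var_plus / within)


# Two chains of 100 draws with slightly different means.
r_hat = r_hat_from_summaries([(100, 0.0, 1.0), (100, 0.2, 1.0)])
```

Because the inputs are exactly what a running mean/variance accumulator keeps per chain, the statistic can be evaluated at any point during sampling without revisiting discarded draws.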
