pyro-ppl/pyro

[feature request] manual mini-batching and batch dimension scaling

Open

#1437 opened on Oct 8, 2018

Labels: documentation, help wanted

Description

In models with mixed levels of nesting (e.g. global_plate > local_plate_1 > local_plate_2 > ...), mini-batching across different batch dimensions requires a proper scale factor for each batch dimension. Pyro handles these scale factors automatically if mini-batching is done via pyro.iarange(..., size=..., subsample_size=...) or pyro.iarange(..., size=..., subsample=...). The latter construct is flexible and allows arbitrary mini-batching schemes, including big-data situations where the full data tensor cannot be loaded all at once.

Mini-batching, however, is often done manually and externally rather than via pyro.iarange. In such cases, the appropriate scale factors must also be applied manually via poutine.scale. This is consistent: manual mini-batching calls for manual scaling. However, most of the examples (DMM, VAE, ...) place little to no emphasis on this issue and neglect scaling altogether. Convergence is usually not a big problem when working with adaptive optimizers, but neglecting the scale factors leads to wrong ELBO estimates.

  • Could we add a word of caution about scale factors to the examples, and/or use poutine.scale wherever they mini-batch manually, to set a good precedent for new users?
