[feature request] manual mini-batching and batch dimension scaling
#1437 opened on Oct 8, 2018
Description
In models with mixed levels of nesting (e.g. global_plate > local_plate_1 > local_plate_2 > ...), mini-batching across different batch dimensions requires a proper scale factor for each batch dimension. Pyro handles these scale factors automatically if mini-batching is done via pyro.iarange(..., size=..., subsample_size=...) or pyro.iarange(..., size=..., subsample=...). The latter construct is flexible and allows arbitrary mini-batching schemes, including big-data situations where the full data tensor cannot be loaded all at once.
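To make the per-dimension scale factors concrete, here is a minimal pure-Python sketch of the arithmetic, not Pyro's actual implementation. It assumes each subsampled dimension contributes a factor of size / subsample_size and that factors of nested dimensions multiply. The helper `site_scale` and all sizes are hypothetical and chosen only for illustration.

```python
from math import prod

def site_scale(plates):
    """Scale for the log-probability of a site enclosed by `plates`.

    `plates` is a list of (size, subsample_size) pairs, outermost first.
    Hypothetical helper: each subsampled dimension contributes
    size / subsample_size, and nested dimensions multiply.
    """
    return prod(size / subsample_size for size, subsample_size in plates)

# Hypothetical sizes, for illustration only.
global_plate = (10_000, 100)   # 100 of 10,000 items per step
local_plate_1 = (50, 10)       # 10 of 50 items per step
local_plate_2 = (8, 8)         # not subsampled

scale_global = site_scale([global_plate])
scale_local_1 = site_scale([global_plate, local_plate_1])
scale_local_2 = site_scale([global_plate, local_plate_1, local_plate_2])

print(scale_global, scale_local_1, scale_local_2)
```

A site inside a dimension that is not subsampled (local_plate_2 here) still inherits the scale of the dimensions enclosing it, which is why one global factor is not enough when nesting levels are mixed.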
Mini-batching, however, is often done manually and externally rather than via pyro.iarange. In such cases, the appropriate scale factors must also be applied manually via poutine.scale. This is consistent: manual mini-batching calls for manual scaling. However, most of the examples (DMM, VAE, ...) put little to no emphasis on this issue and neglect scaling altogether. Convergence is usually not much affected when working with adaptive optimizers, but neglecting the scale factors leads to wrong ELBO estimates.
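The effect on the estimate can be checked without Pyro. The sketch below looks only at the data log-likelihood term of the ELBO, under an assumed Normal(1, 2) likelihood with made-up data; the names `normal_log_prob` and `minibatch_estimate` are hypothetical. Multiplying a mini-batch sum by N / B is the correction that poutine.scale would have to supply when batching is done by hand.

```python
import math
import random

def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

random.seed(0)
N, B = 1000, 50  # hypothetical dataset size and mini-batch size
data = [random.gauss(1.0, 2.0) for _ in range(N)]

# Full-data log-likelihood: the quantity a mini-batch should estimate.
full = sum(normal_log_prob(x, 1.0, 2.0) for x in data)

def minibatch_estimate(batch, scaled):
    """Sum the batch log-likelihood, optionally scaled by N / B."""
    total = sum(normal_log_prob(x, 1.0, 2.0) for x in batch)
    return total * (N / B) if scaled else total

# Average over one pass of disjoint mini-batches.
batches = [data[i:i + B] for i in range(0, N, B)]
scaled_mean = sum(minibatch_estimate(b, True) for b in batches) / len(batches)
unscaled_mean = sum(minibatch_estimate(b, False) for b in batches) / len(batches)

print("full:         ", full)
print("scaled mean:  ", scaled_mean)
print("unscaled mean:", unscaled_mean)
```

The scaled average matches the full-data value, while the unscaled average is too small by a factor of N / B. In a real model the prior and entropy terms of global variables are not shrunk by the same factor, so the unscaled objective weights the data incorrectly relative to them.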
- How about adding a word of caution about scale factors to the examples, and/or using poutine.scale where they mini-batch manually, to set a good precedent for new users?