pyro-ppl/pyro

[FR] Support Automatic Mixed Precision training

Open

#3316 opened on Jan 31, 2024

7 comments, 0 reactions, 0 assignees · Python · 981 forks
Labels: enhancement, help wanted



Issue Description

Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data to float16 or bfloat16, but I cannot use PyTorch's automatic mixed precision (AMP) training. AMP requires the GradScaler class in the optimization loop: it scales the loss before the backward pass so that small float16 gradients do not underflow, then unscales the gradients before the optimizer step. Pyro's optimization loop does not expose a way to do this. See the documentation for more info: https://pytorch.org/docs/stable/amp.html
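For reference, the standard PyTorch AMP loop described in the linked documentation looks roughly like the sketch below. It is plain PyTorch with a toy linear model; the model, data, and hyperparameters are illustrative assumptions. AMP is enabled only when a CUDA device is available; on CPU, `autocast` and `GradScaler` are created with `enabled=False`, which turns them into pass-throughs.

```python
import torch
from torch import nn

# Enable AMP only on CUDA; on CPU autocast and GradScaler become pass-throughs.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
amp_dtype = torch.float16 if use_amp else torch.bfloat16

torch.manual_seed(0)
model = nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
try:
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)  # PyTorch >= 2.3
except AttributeError:
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)     # older releases

# Toy regression data (illustrative only): the target is the sum of the features.
x = torch.randn(64, 10, device=device)
y = x.sum(dim=1, keepdim=True)

losses = []
for _ in range(50):
    optimizer.zero_grad()
    # Inside autocast, eligible ops such as the Linear matmul run in float16.
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # backward on loss * scale to avoid float16 underflow
    scaler.step(optimizer)         # unscales grads; skips the step if they hold inf/NaN
    scaler.update()                # grows or shrinks the scale for the next iteration
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.5f}")
```

The three `scaler` calls replace the usual `loss.backward()` / `optimizer.step()` pair. That is the hook Pyro's `SVI.step` would need to expose.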

It would be nice if Pyro's optimizers supported this class, so that AMP training could be used with SVI.
