pyro-ppl/pyro

[FR] Support Automatic Mixed Precision training

Open

#3316 opened on Jan 31, 2024

(7 comments) (0 reactions) (0 assignees) Python (981 forks)
Labels: enhancement, help wanted

Repository metrics

Stars
(8,211 stars)
PR merge metrics
(average merge time 10d 19h) (1 PR merged in 30 days)

Description

Issue Description

Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data to float16 or bfloat16, but I cannot use PyTorch's automatic mixed precision (AMP) training. This is because AMP with float16 relies on the GradScaler class in the optimization loop: it scales the loss before the backward pass so that small gradients do not underflow, then unscales the gradients before each optimizer step. See the documentation for more info: https://pytorch.org/docs/stable/amp.html

It would be nice if Pyro's optimizers could use this class, enabling AMP training.
