pytorch/ignite

enable automatic mixed precision for xla

Open

#1931 opened on Apr 12, 2021

Labels: enhancement, help wanted

Description

Feature

Automatic mixed precision (AMP) for XLA has landed in PyTorch 1.8.1 and the torch_xla nightly builds. We should enable it in the create_supervised_* helper functions.

Suggested solution

Remove the check in _check_arg() that rejects AMP on XLA devices.
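A minimal sketch of what the relaxed validation could look like. This is not Ignite's actual _check_arg(); the function name check_arg, its return values (including the "tpu_amp" mode string), and the choice to keep rejecting amp_mode="apex" on XLA (apex is CUDA-only) are all hypothetical assumptions for illustration.

```python
from typing import Optional, Tuple


def check_arg(on_tpu: bool, amp_mode: Optional[str], scaler: object) -> Tuple[Optional[str], object]:
    """Hypothetical, simplified validator in the spirit of _check_arg().

    The old "amp_mode cannot be used with an XLA device" error is gone;
    the remaining checks only validate the scaler / amp_mode combination.
    """
    if amp_mode not in (None, "amp", "apex"):
        raise ValueError(f"Unknown amp_mode: {amp_mode!r}")
    # Removed check: `if amp_mode and on_tpu: raise ValueError(...)`
    if scaler and amp_mode != "amp":
        raise ValueError(f"scaler is {scaler!r} but amp_mode is {amp_mode!r}; use amp_mode='amp'.")
    if on_tpu and amp_mode == "apex":
        # Assumption: NVIDIA apex only targets CUDA, so it stays rejected on XLA.
        raise ValueError("amp_mode='apex' is not supported on XLA devices.")
    # Hypothetical mode names used to pick the training-step function.
    if on_tpu:
        mode = "tpu_amp" if amp_mode == "amp" else "tpu"
    else:
        mode = amp_mode
    # A scaler is only meaningful together with native AMP.
    return mode, (scaler if amp_mode == "amp" else None)
```

With this shape, check_arg(True, "amp", None) is accepted instead of raising, which is the behavior change the issue asks for.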

  • For create_supervised_trainer, update supervised_training_step_tpu() to accept a scaler argument, just like supervised_training_step_amp().
  • For create_supervised_evaluator, removing the XLA and AMP checks in _check_arg() should be enough.
  • For the tests, we could remove the XLA checks and run them only with PyTorch 1.8.1.
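A rough sketch of the control flow for a TPU training step that optionally takes a scaler, mirroring the usual AMP pattern (scale the loss, backward, scaler.step, scaler.update). The name training_step_tpu and its parameters are hypothetical. optimizer_step stands in for something like xm.optimizer_step, and autocast stands in for an XLA-aware autocast context. Both are passed in so the logic runs without torch. On a real XLA device the scaler branch would also need gradient reduction across replicas and a step barrier (for example xm.mark_step()), which this sketch leaves out.

```python
from contextlib import nullcontext
from typing import Any, Callable, Optional


def training_step_tpu(
    model: Callable[[Any], Any],
    optimizer: Any,
    loss_fn: Callable[[Any, Any], Any],
    batch: tuple,
    optimizer_step: Callable[[Any], None],
    autocast: Callable[[], Any] = nullcontext,
    scaler: Optional[Any] = None,
) -> float:
    """Hypothetical TPU/XLA training step that accepts an optional scaler."""
    x, y = batch
    optimizer.zero_grad()
    # Forward pass and loss run inside the (injected) autocast context.
    with autocast():
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
    if scaler is not None:
        # Same sequence as a GradScaler-based AMP step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # Plain path: backward, then the device-specific optimizer step.
        loss.backward()
        optimizer_step(optimizer)
    return loss.item()
```

Any scaler object exposing scale(), step() and update(), as torch.cuda.amp.GradScaler does, fits this interface.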

Additional context

This feature should not be included in an Ignite release until the next torch and torch_xla versions are released.

