pytorch/ignite

enable automatic mixed precision for xla

Open

#1931 created on April 12, 2021

8 comments, 1 reaction, 0 assignees, Python, 4,313 stars, 602 forks
Labels: enhancement, help wanted

Description

Feature

Automatic mixed precision (AMP) for XLA has landed in PyTorch 1.8.1 and the torch/xla nightly builds. We should enable it in the create_supervised_* helper functions.

Suggested solution

Remove the checks in _check_arg() that reject combining XLA with AMP.
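As a rough illustration of this change, here is a simplified stand-in for the current guard next to a relaxed version. Both functions (check_arg_before, check_arg_after) are hypothetical and only approximate the shape of ignite's _check_arg(); the rule that a scaler requires an AMP mode is an assumption made for the sketch, not taken from ignite.

```python
from typing import Any, Optional, Tuple


def check_arg_before(on_tpu: bool, amp_mode: Optional[str], scaler: Any = None) -> Tuple[Optional[str], Any]:
    # Hypothetical, simplified current behaviour: AMP on an XLA device is rejected outright.
    if on_tpu and amp_mode:
        raise ValueError("amp_mode cannot be used with xla device")
    return amp_mode, scaler


def check_arg_after(on_tpu: bool, amp_mode: Optional[str], scaler: Any = None) -> Tuple[Optional[str], Any]:
    # Hypothetical proposed behaviour: the xla/amp guard is gone, so on_tpu is kept
    # only for signature parity and no longer affects validation.
    # Assumed rule: a scaler only makes sense when an AMP mode was requested.
    if scaler is not None and amp_mode is None:
        raise ValueError("scaler requires amp_mode to be set")
    return amp_mode, scaler
```

With the guard removed, a call such as check_arg_after(True, "amp") simply passes the arguments through instead of raising.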

  • For create_supervised_trainer, update the supervised_training_step_tpu() function to accept a scaler argument, just like supervised_training_step_amp().
  • For create_supervised_evaluator, removing the XLA and AMP checks in _check_arg() should be enough.
  • For tests, we could remove the XLA checks and run them only with PyTorch 1.8.1.
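A minimal sketch of a TPU training step that accepts an optional scaler, mirroring the AMP path. Everything here is hypothetical: training_step_with_optional_scaler is not an ignite function, and the model, optimizer and loss are duck-typed so the sketch runs without torch or torch_xla. The reduce_gradients callable stands in for XLA cross-replica gradient reduction; a real implementation would use torch_xla for this and would also need the XLA step barrier. The scaler is assumed to follow the GradScaler-style scale()/step()/update() protocol.

```python
from typing import Any, Callable, Optional, Sequence


def training_step_with_optional_scaler(
    model: Callable[[Any], Any],
    optimizer: Any,
    loss_fn: Callable[[Any, Any], Any],
    batch: Sequence[Any],
    reduce_gradients: Callable[[Any], None],
    scaler: Optional[Any] = None,
) -> float:
    """Hypothetical TPU training step that optionally uses a loss scaler.

    reduce_gradients is an injected stand-in for XLA cross-replica gradient
    reduction so the sketch runs without torch_xla.
    """
    optimizer.zero_grad()
    x, y = batch
    loss = loss_fn(model(x), y)
    if scaler is not None:
        # GradScaler-style path: backprop the scaled loss, sync the (still scaled)
        # gradients, then let the scaler step (GradScaler.step unscales and skips
        # the update on inf/NaN gradients) and adjust its scale factor.
        scaler.scale(loss).backward()
        reduce_gradients(optimizer)
        scaler.step(optimizer)
        scaler.update()
    else:
        # Plain path, roughly what an XLA optimizer step does: sync, then step.
        loss.backward()
        reduce_gradients(optimizer)
        optimizer.step()
    return loss.item()
```

Gradient reduction (a sum or mean) is linear, so reducing the still-scaled gradients before scaler.step() gives the same result as reducing after unscaling, and both paths keep the same sync-then-step order.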

Additional context

This feature should not be included in an Ignite release until the next torch and xla releases are out.
