pytorch/ignite

Enable automatic mixed precision for XLA

Open

#1931 opened on April 12, 2021

Labels: enhancement, help wanted

Description

Feature

Automatic mixed precision (AMP) for XLA has landed in PyTorch 1.8.1 and the torch_xla nightly builds. We should enable it in the create_supervised_* helper functions.

Suggested solution

Remove the XLA and AMP checks in _check_arg().

  • For create_supervised_trainer, update the supervised_training_step_tpu() function to accept a scaler argument, just like supervised_training_step_amp().
  • For create_supervised_evaluator, removing the XLA and AMP checks in _check_arg() should be sufficient.
  • For tests, we could remove the XLA checks and run them only with PyTorch 1.8.1.
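The trainer change described in the list above could look roughly like the following. This is a minimal sketch, not ignite's actual supervised_training_step_tpu(): the factory name make_xla_training_step and its autocast and reduce_gradients parameters are hypothetical. It only shows the proposed control flow: when a GradScaler-like scaler (any object with scale(), step() and update()) is supplied, the loss is scaled before the backward pass and scaler.step() performs the update; otherwise the plain backward pass and optimizer step are used. On an XLA device, optimizer_step would presumably be torch_xla.core.xla_model.optimizer_step, and because scaler.step() bypasses it, the gradients would likely need to be all-reduced separately before stepping (an assumption, modeled here by the reduce_gradients hook).

```python
from contextlib import nullcontext
from typing import Any, Callable, Optional


def make_xla_training_step(
    model: Any,
    optimizer: Any,
    loss_fn: Callable[[Any, Any], Any],
    optimizer_step: Callable[[Any], None],
    scaler: Optional[Any] = None,
    autocast: Callable[[], Any] = nullcontext,
    reduce_gradients: Callable[[Any], None] = lambda opt: None,
) -> Callable[[Any, Any], float]:
    """Hypothetical training-step factory that optionally uses a gradient scaler.

    optimizer_step: e.g. torch_xla.core.xla_model.optimizer_step on XLA devices.
    scaler: any GradScaler-like object exposing scale(), step() and update().
    autocast: zero-argument factory returning the mixed-precision context.
    reduce_gradients: hypothetical hook to synchronise gradients across devices
        before scaler.step(), which does not go through optimizer_step().
    """

    def update(engine: Any, batch: Any) -> float:
        model.train()
        optimizer.zero_grad()
        x, y = batch  # simplified stand-in for ignite's prepare_batch
        # Forward pass and loss computation under the mixed-precision context.
        with autocast():
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
        if scaler is not None:
            # Scale the loss so that small low-precision gradients do not underflow.
            scaler.scale(loss).backward()
            # Assumption: gradients must be reduced across devices before stepping.
            reduce_gradients(optimizer)
            # A GradScaler unscales the gradients and skips the step on inf/NaN.
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer_step(optimizer)
        return loss.item()

    return update
```

Injecting the autocast context and the optimizer step keeps the sketch free of a hard torch_xla dependency, so the same control flow can be exercised on CPU or with stand-in objects in tests.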

Additional context

This feature should not be included in an Ignite release until the next torch and torch_xla versions are released.
