FluxML/Flux.jl

Allow BatchNorm training on CUDA with `track_stats=false`

Open

#1606 opened on May 31, 2021

(4 comments) (0 reactions) (0 assignees) Julia (619 forks)
Labels: good first issue, help wanted

Repository metrics

Stars: 4,725
PR merge metrics: average merge time 4h 27m; 2 PRs merged in the last 30 days

Description

Gathered from https://discourse.julialang.org/t/batchnorm-only-track-stats-true-supported-on-gpu/62091.

This would most likely require changes in NNlibCUDA as well. I'm not sure how interchangeable the various `cudnnBatchNormalizationForward*` functions are, so I'm putting a pin in this until someone more knowledgeable can comment.
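To pin down what the request means, here is a minimal CPU-only sketch in plain Julia of a batch-norm forward pass in training mode, with and without running-statistics tracking. It is not Flux's or NNlibCUDA's implementation: `batchnorm_train`, the `running` named tuple, and the `momentum`/`ϵ` defaults are hypothetical, and details such as variance correction for the running estimates may differ from Flux. The point is that with `track_stats=false` the layer still normalizes with batch statistics; it simply never reads or writes running statistics.

```julia
using Statistics

# Hypothetical reference helper (not part of Flux): batch norm over a
# (features, batch) matrix in training mode.
function batchnorm_train(x::AbstractMatrix, γ::AbstractVector, β::AbstractVector, running;
                         track_stats::Bool = true, momentum = 0.1f0, ϵ = 1f-5)
    μ = mean(x; dims = 2)                             # per-feature batch mean
    σ² = var(x; dims = 2, mean = μ, corrected = false) # per-feature biased batch variance
    if track_stats
        # Update running statistics in place with an exponential moving average
        running.μ .= (1 - momentum) .* running.μ .+ momentum .* vec(μ)
        running.σ² .= (1 - momentum) .* running.σ² .+ momentum .* vec(σ²)
    end
    # Training mode always normalizes with the current batch statistics
    return γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β
end
```

Usage: with `track_stats=false` the running buffers are left untouched, while the normalized output is identical to the tracked case.

```julia
x = Float32[1 2 3; 4 6 8]                          # 2 features × 3 samples
running = (μ = zeros(Float32, 2), σ² = ones(Float32, 2))
y = batchnorm_train(x, ones(Float32, 2), zeros(Float32, 2), running; track_stats = false)
```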
