FluxML/Flux.jl

Weights shape not validated against kernel, channels

Open

#2506 opened on 25 Oct 2024

(5 comments) (1 reaction) (0 assignees) Julia (619 forks)
good first issue, help wanted


Description

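using Flux

# 3-D weights (no output-channel dimension) passed to a 2-D convolution
weights = Flux.kaiming_normal()(3, 3, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# displays: Conv((3,), 3 => 1, pad=1)  # 10 parameters

# 4-D weights matching kernel (3, 3), 1 input channel, 1 output channel
weights = Flux.kaiming_normal()(3, 3, 1, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# displays: Conv((3, 3), 1 => 1, pad=1)  # 10 parameters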

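I wanted to specify the weight initialization exactly for testing, but got this odd result: the 3-D array is accepted without complaint, and the layer then displays as a 1-D convolution with 3 input channels instead of the 2-D, 1-channel layer I asked for. I think the constructor should validate that the weight shape matches the kernel size and the input/output channels, and throw an error on a mismatch.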
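To make the request concrete, here is a minimal sketch of the kind of check I have in mind. `check_conv_weight_shape` is a hypothetical helper name, not an existing Flux function, and the sketch assumes the weight layout `(kernel..., in_channels ÷ groups, out_channels)` that `Conv` uses. It is plain Julia, so it runs without loading Flux.

```julia
# Hypothetical helper (not part of Flux): compare a user-supplied weight array
# against the kernel size and channel pair passed to the layer constructor.
function check_conv_weight_shape(weights::AbstractArray, k::NTuple{N,Integer},
                                 ch::Pair{<:Integer,<:Integer};
                                 groups::Integer = 1) where {N}
    # Assumed layout: (kernel dims..., input channels per group, output channels)
    expected = (k..., ch.first ÷ groups, ch.second)
    size(weights) == expected || throw(DimensionMismatch(
        "weight array has size $(size(weights)), expected $expected " *
        "for kernel $k and channels $ch"))
    return weights
end

# 4-D weights matching kernel (3, 3) and 1 => 1 channels are returned unchanged.
ok = check_conv_weight_shape(randn(Float32, 3, 3, 1, 1), (3, 3), 1 => 1)

# 3-D weights from the report above are rejected instead of being reinterpreted.
rejected = try
    check_conv_weight_shape(randn(Float32, 3, 3, 1), (3, 3), 1 => 1)
    false
catch err
    err isa DimensionMismatch
end
```

A check like this could run wherever the constructor receives the array returned by `init`, so that a mismatch fails loudly at construction time rather than producing a differently shaped layer.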
