FluxML/Flux.jl

Weights shape not validated against kernel, channels

Open

#2506 opened on Oct 25, 2024

Labels: good first issue, help wanted

Description

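using Flux

# 3-dimensional weight array: the trailing output-channel dimension is missing
weights = Flux.kaiming_normal()(3, 3, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3,), 3 => 1, pad=1)  # 10 parameters

# 4-dimensional weight array: (kernel width, kernel height, input channels, output channels)
weights = Flux.kaiming_normal()(3, 3, 1, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3, 3), 1 => 1, pad=1)  # 10 parameters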

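I wanted to specify the initial weights exactly for testing, but got this odd result: when `init` returns a 3×3×1 array, the requested `Conv((3, 3), 1 => 1)` is silently built as a 1D convolution with kernel size 3 and 3 input channels, with no error or warning. I think the constructor should validate that the weight shape matches the kernel size and the input and output channels, and throw an error on a mismatch.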
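
A minimal sketch of the kind of check I have in mind. The helper name `check_conv_weight` is hypothetical and not part of Flux, and the sketch assumes `groups = 1` and the weight layout `(kernel..., in_channels, out_channels)`. It uses only Base Julia so it can be tried without loading Flux:

```julia
# Hypothetical helper, not part of Flux: verify that a weight array matches
# the requested kernel size and channel counts (assumes groups = 1).
function check_conv_weight(weight::AbstractArray, k::NTuple{N,Integer},
                           ch::Pair{<:Integer,<:Integer}) where {N}
    # Assumed layout: (kernel dims..., input channels, output channels)
    expected = (k..., first(ch), last(ch))
    if size(weight) != expected
        throw(DimensionMismatch(
            "init returned weights of size $(size(weight)), expected $expected"))
    end
    return weight
end

# Matching shape: the array is returned unchanged.
check_conv_weight(zeros(Float32, 3, 3, 1, 1), (3, 3), 1 => 1)

# Missing output-channel dimension: throws instead of silently building a 1D layer.
try
    check_conv_weight(zeros(Float32, 3, 3, 1), (3, 3), 1 => 1)
catch err
    @assert err isa DimensionMismatch
end
```

Where such a check would best live inside the `Conv` constructor is left open here; the point is only that comparing `size(weight)` with the expected tuple is enough to catch the case above.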