FluxML/Flux.jl

Weights shape not validated against kernel, channels

Open

#2506 opened on October 25, 2024

5 comments · 1 reaction · 0 assignees · Julia · 4,725 stars · 619 forks
Labels: good first issue, help wanted

Description

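using Flux

# 3-D weight array (3×3×1): a 1-D conv with 3 input channels is built instead of the requested 3×3 kernel
weights = Flux.kaiming_normal()(3, 3, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3,), 3 => 1, pad=1)  # 10 parameters

# 4-D weight array (3×3×1×1): matches the requested kernel size and channels
weights = Flux.kaiming_normal()(3, 3, 1, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3, 3), 1 => 1, pad=1)  # 10 parameters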
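Why the first call produces a 1-D layer can be reproduced without Flux. The sketch below assumes Flux's conv weight layout is (k..., cin ÷ groups, cout) with groups = 1, so the last two dimensions are read as channels and everything before them as the kernel. The helper implied_conv is hypothetical and only mirrors that assumed layout; it is not Flux code.

```julia
# Hypothetical helper (not part of Flux): report the kernel size, channel
# mapping and parameter count implied by a weight array's size, ASSUMING the
# layout (k..., cin, cout) with groups = 1 and a bias of length cout.
function implied_conv(sz::Tuple{Vararg{Int}})
    length(sz) >= 3 || throw(ArgumentError("need at least one spatial dim plus two channel dims"))
    kernel = sz[1:end-2]          # all leading dims are spatial
    cin, cout = sz[end-1], sz[end]
    params = prod(sz) + cout      # weights plus one bias per output channel
    return (kernel = kernel, channels = cin => cout, params = params)
end

implied_conv((3, 3, 1))     # (kernel = (3,), channels = 3 => 1, params = 10)
implied_conv((3, 3, 1, 1))  # (kernel = (3, 3), channels = 1 => 1, params = 10)
```

Both shapes hold nine weights plus one bias, which is why the two layers report the same parameter count and the mismatch is easy to miss.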

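I wanted to pin the initial weights exactly for a test, but ran into this surprising result: when init returns a 3-D array, Flux silently builds a 1-D convolution, Conv((3,), 3 => 1), instead of the requested 3×3 kernel with one input channel. I think the constructor should validate that the weight array's shape matches the kernel size and the input and output channel counts, and throw an error on a mismatch.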
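Until such a check exists, a user-side workaround could wrap the initializer so it rejects arrays of the wrong size. This is a minimal sketch, assuming Flux's Conv constructor calls init(k..., cin ÷ groups, cout), i.e. passes the expected dimensions to init. The name strict_init is hypothetical and not part of Flux's API.

```julia
# Hypothetical helper, not part of Flux: wrap an initializer so that it throws
# if the array it returns does not have exactly the size it was asked for.
function strict_init(init)
    return function (dims::Integer...)
        w = init(dims...)
        if size(w) != dims
            throw(DimensionMismatch(
                "initializer returned an array of size $(size(w)), expected $(dims)"))
        end
        return w
    end
end

# Usage with Flux (ASSUMES Conv calls init(k..., cin ÷ groups, cout)):
#   weights = Flux.kaiming_normal()(3, 3, 1)
#   Conv((3, 3), 1 => 1; pad = (1, 1), init = strict_init((_...) -> weights))
#   # would throw DimensionMismatch instead of silently building a 1-D conv
```

Comparing against the requested dimensions, rather than recomputing them from the kernel and channels, keeps the wrapper independent of the layer type, at the cost of relying on how the constructor calls init.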
