FluxML/Flux.jl

Shape-propagating Chain

Open

#703 opened on March 26, 2019

Labels: discussion, enhancement, help wanted

Description

It'd be nice to be able to write something like

model = @Chain(
  Input(28^2),
  Dense(32, relu),
  Dense(10),
  softmax)

It's a relatively minor convenience, but it avoids some redundancy when specifying chains: each layer's input size has to repeat the previous layer's output size, which is tedious to keep in sync and easy to get wrong when trying different layer sizes.

Here's roughly how I imagine this working. The @Chain macro would expand to something like:

shape = nothing
layer1, shape = fromshape(Input, shape, 28^2)
layer2, shape = fromshape(Dense, shape, 32, relu)
...
Chain(layer1, layer2, ...)
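A minimal sketch of how the macro could perform the rewrite described above. It assumes each call `L(args...)` becomes `fromshape(L, shape, args...)` and a bare function such as `softmax` becomes `fromshape(softmax, shape)`. `Input`, `Dense`, `Chain`, `relu`, `softmax` and `fromshape` here are hypothetical toy stand-ins so the example runs without Flux; they are not Flux's definitions.

```julia
# Hypothetical sketch of the proposed @Chain macro. Input, Dense, Chain,
# relu, softmax and fromshape are toy stand-ins, not Flux's definitions.

struct Input end                 # marker type: only declares the input size

struct Dense                     # toy layer spec that only records sizes
    insize::Int
    outsize::Int
    σ::Function
end

struct Chain{T<:Tuple}           # toy container standing in for Flux.Chain
    layers::T
end
Chain(layers...) = Chain(layers)

relu(x) = max(zero(x), x)
softmax(x) = exp.(x) ./ sum(exp.(x))

# Each method returns (layer, output shape).
fromshape(::Type{Input}, shape, n::Integer) = (identity, n)
fromshape(::Type{Dense}, insize::Integer, outsize::Integer, σ = identity) =
    (Dense(insize, outsize, σ), outsize)
fromshape(f::Function, shape) = (f, shape)   # plain functions keep the shape

# Rewrites `L(args...)` into `fromshape(L, shape, args...)` and a bare
# function `f` into `fromshape(f, shape)`, threading `shape` through.
macro Chain(layers...)
    shape = gensym(:shape)
    names = [gensym(:layer) for _ in layers]
    body = Any[:($shape = nothing)]
    for (name, l) in zip(names, layers)
        if l isa Expr && l.head == :call
            f, args = l.args[1], l.args[2:end]
            push!(body, :(($name, $shape) = fromshape($f, $shape, $(args...))))
        else
            push!(body, :(($name, $shape) = fromshape($l, $shape)))
        end
    end
    push!(body, :(Chain($(names...))))
    return esc(Expr(:block, body...))   # resolve fromshape/Chain in the caller
end

model = @Chain(
    Input(28^2),
    Dense(32, relu),
    Dense(10),
    softmax)
```

Using `gensym` for the intermediate names keeps the generated `layer`/`shape` variables from clashing with anything the caller has defined, while `esc` lets `fromshape` and `Chain` resolve in the calling module.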

fromshape can then forward to the appropriate constructor, or throw an error for unsupported layers. Hopefully this strikes the right balance between simplicity and generality, and we don't end up having to turn it into a full shape-inference system.
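The forward-or-error behavior maps naturally onto multiple dispatch. In this sketch, one `fromshape` method per supported layer fills in the input size and forwards to an ordinary constructor. Shape-preserving functions pass the shape through, and a catch-all method throws for anything else. It is hypothetical: the toy `Dense(in, out, σ)` constructor, the `Conv` placeholder and the choice of `ArgumentError` are assumptions, not Flux's API.

```julia
# Hypothetical sketch: fromshape methods chosen by multiple dispatch.
# Dense, Input, Conv, relu and softmax are toy stand-ins, not Flux's API.
using Random   # randn

struct Dense
    W::Matrix{Float64}
    b::Vector{Float64}
    σ::Function
end
# Ordinary constructor with explicit sizes, in the style of Dense(in, out, σ).
Dense(insize::Integer, outsize::Integer, σ = identity) =
    Dense(randn(outsize, insize) ./ sqrt(insize), zeros(outsize), σ)
(d::Dense)(x) = d.σ.(d.W * x .+ d.b)

struct Input end   # marker type: only declares the input size
struct Conv end    # placeholder for a layer without a shape rule yet

relu(x) = max(zero(x), x)
softmax(x) = exp.(x .- maximum(x)) ./ sum(exp.(x .- maximum(x)))

# Supported layers: take the input size from the incoming shape and
# forward to the ordinary constructor. Each returns (layer, output shape).
fromshape(::Type{Input}, shape, n::Integer) = (identity, n)
fromshape(::Type{Dense}, shape::Integer, out::Integer, σ = identity) =
    (Dense(shape, out, σ), out)

# Plain functions such as softmax are assumed not to change the shape.
fromshape(f::Function, shape) = (f, shape)

# Anything else: fail loudly rather than guess a shape.
fromshape(T, shape, args...) =
    throw(ArgumentError("fromshape: no shape rule for $T; construct it explicitly"))

# Usage: the hand-written expansion from the issue.
shape = nothing
layer1, shape = fromshape(Input, shape, 28^2)
layer2, shape = fromshape(Dense, shape, 32, relu)
layer3, shape = fromshape(Dense, shape, 10)
layer4, shape = fromshape(softmax, shape)
layers = (layer1, layer2, layer3, layer4)
model(x) = foldl((acc, f) -> f(acc), layers; init = x)
```

Because the catch-all method is the least specific, adding support for a new layer is just one more `fromshape` method. A call such as `fromshape(Conv, shape, 3)` throws an `ArgumentError` until that method exists.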
