FluxML/Flux.jl

Shape-propagating Chain

Open

#703 opened on March 26, 2019

discussion, enhancement, help wanted

Description

It'd be nice to be able to write something like

model = @Chain(
  Input(28^2),
  Dense(32, relu),
  Dense(10),
  softmax)

It's a relatively minor convenience, but it avoids repeating each layer's input size when specifying chains. That redundancy is tedious to update and easy to get wrong when trying different layer sizes.

Here's roughly how I imagine this working. The @Chain would expand to something like

shape = nothing
layer1, shape = fromshape(Input, shape, 28^2)
layer2, shape = fromshape(Dense, shape, 32, relu)
...
Chain(layer1, layer2, ...)

fromshape can then forward to the appropriate constructor, or throw an error for unsupported layers. Hopefully this strikes the right balance of simplicity and generality, and we don't end up having to turn it into a full shape inference system.
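To make the idea concrete, here is a minimal sketch of how fromshape dispatch could look. It is only an illustration under assumptions, not a proposed implementation: the Input and Dense types below are stand-ins defined locally so the snippet runs without Flux, buildchain is a hypothetical helper that plays the role of the macro expansion, and the (Type, args...) tuple format for layer specs is invented for this example.

```julia
# Stand-in types so the sketch runs without Flux; these are NOT Flux's own types.
struct Input end                       # marker only: supplies the initial shape

struct Dense{F}
    W::Matrix{Float64}
    b::Vector{Float64}
    σ::F
end
(d::Dense)(x) = d.σ.(d.W * x .+ d.b)

relu(x) = max(zero(x), x)
softmax(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

# fromshape(T, shape, args...) returns (layer_or_nothing, output_shape).
# Input has no incoming shape and produces no layer, only a shape.
fromshape(::Type{Input}, ::Nothing, n::Integer) = (nothing, n)

# Dense takes its input size from the propagated shape.
function fromshape(::Type{Dense}, nin::Integer, nout::Integer, σ = identity)
    return Dense(0.01 .* randn(nout, nin), zeros(nout), σ), nout
end

# Fallback: unsupported layer type, or no shape available yet.
fromshape(T, shape, args...) =
    error("fromshape is not defined for $T with input shape $shape")

# Hypothetical helper that mimics what the macro expansion would do:
# thread `shape` through the specs and collect the constructed layers.
function buildchain(specs...)
    shape = nothing
    layers = Any[]
    for spec in specs
        if spec isa Tuple
            layer, shape = fromshape(spec[1], shape, spec[2:end]...)
            layer === nothing || push!(layers, layer)
        else
            push!(layers, spec)        # plain function, assumed shape-preserving
        end
    end
    return layers, shape
end

layers, outshape = buildchain((Input, 28^2), (Dense, 32, relu), (Dense, 10), softmax)

# Apply the layers in order, as a Chain would.
model(x) = foldl((h, f) -> f(h), layers; init = x)
```

With this sketch, calling fromshape(Dense, nothing, 32) falls through to the fallback and raises an error, which is the behaviour wanted for a chain that does not start with a known shape. Plain functions such as softmax are treated as shape-preserving here, which is a simplification.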
