FluxML/Flux.jl

Shape-propagating Chain

Open

#703 opened on Mar 26, 2019

8 comments · 8 reactions · 0 assignees · Julia · 619 forks
Labels: discussion, enhancement, help wanted

Repository metrics

Stars: 4,725
PR merge metrics: average merge time 4h 27m; 2 PRs merged in the last 30 days

Description

It'd be nice to be able to write something like

model = @Chain(
  Input(28^2),
  Dense(32, relu),
  Dense(10),
  softmax)

It's a relatively minor convenience, but it avoids redundancy when specifying chains: each layer's input size has to repeat the previous layer's output size, which is tedious to update and easy to get wrong when trying different layer sizes.

Here's roughly how I imagine this working. The @Chain macro would expand to something like

shape = nothing
layer1, shape = fromshape(Input, shape, 28^2)
layer2, shape = fromshape(Dense, shape, 32, relu)
...
Chain(layer1, layer2, ...)

fromshape can then forward to the appropriate constructor, or throw an error for unsupported layers. Hopefully this strikes the right balance between simplicity and generality, and we don't end up having to turn it into a full shape inference system.
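A minimal, self-contained sketch of the expansion and dispatch described above, under some loud assumptions: `Input`, `Dense`, `Conv`, `relu` and `softmax` here are hypothetical toy stand-ins (not Flux's types), and the macro returns a plain tuple of layers instead of calling `Chain`, so it runs without Flux. `fromshape` dispatches on the layer type and returns `(layer, output_shape)`; a catch-all method throws an `ArgumentError` for any layer type that has no shape rule.

```julia
# Hypothetical toy stand-ins, NOT Flux's layers; run without Flux loaded.
struct Input
    size::Int
end

struct Dense
    in::Int
    out::Int
    σ::Function
end

struct Conv   # hypothetical layer with no shape rule in this sketch
    k::Int
end

relu(x) = max(zero(x), x)
softmax(x) = exp.(x) ./ sum(exp.(x))

# Fallback: any layer type without a dedicated method is rejected.
fromshape(T::Type, shape, args...) =
    throw(ArgumentError("fromshape: no shape rule for $T"))

# Each method returns (layer, output_shape).
fromshape(::Type{Input}, ::Nothing, n::Integer) = (Input(n), n)
fromshape(::Type{Dense}, in::Integer, out::Integer, σ::Function = identity) =
    (Dense(in, out, σ), out)
# Plain functions (activations, softmax) leave the shape unchanged.
fromshape(f::Function, shape) = (f, shape)

# @Chain(A(args...), f, ...) expands to a `let` block that threads a
# hidden `shape` variable through successive fromshape calls.
macro Chain(exprs...)
    shape = gensym(:shape)
    body = Any[:($shape = nothing)]
    layers = Symbol[]
    for ex in exprs
        layer = gensym(:layer)
        call = if ex isa Expr && ex.head == :call
            # Dense(32, relu)  ->  fromshape(Dense, shape, 32, relu)
            :(fromshape($(esc(ex.args[1])), $shape, $(map(esc, ex.args[2:end])...)))
        else
            # softmax  ->  fromshape(softmax, shape)
            :(fromshape($(esc(ex)), $shape))
        end
        push!(body, :(($layer, $shape) = $call))
        push!(layers, layer)
    end
    # Return a tuple of layers; a real version would build a Chain instead.
    return Expr(:let, Expr(:block), Expr(:block, body..., Expr(:tuple, layers...)))
end

m = @Chain(
  Input(28^2),
  Dense(32, relu),
  Dense(10),
  softmax)
```

With these definitions, the second element of `m` is a `Dense` with `in == 784` and `out == 32`, and the third has `in == 32`, so only the output sizes are written by hand. Dispatching on the layer type keeps the scheme open: a new layer opts in by adding one `fromshape` method, and anything else, such as `@Chain(Input(4), Conv(3))` here, fails with an `ArgumentError`.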
