FluxML/Flux.jl

Shape-propagating Chain

Open

#703 opened on March 26, 2019

Labels: discussion, enhancement, help wanted

Description

It'd be nice to be able to write something like

model = @Chain(
  Input(28^2),
  Dense(32, relu),
  Dense(10),
  softmax)

It's a relatively minor convenience, but it avoids having to repeat each layer's input size when specifying chains. That redundancy is tedious to correct and easy to get wrong when trying different layer sizes.

Here's roughly how I imagine this working. The @Chain macro would expand to something like

shape = nothing
layer1, shape = fromshape(Input, shape, 28^2)
layer2, shape = fromshape(Dense, shape, 32, relu)
...
Chain(layer1, layer2, ...)

fromshape can then forward to the appropriate constructor, or throw an error for unsupported layers. Hopefully this strikes the right balance between simplicity and generality, so that we don't end up having to turn it into a full shape inference system.
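
To make the proposal concrete, here is a minimal sketch of how fromshape and @Chain might fit together. Everything in it is an assumption made for illustration: Input and Dense are stand-in types defined locally so the example runs without Flux, Input is modelled as an identity layer that only seeds the shape, bare functions such as softmax are passed through without touching the shape, and a plain tuple stands in for Chain. Keyword arguments are not handled.

```julia
# Hypothetical stand-ins so the sketch runs without Flux.
struct Input end                      # marker type: only carries the input size

struct Dense{F}
    in::Int
    out::Int
    σ::F
end

relu(x) = max(zero(x), x)
softmax(x) = exp.(x) ./ sum(exp.(x))

# fromshape(LayerType, shape, args...) returns (layer, new_shape).
# Input contributes an identity layer and seeds the shape.
fromshape(::Type{Input}, shape, n::Integer) = (identity, Int(n))

# Dense takes its input size from the propagated shape.
function fromshape(::Type{Dense}, shape, out::Integer, σ = identity)
    shape === nothing && error("Dense needs a known input shape; start with Input(n)")
    return Dense(Int(shape), Int(out), σ), Int(out)
end

# Fallback: layer types that have not opted in raise an error.
fromshape(T::Type, shape, args...) = error("fromshape is not defined for $T")

macro Chain(layers...)
    shape = gensym(:shape)
    body = Expr(:block, :($shape = nothing))
    names = Symbol[]
    for ex in layers
        name = gensym(:layer)
        push!(names, name)
        if ex isa Expr && ex.head == :call
            # Layer(args...) becomes: name, shape = fromshape(Layer, shape, args...)
            T, args = ex.args[1], ex.args[2:end]
            push!(body.args, :(($name, $shape) = fromshape($T, $shape, $(args...))))
        else
            # Bare function (e.g. softmax): kept as is, shape unchanged.
            push!(body.args, :($name = $ex))
        end
    end
    # A tuple stands in for Chain(layer1, layer2, ...).
    push!(body.args, :(tuple($(names...))))
    return esc(Expr(:let, Expr(:block), body))
end

model = @Chain(
  Input(28^2),
  Dense(32, relu),
  Dense(10),
  softmax)
```

In this sketch each layer type opts in by adding its own fromshape method, so the macro itself never needs to know how a given layer transforms shapes. That keeps the mechanism short of a full shape inference system, at the cost of erroring on any layer type without a method.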
