FluxML/Flux.jl

Shape-propagating Chain

Open

#703 opened on Mar 26, 2019

8 comments, 8 reactions
Labels: discussion, enhancement, help wanted


Description

It'd be nice to be able to write something like

model = @Chain(
  Input(28^2),
  Dense(32, relu),
  Dense(10),
  softmax)

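It's a relatively minor convenience, but it avoids some redundancy when specifying chains: each layer's input size has to repeat the previous layer's output size (e.g. Dense(28^2, 32, relu) followed by Dense(32, 10)), which is tedious to correct and easy to get wrong when trying different layer sizes.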

Here's roughly how I imagine this working. The @Chain macro would expand to something like:

shape = nothing
layer1, shape = fromshape(Input, shape, 28^2)
layer2, shape = fromshape(Dense, shape, 32, relu)
...
Chain(layer1, layer2, ...)
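The expansion above could be generated by a fairly small macro. Below is a minimal sketch, assuming a hypothetical name @shapechain (chosen so it isn't mistaken for an existing Flux API) and assuming that fromshape and Chain are resolved in the caller's scope when the expanded code runs. It only rewrites syntax; all shape handling is left to fromshape.

```julia
# Hypothetical sketch: `@shapechain` is not part of Flux. It rewrites each
# layer spec into a `fromshape` call that threads `shape` through, then wraps
# the resulting layers in `Chain`. Keyword arguments in a spec are not handled.
macro shapechain(specs...)
    shape = gensym(:shape)
    body = Expr[:($shape = nothing)]
    layers = Symbol[]
    for spec in specs
        layer = gensym(:layer)
        push!(layers, layer)
        if spec isa Expr && spec.head == :call
            # `Dense(32, relu)` becomes `fromshape(Dense, shape, 32, relu)`
            f, args = spec.args[1], spec.args[2:end]
            push!(body, :(($layer, $shape) = fromshape($f, $shape, $(args...))))
        else
            # a bare function such as `softmax` becomes `fromshape(softmax, shape)`
            push!(body, :(($layer, $shape) = fromshape($spec, $shape)))
        end
    end
    push!(body, :(Chain($(layers...))))
    # Escape everything so `fromshape`/`Chain` resolve where the macro is used;
    # the gensym'd names are unique, so they cannot clash with user variables.
    return esc(Expr(:block, body...))
end
```

Running @macroexpand @shapechain(Input(28^2), Dense(32, relu), Dense(10), softmax) should show the same fromshape threading as above, with gensym'd names in place of layer1, shape, etc.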

fromshape can then forward to the appropriate constructor (e.g. turning fromshape(Dense, 784, 32, relu) into Dense(784, 32, relu)), or throw an error for unsupported layers. Hopefully this strikes the right balance between simplicity and generality, so that we don't end up having to turn it into a full shape-inference system.
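To make the forwarding concrete, here is a sketch of fromshape as a set of multiple-dispatch methods. It assumes toy stand-ins (Input, ToyDense and toychain are hypothetical, not Flux's own types) so that it runs on plain Julia, and it simplifies a "shape" to a single feature count. Plain functions such as softmax are assumed to preserve the shape.

```julia
# Hypothetical, Flux-free sketch: `Input`, `ToyDense` and `toychain` are toy
# stand-ins so this runs on plain Julia. A "shape" is simplified to a single
# feature count (an Int), or `nothing` before the input size is known.

struct Input end  # marker type: only used to declare the input size

struct ToyDense{F}
    W::Matrix{Float64}
    b::Vector{Float64}
    σ::F
end
(d::ToyDense)(x) = d.σ.(d.W * x .+ d.b)

# Every method returns (layer, output_shape).
fromshape(::Type{Input}, ::Nothing, n::Integer) = (identity, n)
fromshape(::Type{ToyDense}, nin::Integer, nout::Integer, σ = identity) =
    (ToyDense(0.1 .* randn(nout, nin), zeros(nout), σ), nout)
fromshape(::Type{ToyDense}, ::Nothing, args...) =
    error("ToyDense needs a known input size; start the chain with Input")
# Plain functions (e.g. softmax) are assumed to be shape-preserving.
fromshape(f::Function, shape) = (f, shape)
# Catch-all: anything else is an unsupported layer.
fromshape(layer, shape, args...) = error("fromshape: unsupported layer $layer")

# Apply the layers left to right, like a minimal Chain.
toychain(layers...) = x -> foldl((acc, f) -> f(acc), layers; init = x)

relu(x) = max(zero(x), x)
softmax(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

# The expansion from the issue, written out by hand:
shape = nothing
layer1, shape = fromshape(Input, shape, 28^2)
layer2, shape = fromshape(ToyDense, shape, 32, relu)
layer3, shape = fromshape(ToyDense, shape, 10)
layer4, shape = fromshape(softmax, shape)
model = toychain(layer1, layer2, layer3, layer4)
```

With dispatch, supporting a new layer type means adding one fromshape method, and the catch-all method provides the error for layers that don't support shape propagation.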
