FluxML/Zygote.jl

Zygote.hessian_reverse for BatchNorm broken on julia v1.11

Open

#1531 opened on Oct 11, 2024

Labels: bug, compiler, help wanted

Description

With Zygote v0.6.71 on Julia v1.11, the following code throws an error:

using Flux, Zygote
m = Chain(BatchNorm(3), sum)
x = Float32[1 2; 3 4; 5 6]
Zygote.hessian_reverse(m, x)
ERROR: Compiling Tuple{Zygote.Pullback{Tuple{typeof(Base.sym_in), Symbol, Tuple{Symbol, Symbol}}, Any}, Nothing}: UndefRefError: access to undefined reference
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::Nothing)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [3] Pullback
    @ ./namedtuple.jl:399 [inlined]
  [4] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::Nothing)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [5] Pullback
    @ ~/.julia/dev/Flux/src/layers/normalise.jl:225 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::FillArrays.Fill{Float32, 2, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/dev/Flux/src/layers/normalise.jl:354 [inlined]
  [8] _pullback(ctx::Zygote.Context{…}, f::Zygote.Pullback{…}, args::FillArrays.Fill{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/dev/Flux/src/layers/basic.jl:53 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/dev/Flux/src/layers/basic.jl:51 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [13] #78
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91 [inlined]
 [14] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [15] gradient
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::typeof(gradient), ::Chain{Tuple{BatchNorm{…}, typeof(sum)}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [17] #132
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:75 [inlined]
 [18] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#132#133"{Chain{Tuple{…}}}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [19] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:946
 [20] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [22] call_composed
    @ ./operators.jl:1054 [inlined]
 [23] call_composed
    @ ./operators.jl:1053 [inlined]
 [24] #_#113
    @ ./operators.jl:1050 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::Base.var"##_#113", ::@Kwargs{}, ::ComposedFunction{…}, ::Matrix{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [26] _apply
    @ ./boot.jl:946 [inlined]
 [27] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [29] ComposedFunction
    @ ./operators.jl:1050 [inlined]
 [30] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::Matrix{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [31] pullback(f::Function, cx::Zygote.Context{false}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
 [32] pullback
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88 [inlined]
 [33] withjacobian(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:141
 [34] jacobian
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:128 [inlined]
 [35] hessian_reverse(f::Chain{Tuple{BatchNorm{…}, typeof(sum)}}, x::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:75
 [36] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.

The failure appears to be specific to BatchNorm, since the same call works when the layer is replaced with Dense:

m = Chain(Dense(3=>3), sum)
x = Float32[1 2; 3 4; 5 6]
Zygote.hessian_reverse(m, x)

Found in FluxML/Flux.jl#2492
