FluxML/Zygote.jl

Zygote.hessian_reverse for BatchNorm broken on julia v1.11

Open

#1531 opened on Oct 11, 2024

Labels: bug, compiler, help wanted


Description

With Zygote v0.6.71 on julia v1.11 I get the following error:

using Flux, Zygote
m = Chain(BatchNorm(3), sum)
x = Float32[1 2; 3 4; 5 6]
Zygote.hessian_reverse(m, x)
ERROR: Compiling Tuple{Zygote.Pullback{Tuple{typeof(Base.sym_in), Symbol, Tuple{Symbol, Symbol}}, Any}, Nothing}: UndefRefError: access to undefined reference
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::Nothing)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [3] Pullback
    @ ./namedtuple.jl:399 [inlined]
  [4] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::Nothing)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [5] Pullback
    @ ~/.julia/dev/Flux/src/layers/normalise.jl:225 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::FillArrays.Fill{Float32, 2, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/dev/Flux/src/layers/normalise.jl:354 [inlined]
  [8] _pullback(ctx::Zygote.Context{…}, f::Zygote.Pullback{…}, args::FillArrays.Fill{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/dev/Flux/src/layers/basic.jl:53 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/dev/Flux/src/layers/basic.jl:51 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [13] #78
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91 [inlined]
 [14] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [15] gradient
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::typeof(gradient), ::Chain{Tuple{BatchNorm{…}, typeof(sum)}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [17] #132
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:75 [inlined]
 [18] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#132#133"{Chain{Tuple{…}}}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [19] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:946
 [20] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [22] call_composed
    @ ./operators.jl:1054 [inlined]
 [23] call_composed
    @ ./operators.jl:1053 [inlined]
 [24] #_#113
    @ ./operators.jl:1050 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::Base.var"##_#113", ::@Kwargs{}, ::ComposedFunction{…}, ::Matrix{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [26] _apply
    @ ./boot.jl:946 [inlined]
 [27] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [29] ComposedFunction
    @ ./operators.jl:1050 [inlined]
 [30] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::Matrix{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [31] pullback(f::Function, cx::Zygote.Context{false}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
 [32] pullback
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88 [inlined]
 [33] withjacobian(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:141
 [34] jacobian
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:128 [inlined]
 [35] hessian_reverse(f::Chain{Tuple{BatchNorm{…}, typeof(sum)}}, x::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:75
 [36] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.

The problem is specific to BatchNorm; the same call with a Dense layer works:

m = Chain(Dense(3=>3), sum)
x = Float32[1 2; 3 4; 5 6]
Zygote.hessian_reverse(m, x)
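For context: frames [35], [34] and [17] of the trace show that `hessian_reverse` computes the Hessian as the Jacobian of the gradient, with both levels in Zygote's reverse mode, so the failure happens while Zygote differentiates BatchNorm's own pullback (frames [5] and [7]). The sketch below reproduces that reverse-over-reverse structure on a plain function. `hess_rr` and `f` are hypothetical names for illustration, not Zygote API, and the block assumes only Zygote's exported `jacobian`, `gradient` and `hessian_reverse`.

```julia
using Zygote

# Reverse-over-reverse Hessian: the reverse-mode Jacobian of the reverse-mode gradient.
# Hypothetical helper that mirrors the call structure visible in the stack trace;
# it is not copied from Zygote's source.
hess_rr(f, x) = Zygote.jacobian(y -> Zygote.gradient(f, y)[1], x)[1]

# Hypothetical test function: elementwise cube, so the Hessian is diagonal with entries 6x.
f(x) = sum(x .^ 3)
x = Float32[1 2; 3 4; 5 6]

# 6×6 matrix; rows and columns follow the linear (column-major) order of x.
H = hess_rr(f, x)
```

Passing the BatchNorm chain from the report in place of `f` would presumably exercise the same nested-pullback path that fails above.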

Found in FluxML/Flux.jl#2492
