FluxML/Zygote.jl

Zygote.hessian_reverse for BatchNorm broken on julia v1.11

Open

#1531 opened on Oct 11, 2024

Labels: bug, compiler, help wanted

Description

With Zygote v0.6.71 on Julia v1.11, I get the following error:

using Flux, Zygote

m = Chain(BatchNorm(3), sum)
x = Float32[1 2; 3 4; 5 6]
Zygote.hessian_reverse(m, x)
ERROR: Compiling Tuple{Zygote.Pullback{Tuple{typeof(Base.sym_in), Symbol, Tuple{Symbol, Symbol}}, Any}, Nothing}: UndefRefError: access to undefined reference
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::Nothing)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [3] Pullback
    @ ./namedtuple.jl:399 [inlined]
  [4] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::Nothing)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [5] Pullback
    @ ~/.julia/dev/Flux/src/layers/normalise.jl:225 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Any}, args::FillArrays.Fill{Float32, 2, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/dev/Flux/src/layers/normalise.jl:354 [inlined]
  [8] _pullback(ctx::Zygote.Context{…}, f::Zygote.Pullback{…}, args::FillArrays.Fill{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/dev/Flux/src/layers/basic.jl:53 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/dev/Flux/src/layers/basic.jl:51 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [13] #78
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91 [inlined]
 [14] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [15] gradient
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::typeof(gradient), ::Chain{Tuple{BatchNorm{…}, typeof(sum)}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [17] #132
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:75 [inlined]
 [18] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#132#133"{Chain{Tuple{…}}}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [19] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:946
 [20] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [22] call_composed
    @ ./operators.jl:1054 [inlined]
 [23] call_composed
    @ ./operators.jl:1053 [inlined]
 [24] #_#113
    @ ./operators.jl:1050 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::Base.var"##_#113", ::@Kwargs{}, ::ComposedFunction{…}, ::Matrix{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [26] _apply
    @ ./boot.jl:946 [inlined]
 [27] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [29] ComposedFunction
    @ ./operators.jl:1050 [inlined]
 [30] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::Matrix{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [31] pullback(f::Function, cx::Zygote.Context{false}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
 [32] pullback
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88 [inlined]
 [33] withjacobian(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:141
 [34] jacobian
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:128 [inlined]
 [35] hessian_reverse(f::Chain{Tuple{BatchNorm{…}, typeof(sum)}}, x::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/grad.jl:75
 [36] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.

The failure appears to be specific to BatchNorm, since the same call works when a Dense layer is used instead:

m = Chain(Dense(3=>3), sum)
x = Float32[1 2; 3 4; 5 6]
Zygote.hessian_reverse(m, x)
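
To help narrow this down, here is a minimal diagnostic sketch that runs both layers through `Zygote.hessian` (forward-over-reverse) and `Zygote.hessian_reverse` (reverse-over-reverse) and prints which combinations throw. The helper `try_hessian` is a hypothetical name introduced only for this illustration. Whether `Zygote.hessian` succeeds for BatchNorm on Julia v1.11 is not established by this report, so the script makes no claim about the outcome.

```julia
using Flux, Zygote

# Hypothetical helper: call `hfun(m, x)` and report success or the caught error.
function try_hessian(hfun, m, x)
    try
        H = hfun(m, x)
        return (ok = true, size = size(H), err = nothing)
    catch err
        return (ok = false, size = nothing, err = err)
    end
end

x = Float32[1 2; 3 4; 5 6]

for (name, layer) in (("Dense", Dense(3 => 3)), ("BatchNorm", BatchNorm(3)))
    m = Chain(layer, sum)
    for hfun in (Zygote.hessian, Zygote.hessian_reverse)
        r = try_hessian(hfun, m, x)
        status = r.ok ? "ok, size $(r.size)" : "failed with $(typeof(r.err))"
        println(name, " / ", nameof(hfun), ": ", status)
    end
end
```

If only the `BatchNorm / hessian_reverse` line reports a failure, that would be consistent with the problem lying in the second reverse pass over the BatchNorm pullback rather than in BatchNorm's first-order gradient.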

Found in FluxML/Flux.jl#2492
