FluxML/Zygote.jl

Strip zygote frames from mutation error stack trace

Open

#1501 opened on Feb 16, 2024

Labels: compiler, help wanted


Motivation and description

When differentiating something complicated that contains mutation, it can be hard to tell exactly where the mutation happens. In the example below, the mutation is tucked away inside the ComponentArray constructor; in a larger program (e.g. https://github.com/DARPA-ASKEM/sciml-service/issues/141) it can be much harder to find.

It would be very helpful if the stack trace pointed to the exact location of the mutation that triggers this error, instead of interleaving the relevant frames with Zygote frames. Failing that, it would at least be nice to tell the user to look at every third frame to find where in their code the mutation is.

julia> using ComponentArrays, Zygote

julia> function f(x)
           ca = ComponentArray(var=x)
           ca.var
       end
f (generic function with 1 method)

julia> Zygote.jacobian(f, [1,2,3])
ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:70
  [3] (::Zygote.var"#547#548"{Vector{Any}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:89
  [4] (::Zygote.var"#2643#back#549"{Zygote.var"#547#548"{Vector{Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] merge
    @ ./namedtuple.jl:371 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(merge), @NamedTuple{}, Base.Generator{…}}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [7] make_idx
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:170 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [9] make_carray_args
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:151 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [11] make_carray_args
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:144 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [13] ComponentArray
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:64 [inlined]
 [14] #ComponentArray#21
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:67 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ComponentVector{Int64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [16] ComponentArray
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:67 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ComponentVector{Int64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [18] f
    @ ./REPL[2]:2 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [20] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [21] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [22] call_composed
    @ ./operators.jl:1045 [inlined]
 [23] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{…}, Tuple{…}, @Kwargs{}}, Any})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [24] call_composed
    @ ./operators.jl:1044 [inlined]
 [25] #_#103
    @ ./operators.jl:1041 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [27] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [28] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [29] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [30] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [32] withjacobian(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:150
 [33] jacobian(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:128
 [34] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.
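Until Zygote does this itself, a user can filter the frames by hand. The following is a minimal sketch, assuming that Zygote and ZygoteRules frames can be recognised by their source path (in the trace above they live under .../packages/Zygote/... and .../packages/ZygoteRules/...). The helper names is_ad_frame, user_frames and with_clean_trace are hypothetical and not part of Zygote's API; only stacktrace, catch_backtrace and showerror come from Base.

```julia
# Hypothetical helpers, NOT part of Zygote's API. Assumption: frames that come
# from Zygote or ZygoteRules can be recognised by their source file path.
const AD_PATH = r"[/\\](Zygote|ZygoteRules)[/\\]"

# True if the frame's source file lives inside the Zygote or ZygoteRules package tree.
is_ad_frame(fr::Base.StackTraces.StackFrame) = occursin(AD_PATH, String(fr.file))

# Convert a raw backtrace into StackFrames and drop the Zygote ones,
# leaving only frames from Base, other packages, and user code.
user_frames(bt) = filter(!is_ad_frame, stacktrace(bt))

# Run `thunk()`; if it throws, print the error and the filtered trace to stderr,
# then rethrow so the original error still propagates to the caller.
function with_clean_trace(thunk)
    try
        return thunk()
    catch err
        bt = catch_backtrace()
        showerror(stderr, err)
        println(stderr, "\n\nStacktrace with Zygote frames removed:")
        for (i, fr) in enumerate(user_frames(bt))
            println(stderr, " [", i, "] ", fr)
        end
        rethrow()
    end
end
```

Usage would be something like with_clean_trace(() -> Zygote.jacobian(f, [1,2,3])). Because the helper rethrows, the REPL still prints the full trace afterwards; the shorter list printed first is the one that should point at make_idx, make_carray_args and f. Matching on file paths is crude (it would also hide frames from any other directory named Zygote), which is one reason doing this inside Zygote's error path would be preferable.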
