Labels: help wanted, needs adjoint
Description
I'm not sure whether I should be filing this here or in CUDA.jl, but temporarily enabling scalar indexing with CUDA.@allowscalar leads to an error when taking a gradient.
MWE:
using Zygote, CUDA
CUDA.allowscalar(false)
f(x) = CUDA.@allowscalar x[3]
gradient(f, cu(randn(10, 5)))
And the resulting error:
ERROR: Compiling Tuple{typeof(task_local_storage), var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Symbol, Bool}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/reverse.jl:121
[3] #Primal#20
@ ~/.julia/packages/Zygote/Lw5Kf/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/reverse.jl:315
[5] _generate_pullback_via_decomposition(T::Type)
@ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/emit.jl:101
[6] #s2989#1184
@ ~/.julia/packages/Zygote/Lw5Kf/src/compiler/interface2.jl:28 [inlined]
[7] var"#s2989#1184"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[9] macro expansion
@ ~/.julia/packages/GPUArrays/3sW6s/src/host/indexing.jl:74 [inlined]
[10] _pullback
@ ./REPL[2]:1 [inlined]
[11] _pullback(ctx::Zygote.Context, f::typeof(f), args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/interface2.jl:0
[12] _pullback(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/interface.jl:34
[13] pullback(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/interface.jl:40
[14] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/Lw5Kf/src/compiler/interface.jl:75
[15] top-level scope
@ REPL[11]:1
[16] top-level scope
@ ~/.julia/packages/CUDA/YpW0k/src/initialization.jl:52
This may or may not be related to https://github.com/FluxML/Zygote.jl/issues/1070
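A possible workaround, sketched here under assumptions rather than as a confirmed fix: the stack trace shows Zygote trying to differentiate through the expansion of @allowscalar (GPUArrays indexing.jl), which calls task_local_storage with a closure that contains try/catch. One way to keep Zygote out of that code is to hide the scalar read behind a small helper and give the helper its own ChainRulesCore rrule, so Zygote calls the rule instead of tracing into the macro. The helper name scalar_getindex is hypothetical (not part of Zygote, CUDA.jl, or GPUArrays), the sketch assumes ChainRulesCore is installed alongside Zygote, and it has not been checked against the package versions in the report.

```julia
using Zygote, CUDA, ChainRulesCore

CUDA.allowscalar(false)

# Hypothetical helper: performs the scalar read with scalar indexing
# temporarily allowed. Zygote should never trace into this body because
# of the rrule defined below.
scalar_getindex(x, i) = CUDA.@allowscalar x[i]

# Custom reverse rule: the primal call runs outside Zygote's tracing, and the
# pullback scatters the incoming cotangent into a zero array shaped like x.
function ChainRulesCore.rrule(::typeof(scalar_getindex), x, i)
    y = scalar_getindex(x, i)
    function scalar_getindex_pullback(ȳ)
        dx = zero(x)  # zero array of the same type and size as x (CuArray stays CuArray)
        CUDA.@allowscalar dx[i] = unthunk(ȳ)
        return NoTangent(), dx, NoTangent()
    end
    return y, scalar_getindex_pullback
end

f(x) = scalar_getindex(x, 3)

# Run the original MWE only when a working CUDA device is available.
if CUDA.functional()
    g = gradient(f, cu(randn(10, 5)))[1]
end
```

The same rule also works for plain CPU arrays, since @allowscalar only toggles task-local state there. Writing the rule for a dedicated helper, rather than for getindex itself, avoids overriding Zygote's existing adjoint for indexing.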