bughelp wanted
描述
Hi,
I'm trying to calculate a "masked minimum" of an array, i. e. take the minimum among all values fulfilling a certain condition. On the CPU, this works:
using Zygote
function masked_minimum(x::AbstractArray{T, N}, mask::AbstractArray{T, N}) where {T, N}
x_masked = x[mask .== one(T)]
return minimum(x_masked)
end
A = sprand(200, 100, 0.8)
x = rand(100)
mask = round.(rand(200))
masked_minimum(A * x, mask);
Zygote.gradient(x -> masked_minimum(A * x, mask), x)[1];
However, on the GPU it does not work:
function masked_minimum(x::AbstractArray{T, N}, mask::AbstractArray{T, N}) where {T, N}
# Copy to the CPU to allow for fast indexing.
mask = Array(mask)
x_masked = Array(x)[mask .== one(T)]
return minimum(x_masked)
end
A = cu(sprand(200, 100, 0.8))
x = cu(rand(100))
mask = cu(round.(rand(200)))
masked_minimum(A * x, mask);
Zygote.gradient(x -> masked_minimum(A * x, mask), x)[1]
Error message:
GPU compilation of kernel #broadcast_kernel#17(CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
.args is of type Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
.1 is of type Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
.x is of type Vector{Float32} which is not isbits.
Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/validation.jl:88
[2] macro expansion
@ ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:417 [inlined]
[3] macro expansion
@ ~/.julia/packages/TimerOutputs/jgSVI/src/TimerOutput.jl:252 [inlined]
[4] macro expansion
@ ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:416 [inlined]
[5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
@ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/utils.jl:64
[6] cufunction_compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
@ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:354
[7] #224
@ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:347 [inlined]
[8] JuliaContext(f::CUDA.var"#224#225"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#17", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}}})
@ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:76
[9] cufunction_compile(job::GPUCompiler.CompilerJob)
@ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:346
[10] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
@ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/cache.jl:90
[11] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:299
[12] cufunction
@ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:293 [inlined]
[13] macro expansion
@ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:102 [inlined]
[14] #launch_heuristic#248
@ ~/.julia/packages/CUDA/DfvRa/src/gpuarrays.jl:17 [inlined]
[15] _copyto!
@ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:63 [inlined]
[16] copyto!
@ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:46 [inlined]
[17] copy
@ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:37 [inlined]
[18] materialize
@ ./broadcast.jl:860 [inlined]
[19] broadcast(::typeof(*), ::Vector{Float32}, ::Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Base.Broadcast ./broadcast.jl:798
[20] *
@ <redacted>/julia-1.7.3/share/julia/stdlib/v1.7/LinearAlgebra/src/adjtrans.jl:297 [inlined]
[21] #1404
@ ~/.julia/packages/ChainRules/EyLkg/src/rulesets/Base/arraymath.jl:36 [inlined]
[22] unthunk
@ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:199 [inlined]
[23] wrap_chainrules_output
@ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:104 [inlined]
[24] map
@ ./tuple.jl:223 [inlined]
[25] wrap_chainrules_output
@ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:105 [inlined]
[26] ZBack
@ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:205 [inlined]
[27] Pullback
@ ./In[281]:13 [inlined]
[28] (::typeof(∂(#351)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
[29] (::Zygote.var"#60#61"{typeof(∂(#351))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:41
[30] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:76
[31] top-level scope
@ In[281]:13
[32] eval
@ ./boot.jl:373 [inlined]
[33] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1196
Is this a bug or am I doing something wrong?