FluxML/Zygote.jl

Increasing memory usage in each call of gradient

Open

#1509 opened on 30 Mar 2024

Labels: compiler, help wanted, performance


Description

I was experimenting with alternatives to https://github.com/FluxML/Optimisers.jl/pull/57 when I encountered the following weird issue.

Look at the number of allocations when computing the gradient of `loss1`:

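```julia
using Flux, Functors, BenchmarkTools
using Zygote: gradient  # explicit import: `gradient` here means Zygote's

# Explicit loop that accumulates the sum of every numeric array leaf.
function loss1(m)
    ls = 0f0
    for l in Functors.fleaves(m)
        if l isa AbstractArray{<:Number}
            ls += sum(l)
        end
    end
    return ls
end

# Same reduction written as a generator.
function loss2(m)
    sum(sum(l) for l in Functors.fleaves(m) if l isa AbstractArray{<:Number})
end

# Same reduction written as an array comprehension.
function loss3(m)
    sum([sum(l) for l in Functors.fleaves(m) if l isa AbstractArray{<:Number}])
end

function perf()
    m = Chain(Dense(128 => 128, relu), BatchNorm(3), Dense(128 => 10))
    @btime gradient(loss1, $m)[1]
    @btime gradient(loss2, $m)[1]
    @btime gradient(loss3, $m)[1]
    println()
end

perf();  # 1st call
perf();  # 2nd call
perf();  # 3rd call
```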
```text
# OUTPUT: each perf() call prints one line each for loss1, loss2, loss3
# 1st call
154.795 ms (1022652 allocations: 39.16 MiB)
1.734 ms (7605 allocations: 352.62 KiB)
1.314 ms (5948 allocations: 288.08 KiB)

# 2nd call
258.556 ms (1658450 allocations: 63.37 MiB)
1.735 ms (7605 allocations: 352.62 KiB)
1.316 ms (5948 allocations: 288.08 KiB)

# 3rd call
336.418 ms (2154374 allocations: 82.29 MiB)
1.739 ms (7605 allocations: 352.62 KiB)
1.319 ms (5948 allocations: 288.08 KiB)
```
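This small sketch quantifies the growth. It computes the increase between successive `perf()` calls for `loss1`, using only the figures copied from the output above. Nothing is re-measured. The `loss2` and `loss3` lines do not change between calls, so they are left out.

```julia
# Allocation counts and memory (MiB) reported for gradient(loss1, m)
# on the 1st, 2nd and 3rd call of perf(), copied from the benchmark output.
allocs = [1022652, 1658450, 2154374]
mem_mib = [39.16, 63.37, 82.29]

# Increase from one perf() call to the next.
alloc_growth = diff(allocs)
mem_growth = round.(diff(mem_mib); digits = 2)

println("extra allocations per perf() call: ", alloc_growth)  # → [635798, 495924]
println("extra MiB per perf() call: ", mem_growth)            # → [24.21, 18.92]
```

In other words, each `perf()` call adds roughly half a million allocations and about 20 MiB to a single `gradient(loss1, m)` evaluation.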

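What's going on? Why does the cost of `gradient(loss1, m)` keep growing from one call to the next, while `loss2` and `loss3` stay constant?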
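One way to check whether the growth also shows up outside BenchmarkTools is to record the bytes allocated by successive single `gradient` calls on the same model. This is a minimal sketch that assumes Flux, Zygote and Functors are installed. `allocs_per_call` is a hypothetical helper and is not part of any of these packages. The sketch redefines `loss1` so that it runs on its own.

```julia
using Flux: Chain, Dense, BatchNorm, relu
import Zygote, Functors

# Same loop-based loss as loss1 in the report.
function loss1(m)
    ls = 0f0
    for l in Functors.fleaves(m)
        if l isa AbstractArray{<:Number}
            ls += sum(l)
        end
    end
    return ls
end

# Hypothetical helper: bytes allocated (as reported by @allocated) by
# each of n consecutive gradient calls on the same model.
function allocs_per_call(loss, m; n = 5)
    Zygote.gradient(loss, m)  # warm-up call so compilation is not counted
    return [@allocated(Zygote.gradient(loss, m)) for _ in 1:n]
end

m = Chain(Dense(128 => 128, relu), BatchNorm(3), Dense(128 => 10))
println(allocs_per_call(loss1, m))
```

If the values in the returned vector keep rising, the per-call cost of `gradient(loss1, m)` grows independently of `@btime`'s repeated sampling.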
