FluxML/Zygote.jl

Increasing memory usage in each call of gradient

Open

#1509 opened on March 30, 2024

Labels: compiler, help wanted, performance

Description

I was experimenting with alternatives to https://github.com/FluxML/Optimisers.jl/pull/57 when I encountered the following weird issue.

Look at the number of allocations when computing the gradient of loss1: it grows with every call to perf(), while the numbers for loss2 and loss3 stay constant.

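using BenchmarkTools, Flux, Functors

# Explicit loop: accumulate the sum of every numeric array leaf of the model.
function loss1(m)
    ls = 0f0
    for l in Functors.fleaves(m)
        if l isa AbstractArray{<:Number}
            ls += sum(l)
        end
    end
    return ls
end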

function loss2(m)
    sum(sum(l) for l in Functors.fleaves(m) if l isa AbstractArray{<:Number})
end

function loss3(m)
    sum([sum(l) for l in Functors.fleaves(m) if l isa AbstractArray{<:Number}])
end


function perf()
    m = Chain(Dense(128 => 128, relu), BatchNorm(3), Dense(128 => 10))
    @btime gradient(loss1, $m)[1]
    @btime gradient(loss2, $m)[1]
    @btime gradient(loss3, $m)[1]
    println()
end

perf(); #1st call
perf(); #2nd call
perf(); #3rd call
# OUTPUT
154.795 ms (1022652 allocations: 39.16 MiB)
1.734 ms (7605 allocations: 352.62 KiB)
1.314 ms (5948 allocations: 288.08 KiB)

258.556 ms (1658450 allocations: 63.37 MiB)
1.735 ms (7605 allocations: 352.62 KiB)
1.316 ms (5948 allocations: 288.08 KiB)

336.418 ms (2154374 allocations: 82.29 MiB)
1.739 ms (7605 allocations: 352.62 KiB)
1.319 ms (5948 allocations: 288.08 KiB)

What's going on?
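
To narrow this down, one option is to record the bytes allocated by each successive gradient call directly, without BenchmarkTools. The following is a minimal sketch under that assumption, not a diagnosis: alloc_trace is a hypothetical helper that is not part of Flux, Functors, or Zygote, and the number of calls (n = 5) is arbitrary.

```julia
using Flux, Functors

# Same loss as loss1 above, repeated so the snippet is self-contained.
function loss1(m)
    ls = 0f0
    for l in Functors.fleaves(m)
        if l isa AbstractArray{<:Number}
            ls += sum(l)
        end
    end
    return ls
end

# Hypothetical helper (assumption, not a library function):
# returns the bytes allocated by each of n successive gradient calls.
function alloc_trace(loss, m; n = 5)
    gradient(loss, m)  # warm-up call so compilation is not counted
    return [@allocated(gradient(loss, m)) for _ in 1:n]
end

m = Chain(Dense(128 => 128, relu), BatchNorm(3), Dense(128 => 10))
trace = alloc_trace(loss1, m)
println(trace)
```

If the entries of the trace increase from one call to the next, the growth happens per gradient call rather than per @btime run; a flat trace would instead point at the interaction with the benchmarking loop.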
