FluxML/Zygote.jl

Deepcopy of Dictionaries in combination with mutating leads to errors.

Open

#1,233 opened on Jun 1, 2022

Labels: bug, dictionary, help wanted

Description

Hey, I discovered some strange behavior, and I am not sure whether it is intended.

Consider the following example code:

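using LinearAlgebra  # provides norm
using Zygote         # provides gradient

# Mutates d in place and returns it.
function mutating1(d, a)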
	upd = d[1] * a
	d[1] = upd
	return d
end

function mutating2(d,a)
	nd = deepcopy(d)
		
	upd = d[1] * a
	nd[1] = upd
	return nd
end

function mutating3(d,a)
	nd = deepcopy(d)
		
	upd = nd[1] * a
	d[1] = upd
	return d
end

function nrmsq(d)
	res = norm(d[1]*d[2])
	return res
end

d = 3
arry = map(_ -> randn(d,d), 1:2)
dic = Dict{Any, Any}()
for (jj, tm) in enumerate(arry)
    dic[jj] = tm
end

dicco = deepcopy(dic)
f1 = x -> nrmsq(mutating1(x, 5))
f2 = x -> nrmsq(mutating2(x, 5))
f3 = x -> nrmsq(mutating3(x, 5))

# All three return the same value. f1 and f3 modify their input,
# so f1 is called last and f3 is called on a copy.
f2(dic) |> display
f3(dicco) |> display
f1(dic) |> display

# This works:
g1 = gradient(f1, dic)[1]

# These do not work (AssertionError: a === b):
g2 = gradient(f2, dic)[1]
g3 = gradient(f3, dic)[1]

The versions using f2 and f3 fail with a strange AssertionError: a === b raised from accum, while the version using f1 works flawlessly.

I have no idea how to deal with this or how to find a workaround.

Thanks in advance for any help!
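
For comparison, here is a minimal sketch of the same computation written without a Dict, deepcopy, or any in-place assignment. It assumes the two matrices can be passed to gradient directly, which may not fit the real use case. The names scaled_norm, A, B, and n are hypothetical and do not appear in the code above, and this has not been checked against the Zygote version that produced the error.

```julia
using LinearAlgebra  # provides norm
using Zygote         # provides gradient

# Hypothetical non-mutating counterpart of nrmsq(mutating2(dic, a)):
# scale the first matrix, multiply by the second, and take the norm.
scaled_norm(A, B, a) = norm((A * a) * B)

n = 3
A = randn(n, n)
B = randn(n, n)

# Differentiate with respect to both matrices.
# Nothing is copied or overwritten along the way.
gA, gB = gradient((x, y) -> scaled_norm(x, y, 5), A, B)
```

The gradients come back as a tuple with one entry per argument, so gA and gB play the role that g2[1] and g2[2] would have played in the Dict version.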
