HIPS/autograd

support numpy.take

Open

#743 opened on Nov 24, 2025

Labels: PR welcome, good first issue

Description

Currently, autograd cannot differentiate through anp.take:

import numpy as np
import autograd as ag
import autograd.numpy as anp

rng = np.random.default_rng(42)
x = rng.uniform(size=(3, 4, 5))
idx = rng.integers(0, 4, size=(6,))

def foo(x, idx):
    # Plain fancy indexing works:
    # return x[:, idx, :].sum()
    # The equivalent take call does not:
    return anp.take(x, idx, axis=1).sum()

gfoo = ag.grad(foo, argnum=0)
gfoo(x, idx).shape
# NotImplementedError: VJP of take wrt argnums (0,) not defined

take is essentially a convenient syntax for a subset of getitem (fancy) indexing: anp.take(x, idx, axis=1) selects the same elements as x[:, idx, :]. So I suppose nothing fundamental is blocking support for it.
