HIPS/autograd

support numpy.take

Open

#743 created on November 24, 2025

2 comments | 0 reactions | 0 assignees | Python | 6,628 stars | 909 forks
Labels: PR welcome, good first issue

Description

Currently, tracing a gradient through take is not supported:

import numpy as np
import autograd as ag
import autograd.numpy as anp

rng = np.random.default_rng(42)
x = rng.uniform(size=(3, 4, 5))
idx = rng.integers(0, 4, size=(6,))

def foo(x, idx):
    # Works: plain fancy indexing along axis 1
    # return x[:, idx, :].sum()
    # Fails: the equivalent take call
    return anp.take(x, idx, axis=1).sum()

gfoo = ag.grad(foo, argnum=0)
gfoo(x, idx).shape
# NotImplementedError: VJP of take wrt argnums (0,) not defined
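Because the getitem version already differentiates, one possible workaround is to rewrite take as an index tuple. The NumPy-only sketch below checks that equivalence on every axis. take_as_index is a hypothetical helper, not a NumPy or autograd API, and it assumes an integer axis with the default mode="raise".

```python
import numpy as np


def take_as_index(a, indices, axis):
    """Hypothetical helper: np.take(a, indices, axis=axis) written as getitem.

    Assumes an integer axis and the default mode="raise".
    """
    axis = axis % a.ndim  # normalize negative axes, e.g. -1 -> a.ndim - 1
    # Full slices on the leading axes, then the integer index array on `axis`;
    # the trailing axes are implicitly kept whole.
    index = (slice(None),) * axis + (np.asarray(indices),)
    return a[index]


rng = np.random.default_rng(42)
x = rng.uniform(size=(3, 4, 5))
idx = rng.integers(0, 4, size=(6,))

# The case from the issue: take along axis 1 selects the same values as x[:, idx, :].
assert np.array_equal(take_as_index(x, idx, 1), np.take(x, idx, axis=1))
assert np.array_equal(take_as_index(x, idx, 1), x[:, idx, :])

# Every axis (including negative ones) with a 2-D index array.
for ax in range(-x.ndim, x.ndim):
    ind = rng.integers(0, x.shape[ax], size=(2, 3))
    assert np.array_equal(take_as_index(x, ind, ax), np.take(x, ind, axis=ax))
print("take matches getitem on all axes")
```

Inside foo, returning take_as_index(x, idx, 1).sum() should route the computation through the indexing path that already works under ag.grad; this sketch only verifies the forward values, not the gradient.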

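take is mostly just convenience syntax for a subset of getitem indexing (np.take(x, idx, axis=1) selects the same elements as x[:, idx, :]), and getitem already has a VJP, so I suppose nothing fundamentally blocks supporting it.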
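To make that concrete: take is linear in its first argument, so its VJP is a scatter-add of the cotangent g into a zero array shaped like the input, where repeated indices must accumulate. The NumPy-only sketch below assumes first-order gradients only; take_vjp and brute_take_vjp are hypothetical names, and the check compares against J^T g built one basis vector at a time.

```python
import numpy as np


def take_vjp(g, shape, indices, axis=None, mode="raise"):
    """Sketch of the VJP of np.take with respect to its first argument.

    g is the cotangent (same shape as the take output) and shape is the
    shape of the original input a. First-order only: plain NumPy, not traced.
    """
    indices = np.asarray(indices)
    grad = np.zeros(shape, dtype=np.result_type(g))
    if axis is None:
        # take(..., axis=None) indexes a.ravel(); reshape gives a view into grad.
        target, axis = grad.reshape(-1), 0
    else:
        target, axis = grad, axis % len(shape)
    n = target.shape[axis]
    # Map indices the same way np.take's mode does before scattering.
    if mode == "wrap":
        indices = indices % n
    elif mode == "clip":
        indices = np.clip(indices, 0, n - 1)
    index = (slice(None),) * axis + (indices,)
    # Scatter-add: repeated indices must accumulate, which a[index] += g would not do.
    np.add.at(target, index, g)
    return grad


def brute_take_vjp(g, shape, indices, axis=None, mode="raise"):
    """Reference J^T g: take is linear, so probe it with each basis vector."""
    grad = np.zeros(shape)
    for i in range(grad.size):
        e = np.zeros(shape)
        e.flat[i] = 1.0
        grad.flat[i] = np.sum(np.take(e, indices, axis=axis, mode=mode) * g)
    return grad


rng = np.random.default_rng(0)
shape = (3, 4, 5)
small = np.array([0, 2, 2, -1, 1])  # repeated and negative, valid for mode="raise"
large = np.array([0, 7, 7, -2])     # out of range, only valid for wrap/clip
for axis in (None, 0, 1, -1):
    for mode, ind in (("raise", small), ("wrap", large), ("clip", large)):
        out_shape = np.take(np.zeros(shape), ind, axis=axis, mode=mode).shape
        g = rng.normal(size=out_shape)
        assert np.allclose(take_vjp(g, shape, ind, axis, mode),
                           brute_take_vjp(g, shape, ind, axis, mode))
print("take_vjp matches the brute-force J^T g")
```

Wiring this into autograd would presumably go through autograd.extend.defvjp with a vjpmaker along the lines of lambda ans, a, indices, axis=None, mode="raise": lambda g: take_vjp(g, a.shape, indices, axis, mode). That registration is untested here, and since np.add.at is not traced it would not support higher-order derivatives.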
