pyro-ppl/pyro

PyTorch 1.10 throws new jit errors

Open

#2,965 opened on 2021年11月9日

GitHub で見る
 (1 comment) (2 reactions) (0 assignees)Python (8,211 stars) (981 forks)batch import
bug, help wanted

説明

This issue tracks new errors in the PyTorch 1.10 jit, e.g.

pytest tests/infer/test_jit.py::test_dirichlet_bernoulli -k Jit -vx --runxfail
__ test_dirichlet_bernoulli[JitTraceEnum_ELBO-False] __

Elbo = <class 'pyro.infer.traceenum_elbo.JitTraceEnum_ELBO'>, vectorized = False

    @pytest.mark.parametrize("vectorized", [False, True])
    @pytest.mark.parametrize(
        "Elbo",
        [
            TraceEnum_ELBO,
            JitTraceEnum_ELBO,
        ],
    )
    def test_dirichlet_bernoulli(Elbo, vectorized):
        pyro.clear_param_store()
        data = torch.tensor([1.0] * 6 + [0.0] * 4)

        def model1(data):
            concentration0 = constant([10.0, 10.0])
            f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
            for i in pyro.plate("plate", len(data)):
                pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

        def model2(data):
            concentration0 = constant([10.0, 10.0])
            f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
            pyro.sample(
                "obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), obs=data
            )

        model = model2 if vectorized else model1

        def guide(data):
            concentration_q = pyro.param(
                "concentration_q", constant([15.0, 15.0]), constraint=constraints.positive
            )
            pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))

        elbo = Elbo(
            num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True
        )
        optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
        svi = SVI(model, guide, optim, elbo)
        for step in range(40):
>           svi.step(data)

tests/infer/test_jit.py:462:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/svi.py:145: in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:564: in loss_and_grads
    differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:561: in differentiable_loss
    return self._differentiable_loss(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <pyro.ops.jit.CompiledFunction object at 0x7ff5225e9400>
args = (tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]),)
kwargs = {'_guide_id': 140690819139104, '_model_id': 140690812334288}
key = (1, (('_guide_id', 140690819139104), ('_model_id', 140690812334288)))
unconstrained_params = [tensor([2.7072, 2.7090], requires_grad=True)]
params_and_args = [tensor([2.7072, 2.7090], requires_grad=True), tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.])]
param_capture = <pyro.poutine.trace_messenger.TraceMessenger object at 0x7ff5225c4898>

    def __call__(self, *args, **kwargs):
        key = _hashable_args_kwargs(args, kwargs)

        # if first time
        if key not in self.compiled:
            # param capture
            with poutine.block():
                with poutine.trace(param_only=True) as first_param_capture:
                    self.fn(*args, **kwargs)

            self._param_names = list(set(first_param_capture.trace.nodes.keys()))
            unconstrained_params = tuple(
                pyro.param(name).unconstrained() for name in self._param_names
            )
            params_and_args = unconstrained_params + args
            weakself = weakref.ref(self)

            def compiled(*params_and_args):
                self = weakself()
                unconstrained_params = params_and_args[: len(self._param_names)]
                args = params_and_args[len(self._param_names) :]
                constrained_params = {}
                for name, unconstrained_param in zip(
                    self._param_names, unconstrained_params
                ):
                    constrained_param = pyro.param(
                        name
                    )  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param
                return poutine.replay(self.fn, params=constrained_params)(
                    *args, **kwargs
                )

            if self.ignore_warnings:
                compiled = ignore_jit_warnings()(compiled)
            with pyro.validation_enabled(False):
                time_compilation = self.jit_options.pop("time_compilation", False)
                with optional(timed(), time_compilation) as t:
                    self.compiled[key] = torch.jit.trace(
                        compiled, params_and_args, **self.jit_options
                    )
                if time_compilation:
                    self.compile_time = t.elapsed
        else:
            unconstrained_params = [
                # FIXME this does unnecessary transform work
                pyro.param(name).unconstrained() for name in self._param_names
            ]
            params_and_args = unconstrained_params + list(args)

        with poutine.block(hide=self._param_names):
            with poutine.trace(param_only=True) as param_capture:
>               ret = self.compiled[key](*params_and_args)
E               RuntimeError: The following operation failed in the TorchScript interpreter.
E               Traceback of TorchScript (most recent call last):
E               RuntimeError: Unsupported value kind: Tensor

pyro/ops/jit.py:121: RuntimeError

It looks like an inserted constant tensor is failing the `insertableTensor` check because it requires grad. I've spent a couple of hours debugging but haven't been able to isolate the error.

コントリビューターガイド