pymc-devs/pymc

Use same API for defining internal and external Nuts kwargs

Open

#6757, opened on June 7, 2023

Labels: bug, help wanted, jax


Description


A user on Discourse reported:

How can I set the maximum tree depth for the NUTS method from the numpyro library? The way described in the test file test_mcmc_external.py doesn’t work:

import pymc as pm
import numpy as np

with pm.Model():
    a = pm.Normal("a")
    idata = pm.sample(
        nuts_sampler="numpyro",
        target_accept=0.99,
        nuts={"max_treedepth": 1},
        random_seed=1410,
    )

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)

and specifying something via the nuts_kwargs argument throws ValueError: Unused step method arguments: {'nuts_kwargs'}.

I don't know if nuts should be converted to nuts_kwargs, but even if a user were to pass nuts_kwargs to sample, those wouldn't make it to the sample_numpyro_nuts function because we drop arbitrary kwargs passed here:

https://github.com/pymc-devs/pymc/blob/261862d778910a09c5b61edcc66958519a86815e/pymc/sampling/mcmc.py#L252
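To make the forwarding problem concrete, here is a toy sketch in plain Python. It is not PyMC's code: `external_nuts`, `sample_dropping`, and `sample_forwarding` are hypothetical stand-ins, and the "fix" shown is only one possible way to route a single `nuts` dict to either sampler.

```python
# Toy model of the kwarg-forwarding problem. All functions are hypothetical
# stand-ins written for illustration; none of them is PyMC's real code.

def external_nuts(draws, nuts_kwargs=None):
    """Stand-in for an external sampler that accepts its own NUTS options."""
    return {"draws": draws, "nuts_kwargs": dict(nuts_kwargs or {})}


def sample_dropping(draws=1000, nuts_sampler="internal", **kwargs):
    """Mimics the reported behaviour: extra kwargs never reach the external sampler."""
    if nuts_sampler != "internal":
        return external_nuts(draws)  # kwargs are silently ignored here
    return {"draws": draws, "step_kwargs": kwargs}


def sample_forwarding(draws=1000, nuts_sampler="internal", **kwargs):
    """One possible fix (an assumption): route the same `nuts` dict to whichever sampler runs."""
    nuts_options = kwargs.pop("nuts", {})
    if nuts_sampler != "internal":
        return external_nuts(draws, nuts_kwargs=nuts_options)
    return {"draws": draws, "step_kwargs": {"nuts": nuts_options, **kwargs}}


dropped = sample_dropping(nuts_sampler="numpyro", nuts={"max_treedepth": 1})
forwarded = sample_forwarding(nuts_sampler="numpyro", nuts={"max_treedepth": 1})
print(dropped["nuts_kwargs"])    # {}
print(forwarded["nuts_kwargs"])  # {'max_treedepth': 1}
```

In the first variant the user's setting vanishes without any warning, which matches the symptom above where the tree depth reaches 4 despite the requested limit of 1.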
