pymc-devs/pymc

Use same API for defining internal and external Nuts kwargs

Open

#6757 opened on Jun 7, 2023

2 comments · 0 reactions · 0 assignees · Python · 1902 forks
Labels: bug, help wanted, jax



Description

A user on Discourse reported:

How can I set the maximum tree depth for the NUTS sampler from the NumPyro library? The approach described in the test file `test_mcmc_external.py` doesn't work:

```python
import numpy as np
import pymc as pm

with pm.Model():
    a = pm.Normal("a")
    idata = pm.sample(
        nuts_sampler="numpyro",
        target_accept=0.99,
        nuts={"max_treedepth": 1},
        random_seed=1410,
    )

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)
```

Specifying it via the `nuts_kwargs` argument instead raises `ValueError: Unused step method arguments: {'nuts_kwargs'}`.
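The asymmetry between the two attempts can be illustrated with a simplified, hypothetical model of how extra keyword arguments to `pm.sample` are routed to step methods by name. `KNOWN_STEP_METHODS` and `route_step_kwargs` are invented for illustration and do not reproduce PyMC's actual code; the sketch only shows why a key named after a step method (`nuts`) passes validation while `nuts_kwargs` is rejected.

```python
# Illustrative only: a simplified, hypothetical model of how extra **kwargs
# passed to pm.sample are routed to step methods by their lowercase name.
# This is NOT PyMC's actual implementation.

# Assumed set of step-method names that may receive a kwargs dict.
KNOWN_STEP_METHODS = {"nuts", "hamiltonianmc", "metropolis", "slice"}


def route_step_kwargs(**kwargs):
    """Return the kwargs claimed by a step method; reject any leftovers."""
    claimed = {k: v for k, v in kwargs.items() if k in KNOWN_STEP_METHODS}
    unused = set(kwargs) - set(claimed)
    if unused:
        # Same wording as the error reported in the issue.
        raise ValueError(f"Unused step method arguments: {unused}")
    return claimed


# nuts={...} is accepted because "nuts" names a step method ...
print(route_step_kwargs(nuts={"max_treedepth": 1}))

# ... but no step method claims nuts_kwargs, so it is rejected.
try:
    route_step_kwargs(nuts_kwargs={"max_tree_depth": 1})
except ValueError as err:
    print(err)
```

Passing validation is not the same as being used: in the reported case, `nuts={...}` is accepted but never reaches the external sampler.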

I'm not sure whether `nuts` should be converted to `nuts_kwargs`, but even if a user passes `nuts_kwargs` to `sample`, those arguments never reach the `sample_numpyro_nuts` function, because arbitrary kwargs are dropped here:

https://github.com/pymc-devs/pymc/blob/261862d778910a09c5b61edcc66958519a86815e/pymc/sampling/mcmc.py#L252
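A minimal sketch of what a shared API could look like, assuming the user-facing `nuts` dict keeps PyMC's argument names and is translated before it reaches the external sampler. `PYMC_TO_NUMPYRO`, `to_numpyro_nuts_kwargs`, `fake_external_sampler` and both `sample_external_*` functions are hypothetical; the target names `max_tree_depth` and `target_accept_prob` are assumed to match the keyword arguments of `numpyro.infer.NUTS`. The "buggy" variant mimics the dropped kwargs described above.

```python
# Hypothetical sketch: forward a single, PyMC-style `nuts` dict to an
# external sampler. None of these names are PyMC's real API.

# Assumed mapping from PyMC NUTS argument names to numpyro.infer.NUTS names.
PYMC_TO_NUMPYRO = {
    "max_treedepth": "max_tree_depth",
    "target_accept": "target_accept_prob",
}


def to_numpyro_nuts_kwargs(nuts):
    """Rename PyMC-style NUTS kwargs; pass unknown keys through unchanged."""
    return {PYMC_TO_NUMPYRO.get(key, key): value for key, value in nuts.items()}


def fake_external_sampler(**nuts_kwargs):
    """Stand-in for the external NUTS kernel: just reports what it received."""
    return nuts_kwargs


def sample_external_buggy(target_accept=0.8, **kwargs):
    # Mimics the reported behavior: only target_accept is forwarded and
    # everything in **kwargs (including nuts={...}) is silently dropped.
    return fake_external_sampler(target_accept_prob=target_accept)


def sample_external_fixed(target_accept=0.8, nuts=None):
    # Proposed behavior: merge target_accept into the nuts dict (the explicit
    # argument wins), then translate names for the external sampler.
    merged = {**(nuts or {}), "target_accept": target_accept}
    return fake_external_sampler(**to_numpyro_nuts_kwargs(merged))


print(sample_external_buggy(target_accept=0.99, nuts={"max_treedepth": 1}))
print(sample_external_fixed(target_accept=0.99, nuts={"max_treedepth": 1}))
```

In this sketch an explicit `target_accept` argument overrides any `target_accept` inside `nuts`, and unrecognized keys pass through unchanged, so NumPyro-only options such as `dense_mass` would still reach the sampler.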
