pymc-devs/pymc

Use same API for defining internal and external Nuts kwargs

Open

#6757 opened on Jun 7, 2023

Labels: bug, help wanted, jax


Description


A user on Discourse reported:

How can I set the maximum tree depth for the NUTS method from the numpyro library? The way described in the test file test_mcmc_external.py doesn’t work:

import pymc as pm
import numpy as np

with pm.Model():
    a = pm.Normal("a")
    idata = pm.sample(
        nuts_sampler="numpyro",
        target_accept=0.99,
        nuts={"max_treedepth": 1},
        random_seed=1410,
    )

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)

and specifying something via the nuts_kwargs argument throws ValueError: Unused step method arguments: {'nuts_kwargs'}.

I don't know if nuts should be converted to nuts_kwargs, but even if a user were to pass nuts_kwargs to sample, those wouldn't make it to the sample_numpyro_nuts function because we drop arbitrary kwargs passed here:

https://github.com/pymc-devs/pymc/blob/261862d778910a09c5b61edcc66958519a86815e/pymc/sampling/mcmc.py#L252
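The kwarg-dropping problem described above can be illustrated with a minimal, self-contained sketch. It does not use PyMC or NumPyro. Every function name here (`external_sampler`, `sample_dropping`, `sample_forwarding`) and the default tree depth of 10 are hypothetical stand-ins chosen for illustration, not the actual code in `mcmc.py`.

```python
from __future__ import annotations

from typing import Any


def external_sampler(draws: int, nuts_kwargs: dict[str, Any] | None = None) -> dict[str, Any]:
    """Hypothetical stand-in for an external NUTS backend.

    Returns the settings it would sample with, so the effect of
    forwarded (or dropped) kwargs is easy to inspect.
    """
    settings: dict[str, Any] = {"max_tree_depth": 10}  # assumed placeholder default
    settings.update(nuts_kwargs or {})
    return {"draws": draws, **settings}


def sample_dropping(draws: int, **kwargs: Any) -> dict[str, Any]:
    """Accepts arbitrary kwargs but never passes them on (the reported problem)."""
    return external_sampler(draws=draws)


def sample_forwarding(draws: int, **kwargs: Any) -> dict[str, Any]:
    """Passes the extra kwargs through to the backend (one possible fix)."""
    return external_sampler(draws=draws, **kwargs)


user_settings = {"nuts_kwargs": {"max_tree_depth": 1}}
dropped = sample_dropping(100, **user_settings)
forwarded = sample_forwarding(100, **user_settings)
print(dropped["max_tree_depth"], forwarded["max_tree_depth"])
```

In the dropping variant the user's setting is silently ignored and the backend falls back to its default, which matches the symptom in the report: no error is raised, but the requested tree depth has no effect.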
