pyro-ppl/pyro

[FR] Predictive with deterministic site in the guide

Open

#3358 opened on Apr 19, 2024

6 comments, 0 reactions, 0 assignees
Labels: enhancement, help wanted


Description

Hi,

I'm working on a project where we would like to access the output of an NN in the guide when using Predictive. We've implemented it using a deterministic site in the guide. The program boils down to the following.

import pyro
from pyro.infer import Predictive
from pyro.distributions import Normal
import torch

def model():
    pyro.deterministic('m_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))

def guide():
    # Stands in for the NN output we want to read back through Predictive.
    pyro.deterministic('g_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))


# Run the guide, then the model conditioned on the guide's samples.
samples = Predictive(
    model=model,
    guide=guide,
    return_sites=('m_deter', 'g_deter', 'x'),
    num_samples=1,
)()
# samples includes 'm_deter' and 'x', but not 'g_deter'

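We would like both m_deter and g_deter to be included. It looks like Predictive currently only considers model sites as return sites: the guide is used to generate posterior samples for the model, but sites that exist only in the guide are not returned. Would it be possible to extend it so that deterministic sites from the guide can be included as well?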
