pyro-ppl/pyro

FR: a MaskedConstraint to constrain MaskedDistribution

Open

#2801 opened on Apr 9, 2021

Labels: discussion, enhancement, help wanted

Description

Addresses this forum post

We could implement a constraints.masked(base_constraint, mask):

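from torch.distributions.constraints import Constraint

class MaskedConstraint(Constraint):
    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask  # boolean tensor; True marks elements that must satisfy base_constraint
        super().__init__()

    def check(self, value):
        # Masked-out elements (mask == False) always pass the check.
        return self.base_constraint.check(value) | ~self.mask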
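
To make the intended behavior concrete, here is a minimal, self-contained sketch of the proposal using only torch. It is an illustration under stated assumptions, not Pyro's implementation: the mask is normalized with torch.as_tensor (an addition not present in the proposal) so that a plain Python bool also works, since `~True` on a Python bool evaluates to -2 rather than False.

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraints import Constraint


class MaskedConstraint(Constraint):
    """Sketch of the proposed constraint: masked-out elements always pass."""

    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        # Assumption (not in the proposal): accept a Python bool or a bool
        # tensor and normalize to a tensor so that `~` is a logical not.
        self.mask = torch.as_tensor(mask, dtype=torch.bool)
        super().__init__()

    def check(self, value):
        # An element is valid if it satisfies the base constraint
        # or if it is masked out (mask == False).
        return self.base_constraint.check(value) | ~self.mask


value = torch.tensor([0.5, -1.0, 2.0, -3.0])
mask = torch.tensor([True, True, False, False])

base_ok = constraints.positive.check(value)
masked_ok = MaskedConstraint(constraints.positive, mask).check(value)

print(base_ok.tolist())    # [True, False, True, False]
print(masked_ok.tolist())  # [True, False, True, True]
```

The second element still fails because it is unmasked and negative, while the last element passes only because it is masked out.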

Then we could use this in MaskedDistribution via

class MaskedDistribution(...):
    ...
    @lazy_property
    def support(self):
        return constraints.masked(self.base_dist.support, self.mask)

To fully address the forum post, we would also need to handle poutine.mask, e.g. by replacing poutine.mask with MaskedDistribution under the hood, as in #2284.
