pyro-ppl/pyro

FR: a MaskedConstraint to constrain MaskedDistribution

Open

#2801 opened on Apr 9, 2021

Labels: discussion, enhancement, help wanted

Description

Addresses this forum post

We could implement a constraints.masked(base_constraint, mask), backed by a constraint class that skips validation wherever the mask is False:

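from torch.distributions.constraints import Constraint

class MaskedConstraint(Constraint):
    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask  # boolean tensor; True marks entries that are validated
        super().__init__()

    def check(self, value):
        # Masked-out entries (mask == False) always pass the check.
        return self.base_constraint.check(value) | ~self.mask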
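
To illustrate the intended semantics, here is a minimal self-contained sketch of the proposed constraint applied to a small tensor. It is not part of Pyro; the `torch.as_tensor` coercion is an added assumption so that a plain Python bool mask does not break the `~` operator, and the sample values are made up for illustration.

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraints import Constraint


class MaskedConstraint(Constraint):
    """Sketch of the proposed constraint: only entries where mask is True are validated."""

    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        # Assumption: coerce to a bool tensor so `~` behaves as logical negation
        # even if a plain Python bool is passed in.
        self.mask = torch.as_tensor(mask, dtype=torch.bool)
        super().__init__()

    def check(self, value):
        # An entry passes if it satisfies the base constraint or is masked out.
        return self.base_constraint.check(value) | ~self.mask


# Illustrative data: the second entry violates `positive` but is masked out.
value = torch.tensor([0.5, -1.0, 2.0])
mask = torch.tensor([True, False, True])

result = MaskedConstraint(constraints.positive, mask).check(value)
unmasked = MaskedConstraint(constraints.positive, True).check(value)
```

With this mask, the invalid second entry is ignored, whereas an all-True mask reproduces the base constraint's result unchanged.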

Then we could use this in MaskedDistribution by overriding its support:

class MaskedDistribution(...):
    ...
    @lazy_property
    def support(self):
        return constraints.masked(self.base_dist.support, self.mask)

To fully address the forum post, we would also need to handle poutine.mask, e.g. by replacing poutine.mask with MaskedDistribution under the hood, as in #2284.
