pyro-ppl/pyro

FR: a MaskedConstraint to constrain MaskedDistribution

Open

#2801 opened on Apr 9, 2021

1 comment · 1 reaction · 0 assignees · Python · 981 forks
Labels: discussion, enhancement, help wanted

Description

Addresses this forum post

We could implement a constraints.masked(base_constraint, mask):

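from torch.distributions.constraints import Constraint

class MaskedConstraint(Constraint):
    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask  # assumed to be a boolean tensor; False marks masked-out entries
        super().__init__()

    def check(self, value):
        # Masked-out entries always pass; all others must satisfy the base constraint.
        return self.base_constraint.check(value) | ~self.mask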
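
To make the rule in check concrete, here is a small sketch that evaluates the same expression on example tensors using only torch.distributions. The choice of constraints.positive as the base constraint and the sample values are assumptions for illustration, not part of the proposal.

```python
import torch
from torch.distributions import constraints

# Illustrative data (assumed): -1.0 is a placeholder sitting at a masked-out position.
value = torch.tensor([0.5, -1.0, 2.0])
mask = torch.tensor([True, False, True])  # False marks entries to ignore

# The base constraint alone rejects the placeholder.
base_ok = constraints.positive.check(value)

# The proposed rule: an entry is valid if it satisfies the base
# constraint OR it is masked out.
masked_ok = base_ok | ~mask

# An invalid value at an unmasked position must still be rejected.
bad_value = torch.tensor([-0.5, -1.0, 2.0])
bad_ok = constraints.positive.check(bad_value) | ~mask

print(base_ok.tolist(), masked_ok.tolist(), bad_ok.tolist())
```

Note that ~mask relies on mask being a boolean tensor. A plain Python bool would need converting first (for example with torch.as_tensor), since ~True evaluates to the integer -2.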

Then we could use this in MaskedDistribution via

class MaskedDistribution(...):
    ...
    @lazy_property
    def support(self):
        return constraints.masked(self.base_dist.support, self.mask)
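
As a rough sketch of how such a support would interact with sample validation, the block below wires the constraint into a toy wrapper built on torch.distributions alone. MaskedSketch is a hypothetical stand-in, not Pyro's MaskedDistribution: it uses a plain property instead of lazy_property, assumes a boolean tensor mask, and assumes that masked-out entries contribute a log-probability of zero.

```python
import torch
from torch.distributions import Distribution, Exponential
from torch.distributions.constraints import Constraint


class MaskedConstraint(Constraint):
    """The constraint proposed in this issue, repeated so the sketch is self-contained."""

    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask
        super().__init__()

    def check(self, value):
        return self.base_constraint.check(value) | ~self.mask


class MaskedSketch(Distribution):
    """Hypothetical toy wrapper; NOT Pyro's MaskedDistribution."""

    arg_constraints = {}

    def __init__(self, base_dist, mask, validate_args=None):
        self.base_dist = base_dist
        self.mask = mask
        super().__init__(base_dist.batch_shape, base_dist.event_shape,
                         validate_args=validate_args)

    @property
    def support(self):
        return MaskedConstraint(self.base_dist.support, self.mask)

    def log_prob(self, value):
        if self._validate_args:
            # Uses self.support, so masked-out placeholders are accepted.
            self._validate_sample(value)
        # Swap placeholders for a value inside the base support before
        # delegating, then zero out the masked-out entries.
        safe = torch.where(self.mask, value, torch.ones_like(value))
        log_p = self.base_dist.log_prob(safe)
        return torch.where(self.mask, log_p, torch.zeros_like(log_p))


value = torch.tensor([0.5, -1.0, 2.0])    # -1.0 is outside Exponential's support
mask = torch.tensor([True, False, True])

base = Exponential(torch.ones(3), validate_args=True)
try:
    base.log_prob(value)
    base_rejected = False
except ValueError:
    base_rejected = True  # the unmasked distribution rejects the placeholder

masked = MaskedSketch(base, mask, validate_args=True)
log_p = masked.log_prob(value)  # passes validation despite the placeholder
print(base_rejected, log_p)
```

Validation still fails, as it should, when an entry that is not masked out lies outside the base support.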

To fully address the forum post, we would also need to handle poutine.mask, e.g. by replacing poutine.mask with MaskedDistribution under the hood, as in #2284.
