pyro-ppl/pyro

FR: a MaskedConstraint to constrain MaskedDistribution

Open

#2801 opened on Apr 9, 2021

Labels: discussion, enhancement, help wanted

Description

Addresses this forum post

We could implement a constraints.masked(base_constraint, mask)

class MaskedConstraint(Constraint):
    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask
        super().__init__()

    def check(self, value):
        # Entries where mask is False are treated as always valid.
        return self.base_constraint.check(value) | ~self.mask
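
To make the intended semantics concrete, here is a minimal runnable sketch of the proposed constraint built directly on torch.distributions.constraints. It is not an existing Pyro API. It assumes that mask is a torch.bool tensor broadcastable against the result of base_constraint.check(value); a plain Python bool mask would need separate handling, since ~True is not a boolean negation in Python.

```python
import torch
from torch.distributions import constraints


class MaskedConstraint(constraints.Constraint):
    """Sketch of the proposal: enforce base_constraint only where mask is True."""

    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask  # assumed to be a torch.bool tensor
        super().__init__()

    def check(self, value):
        # Masked-out entries (mask == False) pass unconditionally.
        return self.base_constraint.check(value) | ~self.mask


value = torch.tensor([0.5, -1.0, 2.0])

# The negative entry is masked out, so it no longer fails the check.
partial_mask = torch.tensor([True, False, True])
masked_ok = MaskedConstraint(constraints.positive, partial_mask).check(value)

# With nothing masked out, the result matches the base constraint.
full_mask = torch.tensor([True, True, True])
unmasked_ok = MaskedConstraint(constraints.positive, full_mask).check(value)
```

With this behavior, argument validation that reduces the check with .all() would accept out-of-support placeholder values at masked-out positions while still rejecting them elsewhere.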

Then we could use this in MaskedDistribution via

class MaskedDistribution(...):
    ...
    @lazy_property
    def support(self):
        return constraints.masked(self.base_dist.support, self.mask)

To fully address the forum post, we would also need to handle poutine.mask, e.g. by replacing poutine.mask with MaskedDistribution under the hood, as in #2284.
