pyro-ppl/pyro
FR: a MaskedConstraint to constrain MaskedDistribution
Open
#2801 opened on Apr 9, 2021
Labels: discussion, enhancement, help wanted
Description
Addresses this forum post.
We could implement a `constraints.masked(base_constraint, mask)` constraint that treats masked-out elements as always valid:
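```python
from torch.distributions.constraints import Constraint

class MaskedConstraint(Constraint):
    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask  # BoolTensor; False marks masked-out elements
        super().__init__()

    # Inherit event_dim and is_discrete from the base constraint.
    event_dim = property(lambda self: self.base_constraint.event_dim)
    is_discrete = property(lambda self: self.base_constraint.is_discrete)

    def check(self, value):
        return self.base_constraint.check(value) | ~self.mask  # masked-out -> True
```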
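A minimal self-contained sketch of the intended semantics, using only PyTorch's `torch.distributions.constraints`. Two details are assumptions not stated in the proposal: `event_dim` and `is_discrete` are forwarded from the base constraint, and a plain Python `bool` mask is promoted to a tensor, since `MaskedDistribution` also accepts a `bool` mask and `~True` on a Python bool is integer negation (`-2`), not logical negation.

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraints import Constraint


class MaskedConstraint(Constraint):
    """Sketch: elements where ``mask`` is False always satisfy the constraint."""

    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        # Assumption: promote a Python bool to a BoolTensor so that ``~`` is a
        # logical negation rather than bitwise integer negation.
        self.mask = torch.as_tensor(mask, dtype=torch.bool)
        super().__init__()

    @property
    def is_discrete(self):
        # Assumption: forwarded from the base constraint.
        return self.base_constraint.is_discrete

    @property
    def event_dim(self):
        # Assumption: forwarded from the base constraint (e.g. 1 for simplex).
        return self.base_constraint.event_dim

    def check(self, value):
        # Valid where the base constraint holds OR the element is masked out.
        return self.base_constraint.check(value) | ~self.mask


mask = torch.tensor([True, True, False])
value = torch.tensor([1.0, -1.0, -1.0])  # the last entry is invalid but masked out

print(constraints.positive.check(value).tolist())                    # [True, False, False]
print(MaskedConstraint(constraints.positive, mask).check(value).tolist())  # [True, False, True]
```

Only the unmasked invalid entry (index 1) fails the check; with `mask=False` every entry passes.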
We could then use this in `MaskedDistribution` via:
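```python
from torch.distributions.utils import lazy_property

class MaskedDistribution(...):  # existing base class and body elided
    ...

    @lazy_property
    def support(self):
        # Entries masked out by self.mask are exempt from the base support.
        return constraints.masked(self.base_dist.support, self.mask)
```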
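The same idea inside a distribution, as a torch-only sketch: `ToyMaskedDistribution` is a hypothetical stand-in written for illustration, not Pyro's `MaskedDistribution`, and its zero-filling `log_prob` is a simplification assumed here. With the masked support, a value outside the base support passes the support check as long as it sits in a masked-out position.

```python
import torch
from torch.distributions import Distribution, Exponential, constraints
from torch.distributions.utils import lazy_property


class MaskedConstraint(constraints.Constraint):
    # Compact copy of the proposed constraint so this block runs on its own.
    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = torch.as_tensor(mask, dtype=torch.bool)
        super().__init__()

    event_dim = property(lambda self: self.base_constraint.event_dim)

    def check(self, value):
        return self.base_constraint.check(value) | ~self.mask


class ToyMaskedDistribution(Distribution):
    """Hypothetical stand-in for Pyro's MaskedDistribution; not Pyro's code."""

    arg_constraints = {}

    def __init__(self, base_dist, mask):
        self.base_dist = base_dist
        self.mask = torch.as_tensor(mask, dtype=torch.bool)
        super().__init__(base_dist.batch_shape, base_dist.event_shape,
                         validate_args=False)

    @lazy_property
    def support(self):
        # The proposed change: the support ignores masked-out entries.
        return MaskedConstraint(self.base_dist.support, self.mask)

    def log_prob(self, value):
        # Simplification: masked-out entries contribute zero log-probability.
        return self.base_dist.log_prob(value).masked_fill(~self.mask, 0.0)


# validate_args=False so the base log_prob does not reject the invalid entry.
base = Exponential(torch.ones(3), validate_args=False)
d = ToyMaskedDistribution(base, torch.tensor([True, True, False]))
value = torch.tensor([0.5, 2.0, -1.0])  # -1.0 lies outside the base support

print(base.support.check(value).tolist())  # [True, True, False]
print(d.support.check(value).tolist())     # [True, True, True]
```

Here the base support rejects `-1.0`, while the masked support accepts it because that position is masked out.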
To fully address the forum post, we would also need to handle `poutine.mask`, e.g. by replacing `poutine.mask` with `MaskedDistribution` under the hood, as in #2284.