pyro-ppl/pyro

FR: a MaskedConstraint to constrain MaskedDistribution

Open

#2801 created on April 9, 2021

(1 comment) (1 reaction) (0 assignees) Python (8,211 stars) (981 forks)
discussion, enhancement, help wanted

Description

Addresses this forum post

We could implement a constraints.masked(base_constraint, mask) constraint:

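from torch.distributions.constraints import Constraint

class MaskedConstraint(Constraint):
    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        self.mask = mask  # boolean tensor; True where the constraint applies
        super().__init__()

    def check(self, value):
        # Masked-out elements (mask == False) always pass the check.
        return self.base_constraint.check(value) | ~self.mask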
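
To make the proposal concrete, here is a minimal runnable sketch of the constraint using only torch. It is not Pyro's implementation. Two details are assumptions that go beyond the snippet above: the mask is coerced with torch.as_tensor so that a plain Python bool also works (`~True` on a bool is integer bitwise-not, not logical negation), and event_dim / is_discrete are delegated to the base constraint.

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraints import Constraint


class MaskedConstraint(Constraint):
    """Sketch: enforce ``base_constraint`` only where ``mask`` is True."""

    def __init__(self, base_constraint, mask):
        self.base_constraint = base_constraint
        # Assumption: accept a Python bool as well as a boolean tensor.
        self.mask = torch.as_tensor(mask, dtype=torch.bool)
        super().__init__()

    # Assumption: the masked constraint should report the same event_dim and
    # discreteness as the constraint it wraps.
    @property
    def event_dim(self):
        return self.base_constraint.event_dim

    @property
    def is_discrete(self):
        return self.base_constraint.is_discrete

    def check(self, value):
        # Entries with mask == False always pass; the rest defer to the base.
        return self.base_constraint.check(value) | ~self.mask


mask = torch.tensor([True, False, True])
constraint = MaskedConstraint(constraints.positive, mask)
value = torch.tensor([1.0, -1.0, -2.0])

result = constraint.check(value)
print(result.tolist())  # [True, True, False]
```

The second entry is negative but masked out, so it passes; the third is negative and unmasked, so it fails. Note that base_constraint.check reduces over event dimensions, so the mask is expected to broadcast against the batch shape, as it does in MaskedDistribution.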

Then we could use this in MaskedDistribution via

class MaskedDistribution(...):
    ...
    @lazy_property
    def support(self):
        return constraints.masked(self.base_dist.support, self.mask)

To fully address the forum post, we would also need to handle poutine.mask, e.g. by replacing poutine.mask with MaskedDistribution under the hood, as in #2284.
