pyro-ppl/pyro

Normalization of Weights within SMCFilter class

Open

#3382 opened on July 12, 2024

Labels: enhancement, help wanted

Description

The weights reported by the `get_empirical` method of the `SMCFilter` class usually do not sum to one. I think it would be useful to have an option to get the normalized weights directly (perhaps as the default), in line with the literature (see, for example, An Introduction to Sequential Monte Carlo, page 130). The weights are actually normalized in this line. However, the result is stored in a local variable instead of updating the state's weights. It seems that this command only rescales the values of the `log_weights` variable to lie between 0 and 1. What is the reason for not normalizing the weights variable directly?
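For illustration, normalizing in log space means subtracting the log-sum-exp of all log weights from each log weight, so that the exponentiated weights sum to one. The sketch below does this in plain Python; the helper `normalize_log_weights` is hypothetical and not part of Pyro. It uses the max-shift form of log-sum-exp, so very negative log weights (whose raw exponentials underflow to zero) still normalize correctly.

```python
import math


def normalize_log_weights(log_weights):
    """Hypothetical helper: shift log weights so their exponentials sum to one.

    Equivalent to log_w - logsumexp(log_w), computed with the max-shift trick
    for numerical stability.
    """
    m = max(log_weights)
    log_z = m + math.log(sum(math.exp(w - m) for w in log_weights))
    return [w - log_z for w in log_weights]


# Unnormalized log weights, e.g. after many likelihood updates.
log_w = [-1050.0, -1051.0, -1053.0]

# Exponentiating directly underflows to 0.0 for every particle.
naive = [math.exp(w) for w in log_w]

# After normalization the weights are usable and sum to (approximately) one.
weights = [math.exp(w) for w in normalize_log_weights(log_w)]
print(naive)
print(sum(weights))
```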

The simplest fix would be to add an optional argument to `get_empirical` that, when true, normalizes the weights, e.g.:

```python
# Proposed version of SMCFilter.get_empirical; `dist` refers to pyro.distributions.
def get_empirical(self, normalize_weights=True):
    """
    :param bool normalize_weights: If True, normalize the log weights before creating the empirical distribution.
    :returns: a marginal distribution over all state tensors.
    :rtype: a dictionary with keys which are latent variables and values
        which are :class:`~pyro.distributions.Empirical` objects.
    """
    if normalize_weights:
        # Subtract the log normalizer so that the exponentiated weights sum to one.
        log_weights = self.state._log_weights - self.state._log_weights.logsumexp(-1)
    else:
        log_weights = self.state._log_weights

    return {
        key: dist.Empirical(value, log_weights)
        for key, value in self.state.items()
    }
```
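A minimal pure-Python sketch of what such a flag changes, assuming (not verified here against Pyro's source) that the empirical distribution treats the log weights as unnormalized logits. `ToyParticleState`, `get_weights`, and `softmax` are hypothetical stand-ins, not Pyro APIs. Because a softmax is invariant to adding a constant to every logit, both settings give the same sampling probabilities; the flag only changes the weights a user reads back.

```python
import math


class ToyParticleState:
    """Hypothetical stand-in for SMCFilter's state: particle values plus log weights."""

    def __init__(self, values, log_weights):
        self.values = list(values)
        self._log_weights = list(log_weights)

    def get_weights(self, normalize_weights=True):
        # Mirrors the proposed flag: optionally subtract logsumexp before reporting.
        if normalize_weights:
            m = max(self._log_weights)
            log_z = m + math.log(sum(math.exp(w - m) for w in self._log_weights))
            return [w - log_z for w in self._log_weights]
        return list(self._log_weights)


def softmax(log_weights):
    # Turns log weights (treated as logits) into sampling probabilities.
    m = max(log_weights)
    exps = [math.exp(w - m) for w in log_weights]
    total = sum(exps)
    return [e / total for e in exps]


state = ToyParticleState(values=[0.1, 0.5, 0.9], log_weights=[2.0, 3.0, 1.0])
raw = state.get_weights(normalize_weights=False)
norm = state.get_weights(normalize_weights=True)

print(sum(math.exp(w) for w in raw))   # e^2 + e^3 + e^1, not 1
print(sum(math.exp(w) for w in norm))  # approximately 1.0
print(softmax(raw))
print(softmax(norm))                   # same probabilities as softmax(raw)
```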
