pymc-devs/pymc
Add rewrite for Mixture when `comp_dists` can be "fused"
Open
#6803 opened on June 29, 2023
Labels: feature request, help wanted, logprob
Description
The following distributions are equivalent:
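```python
import pymc as pm

# Form 1: a list of two separate scalar Normal components
pm.Mixture.dist(w=[0.5, 0.5], comp_dists=[pm.Normal.dist(-1), pm.Normal.dist(1)])
# Form 2: a single batched Normal whose last dimension indexes the components
pm.Mixture.dist(w=[0.5, 0.5], comp_dists=pm.Normal.dist([-1, 1]))
```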
But the second one is more efficient, because its logp is vectorized over a single batched Normal instead of being evaluated separately for each component.
We could add a rewrite to `logprob_rewrites` that converts the first form into the second, so that users are not penalized for using the first form (which some may find more intuitive).
More generally, a rewrite of the form `stack([rv1, rv2]) -> rv3` could be useful in many places in the `logprob` submodule.