pymc-devs/pymc

Add rewrite for Mixture when `comp_dists` can be "fused"

Open

#6803 opened on 29 Jun 2023

Labels: feature request, help wanted, logprob


Description

The following distributions are equivalent:

import pymc as pm

pm.Mixture.dist(w=[0.5, 0.5], comp_dists=[pm.Normal.dist(-1), pm.Normal.dist(1)])
pm.Mixture.dist(w=[0.5, 0.5], comp_dists=pm.Normal.dist([-1, 1]))

But the second one is more efficient, because the logp is vectorized across a single batched Normal instead of being computed separately for each component.
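To make the equivalence concrete: both forms evaluate the same mixture log-density, logsumexp over k of log(w_k) + log N(x | mu_k, 1). The NumPy sketch below is not PyMC code, and the helper names (`normal_logpdf`, `mixture_logp_looped`, `mixture_logp_batched`) are made up for illustration. It computes that quantity once with one call per component and once with a single broadcast call, then checks that the two agree.

```python
import numpy as np


def normal_logpdf(x, mu, sigma=1.0):
    """Log-density of Normal(mu, sigma) at x; broadcasts over array inputs."""
    z = (x - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)


def mixture_logp_looped(x, w, mus):
    """One logpdf call per component, mirroring a list of separate comp_dists."""
    terms = [np.log(w_k) + normal_logpdf(x, mu_k) for w_k, mu_k in zip(w, mus)]
    # Stack the per-component terms, then log-sum-exp over the component axis.
    return np.logaddexp.reduce(np.stack(terms, axis=-1), axis=-1)


def mixture_logp_batched(x, w, mus):
    """A single broadcast logpdf call, mirroring one batched comp_dists."""
    # x[..., None] has shape (n, 1); mus has shape (k,); the result is (n, k).
    terms = np.log(w) + normal_logpdf(np.asarray(x)[..., None], np.asarray(mus))
    return np.logaddexp.reduce(terms, axis=-1)


x = np.array([-2.0, 0.0, 0.5])
w = np.array([0.5, 0.5])
mus = np.array([-1.0, 1.0])

looped = mixture_logp_looped(x, w, mus)
batched = mixture_logp_batched(x, w, mus)
assert np.allclose(looped, batched)
```

The looped version builds one intermediate array per component before combining them, which loosely mirrors the extra graph nodes that the list form of `comp_dists` produces.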

We could add a rewrite in the logprob_rewrites to convert the former into the latter, so that users are not penalized for using the first form (which may be more intuitive for some).

Actually, that sort of rewrite, stack([rv1, rv2]) -> rv3, could be useful in many places in the logprob submodule.
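As a rough illustration of the stack([rv1, rv2]) -> rv3 idea, here is a toy sketch in plain Python. It does not use PyMC or PyTensor: `ToyRV` and `fuse_stack` are hypothetical names, and the fusability check (same family, same number of parameters, scalar components only) is an assumed simplification. A real rewrite would operate on graph nodes and would likely need further conditions that are not modeled here.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ToyRV:
    """Hypothetical stand-in for a random variable: a family plus parameters."""

    family: str
    # One tuple of (possibly batched) values per parameter, e.g. (mu, sigma).
    params: Tuple[Tuple[float, ...], ...]


def fuse_stack(rvs: Sequence[ToyRV]) -> Optional[ToyRV]:
    """Rewrite stack([rv1, rv2, ...]) into one batched rv, or return None."""
    if not rvs:
        return None
    first = rvs[0]
    for rv in rvs:
        # Only components of the same family and parameter count can be fused.
        if rv.family != first.family or len(rv.params) != len(first.params):
            return None
        # This toy version only handles scalar (unbatched) components.
        if any(len(p) != 1 for p in rv.params):
            return None
    # Gather each parameter across components into one batched parameter.
    fused = tuple(
        tuple(rv.params[i][0] for rv in rvs) for i in range(len(first.params))
    )
    return ToyRV(first.family, fused)


# Two scalar Normals with mu = -1 and mu = 1 (sigma = 1) fuse into one batched Normal.
components = [ToyRV("Normal", ((-1.0,), (1.0,))), ToyRV("Normal", ((1.0,), (1.0,)))]
fused = fuse_stack(components)

# Components from different families are left alone.
mixed = fuse_stack([ToyRV("Normal", ((0.0,), (1.0,))), ToyRV("Exponential", ((1.0,),))])
```

Returning None when the components cannot be fused keeps the original expression untouched, which is the usual convention for an optional rewrite.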
