Delayed sampling
Automatic marginalization and conditioning are implemented using a heuristic called delayed sampling¹. This has a small limitation: it supports marginalization through chains of random variables, but not through trees of random variables. Trees are reduced to chains by simulation where necessary.
Consider:
x ~ Gamma(2.0, 1.0);
y ~ Poisson(x);
z ~ Poisson(x);
This is a tree (the simplest possible), so delayed sampling cannot maintain a joint marginal distribution over all three random variables. Instead, as its heuristic, it reduces the tree to a chain by simulating y. This proceeds as follows (a sketch of the underlying conjugate computations appears after these steps):
- x ~ Gamma(2.0, 1.0);

  x is trivially marginalized out.

      ( x )
- y ~ Poisson(x);

  x is marginalized out, y is trivially marginalized out. There is a single chain of marginalized variables.

      ( x )
        |
        v
      ( y )
- z ~ Poisson(x);

  Adding z to the graph would create a tree of marginalized variables. To avoid this, the delayed sampling heuristic simulates a value for y (now depicted as a square rather than a circle below). This breaks the first branch of the tree that would otherwise form. Now x is still marginalized out, but conditioned on y.

      ( x )
        |
        v
      [ y ]
- Finally, the new node is added; x remains marginalized out, y is simulated, z is trivially marginalized out. There is only a single chain of marginalized variables.

         ( x )
         /   \
        v     v
      [ y ] ( z )
Delayed sampling over Gaussian random variables handles some common use cases without explicit coding. Consider the following linear-Gaussian state-space model:
x[1] ~ Gaussian(0.0, 4.0);
y[1] ~ Gaussian(b*x[1], 1.0);
for t in 2..4 {
x[t] ~ Gaussian(a*x[t - 1], 4.0);
y[t] ~ Gaussian(b*x[t], 1.0);
}
Here x and y are both latent (typically, for such a model, the x are latent while the y are observed). We see the previous steps repeated on each iteration of the loop:
- At the start of the t-th iteration, both x[t-1] and y[t-1] are marginalized out.

      ╌╌╌▶(x[t-2])---->(x[t-1])
              |            |
              v            v
          [y[t-2]]     (y[t-1])
- x[t] ~ Gaussian(a*x[t - 1], 4.0);

  Adding x[t] to the graph would create a tree of marginalized variables, so y[t-1] is simulated to reduce the tree to a chain.

      ╌╌╌▶(x[t-2])---->(x[t-1])
              |            |
              v            v
          [y[t-2]]     [y[t-1]]
- Now the new node for the latent state is added; x[t-1] and x[t] remain marginalized out.

      ╌╌╌▶(x[t-2])---->(x[t-1])---->( x[t] )
              |            |
              v            v
          [y[t-2]]     [y[t-1]]
- Finally the new node for the observation is added; x[t-1], x[t] and y[t] remain marginalized out, ready for the next iteration of the loop.

      ╌╌╌▶(x[t-2])---->(x[t-1])---->( x[t] )
              |            |            |
              v            v            v
          [y[t-2]]     [y[t-1]]     ( y[t] )
In fact, the operations automatically performed for this example are precisely those of the Kalman filter, without having to code them by hand.
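To make that correspondence concrete, here is a minimal NumPy sketch of the same computations. It is an illustration rather than Birch's implementation, and it assumes values a = 0.8 and b = 1.0, since the program above leaves a and b unspecified:

import numpy as np

rng = np.random.default_rng(1)

a, b = 0.8, 1.0  # assumed for illustration; not fixed by the program
q, r = 4.0, 1.0  # transition and observation variances from the program

# Mean and variance of the marginal distribution that delayed sampling
# maintains for the current x[t].
m, v = 0.0, 4.0  # x[1] ~ Gaussian(0.0, 4.0)

for t in range(1, 5):
    if t > 1:
        # Predict: marginalize x[t] ~ Gaussian(a*x[t-1], 4.0) through x[t-1].
        m, v = a*m, a*a*v + q

    # Marginal of y[t] is Gaussian(b*m, b*b*v + r). Simulate it, as delayed
    # sampling does for this all-latent model; an observed y[t] would be
    # substituted here instead.
    s = b*b*v + r
    y = rng.normal(b*m, np.sqrt(s))

    # Condition x[t] on y[t]: the Kalman gain and measurement update.
    g = b*v/s
    m, v = m + g*(y - b*m), (1.0 - g*b)*v

The predict and update lines are the textbook Kalman filter for this model, and the pair (m, v) is exactly the marginal that delayed sampling attaches to each x[t] in the diagrams above.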
¹ L.M. Murray, D. Lundén, J. Kudlicka, D. Broman and T.B. Schön (2018). Delayed Sampling and Automatic Rao–Blackwellization of Probabilistic Programs. In Proceedings of the 21st International Conference on Artificial Intelligence and Statistics (AISTATS), Lanzarote, Spain.