# Variational Inference for MRP with Reliable Posterior Distributions

Part 2- The errors of our ways

MRP
Variational Inference
Author

Andy Timm

Published

November 20, 2022

This is the second post in my series on using Variational Inference to speed up relatively complex Bayesian models like Multilevel Regression and Poststratification without the approximation being of disastrously poor quality.

In the last post, I laid out why reformulating Bayesian inference as an optimization problem might be desirable, but previewed why it might be quite hard to find high-quality approximations amenable to optimization. I then introduced our running example (predicting national and sub-national opinion on an abortion question from the CCES using MRP), and gave an initial introduction to a version of Variational Inference where we maximize the Evidence Lower Bound (ELBO) as an objective, using a mean-field Gaussian approximation. We saw that with 60k examples, this took about 8 hours to fit with MCMC, but 144 seconds (!) with VI.
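As a reminder of what is being optimized: for an approximating density $q(z)$ over the model's latent parameters $z$ given data $x$, the ELBO can be written as below. Because $\log p(x)$ doesn't depend on $q$, maximizing the ELBO is equivalent to minimizing the KL divergence from $q$ to the exact posterior. A larger (less negative) ELBO is better.

$$
\text{ELBO}(q) = \mathbb{E}_{q(z)}\left[\log p(x, z)\right] - \mathbb{E}_{q(z)}\left[\log q(z)\right] = \log p(x) - \mathrm{KL}\left(q(z) \,\|\, p(z \mid x)\right)
$$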

In this post, we'll explore the shortcomings of this initial approximation, and take a first pass at doing better with a more complex (full-rank) variational approximation. The goal is to get a better feel for what failing models can look like, at least in this relatively simple case.

The rough plan for the series is as follows:

1. Introducing the Problem- Why is VI useful, why VI can produce spherical cows
2. (This post) How far does iteration on classic VI algorithms like mean-field and full-rank get us?
3. Some theory on why posterior approximation with VI can be so poor
4. Seeing if some more sophisticated techniques like normalizing flows help

# The disclaimer

One fairly obvious objection to how I've set up this series is "Why not talk about the theory of why VI approximations can be poor before trying stuff?". In practice I did read a lot of the papers for the next post before writing this one. Still, I think there's a lot of value in looking at failed solutions to a problem to build up intuition about what our failure modes look like, and what it might take to get things right.

# Toplines

```r
meanfield_60k <- readRDS("fit_60k_meanfield.rds")

# Meanfield
epred_mat_mf <- posterior_epred(meanfield_60k, newdata = poststrat_df_60k, draws = 1000)
# Poststratify: weight each cell's prediction by its share of the population
mrp_estimates_vector_mf <- epred_mat_mf %*% poststrat_df_60k$n / sum(poststrat_df_60k$n)
mrp_estimate_mf <- c(mean = mean(mrp_estimates_vector_mf),
                     sd = sd(mrp_estimates_vector_mf))

# MCMC (mcmc_60k is the MCMC fit of the same model from the previous post)
epred_mat_mcmc <- posterior_epred(mcmc_60k, newdata = poststrat_df_60k, draws = 1000)
mrp_estimates_vector_mcmc <- epred_mat_mcmc %*% poststrat_df_60k$n / sum(poststrat_df_60k$n)
mrp_estimate_mcmc <- c(mean = mean(mrp_estimates_vector_mcmc),
                       sd = sd(mrp_estimates_vector_mcmc))

cat("Meanfield MRP estimate mean, sd: ", round(mrp_estimate_mf, 3), "\n")
cat("MCMC MRP estimate mean, sd: ", round(mrp_estimate_mcmc, 3), "\n")
```
| Model         | Mean  | SD   |
|---------------|-------|------|
| MCMC          | 43.9% | 0.2% |
| mean-field VI | 43.7% | 0.2% |

Starting with the basics, the toplines are pretty much identical, which is a good start. The minor difference here could easily reverse on a different seed. In a few quick re-runs, the means often matched to three decimal places.
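The poststratification step in the code above is just a weighted average. For each posterior draw $s$, the population estimate is $\sum_j N_j \pi_j^{(s)} / \sum_j N_j$, where $\pi_j^{(s)}$ is the predicted probability for cell $j$ and $N_j$ is that cell's population count. Here is a minimal base R sketch of that computation. The draws and cell counts are synthetic stand-ins, not the CCES data or real `posterior_epred()` output.

```r
set.seed(1)

n_draws <- 1000  # posterior draws (rows of the epred matrix)
n_cells <- 12    # poststratification cells (columns); toy size

# Synthetic stand-in for posterior_epred(): one predicted probability
# per draw (row) and cell (column)
epred_mat <- matrix(plogis(rnorm(n_draws * n_cells, mean = -0.2, sd = 0.3)),
                    nrow = n_draws, ncol = n_cells)

# Synthetic population counts N_j for each cell
cell_n <- sample(1000:50000, n_cells)

# Weighted average over cells, computed once per draw:
# (n_draws x n_cells) %*% (n_cells x 1) -> n_draws x 1
mrp_draws <- epred_mat %*% cell_n / sum(cell_n)

# Summarize the posterior of the population-level estimate
c(mean = mean(mrp_draws), sd = sd(mrp_draws))
```

The result is a whole posterior distribution for the population quantity, one value per draw. That is why the mean and SD in the table above come for free.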

# State Level Estimates

What happens if we produce state-level estimates, similar to last post's plot comparing MRP to a simple weighted estimate? Note that I'll depart from the MRP Case Study example here in a couple of ways. I'll use tidybayes for working with the draws, which is more elegant than the case study's loop-based approach. I'll also use more draws, which helps with simulation error in smaller states.

```r
mcmc_state_level <- poststrat_df_60k %>% add_epred_draws(mcmc_60k, ndraws = 1000)
mfvi_state_level <- poststrat_df_60k %>% add_epred_draws(meanfield_60k, ndraws = 1000)

mcmc_state_level %>% glimpse()
```
```
Rows: 12,000,000
Columns: 13
Groups: state, eth, male, age, educ, n, repvote, region, .row [12,000]
$ state      <chr> "AL", "AL", "AL", "AL", "AL", "AL", "AL", "AL", "AL", "AL",…
$ eth        <chr> "White", "White", "White", "White", "White", "White", "Whit…
$ male       <dbl> -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,…
$ age        <chr> "18-29", "18-29", "18-29", "18-29", "18-29", "18-29", "18-2…
$ educ       <chr> "No HS", "No HS", "No HS", "No HS", "No HS", "No HS", "No H…
$ n          <dbl> 23948, 23948, 23948, 23948, 23948, 23948, 23948, 23948, 239…
$ repvote    <dbl> 0.6437414, 0.6437414, 0.6437414, 0.6437414, 0.6437414, 0.64…
$ region     <chr> "South", "South", "South", "South", "South", "South", "Sout…
$ .row       <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
$ .chain     <int> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA,…
$ .iteration <int> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA,…
$ .draw      <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, …
$ .epred     <dbl> 0.5771322, 0.5189677, 0.5483006, 0.5421404, 0.5417602, 0.55…
```

If you haven't worked with tidybayes before, the glimpse above should give some intuition about the new shape of the data. We've taken the 12,000-row poststrat_df_60k and repeated each poststratification cell once per draw, adding the prediction (.epred) and related metadata. This gives 12,000 x 1,000 = 12 million rows. This isn't the most space-efficient storage, but it allows for very elegant dplyr-style manipulation of results and quick exploration.

Let's now plot and compare the 50% and 95% credible intervals by state between the two models.

```r
mcmc_state_summary <- mcmc_state_level %>%
  # multiply each draw by its cell's proportion of state N
  # this is the P in MRP
  group_by(state, .draw) %>%
  summarize(postrat_draw = sum(.epred * (n / sum(n)))) %>%
  group_by(state) %>%
  median_qi(postrat_draw, .width = c(.5, .95)) %>%
  mutate(model = "MCMC")

mfvi_state_summary <- mfvi_state_level %>%
  group_by(state, .draw) %>%
  summarize(postrat_draw = sum(.epred * (n / sum(n)))) %>%
  group_by(state) %>%
  median_qi(postrat_draw, .width = c(.5, .95)) %>%
  mutate(model = "MF-VI")

combined_summary <- bind_rows(mcmc_state_summary, mfvi_state_summary)

combined_summary %>%
  ungroup() %>%
  mutate(ordered_state = fct_reorder(state, postrat_draw)) %>%
  ggplot(aes(y = ordered_state, x = postrat_draw,
             xmin = .lower, xmax = .upper, color = model)) +
  geom_pointinterval(position = position_dodge(1)) +
  xlim(.25, .75) +
  theme(legend.position = "top") +
  xlab("Should employers be allowed to deny their employees abortion care?") +
  ylab("State")
```

…

That looks concerning. What might you get wrong if you used the VI approximation for inference here? If you primarily cared about the median estimate, you might be OK with this effort. If you care about uncertainty, though, here's a non-exhaustive list of concerns:

1. What are probably unimodal, smooth posterior distributions under MCMC have gone so far off course under MF-VI that the median/50%/95% presentation no longer seems adequate to express the posterior shape (more on this in a second).
2. The MF-VI 50% and 95% intervals are often narrower, so on average we'd underestimate various types of uncertainty here.
3. Worse¹, the MF-VI intervals aren't consistently narrower, either in the sense that they are always narrower, or in the sense that they consistently distort in the same way. Sometimes both the 50% and 95% intervals are just a small amount narrower than MCMC's. The Michigan posterior, for example, looks passable. Sometimes things are worse, with the 50% MF-VI interval almost as wide as the MCMC 95% interval, and Wyoming shows such a distortion. Sometimes the probability mass between the 50% and 95% bounds is confined to such a minuscule range that it looks like I forgot to plot it.

That last point is particularly important. It suggests there's no easy rule of thumb for mechanically correcting these intervals, or for deciding which could be plausible approximations without the MCMC plot alongside to guide that process. We can't use VI to save a ton of time, infer that the intervals consistently need to be x% wider, and call it a day. We need to reckon more precisely with why they're distorted.

Let's return now to the point about how the shape has gone wrong.
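Before that, here is one way to make point 3 checkable. The sketch poststratifies long-format draws to the state level in base R, as an alternative to the dplyr/tidybayes pipeline above. It then computes, per state, the ratio of one model's 50% and 95% interval widths to another's. Everything in it is synthetic. The states, cells, and the two "models" `a` and `b` are hypothetical stand-ins, and `b` is narrowed by a different made-up factor in each state. With real output you would use the MCMC and MF-VI `.epred` columns instead. A single correction factor would only be defensible if these ratios were roughly constant across states.

```r
set.seed(2)

states  <- c("AL", "MI", "WY")
n_cells <- 4     # cells per state (toy size)
n_draws <- 1000

# Hypothetical poststratification frame: one row per state x cell
cells <- expand.grid(cell = seq_len(n_cells), state = states,
                     stringsAsFactors = FALSE)
cells$n    <- sample(500:5000, nrow(cells))   # cell population counts
cells$base <- rnorm(nrow(cells), -0.2, 0.5)   # cell-level logit offset

# Long format, like add_epred_draws(): one row per cell per draw
long <- cells[rep(seq_len(nrow(cells)), times = n_draws), ]
long$draw <- rep(seq_len(n_draws), each = nrow(cells))

# One latent state effect per (state, draw), shared by that state's cells
key <- paste(long$state, long$draw)
state_eff <- rnorm(length(unique(key)), sd = 0.3)
names(state_eff) <- unique(key)
long$z <- state_eff[key]

# Two hypothetical "models": b shrinks each state's spread by a
# different, made-up factor, mimicking inconsistent narrowing
narrowing <- c(AL = 0.9, MI = 0.7, WY = 0.4)
long$epred_a <- plogis(long$base + long$z)
long$epred_b <- plogis(long$base + long$z * narrowing[long$state])

# Poststratify within state: weight = cell count / state total
state_total <- tapply(cells$n, cells$state, sum)
long$w  <- long$n / as.vector(state_total[long$state])
long$wa <- long$epred_a * long$w
long$wb <- long$epred_b * long$w
by_draw <- aggregate(cbind(wa, wb) ~ state + draw, data = long, FUN = sum)

# Width of a central quantile interval, e.g. width = 0.5 or 0.95
interval_width <- function(x, width) {
  q <- quantile(x, probs = c((1 - width) / 2, (1 + width) / 2), names = FALSE)
  q[2] - q[1]
}

# Ratio of model b's interval width to model a's, by state and width
width_ratio <- t(sapply(states, function(s) {
  a <- by_draw$wa[by_draw$state == s]
  b <- by_draw$wb[by_draw$state == s]
  c(ratio_50 = interval_width(b, 0.5) / interval_width(a, 0.5),
    ratio_95 = interval_width(b, 0.95) / interval_width(a, 0.95))
}))
round(width_ratio, 2)
```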
Below is a dot plot (Kay et al., 2016), where each point represents about 1% of the probability mass. I enjoy this approach to posterior visualization when things are getting weird. It clarifies a lot about the full shape of the posterior distribution, and makes fewer smoothing assumptions than a density or eye plot might.

```r
mcmc_state_points <- mcmc_state_level %>%
  # multiply each draw by its cell's proportion of state N
  # this is the P in MRP
  group_by(state, .draw) %>%
  summarize(postrat_draw = sum(.epred * (n / sum(n)))) %>%
  mutate(model = "MCMC")

mfvi_state_points <- mfvi_state_level %>%
  group_by(state, .draw) %>%
  summarize(postrat_draw = sum(.epred * (n / sum(n)))) %>%
  mutate(model = "MF-VI")

combined_points <- mcmc_state_points %>%
  bind_rows(mfvi_state_points) %>%
  ungroup()

combined_points %>%
  mutate(ordered_state = fct_reorder(state, postrat_draw)) %>%
  ggplot(aes(y = ordered_state, x = postrat_draw, color = model)) +
  stat_dots(quantiles = 100) +
  facet_wrap(~model) +
  theme(legend.position = "none") +
  xlab("Should employers be allowed to deny their employees abortion care?") +
  ylab("State")
```

…

That certainly looks different, but I don't really think I'd say it looks meaningfully better⁴. Looking at the printed-out ELBO, it's pretty clear that there was no traction after the first ~1,000 iterations. A variational family this simple isn't going to get much better, no matter how much time you give it.

## Full-Rank Approximation

So if extending training time doesn't produce improvements, maybe the next option is to ask whether we need something more sophisticated than a mean-field approximation. Instead of

$$q(z) = \prod_{j=1}^{m} q_j(z_j)$$

let's now try the full-rank approximation. Rather than giving each $z_j$ its own independent Gaussian, this uses a single multivariate normal distribution, so we can now (roughly) learn correlation structure. Fancy.

$$q(z) = \mathcal{N}(z \mid \mu, \Sigma)$$

```r
tic()
fit_60k_fullrank <- stan_glmer(
  abortion ~ (1 | state) + (1 | eth) + (1 | educ) + male +
    (1 | male:eth) + (1 | educ:age) + (1 | educ:eth) +
    repvote + factor(region),
  family = binomial(link = "logit"),
  data = cces_all_df,
  prior = normal(0, 1, autoscale = TRUE),
  prior_covariance = decov(scale = 0.50),
  adapt_delta = 0.99,
  tol_rel_obj = 1e-8,
  # Intended to print the ELBO every 1k iterations
  refresh = 1000,
  algorithm = "fullrank",
  QR = TRUE,
  seed = 605)
```

```
Chain 1: ------------------------------------------------------------
Chain 1: EXPERIMENTAL ALGORITHM:
Chain 1:   This procedure has not been thoroughly tested and may be unstable
Chain 1:   or buggy. The interface is subject to change.
Chain 1: ------------------------------------------------------------
Chain 1: 
Chain 1: 
Chain 1: 
Chain 1: Gradient evaluation took 0.025 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 250 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: Begin eta adaptation.
Chain 1: Iteration:   1 / 250 [  0%]  (Adaptation)
Chain 1: Iteration:  50 / 250 [ 20%]  (Adaptation)
Chain 1: Iteration: 100 / 250 [ 40%]  (Adaptation)
Chain 1: Iteration: 150 / 250 [ 60%]  (Adaptation)
Chain 1: Iteration: 200 / 250 [ 80%]  (Adaptation)
Chain 1: Iteration: 250 / 250 [100%]  (Adaptation)
Chain 1: Success! Found best value [eta = 0.1].
Chain 1: 
Chain 1: Begin stochastic gradient ascent.
Chain 1:  iter         ELBO  delta_ELBO_mean  delta_ELBO_med  notes
Chain 1:   100  -248586.032  1.000  1.000
Chain 1:   200  -180460.369  0.689  1.000
Chain 1:   300  -121675.221  0.620  0.483
Chain 1:   400   -87431.017  0.563  0.483
Chain 1:   500  -120999.829  0.506  0.392
Chain 1:   600   -96768.296  0.463  0.392
Chain 1:   700   -93851.607  0.402  0.378
Chain 1:   800   -92494.273  0.353  0.378
Chain 1:   900   -74378.556  0.341  0.277
Chain 1:  1000   -77681.560  0.311  0.277
Chain 1:  1100   -77465.866  0.211  0.250
Chain 1:  1200   -68692.287  0.186  0.244
Chain 1:  1300   -75140.633  0.147  0.128
Chain 1:  1400   -49430.772  0.160  0.128
Chain 1:  1500   -59011.994  0.148  0.128
Chain 1:  1600   -57033.572  0.127  0.086
Chain 1:  1700   -56133.855  0.125  0.086
Chain 1:  1800   -46605.149  0.144  0.128
Chain 1:  1900   -47895.964  0.122  0.086
Chain 1:  2000   -44745.890  0.125  0.086
Chain 1:  2100   -43472.467  0.128  0.086
Chain 1:  2200   -43454.384  0.115  0.070
Chain 1:  2300   -41781.249  0.110  0.040
Chain 1:  2400   -42045.221  0.059  0.035
Chain 1:  2500   -41381.652  0.044  0.029
Chain 1:  2600   -40754.440  0.043  0.027
Chain 1:  2700   -41108.136  0.042  0.027
Chain 1:  2800   -40450.439  0.023  0.016
Chain 1:  2900   -40423.015  0.020  0.016
Chain 1:  3000   -40375.121  0.013  0.015
Chain 1:  3100   -40227.022  0.011  0.009
Chain 1:  3200   -40302.411  0.011  0.009
Chain 1:  3300   -40352.339  0.007  0.006
Chain 1:  3400   -40174.196  0.007  0.004
Chain 1:  3500   -40089.973  0.006  0.004
Chain 1:  3600   -40143.009  0.004  0.002
Chain 1:  3700   -40123.486  0.003  0.002
Chain 1:  3800   -40044.004  0.002  0.002
Chain 1:  3900   -39955.515  0.002  0.002
Chain 1:  4000   -40003.851  0.002  0.002
Chain 1:  4100   -39948.544  0.002  0.002
Chain 1:  4200   -40028.027  0.002  0.002
Chain 1:  4300   -39907.006  0.002  0.002
Chain 1:  4400   -39868.266  0.002  0.002
Chain 1:  4500   -39938.386  0.002  0.002
Chain 1:  4600   -39837.339  0.002  0.002
Chain 1:  4700   -39852.349  0.002  0.002
Chain 1:  4800   -39823.670  0.002  0.002
Chain 1:  4900   -39809.797  0.001  0.001
Chain 1:  5000   -39807.261  0.001  0.001
Chain 1:  5100   -39806.402  0.001  0.001
Chain 1:  5200   -39818.805  0.001  0.001
Chain 1:  5300   -39797.428  0.001  0.001
Chain 1:  5400   -39790.469  0.001  0.000
Chain 1:  5500   -39785.797  0.001  0.000
Chain 1:  5600   -39779.121  0.000  0.000
Chain 1:  5700   -39780.314  0.000  0.000
Chain 1:  5800   -39771.363  0.000  0.000
Chain 1:  5900   -39770.673  0.000  0.000
Chain 1:  6000   -39764.096  0.000  0.000
Chain 1:  6100   -39764.173  0.000  0.000
Chain 1:  6200   -39765.651  0.000  0.000
Chain 1:  6300   -39756.809  0.000  0.000
Chain 1:  6400   -39753.724  0.000  0.000
Chain 1:  6500   -39754.753  0.000  0.000
Chain 1:  6600   -39750.392  0.000  0.000
Chain 1:  6700   -39753.067  0.000  0.000
Chain 1:  6800   -39750.341  0.000  0.000
Chain 1:  6900   -39745.696  0.000  0.000
Chain 1:  7000   -39743.521  0.000  0.000
Chain 1:  7100   -39739.157  0.000  0.000
Chain 1:  7200   -39736.689  0.000  0.000
Chain 1:  7300   -39743.472  0.000  0.000
Chain 1:  7400   -39738.431  0.000  0.000
Chain 1:  7500   -39740.789  0.000  0.000
Chain 1:  7600   -39735.842  0.000  0.000
Chain 1:  7700   -39733.493  0.000  0.000
Chain 1:  7800   -39735.015  0.000  0.000
Chain 1:  7900   -39736.429  0.000  0.000
Chain 1:  8000   -39733.548  0.000  0.000
Chain 1:  8100   -39732.722  0.000  0.000
Chain 1:  8200   -39734.720  0.000  0.000
Chain 1:  8300   -39732.932  0.000  0.000
Chain 1:  8400   -39727.658  0.000  0.000
Chain 1:  8500   -39734.522  0.000  0.000
Chain 1:  8600   -39728.602  0.000  0.000
Chain 1:  8700   -39724.690  0.000  0.000
Chain 1:  8800   -39725.374  0.000  0.000
Chain 1:  8900   -39731.450  0.000  0.000
Chain 1:  9000   -39725.866  0.000  0.000
Chain 1:  9100   -39728.639  0.000  0.000
Chain 1:  9200   -39730.156  0.000  0.000
Chain 1:  9300   -39729.036  0.000  0.000
Chain 1:  9400   -39725.536  0.000  0.000
Chain 1:  9500   -39727.031  0.000  0.000
Chain 1:  9600   -39725.389  0.000  0.000
Chain 1:  9700   -39727.947  0.000  0.000
Chain 1:  9800   -39723.932  0.000  0.000
Chain 1:  9900   -39723.173  0.000  0.000
Chain 1: 10000   -39723.944  0.000  0.000
Chain 1: Informational Message: The maximum number of iterations is reached! The algorithm may not have converged.
Chain 1: This variational approximation is not guaranteed to be meaningful.
Chain 1: 
Chain 1: Drawing a sample of size 1000 from the approximate posterior... 
Chain 1: COMPLETED.
Warning: Pareto k diagnostic value is 2.95. Resampling is disabled. Decreasing tol_rel_obj may help if variational algorithm has terminated prematurely. Otherwise consider using sampling instead.
```

```r
toc()
```

```
350.16 sec elapsed
```

```r
full_rank_draws <- poststrat_df_60k %>%
  add_epred_draws(fit_60k_fullrank, ndraws = 1000)

frvi_points <- full_rank_draws %>%
  group_by(state, .draw) %>%
  summarize(postrat_draw = sum(.epred * (n / sum(n)))) %>%
  mutate(model = "FR-VI")

combined_points_w_frvi <- combined_points_w_lower_tol %>%
  bind_rows(frvi_points) %>%
  ungroup()

combined_points_w_frvi %>%
  mutate(ordered_state = fct_reorder(state, postrat_draw)) %>%
  ggplot(aes(y = ordered_state,
             x = postrat_draw,
             color = model)) +
  stat_dots(quantiles = 100) +
  facet_wrap(~model) +
  theme(legend.position = "none") +
  xlab("Should employers be allowed to deny their employees abortion care?") +
  ylab("State")
```
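The `stat_dots(quantiles = 100)` calls above draw a quantile dotplot. Rather than plotting all 1,000 draws, they plot 100 evenly spaced quantiles of them, so each dot stands for roughly 1% of the probability mass. Below is a rough base R sketch of that summarization step. It assumes evenly spaced probabilities from `ppoints()`, and ggdist's exact quantile convention may differ. It also shows why a spiky approximation stands out: many of its 100 quantiles land on nearly the same value, so the dots pile up. Both sets of draws are hypothetical.

```r
set.seed(3)

# Hypothetical draws: a smooth posterior vs. a "spiky" one that puts
# most of its mass in a very narrow range
smooth_draws <- rnorm(1000, mean = 0.45, sd = 0.02)
spiky_draws  <- c(rnorm(900, mean = 0.45, sd = 0.001),
                  rnorm(100, mean = 0.45, sd = 0.03))

# 100 evenly spaced probabilities; each quantile stands for ~1% of the mass
probs <- ppoints(100)

smooth_q <- quantile(smooth_draws, probs, names = FALSE)
spiky_q  <- quantile(spiky_draws,  probs, names = FALSE)

# Share of the 100 dots within +/- 0.005 of the median: a crude
# measure of how "piled up" the dotplot will look
pile_up <- function(q) mean(abs(q - median(q)) < 0.005)
c(smooth = pile_up(smooth_q), spiky = pile_up(spiky_q))
```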

The first thing to note here is that, unlike the mean-field approximation, this model required some tinkering to fit. I ended up needing to set QR = TRUE (i.e., apply a QR decomposition to the design matrix) to get it to fit at all. The one exception was setting the initialization to 0, but then the posterior collapsed to nearly a single point.
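rstanarm's documentation describes `QR = TRUE` as applying a scaled QR decomposition to the design matrix. The idea is to write $X = QR$, estimate coefficients $\theta$ on the orthogonal columns of $Q$, which usually gives much friendlier geometry, and then map back with $\beta = R^{-1}\theta$. The sketch below illustrates this with ordinary least squares via `lm()`, not with anything Stan-specific. The data are synthetic. The $\sqrt{n-1}$ scaling follows the convention in Stan's QR reparameterization guidance and is an assumption here, not a claim about rstanarm's internals.

```r
set.seed(4)

n <- 500
# Two strongly correlated predictors: awkward geometry for a sampler/optimizer
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.05)
X  <- scale(cbind(x1, x2), center = TRUE, scale = FALSE)
y  <- 1 + 0.5 * x1 - 0.3 * x2 + rnorm(n)

# Thin QR decomposition, scaled so the columns of Q* have variance 1
qr_x   <- qr(X)
Q_star <- qr.Q(qr_x) * sqrt(n - 1)
R_star <- qr.R(qr_x) / sqrt(n - 1)

# Fit on the orthogonal design, then map back: beta = R*^{-1} theta
theta        <- coef(lm(y ~ Q_star))[-1]
beta_from_qr <- solve(R_star, theta)

# Same answer as fitting on X directly
beta_direct <- coef(lm(y ~ X))[-1]
cbind(beta_from_qr, beta_direct)
```

The likelihood is unchanged, since $X\beta = Q^*\theta$. Only the coordinates the algorithm works in differ, which is why this can matter a lot for an optimizer or sampler.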

Unfortunately, this version has a similarly spiky posterior distribution. In terms of uncertainty, it's clearly worse than the mean-field implementation. The ELBO starts from a much worse value (about -248,600 at iteration 100) and spends a few thousand iterations genuinely improving. It then plateaus around -39,700, long before the 10,000-iteration cap. It doesn't seem like this is a way out either.
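For intuition about what separates the two families, here is a base R sketch of how each one generates draws via the reparameterization $z = \mu + L\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. Mean-field restricts $L$ to be diagonal, giving one location and one scale per coordinate ($2m$ parameters). Full-rank allows a lower-triangular Cholesky factor with $\Sigma = LL^\top$ ($m + m(m+1)/2$ parameters). The 3-dimensional $\mu$ and $\Sigma$ are made up. The mean-field scales are simply set to the marginal standard deviations for illustration. For a Gaussian target, the mean-field fit that minimizes $\mathrm{KL}(q \,\|\, p)$ would generally be narrower than this. The sketch illustrates the families themselves, not how Stan optimizes them.

```r
set.seed(5)

m  <- 3
mu <- c(0.5, -1, 2)
# Made-up covariance: strong correlation between the first two coordinates
Sigma <- matrix(c(1.0, 0.8, 0.0,
                  0.8, 1.0, 0.0,
                  0.0, 0.0, 0.5), nrow = m, byrow = TRUE)

n_draws <- 20000
eps <- matrix(rnorm(m * n_draws), nrow = m)  # one column per draw

# Full-rank: lower-triangular Cholesky factor with Sigma = L L^T
# (chol() returns the upper-triangular factor, so transpose it)
L_full <- t(chol(Sigma))
z_full <- mu + L_full %*% eps

# Mean-field: diagonal L, so coordinates are independent and the
# correlation in Sigma simply cannot be represented
L_mf <- diag(sqrt(diag(Sigma)))
z_mf <- mu + L_mf %*% eps

# Number of variational parameters in each family
c(mean_field = 2 * m, full_rank = m + m * (m + 1) / 2)

# Empirical correlation between the first two coordinates
c(full_rank  = cor(z_full[1, ], z_full[2, ]),
  mean_field = cor(z_mf[1, ], z_mf[2, ]))
```

The full-rank parameter count grows quadratically in $m$. That growth is part of why it is slower and touchier to fit on a model with many varying intercepts like this one.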

# Where to from here? (Why is it like this?)

We’ve seen that simple variational families like the mean-field and full-rank can approximately mirror the central tendencies of MCMC, but things fall apart as we attempt to consider uncertainty, either through simple credible intervals, or especially once we start to visualize the unrealistic, lumpy VI posterior distributions in their entirety.

This isn't something we can solve with more training time: each of these algorithms had reached the highest ELBO it was going to find well before we produced final draws. If I had to guess, we need a fundamentally more expressive variational family to make progress.
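To put a number on "no amount of extra training helps", here is a quick base R sketch using a subset of the ELBO values printed in the full-rank log: every 1,000th iteration from 5,000 on. Since the ELBO is being maximized, positive differences are improvements. The relative change computed here, $|\text{ELBO}_t - \text{ELBO}_{t-1}| / |\text{ELBO}_t|$, is similar in spirit to the `delta_ELBO` columns Stan prints, but it is not necessarily computed the same way. Also note that the run used `tol_rel_obj = 1e-8`, so it stopped at the iteration cap rather than on a convergence check, as the informational message in the log says.

```r
# ELBO values copied from the full-rank log (iterations 5000 to 10000)
iter <- seq(5000, 10000, by = 1000)
elbo <- c(-39807.261, -39764.096, -39743.521, -39733.548, -39725.866, -39723.944)

# Relative change between consecutive checkpoints (1,000 iterations apart)
rel_change <- abs(diff(elbo)) / abs(elbo[-1])
names(rel_change) <- iter[-1]
signif(rel_change, 2)

# Total improvement over the last 5,000 iterations, relative to the
# ELBO's magnitude at iteration 5,000
(elbo[length(elbo)] - elbo[1]) / abs(elbo[1])
```

Every 1,000-iteration step in the second half of the run moves the ELBO by about a tenth of a percent or less. Over the first 5,000 iterations, by contrast, it climbed from roughly -248,600 to -39,800.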

While trying to fit models without digging too much into the theory of why VI approximations can be poor has been fun, it’s time to bring in some theory. In the next post, I’ll explore the literature on why the uncertainty behavior of VI can be so dubious. In the following one, I’ll illustrate some better diagnostics as well.

The code for this post can be found here. Thanks for reading.

## Footnotes

1. Really, the worst type of wrong: completely unpredictable wrong. If you spend time staring at this trying to infer a causal pattern of which states we can't estimate well, you're likely just going to end up confused.↩︎

2. Some of these MF-VI distributions are bad enough that you might reasonably wonder if some of the badness is just plotting weirdness. That was my intuition at first. But this level of granularity is enough to make the MCMC results look reasonable, and even if you zoom in on one or two states and add way more points, the improbably sharp spikes remain.↩︎

3. Phrase due to Richard McElreath. The magic of good visualizations like Kay et al.'s is that they make it trivial to let pattern recognition go to work, so you can go "oh, that looks wrong".↩︎

4. Also, apologies for showing every 100 iterations. The rstanarm parameter for this, refresh, doesn't appear to work properly with non-MCMC models, so I can either not show the ELBO or blow up the post with this.↩︎

## Citation

BibTeX citation:
@online{timm2022,
author = {Andy Timm},
title = {Variational {Inference} for {MRP} with {Reliable} {Posterior}
{Distributions}},
date = {2022-11-20},
url = {https://andytimm.github.io/Variational_MRP_pt2.html},
langid = {en}
}