Saturday, January 18, 2025

Hierarchical partial pooling with tfprobability

Before we jump into the technicalities: This post is, of course, dedicated to McElreath, who wrote one of the most intriguing books on Bayesian (or should we just say, scientific?) modeling we’re aware of. If you haven’t read Statistical Rethinking and are interested in modeling, you should definitely check it out. In this post, we’re not going to try to retell the story: Our clear focus will, instead, be a demonstration of how to do MCMC with tfprobability.

Concretely, this post has two parts. The first is a quick overview of how to use tfd_joint_distribution_sequential to construct a model, and then sample from it using Hamiltonian Monte Carlo. This part can be consulted for quick code look-up, or as a frugal template of the whole process.
The second part then walks through a multi-level model in more detail, showing how to extract, post-process and visualize sampling as well as diagnostic outputs.

Reedfrogs

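The data comes with the rethinking package, as the data frame reedfrogs. Below, we refer to it as d:

library(rethinking)
data(reedfrogs)
d <- reedfrogs
str(d)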

'data.frame':   48 obs. of  5 variables:
 $ density : int  10 10 10 10 10 10 10 10 10 10 ...
 $ pred    : Factor w/ 2 levels "no","pred": 1 1 1 1 1 1 1 1 2 2 ...
 $ size    : Factor w/ 2 levels "big","small": 1 1 1 1 2 2 2 2 1 1 ...
 $ surv    : int  9 10 7 10 9 9 10 9 4 9 ...
 $ propsurv: num  0.9 1 0.7 1 0.9 0.9 1 0.9 0.4 0.9 ...

The task is modeling survivor counts among tadpoles, where tadpoles are held in tanks of different sizes (equivalently, different numbers of inhabitants). Each row in the dataset describes one tank, with its initial count of inhabitants (density) and number of survivors (surv).
In the technical overview part, we build a simple unpooled model that describes every tank in isolation. Then, in the detailed walk-through, we’ll see how to construct a varying intercepts model that allows for information sharing between tanks.

Constructing models with tfd_joint_distribution_sequential

tfd_joint_distribution_sequential represents a model as a list of conditional distributions.
This is easiest to see on a real example, so we’ll jump right in, creating an unpooled model of the tadpole data.

This is how the model specification would look in Stan:

model{
    vector[48] p;
    a ~ normal( 0 , 1.5 );
    for ( i in 1:48 ) {
        p[i] = a[tank[i]];
        p[i] = inv_logit(p[i]);
    }
    S ~ binomial( N , p );
}

And here is tfd_joint_distribution_sequential:

library(tensorflow)

# make sure you have at least version 0.7 of TensorFlow Probability
# as of this writing, it is required to install the master branch:
# install_tensorflow(version = "nightly")
library(tfprobability)

n_tadpole_tanks <- nrow(d)
n_surviving <- d$surv
n_start <- d$density

m1 <- tfd_joint_distribution_sequential(
  list(
    # normal prior of per-tank logits
    tfd_multivariate_normal_diag(
      loc = rep(0, n_tadpole_tanks),
      scale_identity_multiplier = 1.5),
    # binomial distribution of survival counts
    function(l)
      tfd_independent(
        tfd_binomial(total_count = n_start, logits = l),
        reinterpreted_batch_ndims = 1
      )
  )
)

The model consists of two distributions: tfd_multivariate_normal_diag specifies the prior for the 48 per-tank logits (mean 0, standard deviation 1.5); then tfd_binomial generates survival counts for each tank.
Note how the first distribution is unconditional, while the second depends on the first. Note too how the second has to be wrapped in tfd_independent to avoid wrong broadcasting: with reinterpreted_batch_ndims = 1, the 48 per-tank binomials are treated as a single distribution over a vector of 48 counts, so their log probabilities are summed. (This aspect of tfd_joint_distribution_sequential usage deserves to be documented more systematically, which is surely going to happen. After all, this functionality was added to TFP master only three weeks ago!)
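A base-R sketch of that difference (plain R, not tfprobability; tank sizes, counts and logits are made up for illustration): the "batch" view yields one log probability per sample and tank, the "independent" view one per sample.

```r
set.seed(1)
n_samples <- 2
n_tanks <- 48
n_start <- rep(c(10, 25, 35), each = 16)                  # illustrative tank sizes
survivors <- rbinom(n_tanks, size = n_start, prob = 0.7)  # illustrative counts

# two sets of per-tank logits, one row per sample
logits <- matrix(rnorm(n_samples * n_tanks, mean = 0, sd = 1.5), nrow = n_samples)

# batch view: one log probability per sample and tank
per_tank_lp <- t(apply(logits, 1, function(l) {
  dbinom(survivors, size = n_start, prob = plogis(l), log = TRUE)
}))
dim(per_tank_lp)        # 2 48

# independent view: the tanks form one event, so their log probabilities add up
per_sample_lp <- rowSums(per_tank_lp)
length(per_sample_lp)   # 2
```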

As an aside, the model specification here ends up shorter than in Stan because tfd_binomial optionally takes logits as parameters, so no explicit inverse-logit step is needed.
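Working with logits directly is natural because the binomial log probability can be written in terms of the logit \(l = \log(p / (1 - p))\) as \(\log\binom{n}{k} + k\,l - n \log(1 + e^{l})\). The base-R sketch below (a hypothetical helper, not part of tfprobability) checks this against dbinom() with \(p\) obtained via the inverse logit, plogis(). The counts are the first three tanks from the str() output; the logits are arbitrary.

```r
# binomial log-pmf parameterized by the logit l instead of the probability p
binom_log_pmf_logits <- function(k, n, l) {
  lchoose(n, k) + k * l - n * log1p(exp(l))
}

k <- c(9, 10, 7)        # survivors in the first three tanks
n <- c(10, 10, 10)      # initial counts in those tanks
l <- c(2.1, 3.2, 1.0)   # arbitrary logits, for illustration only

direct <- binom_log_pmf_logits(k, n, l)
via_prob <- dbinom(k, size = n, prob = plogis(l), log = TRUE)
all.equal(direct, via_prob)   # TRUE
```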

As with every TFP distribution, you can do a quick functionality check by sampling from the model:

# sample a batch of 2 values 
# we get samples for every distribution in the model
s <- m1 %>% tfd_sample(2)
s
[[1]]
Tensor("MultivariateNormalDiag/sample/affine_linear_operator/forward/add:0",
shape=(2, 48), dtype=float32)

[[2]]
Tensor("IndependentJointDistributionSequential/sample/Beta/sample/Reshape:0",
shape=(2, 48), dtype=float32)

and computing log probabilities:

# we should get only the overall log probability of the model:
# one value per sample, i.e., a tensor of shape (2,)
m1 %>% tfd_log_prob(s)
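Written out, the joint log probability of the unpooled model is the sum of the normal log densities of the logits and the binomial log probabilities of the observed counts. Here is a base-R sketch of that quantity; the helper unpooled_log_prob is hypothetical, and the example uses only the first ten tanks from the str() output. It should correspond to what tfd_log_prob() computes for m1, one value per sample.

```r
# joint log density of the unpooled model, for one vector of per-tank logits
unpooled_log_prob <- function(l, n_surviving, n_start) {
  log_prior <- sum(dnorm(l, mean = 0, sd = 1.5, log = TRUE))
  log_lik <- sum(dbinom(n_surviving, size = n_start, prob = plogis(l), log = TRUE))
  log_prior + log_lik
}

n_surviving <- c(9, 10, 7, 10, 9, 9, 10, 9, 4, 9)   # first ten tanks
n_start <- rep(10, 10)

# two "samples" of logits, one per row, give two log probabilities
logit_samples <- rbind(rep(0, 10), rep(2, 10))
lp <- apply(logit_samples, 1, unpooled_log_prob,
            n_surviving = n_surviving, n_start = n_start)
lp
```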

Now, let’s see how we can sample from this model using Hamiltonian Monte Carlo.

Running Hamiltonian Monte Carlo in TFP

We define a Hamiltonian Monte Carlo kernel with dynamic step size adaptation based on a desired acceptance probability.

# number of burn-in steps
n_burnin <- 500

# the target is the joint log probability of the logits and the observed data,
# i.e., the unnormalized log posterior of the logits
logprob <- function(l)
  m1 %>% tfd_log_prob(list(l, n_surviving))

hmc <- mcmc_hamiltonian_monte_carlo(
  target_log_prob_fn = logprob,
  num_leapfrog_steps = 3,
  step_size = 0.1
) %>%
  mcmc_simple_step_size_adaptation(
    target_accept_prob = 0.8,
    num_adaptation_steps = n_burnin
  )
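To make the moving parts concrete, here is a compact base-R sketch of Hamiltonian Monte Carlo with simple step size adaptation, for a single tank's logit under the unpooled model (prior normal(0, 1.5), 9 survivors out of 10). It is not the TFP implementation: the leapfrog integrator and Metropolis correction are the standard ones, and the adaptation rule (multiply the step size by 1 + rate when the acceptance probability exceeds the target, divide by it otherwise) is our simplified reading of mcmc_simple_step_size_adaptation, with one chain and no averaging across chains. All function names are hypothetical.

```r
# target: log posterior (up to a constant) of one tank's logit
k <- 9
n <- 10
log_target <- function(l) {
  dnorm(l, mean = 0, sd = 1.5, log = TRUE) + dbinom(k, n, plogis(l), log = TRUE)
}
grad_log_target <- function(l) -l / 1.5^2 + (k - n * plogis(l))

# one HMC transition: leapfrog integration plus Metropolis accept/reject
hmc_step <- function(q, step_size, n_leapfrog) {
  p <- rnorm(length(q))                                     # fresh momentum
  q_new <- q
  p_new <- p + 0.5 * step_size * grad_log_target(q_new)    # initial half step
  for (i in seq_len(n_leapfrog)) {
    q_new <- q_new + step_size * p_new                      # full position step
    if (i < n_leapfrog) p_new <- p_new + step_size * grad_log_target(q_new)
  }
  p_new <- p_new + 0.5 * step_size * grad_log_target(q_new)  # final half step
  log_ratio <- (log_target(q_new) - 0.5 * sum(p_new^2)) -
    (log_target(q) - 0.5 * sum(p^2))
  accept_prob <- min(1, exp(log_ratio))
  accepted <- runif(1) < accept_prob
  list(q = if (accepted) q_new else q, accept_prob = accept_prob, accepted = accepted)
}

run_chain <- function(q0, n_burnin = 500, n_steps = 500, step_size = 0.1,
                      n_leapfrog = 3, target_accept_prob = 0.8, rate = 0.01) {
  q <- q0
  draws <- numeric(n_steps)
  is_accepted <- logical(n_steps)
  for (i in seq_len(n_burnin + n_steps)) {
    res <- hmc_step(q, step_size, n_leapfrog)
    q <- res$q
    if (i <= n_burnin) {
      # adapt the step size during burn-in only
      step_size <- if (res$accept_prob > target_accept_prob) {
        step_size * (1 + rate)
      } else {
        step_size / (1 + rate)
      }
    } else {
      draws[i - n_burnin] <- q
      is_accepted[i - n_burnin] <- res$accepted
    }
  }
  list(draws = draws, is_accepted = is_accepted, step_size = step_size)
}

set.seed(42)
chain <- run_chain(q0 = 0)
mean(chain$is_accepted)     # acceptance rate after burn-in
mean(plogis(chain$draws))   # posterior mean survival probability
```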

We then run the sampler, passing in an initial state. If we want to run \(n\) chains, that state needs a leading dimension of size \(n\) for every parameter in the model (here we have just one, the vector of 48 logits, so the initial state has shape \((n, 48)\)).

The sampling function, mcmc_sample_chain, may optionally be passed a trace_fn that tells TFP which kinds of meta information to save. Here we save acceptance ratios and step sizes.

# number of steps after burnin
n_steps <- 500
# number of chains
n_chain <- 4

# get starting values for the parameters
# their shape implicitly determines the number of chains we will run
# see current_state parameter passed to mcmc_sample_chain below
# (%<-% is the destructuring assignment operator from the zeallot package)
library(zeallot)
c(initial_logits, .) %<-% (m1 %>% tfd_sample(n_chain))

# tell TFP to keep track of acceptance ratio and step size
trace_fn <- function(state, pkr) {
  list(pkr$inner_results$is_accepted,
       pkr$inner_results$accepted_results$step_size)
}

res <- hmc %>% mcmc_sample_chain(
  num_results = n_steps,
  num_burnin_steps = n_burnin,
  current_state = initial_logits,
  trace_fn = trace_fn
)

When sampling is finished, we can access the samples as res$all_states:

mcmc_trace <- res$all_states
mcmc_trace
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack/TensorArrayGatherV3:0",
shape=(500, 4, 48), dtype=float32)

This is the shape of the samples for l, the 48 per-tank logits: 500 samples times 4 chains times 48 parameters.

From these samples, we can compute the effective sample size and \(\hat{R}\) (rhat, computed by mcmc_potential_scale_reduction):

# effective sample size is computed per chain (shape (4, 48)); we average over chains
# Tensor("Mean:0", shape=(48,), dtype=float32)
ess <- mcmc_effective_sample_size(mcmc_trace) %>% tf$reduce_mean(axis = 0L)

# Tensor("potential_scale_reduction/potential_scale_reduction_single_state/sub_1:0", shape=(48,), dtype=float32)
rhat <- mcmc_potential_scale_reduction(mcmc_trace)
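For intuition about these two diagnostics, here is a base-R sketch of both, applied to simulated draws. The helper names are hypothetical, and details such as the exact truncation rule may differ from TFP's implementation. \(\hat{R}\) compares between-chain and within-chain variance (following the Gelman-Rubin formulation); values close to 1 suggest the chains agree. The effective sample size divides the number of draws by \(1 + 2\sum_k \rho_k\), where the \(\rho_k\) are autocorrelations, here summed up to the first non-positive one.

```r
# potential scale reduction for one parameter; draws: n_samples x n_chains matrix
potential_scale_reduction <- function(draws) {
  n <- nrow(draws)
  m <- ncol(draws)
  w <- mean(apply(draws, 2, var))      # mean within-chain variance
  b_over_n <- var(colMeans(draws))     # between-chain variance, divided by n
  sigma2_plus <- (n - 1) / n * w + b_over_n
  (m + 1) / m * sigma2_plus / w - (n - 1) / (m * n)
}

# effective sample size of a single chain
effective_sample_size <- function(x) {
  n <- length(x)
  rho <- acf(x, lag.max = n - 1, plot = FALSE)$acf[-1]   # drop lag 0
  first_nonpositive <- which(rho <= 0)[1]
  if (!is.na(first_nonpositive)) rho <- rho[seq_len(first_nonpositive - 1)]
  n / (1 + 2 * sum(rho))
}

set.seed(7)
# four well-mixed chains vs. four chains stuck at different levels
good_chains <- matrix(rnorm(500 * 4), ncol = 4)
bad_chains <- sweep(good_chains, 2, c(0, 3, 6, 9), "+")
potential_scale_reduction(good_chains)   # close to 1
potential_scale_reduction(bad_chains)    # far above 1

# independent draws vs. a strongly autocorrelated AR(1) chain
iid <- rnorm(2000)
sticky <- as.numeric(arima.sim(list(ar = 0.9), n = 2000))
effective_sample_size(iid)
effective_sample_size(sticky)
```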

Diagnostic information, in turn, is available in res$trace:

# Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack_1/TensorArrayGatherV3:0",
# shape=(500, 4), dtype=bool)
is_accepted <- res$trace[[1]] 

# Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack_2/TensorArrayGatherV3:0",
# shape=(500,), dtype=float32)
step_size <- res$trace[[2]] 

After this quick outline, let’s move on to the topic promised in the title: multi-level modeling, or partial pooling. This time, we’ll also take a closer look at sampling results and diagnostic outputs.

Multi-level tadpoles

The multi-level model (or varying intercepts model, in this case; we’ll get to varying slopes in a later post) adds a hyperprior to the model. Instead of fixing the mean and standard deviation of the normal prior the logits are drawn from, we let the model learn them from the data.
The per-tank logits are still assumed to be normally distributed, but now around a learned mean, a_bar, with a learned standard deviation, sigma. These two are themselves regularized by a normal prior for the mean and an exponential prior for the standard deviation.
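Read generatively, the model says: draw a_bar and sigma, then draw one logit per tank around a_bar, then draw survival counts. The base-R sketch below simulates one such draw from the prior (a prior predictive simulation). The helper name is hypothetical, and the tank sizes assume the reedfrogs layout of 16 tanks each with 10, 25 and 35 tadpoles.

```r
simulate_multilevel_prior <- function(n_start) {
  n_tanks <- length(n_start)
  a_bar <- rnorm(1, mean = 0, sd = 1.5)                  # prior for the mean of the logits
  sigma <- rexp(1, rate = 1)                             # prior for their standard deviation
  logits <- rnorm(n_tanks, mean = a_bar, sd = sigma)     # one intercept per tank
  surv <- rbinom(n_tanks, size = n_start, prob = plogis(logits))
  list(a_bar = a_bar, sigma = sigma, logits = logits, surv = surv)
}

set.seed(2019)
n_start <- rep(c(10, 25, 35), each = 16)   # assumed reedfrogs tank sizes
draw <- simulate_multilevel_prior(n_start)
str(draw)
```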

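For the Stan-savvy: compared to the unpooled Stan model above, the prior on a becomes normal(a_bar, sigma), with a_bar ~ normal(0, 1.5) and sigma ~ exponential(1). Here is the same model expressed with tfd_joint_distribution_sequential: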

m2 <- tfd_joint_distribution_sequential(
  list(
    # a_bar, the prior for the mean of the normal distribution of per-tank logits
    tfd_normal(loc = 0, scale = 1.5),
    # sigma, the prior for the standard deviation of the normal distribution of per-tank logits
    tfd_exponential(rate = 1),
    # normal distribution of per-tank logits
    # parameters sigma and a_bar refer to the outputs of the above two distributions
    function(sigma, a_bar) 
      tfd_sample_distribution(
        tfd_normal(loc = a_bar, scale = sigma),
        sample_shape = list(n_tadpole_tanks)
      ), 
    # binomial distribution of survival counts
    # parameter l refers to the output of the normal distribution immediately above
    function(l)
      tfd_independent(
        tfd_binomial(total_count = n_start, logits = l),
        reinterpreted_batch_ndims = 1
      )
  )
)

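Technically, dependencies in tfd_joint_distribution_sequential are resolved by position in the list: a function’s first argument refers to the distribution immediately preceding it, its second argument to the one before that, and so on. In the learned prior for the logits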

function(sigma, a_bar)
  tfd_sample_distribution(
    tfd_normal(loc = a_bar, scale = sigma),
    sample_shape = list(n_tadpole_tanks)
  )

sigma refers to the distribution immediately above, and a_bar to the one above that.

Analogously, in the distribution of survival counts

function(l)
  tfd_independent(
    tfd_binomial(total_count = n_start, logits = l),
    reinterpreted_batch_ndims = 1
  )

l refers to the distribution immediately preceding its own definition.
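The position-based wiring can be mimicked in a few lines of base R. This sketch is not how tfprobability is implemented; it only mirrors the convention described above: every list element is a function, and a function with k arguments receives the k most recent draws, the most recent one first. Zero-argument functions play the role of unconditional distributions. The helper sample_sequential is hypothetical.

```r
sample_sequential <- function(model) {
  draws <- list()
  for (component in model) {
    k <- length(formals(component))
    previous <- rev(draws)[seq_len(k)]   # most recent draw first
    draws[[length(draws) + 1]] <- do.call(component, unname(previous))
  }
  draws
}

n_start <- rep(c(10, 25, 35), each = 16)   # assumed tank sizes
model <- list(
  function() rnorm(1, mean = 0, sd = 1.5),                     # a_bar
  function() rexp(1, rate = 1),                                # sigma
  function(sigma, a_bar) rnorm(48, mean = a_bar, sd = sigma),  # logits l
  function(l) rbinom(48, size = n_start, prob = plogis(l))     # survivors
)

set.seed(3)
s <- sample_sequential(model)
lengths(s)   # 1 1 48 48
```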

Again, let’s sample from this model to see if shapes are correct.

s <- m2 %>% tfd_sample(2)
s

[[1]]
Tensor("Normal/sample_1/Reshape:0", shape=(2,), dtype=float32)

[[2]]
Tensor("Exponential/sample_1/Reshape:0", shape=(2,), dtype=float32)

[[3]]
Tensor("SampleJointDistributionSequential/sample_1/Normal/sample/Reshape:0",
shape=(2, 48), dtype=float32)

[[4]]
Tensor("IndependentJointDistributionSequential/sample_1/Beta/sample/Reshape:0",
shape=(2, 48), dtype=float32)

They are.

And to make sure we get one overall log_prob per sample:

m2 %>% tfd_log_prob(s)
Tensor("JointDistributionSequential/log_prob/add_3:0", shape=(2,), dtype=float32)

Sampling works like before, except that now the initial state comprises three parameters, a_bar, sigma and l:

c(initial_a, initial_s, initial_logits, .) %<-% (m2 %>% tfd_sample(n_chain))

Here is the sampling routine:

# the joint log probability now is based on three parameters
logprob <- function(a, s, l)
  m2 %>% tfd_log_prob(list(a, s, l, n_surviving))

hmc <- mcmc_hamiltonian_monte_carlo(
  target_log_prob_fn = logprob,
  num_leapfrog_steps = 3,
  # one step size for each parameter
  step_size = list(0.1, 0.1, 0.1)
) %>%
  mcmc_simple_step_size_adaptation(target_accept_prob = 0.8,
                                   num_adaptation_steps = n_burnin)

run_mcmc <- function(kernel) {
  kernel %>% mcmc_sample_chain(
    num_results = n_steps,
    num_burnin_steps = n_burnin,
    # a_bar and the logits start from prior draws; sigma starts at 1 in every chain
    current_state = list(initial_a, tf$ones_like(initial_s), initial_logits),
    trace_fn = trace_fn
  )
}

res <- hmc %>% run_mcmc()
 
mcmc_trace <- res$all_states
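For reference, the target function logprob(a, s, l) corresponds to the following joint log density, sketched in base R. The helper multilevel_log_prob is hypothetical; the example uses only the first ten tanks from the str() output and illustrative parameter values. The normal log densities of the logits are summed, which is what tfd_sample_distribution does over its sample_shape.

```r
multilevel_log_prob <- function(a_bar, sigma, l, n_surviving, n_start) {
  if (sigma <= 0) return(-Inf)   # outside the support of the exponential prior
  dnorm(a_bar, mean = 0, sd = 1.5, log = TRUE) +
    dexp(sigma, rate = 1, log = TRUE) +
    sum(dnorm(l, mean = a_bar, sd = sigma, log = TRUE)) +
    sum(dbinom(n_surviving, size = n_start, prob = plogis(l), log = TRUE))
}

n_surviving <- c(9, 10, 7, 10, 9, 9, 10, 9, 4, 9)   # first ten tanks
n_start <- rep(10, 10)
l <- rep(1.35, 10)   # illustrative logits

multilevel_log_prob(1.35, 1.64, l, n_surviving, n_start)
multilevel_log_prob(1.35, -1, l, n_surviving, n_start)   # -Inf
```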

This time, mcmc_trace is a list of three tensors:

[[1]]
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack/TensorArrayGatherV3:0",
shape=(500, 4), dtype=float32)

[[2]]
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack_1/TensorArrayGatherV3:0",
shape=(500, 4), dtype=float32)

[[3]]
Tensor("mcmc_sample_chain/trace_scan/TensorArrayStack_2/TensorArrayGatherV3:0",
shape=(500, 4, 48), dtype=float32)

Now let’s create graph nodes for the results and information we’re interested in.

# as above, this is the raw result
mcmc_trace_ <- res$all_states

# we perform some reshaping operations directly in tensorflow
all_samples_ <-
  tf$concat(
    list(
      mcmc_trace_[[1]] %>% tf$expand_dims(axis = -1L),
      mcmc_trace_[[2]]  %>% tf$expand_dims(axis = -1L),
      mcmc_trace_[[3]]
    ),
    axis = -1L
  ) %>%
  tf$reshape(list(2000L, 50L))

# diagnostics, also as above
is_accepted_ <- res$trace[[1]]
step_size_ <- res$trace[[2]]

# effective sample size
# again we use tensorflow to get conveniently shaped outputs
ess_ <- mcmc_effective_sample_size(mcmc_trace) 
ess_ <- tf$concat(
  list(
    ess_[[1]] %>% tf$expand_dims(axis = -1L),
    ess_[[2]]  %>% tf$expand_dims(axis = -1L),
    ess_[[3]]
  ),
  axis = -1L
) 

# rhat, conveniently post-processed
rhat_ <- mcmc_potential_scale_reduction(mcmc_trace)
rhat_ <- tf$concat(
  list(
    rhat_[[1]] %>% tf$expand_dims(axis = -1L),
    rhat_[[2]]  %>% tf$expand_dims(axis = -1L),
    rhat_[[3]]
  ),
  axis = -1L
) 
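The same concatenate-and-flatten step can also be done with base R arrays once results have been evaluated. The sketch below uses random placeholder arrays of the shapes shown above (500 samples, 4 chains, 48 tanks). One difference worth knowing: R fills arrays in column-major order while tf$reshape uses row-major order, so rows end up ordered differently (sample index varying fastest in R, chain index fastest in TensorFlow). Each column still holds all 2000 draws of one parameter, which is all that means, standard deviations and HPDIs need.

```r
set.seed(5)
n_steps <- 500
n_chain <- 4
n_tanks <- 48

# placeholder traces with the shapes of mcmc_trace[[1]], [[2]] and [[3]]
a_bar_trace <- matrix(rnorm(n_steps * n_chain), n_steps, n_chain)
sigma_trace <- matrix(rexp(n_steps * n_chain), n_steps, n_chain)
logit_trace <- array(rnorm(n_steps * n_chain * n_tanks), c(n_steps, n_chain, n_tanks))

# concatenate along the last axis: 500 x 4 x 50
combined <- array(c(a_bar_trace, sigma_trace, logit_trace),
                  dim = c(n_steps, n_chain, n_tanks + 2))

# flatten samples and chains into rows: 2000 x 50
all_samples <- matrix(combined, nrow = n_steps * n_chain, ncol = n_tanks + 2)
dim(all_samples)   # 2000 50
```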

And we’re ready to actually run the chains.

# so far, no sampling has been done!
# the actual sampling happens when we create a Session 
# and run the above-defined nodes
sess <- tf$Session()
eval <- function(...) sess$run(list(...))

c(mcmc_trace, all_samples, is_accepted, step_size, ess, rhat) %<-%
  eval(mcmc_trace_, all_samples_, is_accepted_, step_size_, ess_, rhat_)

This time, let’s actually inspect those results.

Multi-level tadpoles: Results

First, how do the chains behave?

Trace plots

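Let’s extract the samples for a_bar and sigma, as well as for one of the per-tank logits, a_1:

a_bar <- mcmc_trace[[1]] %>% as.matrix()
sigma <- mcmc_trace[[2]] %>% as.matrix()
a_1 <- mcmc_trace[[3]][ , , 1] %>% as.matrix()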

Here’s a trace plot for a_bar:

library(tidyverse)

prep_tibble <- function(samples) {
  as_tibble(samples, .name_repair = ~ c("chain_1", "chain_2", "chain_3", "chain_4")) %>% 
    add_column(sample = 1:500) %>%
    gather(key = "chain", value = "value", -sample)
}

plot_trace <- function(samples, param_name) {
  prep_tibble(samples) %>% 
    ggplot(aes(x = sample, y = value, color = chain)) +
    geom_line() + 
    ggtitle(param_name)
}

plot_trace(a_bar, "a_bar")

And here for sigma and a_1:

plot_trace(sigma, "sigma")
plot_trace(a_1, "a_1")

How about the posterior distributions of the parameters, first and foremost the varying intercepts a_1 to a_48?

Posterior distributions

plot_posterior <- function(samples) {
  prep_tibble(samples) %>% 
    ggplot(aes(x = value, color = chain)) +
    geom_density() +
    theme_classic() +
    theme(legend.position = "none",
          axis.title = element_blank(),
          axis.text = element_blank(),
          axis.ticks = element_blank())
}

library(gridExtra)  # provides grid.arrange()

plot_posteriors <- function(sample_array, num_params) {
  plots <- purrr::map(1:num_params, ~ plot_posterior(sample_array[ , , .x] %>% as.matrix()))
  do.call(grid.arrange, plots)
}

plot_posteriors(mcmc_trace[[3]], dim(mcmc_trace[[3]])[3])

Now let’s see the corresponding posterior means and highest posterior density intervals (HPDIs).
(The code below includes the hyperparameters a_bar and sigma in the summary, as we’ll want to display a complete precis-like output soon.)

Posterior means and HPDIs

library(HDInterval)  # provides hdi()

all_samples <- all_samples %>%
  as_tibble(.name_repair = ~ c("a_bar", "sigma", paste0("a_", 1:48)))

means <- all_samples %>%
  summarise_all(mean) %>%
  gather(key = "key", value = "mean")

sds <- all_samples %>%
  summarise_all(sd) %>%
  gather(key = "key", value = "sd")

hpdis <-
  all_samples %>%
  summarise_all(list(~ list(hdi(.) %>% t() %>% as_tibble()))) %>% 
  unnest() 

hpdis_lower <- hpdis %>% select(-contains("upper")) %>%
  rename(lower0 = lower) %>%
  gather(key = "key", value = "lower") %>% 
  arrange(as.integer(str_sub(key, 6))) %>%
  mutate(key = c("a_bar", "sigma", paste0("a_", 1:48)))

hpdis_upper <- hpdis %>% select(-contains("lower")) %>%
  rename(upper0 = upper) %>%
  gather(key = "key", value = "upper") %>% 
  arrange(as.integer(str_sub(key, 6))) %>%
  mutate(key = c("a_bar", "sigma", paste0("a_", 1:48)))

summary <- means %>% 
  inner_join(sds, by = "key") %>% 
  inner_join(hpdis_lower, by = "key") %>%
  inner_join(hpdis_upper, by = "key")


summary %>% 
  filter(!key %in% c("a_bar", "sigma")) %>%
  mutate(key_fct = factor(key, levels = unique(key))) %>%
  ggplot(aes(x = key_fct, y = mean, ymin = lower, ymax = upper)) +
   geom_pointrange() + 
   coord_flip() +  
   xlab("") + ylab("post. mean and HPDI") +
   theme_minimal() 
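In case the hdi() call looks like a black box: a highest posterior density interval is simply the narrowest interval that contains a given share of the draws. Here is a base-R sketch (hypothetical helper hpdi, 95% by default; HDInterval’s own implementation may differ in how it rounds the number of included draws), compared to an equal-tailed interval on skewed draws:

```r
hpdi <- function(draws, prob = 0.95) {
  sorted <- sort(draws)
  n <- length(sorted)
  n_inside <- ceiling(prob * n)                     # draws the interval must contain
  starts <- seq_len(n - n_inside + 1)               # candidate lower ends
  widths <- sorted[starts + n_inside - 1] - sorted[starts]
  best <- which.min(widths)                         # narrowest candidate
  c(lower = sorted[best], upper = sorted[best + n_inside - 1])
}

set.seed(11)
skewed <- rexp(10000, rate = 1)
hpdi(skewed)                        # starts right at the mode, near 0
quantile(skewed, c(0.025, 0.975))   # equal-tailed interval, for comparison
```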

Now for an equivalent to precis. We already computed means, standard deviations and HPDIs.
Let’s add the effective sample size (n_eff in precis) and rhat, the Gelman-Rubin statistic.

Comprehensive summary (a.k.a. “precis”)

is_accepted <- is_accepted %>% as.integer() %>% mean()
step_size <- purrr::map(step_size, mean)

ess <- apply(ess, 2, mean)

summary_with_diag <- summary %>% add_column(ess = ess, rhat = rhat)
summary_with_diag
# A tibble: 50 x 7
   key    mean    sd  lower upper   ess  rhat
   <chr> <dbl> <dbl>  <dbl> <dbl> <dbl> <dbl>
 1 a_bar  1.35 0.266  0.792  1.87 405.   1.00
 2 sigma  1.64 0.218  1.23   2.05  83.6  1.00
 3 a_1    2.14 0.887  0.451  3.92  33.5  1.04
 4 a_2    3.16 1.13   1.09   5.48  23.7  1.03
 5 a_3    1.01 0.698 -0.333  2.31  65.2  1.02
 6 a_4    3.02 1.04   1.06   5.05  31.1  1.03
 7 a_5    2.11 0.843  0.625  3.88  49.0  1.05
 8 a_6    2.06 0.904  0.496  3.87  39.8  1.03
 9 a_7    3.20 1.27   1.11   6.12  14.2  1.02
10 a_8    2.21 0.894  0.623  4.18  44.7  1.04
# ... with 40 more rows

For the varying intercepts, effective sample sizes are pretty low (between about 14 and 65 for the tanks shown, averaged over chains), indicating we might want to investigate possible reasons.

Let’s also display posterior survival probabilities, analogously to figure 13.2 in the book.

Posterior survival probabilities

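# simulate per-tank logits from the learned prior;
# the 2000 posterior draws of a_bar and sigma are recycled to 8000 values
sim_tanks <- rnorm(8000, a_bar, sigma)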
tibble(x = sim_tanks) %>% ggplot(aes(x = x)) + geom_density() + xlab("distribution of per-tank logits")

# our usual sigmoid by another name (undo the logit)
logistic <- function(x) 1/(1 + exp(-x))
probs <- map_dbl(sim_tanks, logistic)
tibble(x = probs) %>% ggplot(aes(x = x)) + geom_density() + xlab("probability of survival")
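A small design note on the code above: logistic() is already vectorized, so map_dbl() is not required, and base R ships the same function as plogis(), with qlogis() as its inverse (the logit). A quick base-R check:

```r
logistic <- function(x) 1 / (1 + exp(-x))

x <- c(-3, 0, 1.35, 5)
probs <- logistic(x)          # vectorized: one probability per logit
all.equal(probs, plogis(x))   # TRUE: same as base R's logistic function
all.equal(qlogis(probs), x)   # TRUE: qlogis() maps probabilities back to logits
```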

Finally, we want to make sure we see the shrinkage behavior displayed in figure 13.1 in the book.

Shrinkage

summary %>% 
  filter(!key %in% c("a_bar", "sigma")) %>%
  select(key, mean) %>%
  mutate(est_survival = logistic(mean)) %>%
  add_column(act_survival = d$propsurv) %>%
  select(-mean) %>%
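  gather(key = "type", value = "value", -key) %>%
  # keep tanks in numerical order (a_1, a_2, ...) rather than alphabetical order
  mutate(key = factor(key, levels = paste0("a_", 1:48))) %>%
  ggplot(aes(x = key, y = value, color = type)) +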
  geom_point() +
  geom_hline(yintercept = mean(d$propsurv), size = 0.5, color = "cyan" ) +
  xlab("") +
  ylab("") +
  theme_minimal() +
  theme(axis.text.x = element_blank())

We see results similar in spirit to McElreath’s: estimates are shrunk toward the mean (the cyan-colored line). Also, shrinkage seems to be stronger in smaller tanks, which are the lower-numbered ones on the left of the plot.
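Why would small tanks be shrunk more? A rough intuition comes from a normal approximation, which is a simplification and not what the MCMC model computes: a tank’s empirical logit has a sampling variance of roughly \(1 / (n p (1 - p))\), and under a normal prior with mean a_bar and standard deviation sigma, the estimate becomes a precision-weighted average of the empirical logit and a_bar. Fewer tadpoles mean a noisier empirical logit and therefore more weight on a_bar. The sketch uses the posterior means of a_bar and sigma from the summary table above and a hypothetical observed survival proportion of 0.6; the helper name is hypothetical too.

```r
shrunken_logit <- function(p_obs, n, a_bar, sigma) {
  se2 <- 1 / (n * p_obs * (1 - p_obs))   # approx. sampling variance of the empirical logit
  w <- sigma^2 / (sigma^2 + se2)         # weight on the tank's own data
  w * qlogis(p_obs) + (1 - w) * a_bar
}

a_bar <- 1.35        # posterior mean of a_bar (summary table)
sigma <- 1.64        # posterior mean of sigma (summary table)
n <- c(10, 25, 35)   # tank sizes
p_obs <- 0.6         # hypothetical observed survival proportion

est <- plogis(shrunken_logit(p_obs, n, a_bar, sigma))
round(est - p_obs, 3)   # pull toward plogis(a_bar), largest for the smallest tank
```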

Outlook

In this post, we saw how to construct a varying intercepts model with tfprobability, as well as how to extract sampling results and relevant diagnostics. In an upcoming post, we’ll move on to varying slopes.
With non-negligible probability, our example will build on one of McElreath’s again…
Thanks for reading!
