Here is a little example of fitting a mixture model with R and Stan.
Given a set of likelihood functions \(\pi_1(\boldsymbol y \mid \alpha_1), \ldots, \pi_K(\boldsymbol y \mid \alpha_K)\) for variables \(\boldsymbol y = (y_1, \ldots, y_N)\), parameterized by \(\boldsymbol \alpha = (\alpha_1, \ldots, \alpha_K)\), and mixture weights \(\boldsymbol \theta = (\theta_1, \ldots, \theta_K)\) such that \(0 \le \theta_{k} \le 1\) and \(\sum_{k}\theta_{k} = 1\), the mixture likelihood is
\[ \pi(\boldsymbol y \mid {\boldsymbol \alpha}, {\boldsymbol \theta}) = \prod_{n=1}^{N}\sum_{k=1}^{K}\theta_k \pi_k(y_n \mid \alpha_k). \]
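To make the structure explicit, here is a minimal R sketch of this log-likelihood for an arbitrary set of component density functions (`mix_loglik`, `dens_list`, and the toy normal components are illustrative placeholders, not part of the example below): for each observation we mix the component densities with the weights, then sum the logs over observations.
# Mixture log-likelihood: product over observations of the weighted sum
# of component densities (computed on the log scale at the end).
mix_loglik <- function(y, theta, dens_list) {
  comp <- sapply(dens_list, function(f) f(y))  # N x K matrix of pi_k(y_n | alpha_k)
  sum(log(comp %*% theta))                     # sum_n log( sum_k theta_k pi_k(y_n) )
}
# Toy example: two normal components with weights 0.3 and 0.7
mix_loglik(c(-1, 0, 2), theta = c(0.3, 0.7),
           dens_list = list(function(y) dnorm(y, 0, 1),
                            function(y) dnorm(y, 3, 1)))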
Let’s consider a simple example where the likelihood is given by a mixture of two Poisson-lognormal (PoiLog) components,
\[ \pi(y_1, \ldots, y_N \mid \mu_1, \mu_2, \sigma_1, \sigma_2, \theta) = \prod_{n=1}^{N}\left[\theta\, \mathrm{PoiLog}(y_n \mid \mu_1, \sigma_1) + (1-\theta)\, \mathrm{PoiLog}(y_n \mid \mu_2, \sigma_2)\right]. \]
We define two Poisson-lognormal components, \(\mathrm{PoiLog}(\mu_1 = \log(50) \simeq 3.91, \sigma_1 = 0.8)\) and \(\mathrm{PoiLog}(\mu_2 = \log(400) \simeq 5.99, \sigma_2 = 0.2)\), with mixture weight \(\theta = 0.8\).
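The Poisson-lognormal pmf has no closed form; it marginalizes the Poisson rate over a lognormal distribution. Below is a rough sketch that approximates \(\mathrm{PoiLog}(y \mid \mu, \sigma)\) with a grid over \(\log\lambda\) (`dpoilog_num` is a throwaway helper written for this post; the poilog package on CRAN provides a proper implementation).
# Approximate PoiLog(y | mu, sigma) by averaging Poisson(y | lambda) over a
# fine grid of log-rates weighted by the normal density of log(lambda).
dpoilog_num <- function(y, mu, sigma, n_grid = 2000) {
  log_lambda <- seq(mu - 6 * sigma, mu + 6 * sigma, length.out = n_grid)
  w <- dnorm(log_lambda, mu, sigma)
  w <- w / sum(w)  # normalized quadrature weights
  sapply(y, function(yi) sum(dpois(yi, exp(log_lambda)) * w))
}
# pmf of the first component at a few counts
dpoilog_num(c(10, 50, 200), mu = log(50), sigma = 0.8)
Next we visualize the lognormal part of each component and their mixture on the log scale.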
library(dplyr)     # data_frame(), mutate(), %>%
library(tidyr)     # gather()
library(ggplot2)
library(cowplot)
library(ggthemes)
theme_set(theme_solarized(light = FALSE))
# Component parameters and mixture weight used throughout the post
mu <- c(log(50), log(400))
sigma <- c(0.8, 0.2)
theta <- 0.8
# Lognormal part of each component (on the log scale) and their mixture
xx <- seq(1, 1000, by = 1)
y1 <- dnorm(log(xx), mu[1], sigma[1])
y2 <- dnorm(log(xx), mu[2], sigma[2])
y_mix <- theta * y1 + (1 - theta) * y2
line_dat <- data_frame(xx,
                       comp1 = y1,
                       comp2 = y2,
                       comp_mix = y_mix) %>%
  gather(Model, val, 2:4) %>%
  mutate(model2 = ifelse(Model == "comp_mix",
                         "Mixed components",
                         "Each component"))
ggplot(line_dat, aes(x = xx %>% log, y = val %>% exp, col = Model)) +
  geom_line() +
  geom_point() +
  facet_wrap(~model2, scale = "free") +
  xlab("log k") +
  ylab("Density")
Then we simulate data from the mixture density defined above.
set.seed(123)
N <- 100
# z: latent component indicator (1 or 2); P(z = 1) = theta
z <- rbinom(N, 1, 1 - theta) + 1
# Draw a log-rate for each observation from its component, then a Poisson count
log_lambda <- rnorm(N, mu[z], sigma[z])
Y <- rpois(N, exp(log_lambda))
#lambda <- rlnorm(N, mu[z] - 0.5 * sigma[z]^2, sigma[z])
#Y <- rpois(N, lambda)
hist_dat <- data_frame(Y)
p1 <- ggplot(hist_dat, aes(Y)) +
  geom_histogram(bins = 20, fill = "#268bd2", col = "black")
p2 <- ggplot(hist_dat, aes(log(Y))) +
  geom_histogram(bins = 20, fill = "#268bd2", col = "black")
plot_grid(p1, p2)
First, we fit the model with the direct, centered parameterization:
\[ \mu_1, \mu_2 \sim \mathcal{N}(0, 10) \]
\[ \sigma_1, \sigma_2 \sim \text{Half-Cauchy}(0, 2.5) \]
\[ \log\lambda_{1n} \sim \mathcal{N}(\mu_1, \sigma_1) \]
\[ \log\lambda_{2n} \sim \mathcal{N}(\mu_2, \sigma_2) \]
\[ \theta \sim \mathrm{Beta}(2, 2) \]
\[ y_n \sim \theta\,\mathrm{Poisson}(e^{\log\lambda_{1n}}) + (1 - \theta)\,\mathrm{Poisson}(e^{\log\lambda_{2n}}) \]
## data{
##   int<lower=0> N;
##   int y[N];
## }
##
## parameters {
##   ordered[2] mu;
##   real<lower=0> sigma[2];
##   real<lower=0, upper=1> theta;
##   vector[N] log_lambda[2];
## }
## model {
##   sigma ~ cauchy(0, 2.5);
##   mu ~ normal(0, 10);
##   theta ~ beta(2, 2);
##   for (i in 1:2)
##     log_lambda[i,] ~ normal(mu[i], sigma[i]);
##   for (n in 1:N)
##     target += log_mix(theta,
##                       poisson_log_lpmf(y[n] | log_lambda[1, n]),
##                       poisson_log_lpmf(y[n] | log_lambda[2, n]));
## }
The model looks straightforward, but it samples inefficiently because the latent \(\log\lambda\) values are strongly correlated with \(\mu\) and \(\sigma\) in the posterior (Betancourt and Girolami 2013). When the model is expressed this way, Stan has trouble sampling from the neck of the funnel-shaped geometry (Figure 28.1 in the Stan 2.17.0 manual), where \(\sigma\) is small and the \(\log\lambda\) values are forced to lie very close to \(\mu\).
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
list_dat <- list(N = N, y = Y)
fit_cp <- stan(file = "poilog_mix_cp.stan",
               data = list_dat,
               iter = 2000,
               warmup = 1000,
               thin = 1,
               chains = 4,
               refresh = 500,
               control = list(adapt_delta = 0.9, max_treedepth = 20))
## Warning: There were 4 divergent transitions after warmup. Increasing adapt_delta above 0.9 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: There were 2 chains where the estimated Bayesian Fraction of Missing Information was low. See
## http://mc-stan.org/misc/warnings.html#bfmi-low
## Warning: Examine the pairs() plot to diagnose sampling problems
The model did not converge well. In addition, there were divergent transitions, which indicate that the simulated Hamiltonian trajectory has diverged from the true trajectory. Even a small number of divergent transitions can signal serious bias in the estimates.
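Following the warning's advice, two quick diagnostics are to count the post-warmup divergences per chain and to inspect a pairs plot of the hyperparameters, which shows the funnel-shaped dependence described above (a minimal sketch; which parameters to plot is a matter of choice).
# Post-warmup divergent transitions in each chain
sapply(get_sampler_params(fit_cp, inc_warmup = FALSE),
       function(x) sum(x[, "divergent__"]))
# Pairwise posterior plots of the hyperparameters; divergent draws are highlighted
pairs(fit_cp, pars = c("mu", "sigma", "lp__"))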
print(fit_cp,
      pars = c("mu", "sigma", "theta",
               "log_lambda[1,1]",
               "log_lambda[1,2]",
               "log_lambda[1,100]",
               "log_lambda[2,1]",
               "log_lambda[2,2]",
               "log_lambda[2,100]",
               "lp__"))
## Inference for Stan model: poilog_mix_cp.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75%
## mu[1] 3.92 0.12 0.20 3.66 3.78 3.86 4.04
## mu[2] 6.73 0.68 2.81 4.39 5.94 5.99 6.06
## sigma[1] 0.86 0.11 0.17 0.64 0.73 0.80 0.98
## sigma[2] 1.28 1.08 3.02 0.17 0.22 0.26 0.37
## theta 0.85 0.05 0.08 0.72 0.79 0.83 0.91
## log_lambda[1,1] 4.19 0.00 0.13 3.94 4.11 4.19 4.27
## log_lambda[1,2] 3.98 0.00 0.13 3.71 3.89 3.98 4.07
## log_lambda[1,100] 3.08 0.00 0.21 2.67 2.94 3.09 3.22
## log_lambda[2,1] 6.75 0.71 4.60 2.63 5.77 6.00 6.24
## log_lambda[2,2] 6.72 0.68 4.21 2.67 5.77 6.00 6.24
## log_lambda[2,100] 6.75 0.71 4.23 2.81 5.78 6.00 6.24
## lp__ -391.00 63.61 113.81 -715.46 -380.54 -345.35 -327.57
## 97.5% n_eff Rhat
## mu[1] 4.36 2 2.13
## mu[2] 17.09 17 1.22
## sigma[1] 1.22 2 2.23
## sigma[2] 11.05 8 1.96
## theta 0.99 2 2.28
## log_lambda[1,1] 4.41 2297 1.00
## log_lambda[1,2] 4.24 4000 1.00
## log_lambda[1,100] 3.46 4000 1.00
## log_lambda[2,1] 19.39 42 1.08
## log_lambda[2,2] 18.64 38 1.08
## log_lambda[2,100] 19.79 35 1.09
## lp__ -300.65 3 3.63
##
## Samples were drawn using NUTS(diag_e) at Thu Sep 27 11:22:27 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
Instead of sampling \(\log\lambda_{1n}\) and \(\log\lambda_{2n}\) directly, the model can be rewritten in the following non-centered, more efficient form:
\[ \mu_1, \mu_2 \sim \mathcal{N}(0, 10) \]
\[ \sigma_1, \sigma_2 \sim \text{Half-Cauchy}(0, 2.5) \]
\[ \log\tilde\lambda_{1n}, \log\tilde\lambda_{2n} \sim \mathcal{N}(0, 1) \]
\[ \log\lambda_{1n} = \mu_1 + \log\tilde\lambda_{1n} \times \sigma_1 \]
\[ \log\lambda_{2n} = \mu_2 + \log\tilde\lambda_{2n} \times \sigma_2 \]
\[ \theta \sim \mathrm{Beta}(2, 2) \]
## data{
##   int<lower=0> N;
##   int y[N];
## }
##
## parameters {
##   ordered[2] mu;
##   real<lower=0> sigma[2];
##   real<lower=0, upper=1> theta;
##   vector[N] log_lambda_tilde[2];
## }
## transformed parameters {
##   vector[N] log_lambda[2];
##   for (i in 1:2)
##     log_lambda[i,] = mu[i] + log_lambda_tilde[i,] * sigma[i];
## }
##
## model {
##   sigma ~ cauchy(0, 2.5);
##   mu ~ normal(0, 10);
##   theta ~ beta(2, 2);
##   for (i in 1:2)
##     log_lambda_tilde[i, ] ~ normal(0, 1);
##
##   for (n in 1:N)
##     target += log_mix(theta,
##                       poisson_log_lpmf(y[n] | log_lambda[1, n]),
##                       poisson_log_lpmf(y[n] | log_lambda[2, n]));
## }
This non-centered parameterization exchanges the direct prior dependence between (\(\mu\), \(\sigma\)) and \(\log\lambda\) for a dependence between (\(\mu\), \(\sigma\)) and the data \(y_n\), which reduces the posterior dependence among \(\log\lambda\), \(\mu\), and \(\sigma\).
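As a quick sanity check (not part of the Stan program), the transform in the transformed parameters block only re-expresses the same distribution: \(\mu + \sigma \tilde z\) with \(\tilde z \sim \mathcal{N}(0, 1)\) is exactly \(\mathcal{N}(\mu, \sigma)\), so the reparameterization changes the geometry the sampler explores, not the model.
# Centered vs. non-centered draws of the same normal distribution
set.seed(1)
z_tilde <- rnorm(1e5)                        # standard normal "raw" parameter
non_centered <- mu[1] + sigma[1] * z_tilde   # shift and scale
c(mean(non_centered), sd(non_centered))      # ~3.91 and ~0.8, as in N(mu[1], sigma[1])
We fit this non-centered model to the same data.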
fit_ncp <- stan(file = "poilog_mix_ncp.stan",
                data = list_dat,
                iter = 2000,
                warmup = 1000,
                thin = 1,
                chains = 4,
                refresh = 500,
                control = list(adapt_delta = 0.9, max_treedepth = 20))
Now there are no divergent transitions, and the estimates recover the original values (\(\boldsymbol \mu\), \(\boldsymbol \sigma\), and \(\theta\)) well 🍺. Note that the above model will not work very well when the two components are relatively close (see detailed explanation).
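We can confirm both points with the same divergence count as before and a side-by-side comparison of the posterior means against the values used in the simulation (a quick sketch; the `truth` vector is just the parameters defined above).
# Post-warmup divergent transitions per chain (should be all zeros)
sapply(get_sampler_params(fit_ncp, inc_warmup = FALSE),
       function(x) sum(x[, "divergent__"]))
# Posterior means vs. simulation truth
truth <- c(mu, sigma, theta)
est <- summary(fit_ncp, pars = c("mu", "sigma", "theta"))$summary[, "mean"]
round(cbind(truth, est), 2)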
print(fit_ncp,
      pars = c("mu", "sigma", "theta",
               "log_lambda[1,1]",
               "log_lambda[1,2]",
               "log_lambda[1,100]",
               "log_lambda[2,1]",
               "log_lambda[2,2]",
               "log_lambda[2,100]",
               "lp__"))
## Inference for Stan model: poilog_mix_ncp.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75%
## mu[1] 3.79 0.01 0.09 3.62 3.73 3.79 3.85
## mu[2] 5.97 0.01 0.07 5.83 5.93 5.97 6.02
## sigma[1] 0.75 0.01 0.09 0.61 0.69 0.74 0.80
## sigma[2] 0.26 0.01 0.07 0.16 0.21 0.25 0.30
## theta 0.80 0.00 0.04 0.71 0.77 0.80 0.83
## log_lambda[1,1] 4.18 0.00 0.12 3.94 4.10 4.19 4.27
## log_lambda[1,2] 3.97 0.00 0.14 3.70 3.88 3.98 4.07
## log_lambda[1,100] 3.08 0.00 0.20 2.67 2.95 3.09 3.23
## log_lambda[2,1] 5.97 0.00 0.29 5.37 5.80 5.97 6.15
## log_lambda[2,2] 5.97 0.00 0.29 5.40 5.79 5.98 6.16
## log_lambda[2,100] 5.98 0.00 0.28 5.41 5.81 5.98 6.15
## lp__ -505.78 0.65 13.07 -532.06 -514.33 -505.31 -496.95
## 97.5% n_eff Rhat
## mu[1] 3.96 112 1.02
## mu[2] 6.11 135 1.05
## sigma[1] 0.96 86 1.04
## sigma[2] 0.43 83 1.02
## theta 0.88 4000 1.01
## log_lambda[1,1] 4.42 4000 1.00
## log_lambda[1,2] 4.23 4000 1.00
## log_lambda[1,100] 3.47 4000 1.00
## log_lambda[2,1] 6.54 4000 1.00
## log_lambda[2,2] 6.53 4000 1.00
## log_lambda[2,100] 6.53 4000 1.00
## lp__ -481.13 403 1.01
##
## Samples were drawn using NUTS(diag_e) at Thu Sep 27 11:23:14 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
## CXXFLAGS=-O3 -mtune=native -march=native -Wno-unused-variable -Wno-unused-function
## CC=clang
## CXX=clang++ -arch x86_64 -ftemplate-depth-256
## Session info -------------------------------------------------------------
## setting value
## version R version 3.5.1 (2018-07-02)
## system x86_64, darwin17.7.0
## ui unknown
## language (EN)
## collate en_US.UTF-8
## tz America/Detroit
## date 2018-09-27
## Packages -----------------------------------------------------------------
## package * version date source
## assertthat 0.2.0 2017-04-11 CRAN (R 3.5.0)
## BH 1.66.0-1 2018-02-13 CRAN (R 3.5.0)
## cli 1.0.0 2017-11-05 CRAN (R 3.5.0)
## colorspace 1.3-2 2016-12-14 CRAN (R 3.5.0)
## crayon 1.3.4 2017-09-16 CRAN (R 3.5.0)
## digest 0.6.15 2018-01-28 CRAN (R 3.5.0)
## ggplot2 * 3.0.0 2018-07-03 cran (@3.0.0)
## glue 1.3.0 2018-09-21 Github (tidyverse/glue@4e74901)
## graphics * 3.5.1 2018-07-26 local
## grDevices * 3.5.1 2018-07-26 local
## grid 3.5.1 2018-07-26 local
## gridExtra 2.3 2017-09-09 CRAN (R 3.5.0)
## gtable 0.2.0 2016-02-26 CRAN (R 3.5.0)
## inline 0.3.15 2018-05-18 CRAN (R 3.5.0)
## labeling 0.3 2014-08-23 CRAN (R 3.5.0)
## lattice 0.20-35 2017-03-25 CRAN (R 3.5.1)
## lazyeval 0.2.1 2017-10-29 CRAN (R 3.5.0)
## magrittr 1.5 2014-11-22 CRAN (R 3.5.0)
## MASS 7.3-50 2018-04-30 CRAN (R 3.5.1)
## Matrix 1.2-14 2018-04-13 CRAN (R 3.5.1)
## methods * 3.5.1 2018-07-26 local
## mgcv 1.8-24 2018-06-23 CRAN (R 3.5.1)
## munsell 0.5.0 2018-06-12 CRAN (R 3.5.0)
## nlme 3.1-137 2018-04-07 CRAN (R 3.5.1)
## pillar 1.2.3 2018-05-25 CRAN (R 3.5.0)
## plyr 1.8.4 2016-06-08 CRAN (R 3.5.0)
## R6 2.2.2 2017-06-17 CRAN (R 3.5.0)
## RColorBrewer 1.1-2 2014-12-07 CRAN (R 3.5.1)
## Rcpp 0.12.18 2018-07-23 cran (@0.12.18)
## RcppEigen 0.3.3.4.0 2018-02-07 CRAN (R 3.5.0)
## reshape2 1.4.3 2017-12-11 CRAN (R 3.5.0)
## rlang 0.2.2 2018-08-16 cran (@0.2.2)
## rstan * 2.17.3 2018-01-20 CRAN (R 3.5.0)
## scales 1.0.0 2018-08-09 cran (@1.0.0)
## StanHeaders * 2.17.2 2018-01-20 CRAN (R 3.5.0)
## stats * 3.5.1 2018-07-26 local
## stats4 3.5.1 2018-07-26 local
## stringi 1.2.3 2018-06-12 CRAN (R 3.5.0)
## stringr * 1.3.1 2018-05-10 CRAN (R 3.5.0)
## tibble * 1.4.2 2018-01-22 CRAN (R 3.5.0)
## tools 3.5.1 2018-07-26 local
## utf8 1.1.4 2018-05-24 CRAN (R 3.5.0)
## utils * 3.5.1 2018-07-26 local
## viridisLite 0.3.0 2018-02-01 CRAN (R 3.5.0)
## withr 2.1.2 2018-03-15 CRAN (R 3.5.0)
Betancourt, M. J., and Mark Girolami. 2013. “Hamiltonian Monte Carlo for Hierarchical Models,” December. http://arxiv.org/abs/1312.0906.