
Here is a little example of fitting a mixture model with R and Stan.

1 Mixture model

Given a set of likelihood functions \(\pi_1(\boldsymbol y \mid \alpha_1), \ldots, \pi_K(\boldsymbol y \mid \alpha_K)\) for observations \(\boldsymbol y = (y_1, \ldots, y_N)\), parameterized by \(\boldsymbol \alpha = (\alpha_1, \ldots, \alpha_K)\), and mixture weights \(\boldsymbol \theta = (\theta_1, \ldots, \theta_K)\) such that \(0 \le \theta_{k} \le 1\) and \(\sum_{k=1}^{K}\theta_{k} = 1\), the mixture likelihood is

\[ \pi(\boldsymbol y \mid {\boldsymbol \alpha}, {\boldsymbol \theta}) = \prod_{n=1}^{N}\sum_{k=1}^{K}\theta_k \, \pi_k(y_n \mid \alpha_k). \]
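In practice this likelihood is evaluated on the log scale for numerical stability, turning the inner sum into a log-sum-exp over components; for \(K = 2\) this is exactly what Stan's log_mix function computes. A minimal R sketch, assuming log_liks is an \(N \times K\) matrix of component log densities \(\log \pi_k(y_n \mid \alpha_k)\):

# Numerically stable log-sum-exp
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Mixture log-likelihood: sum_n log( sum_k theta_k * pi_k(y_n | alpha_k) )
mixture_loglik <- function(log_liks, theta) {
  sum(apply(log_liks, 1, function(lp) log_sum_exp(lp + log(theta))))
}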

2 Poisson-lognormal mixture model

Let’s consider a simple example where the likelihood is a mixture of two Poisson-lognormal (PoiLog) distributions,

\[ \pi(y_1, \ldots, y_N \mid \mu_1, \mu_2, \sigma_1, \sigma_2, \theta) = \prod_{n=1}^{N}\left[\theta \, \text{PoiLog}(y_n \mid \mu_1, \sigma_1) + (1-\theta) \, \text{PoiLog}(y_n \mid \mu_2, \sigma_2)\right]. \]
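Here \(\text{PoiLog}(y \mid \mu, \sigma)\) denotes a Poisson distribution whose rate \(\lambda\) is lognormal, i.e. \(y \sim \text{Poisson}(\lambda)\) with \(\log\lambda \sim \mathcal{N}(\mu, \sigma)\). Its pmf has no closed form, but it can be evaluated by integrating the rate out numerically; a base-R sketch (the name dpoilog_num is ours, chosen to avoid clashing with the CRAN poilog package):

# Integrate Poisson(y | lambda) against the lognormal density of lambda
dpoilog_num <- function(y, mu, sigma) {
  integrate(function(lambda) dpois(y, lambda) * dlnorm(lambda, mu, sigma),
            lower = 0, upper = Inf)$value
}
dpoilog_num(50, log(50), 0.8)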

We define the two Poisson-lognormal components as \(\text{PoiLog}(\mu_1 = \log(50) \simeq 3.91, \sigma_1 = 0.8)\) and \(\text{PoiLog}(\mu_2 = \log(400) \simeq 5.99, \sigma_2 = 0.2)\), with mixture weight \(\theta = 0.8\), and then simulate data from this mixture density.
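A minimal simulation sketch, assuming \(N = 100\) observations and an arbitrary seed (the exact values used in the post are not shown):

set.seed(123)
N <- 100
theta <- 0.8
mu <- c(log(50), log(400))
sigma <- c(0.8, 0.2)
z <- sample(1:2, N, replace = TRUE, prob = c(theta, 1 - theta))  # component labels
log_lambda <- rnorm(N, mean = mu[z], sd = sigma[z])              # latent log rates
y <- rpois(N, exp(log_lambda))                                   # observed counts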

3 Testing

3.1 Centered parameterization

First, we fit the model using the direct centered parameterization.

\[ \mu_1, \mu_2 \sim \mathcal{N}(0, 10) \]

\[ \sigma_1, \sigma_2 \sim \text{Half-Cauchy}(0, 2.5) \]

\[ \log\lambda_{1n} \sim \mathcal{N}(\mu_1, \sigma_1) \]

\[ \log\lambda_{2n} \sim \mathcal{N}(\mu_2, \sigma_2) \]

\[ \theta \sim \text{Beta}(2, 2) \]

## data {
##   int<lower=0> N;                  // number of observations
##   int y[N];                        // observed counts
## }
## 
## parameters {
##   ordered[2] mu;                   // component means (ordered to avoid label switching)
##   real<lower=0> sigma[2];          // component scales
##   real<lower=0, upper=1> theta;    // weight of the first component
##   vector[N] log_lambda[2];         // latent log rates, one vector per component
## }
## 
## model {
##   sigma ~ cauchy(0, 2.5);          // half-Cauchy via the lower=0 constraint
##   mu ~ normal(0, 10);
##   theta ~ beta(2, 2);
##   for (i in 1:2)
##     log_lambda[i, ] ~ normal(mu[i], sigma[i]);
##   // marginalize over the discrete component membership
##   for (n in 1:N)
##     target += log_mix(theta,
##                       poisson_log_lpmf(y[n] | log_lambda[1, n]),
##                       poisson_log_lpmf(y[n] | log_lambda[2, n]));
## }

The model looks straightforward, but it suffers from inefficient sampling: the latent \(\log\lambda_n\) are highly correlated with \(\mu\) and \(\sigma\) in the posterior (Betancourt and Girolami 2013). When the model is expressed this way, Stan has trouble sampling from the neck of the resulting funnel geometry (Figure 28.1 in the Stan Version 2.17.0 Manual), where \(\sigma\) is small and the \(\log\lambda_n\) are constrained to lie close to \(\mu\).
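For reference, the fitting call that produced the warnings below might look like this sketch (the file name poilog_mix_cp.stan is inferred from the model name in the printed summary; N and y come from the simulation above):

library(rstan)
options(mc.cores = parallel::detectCores())
fit_cp <- stan("poilog_mix_cp.stan",
               data = list(N = N, y = y),
               chains = 4, iter = 2000)
print(fit_cp, pars = c("mu", "sigma", "theta"))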

Warning: There were 4 divergent transitions after warmup. Increasing adapt_delta above 0.9 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
Warning: There were 2 chains where the estimated Bayesian Fraction of Missing Information was low. See
http://mc-stan.org/misc/warnings.html#bfmi-low
Warning: Examine the pairs() plot to diagnose sampling problems

The model did not converge well (note the large \(\hat{R}\) values and tiny effective sample sizes below). In addition, there were divergent transitions, which indicate that the simulated Hamiltonian trajectory has diverged from the true Hamiltonian. Even a small number of divergent transitions can indicate serious bias in the estimates.
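As the warning message suggests, one first remedy is to increase adapt_delta, which makes the step size adaptation more conservative, although it does not fix the underlying funnel geometry; a sketch reusing the assumed names above:

fit_cp2 <- stan("poilog_mix_cp.stan",
                data = list(N = N, y = y),
                chains = 4, iter = 2000,
                control = list(adapt_delta = 0.99))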

## Inference for Stan model: poilog_mix_cp.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                      mean se_mean     sd    2.5%     25%     50%     75%
## mu[1]                3.92    0.12   0.20    3.66    3.78    3.86    4.04
## mu[2]                6.73    0.68   2.81    4.39    5.94    5.99    6.06
## sigma[1]             0.86    0.11   0.17    0.64    0.73    0.80    0.98
## sigma[2]             1.28    1.08   3.02    0.17    0.22    0.26    0.37
## theta                0.85    0.05   0.08    0.72    0.79    0.83    0.91
## log_lambda[1,1]      4.19    0.00   0.13    3.94    4.11    4.19    4.27
## log_lambda[1,2]      3.98    0.00   0.13    3.71    3.89    3.98    4.07
## log_lambda[1,100]    3.08    0.00   0.21    2.67    2.94    3.09    3.22
## log_lambda[2,1]      6.75    0.71   4.60    2.63    5.77    6.00    6.24
## log_lambda[2,2]      6.72    0.68   4.21    2.67    5.77    6.00    6.24
## log_lambda[2,100]    6.75    0.71   4.23    2.81    5.78    6.00    6.24
## lp__              -391.00   63.61 113.81 -715.46 -380.54 -345.35 -327.57
##                     97.5% n_eff Rhat
## mu[1]                4.36     2 2.13
## mu[2]               17.09    17 1.22
## sigma[1]             1.22     2 2.23
## sigma[2]            11.05     8 1.96
## theta                0.99     2 2.28
## log_lambda[1,1]      4.41  2297 1.00
## log_lambda[1,2]      4.24  4000 1.00
## log_lambda[1,100]    3.46  4000 1.00
## log_lambda[2,1]     19.39    42 1.08
## log_lambda[2,2]     18.64    38 1.08
## log_lambda[2,100]   19.79    35 1.09
## lp__              -300.65     3 3.63
## 
## Samples were drawn using NUTS(diag_e) at Thu Sep 27 11:22:27 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

3.2 Non-centered parameterization

Instead of sampling \(\log\lambda_{1n}\) and \(\log\lambda_{2n}\) directly, the model can be rewritten in the following, more efficient, non-centered form,

\[ \mu_1, \mu_2 \sim \mathcal{N}(0, 10) \]

\[ \sigma_1, \sigma_2 \sim \text{Half-Cauchy}(0, 2.5) \]

\[ \log\tilde\lambda_{1n}, \log\tilde\lambda_{2n} \sim \mathcal{N}(0, 1) \]

\[ \log\lambda_{1n} = \mu_1 + \sigma_1 \log\tilde\lambda_{1n} \]

\[ \log\lambda_{2n} = \mu_2 + \sigma_2 \log\tilde\lambda_{2n} \]

\[ \theta \sim \text{Beta}(2, 2) \]

## data {
##   int<lower=0> N;                  // number of observations
##   int y[N];                        // observed counts
## }
## 
## parameters {
##   ordered[2] mu;
##   real<lower=0> sigma[2];
##   real<lower=0, upper=1> theta;
##   vector[N] log_lambda_tilde[2];   // standard-normal auxiliary variables
## }
## 
## transformed parameters {
##   vector[N] log_lambda[2];
##   // shift and scale: log_lambda[i] = mu[i] + sigma[i] * log_lambda_tilde[i]
##   for (i in 1:2)
##     log_lambda[i, ] = mu[i] + log_lambda_tilde[i, ] * sigma[i];
## }
## 
## model {
##   sigma ~ cauchy(0, 2.5);
##   mu ~ normal(0, 10);
##   theta ~ beta(2, 2);
##   for (i in 1:2)
##     log_lambda_tilde[i, ] ~ normal(0, 1);   // priors no longer depend on mu, sigma
## 
##   for (n in 1:N)
##     target += log_mix(theta,
##                       poisson_log_lpmf(y[n] | log_lambda[1, n]),
##                       poisson_log_lpmf(y[n] | log_lambda[2, n]));
## }

This non-centered parameterization exchanges the direct dependence between (\(\mu\), \(\sigma\)) and \(\log\lambda\) for an indirect dependence between (\(\mu\), \(\sigma\)) and the data \(y_n\), which reduces the posterior correlation among \(\log\lambda\), \(\mu\), and \(\sigma\).

Now there are no divergent transitions, and the estimates recover the original values (\(\boldsymbol \mu\), \(\boldsymbol \sigma\), and \(\theta\)) well 🍺. Note that this model will still not work very well when the two components are relatively close together (see the detailed explanation).
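One way to confirm the absence of divergences is to inspect the sampler parameters directly; a sketch, with fit_ncp as our assumed name for the non-centered fit:

# Count post-warmup divergent transitions across all chains
sp <- get_sampler_params(fit_ncp, inc_warmup = FALSE)
sum(sapply(sp, function(x) sum(x[, "divergent__"])))  # expect 0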

## Inference for Stan model: poilog_mix_ncp.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                      mean se_mean    sd    2.5%     25%     50%     75%
## mu[1]                3.79    0.01  0.09    3.62    3.73    3.79    3.85
## mu[2]                5.97    0.01  0.07    5.83    5.93    5.97    6.02
## sigma[1]             0.75    0.01  0.09    0.61    0.69    0.74    0.80
## sigma[2]             0.26    0.01  0.07    0.16    0.21    0.25    0.30
## theta                0.80    0.00  0.04    0.71    0.77    0.80    0.83
## log_lambda[1,1]      4.18    0.00  0.12    3.94    4.10    4.19    4.27
## log_lambda[1,2]      3.97    0.00  0.14    3.70    3.88    3.98    4.07
## log_lambda[1,100]    3.08    0.00  0.20    2.67    2.95    3.09    3.23
## log_lambda[2,1]      5.97    0.00  0.29    5.37    5.80    5.97    6.15
## log_lambda[2,2]      5.97    0.00  0.29    5.40    5.79    5.98    6.16
## log_lambda[2,100]    5.98    0.00  0.28    5.41    5.81    5.98    6.15
## lp__              -505.78    0.65 13.07 -532.06 -514.33 -505.31 -496.95
##                     97.5% n_eff Rhat
## mu[1]                3.96   112 1.02
## mu[2]                6.11   135 1.05
## sigma[1]             0.96    86 1.04
## sigma[2]             0.43    83 1.02
## theta                0.88  4000 1.01
## log_lambda[1,1]      4.42  4000 1.00
## log_lambda[1,2]      4.23  4000 1.00
## log_lambda[1,100]    3.47  4000 1.00
## log_lambda[2,1]      6.54  4000 1.00
## log_lambda[2,2]      6.53  4000 1.00
## log_lambda[2,100]    6.53  4000 1.00
## lp__              -481.13   403 1.01
## 
## Samples were drawn using NUTS(diag_e) at Thu Sep 27 11:23:14 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

4 Computing Environment

## CXXFLAGS=-O3 -mtune=native -march=native -Wno-unused-variable -Wno-unused-function
## CC=clang
## CXX=clang++ -arch x86_64 -ftemplate-depth-256
## Session info -------------------------------------------------------------
##  setting  value                       
##  version  R version 3.5.1 (2018-07-02)
##  system   x86_64, darwin17.7.0        
##  ui       unknown                     
##  language (EN)                        
##  collate  en_US.UTF-8                 
##  tz       America/Detroit             
##  date     2018-09-27
## Packages -----------------------------------------------------------------
##  package      * version   date       source                         
##  assertthat     0.2.0     2017-04-11 CRAN (R 3.5.0)                 
##  BH             1.66.0-1  2018-02-13 CRAN (R 3.5.0)                 
##  cli            1.0.0     2017-11-05 CRAN (R 3.5.0)                 
##  colorspace     1.3-2     2016-12-14 CRAN (R 3.5.0)                 
##  crayon         1.3.4     2017-09-16 CRAN (R 3.5.0)                 
##  digest         0.6.15    2018-01-28 CRAN (R 3.5.0)                 
##  ggplot2      * 3.0.0     2018-07-03 cran (@3.0.0)                  
##  glue           1.3.0     2018-09-21 Github (tidyverse/glue@4e74901)
##  graphics     * 3.5.1     2018-07-26 local                          
##  grDevices    * 3.5.1     2018-07-26 local                          
##  grid           3.5.1     2018-07-26 local                          
##  gridExtra      2.3       2017-09-09 CRAN (R 3.5.0)                 
##  gtable         0.2.0     2016-02-26 CRAN (R 3.5.0)                 
##  inline         0.3.15    2018-05-18 CRAN (R 3.5.0)                 
##  labeling       0.3       2014-08-23 CRAN (R 3.5.0)                 
##  lattice        0.20-35   2017-03-25 CRAN (R 3.5.1)                 
##  lazyeval       0.2.1     2017-10-29 CRAN (R 3.5.0)                 
##  magrittr       1.5       2014-11-22 CRAN (R 3.5.0)                 
##  MASS           7.3-50    2018-04-30 CRAN (R 3.5.1)                 
##  Matrix         1.2-14    2018-04-13 CRAN (R 3.5.1)                 
##  methods      * 3.5.1     2018-07-26 local                          
##  mgcv           1.8-24    2018-06-23 CRAN (R 3.5.1)                 
##  munsell        0.5.0     2018-06-12 CRAN (R 3.5.0)                 
##  nlme           3.1-137   2018-04-07 CRAN (R 3.5.1)                 
##  pillar         1.2.3     2018-05-25 CRAN (R 3.5.0)                 
##  plyr           1.8.4     2016-06-08 CRAN (R 3.5.0)                 
##  R6             2.2.2     2017-06-17 CRAN (R 3.5.0)                 
##  RColorBrewer   1.1-2     2014-12-07 CRAN (R 3.5.1)                 
##  Rcpp           0.12.18   2018-07-23 cran (@0.12.18)                
##  RcppEigen      0.3.3.4.0 2018-02-07 CRAN (R 3.5.0)                 
##  reshape2       1.4.3     2017-12-11 CRAN (R 3.5.0)                 
##  rlang          0.2.2     2018-08-16 cran (@0.2.2)                  
##  rstan        * 2.17.3    2018-01-20 CRAN (R 3.5.0)                 
##  scales         1.0.0     2018-08-09 cran (@1.0.0)                  
##  StanHeaders  * 2.17.2    2018-01-20 CRAN (R 3.5.0)                 
##  stats        * 3.5.1     2018-07-26 local                          
##  stats4         3.5.1     2018-07-26 local                          
##  stringi        1.2.3     2018-06-12 CRAN (R 3.5.0)                 
##  stringr      * 1.3.1     2018-05-10 CRAN (R 3.5.0)                 
##  tibble       * 1.4.2     2018-01-22 CRAN (R 3.5.0)                 
##  tools          3.5.1     2018-07-26 local                          
##  utf8           1.1.4     2018-05-24 CRAN (R 3.5.0)                 
##  utils        * 3.5.1     2018-07-26 local                          
##  viridisLite    0.3.0     2018-02-01 CRAN (R 3.5.0)                 
##  withr          2.1.2     2018-03-15 CRAN (R 3.5.0)

References

Betancourt, M. J., and Mark Girolami. 2013. “Hamiltonian Monte Carlo for Hierarchical Models,” December. http://arxiv.org/abs/1312.0906.