\[ %% % Add your macros here; they'll be included in pdf and html output. %% \newcommand{\R}{\mathbb{R}} % reals \newcommand{\E}{\mathbb{E}} % expectation \renewcommand{\P}{\mathbb{P}} % probability \DeclareMathOperator{\logit}{logit} \DeclareMathOperator{\logistic}{logistic} \DeclareMathOperator{\SE}{SE} \DeclareMathOperator{\sd}{sd} \DeclareMathOperator{\var}{var} \DeclareMathOperator{\cov}{cov} \DeclareMathOperator{\cor}{cor} \DeclareMathOperator{\Normal}{Normal} \DeclareMathOperator{\MVN}{MVN} \DeclareMathOperator{\LogNormal}{logNormal} \DeclareMathOperator{\Poisson}{Poisson} \DeclareMathOperator{\Beta}{Beta} \DeclareMathOperator{\Binom}{Binomial} \DeclareMathOperator{\Gam}{Gamma} \DeclareMathOperator{\Exp}{Exponential} \DeclareMathOperator{\Cauchy}{Cauchy} \DeclareMathOperator{\Unif}{Unif} \DeclareMathOperator{\Dirichlet}{Dirichlet} \DeclareMathOperator{\Wishart}{Wishart} \DeclareMathOperator{\StudentsT}{StudentsT} \DeclareMathOperator{\Weibull}{Weibull} \newcommand{\given}{\;\vert\;} \]

Introduction to brms

Peter Ralph

Advanced Biological Statistics

Stan, but with formulas

The brms package lets you

fit hierarchical models using Stan

with mixed-model syntax!!!

# e.g., a Poisson response with a log link:
brm(
  formula = z ~ x + y + (1 + y | f),
  data = xy,
  family = poisson(link = "log")
)
# or, a Student-t response with an identity link:
brm(
  formula = z ~ x + y + (1 + y | f),
  data = xy,
  family = student(link = "identity")
)
brms logo

Overview of brms

Fitting models

# Gaussian response with an identity link (a linear mixed model).
brm(
  formula = z ~ x + y + (1 + y | f),
  data = xy,
  family = gaussian(link = "identity")
)
  • formula: almost just like lme4
  • data: must contain all the variables
  • family: distribution of response
  • link: connects mean to linear predictor

Parameterization

There are several classes of parameter in a brms model:

  • b : the population-level (or, fixed) effects
  • sd : the standard deviations of group-level (or, random) effects
  • family-specific parameters, like sigma for the Gaussian

Examples:

  • b_x : the slope of x : class="b", coef="x"
  • sd_f : the SD of effects for levels of f : class="sd", group="f"

Setting priors

Pass a vector of “priors”, specified by

    set_prior(prior, class="b", ...)

where prior is some valid Stan code.

# Normal(0, 5) prior on every population-level slope (class "b"), and a
# Cauchy(0, 1) prior on the group-level SDs of grouping factor f.
# The grouping factor is chosen with `group = "f"`; `coef =` would instead
# name a coefficient within that group (e.g. "Intercept" or "y").
brm(
  formula = z ~ x + y + (1 + y | f),
  data = xy,
  family = gaussian(link = "identity"),
  prior = c(
    set_prior("normal(0, 5)", class = "b"),
    set_prior("cauchy(0, 1)", class = "sd", group = "f")
  )
)

1. Set up the formula

# Simulated data: 100 observations, with a three-level grouping factor f.
xy <- data.frame(
  x = rnorm(100),
  y = rexp(100),
  f = factor(sample(letters[1:3], 100, replace = TRUE))
)
# The slope of z on y equals the integer level code of f, plus small noise.
xy$z <- with(xy, x + as.numeric(f) * y + rnorm(100, sd = 0.1))
# Random intercept and random slope in y, both varying by f.
the_formula <- brmsformula(z ~ x + y + (1 + y | f))

2. Choose priors

Default:

# List the default priors brms would use for this formula and data.
get_prior(the_formula, data = xy)
##                   prior     class      coef group resp dpar nlpar bound       source
##                  (flat)         b                                            default
##                  (flat)         b         x                             (vectorized)
##                  (flat)         b         y                             (vectorized)
##                  lkj(1)       cor                                            default
##                  lkj(1)       cor               f                       (vectorized)
##  student_t(3, 1.6, 2.5) Intercept                                            default
##    student_t(3, 0, 2.5)        sd                                            default
##    student_t(3, 0, 2.5)        sd               f                       (vectorized)
##    student_t(3, 0, 2.5)        sd Intercept     f                       (vectorized)
##    student_t(3, 0, 2.5)        sd         y     f                       (vectorized)
##    student_t(3, 0, 2.5)     sigma                                            default

Choose modifications:

# for example, no good reason to do this:
# Normal(0, 5) on all population-level slopes, and Normal(0, 1) on the SD
# of the group-level slopes in y within grouping factor f.
the_priors <- c(
  set_prior("normal(0, 5)", class = "b"),
  set_prior("normal(0, 1)", class = "sd", coef = "y", group = "f")
)

3. Fit the model

# Fit with the customized priors. If sampling reports divergent transitions,
# raising adapt_delta above its 0.8 default, e.g.
# control = list(adapt_delta = 0.95), may help.
the_fit <- brm(
  formula = the_formula,
  data = xy,
  family = gaussian(),
  prior = the_priors
)
## Compiling Stan program...
## Start sampling
## Warning: There were 121 divergent transitions after warmup. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## to find out why this is a problem and how to eliminate them.
## Warning: There were 75 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems

4. Check convergence

# Posterior summaries with convergence diagnostics (Rhat, Bulk_ESS, Tail_ESS)
summary(the_fit)
## Warning: There were 121 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: z ~ x + y + (1 + y | f) 
##    Data: xy (Number of observations: 100) 
##   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
##          total post-warmup draws = 4000
## 
## Group-Level Effects: 
## ~f (Number of levels: 3) 
##                  Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept)        0.08      0.19     0.00     0.52 1.00      926     1132
## sd(y)                1.16      0.42     0.54     2.16 1.00     1668     1890
## cor(Intercept,y)    -0.08      0.54    -0.95     0.91 1.00     1134     1187
## 
## Population-Level Effects: 
##           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept     0.02      0.10    -0.09     0.15 1.00     2467     1498
## x             1.00      0.01     0.98     1.03 1.00     3633     2541
## y             1.90      0.66     0.55     3.20 1.00      994     1200
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     0.10      0.01     0.09     0.12 1.00     3580     2548
## 
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

Or…

# Interactive (Shiny) exploration of the fit, as an alternative to summary()
launch_shinystan(the_fit)

5. Look at results

Summaries of, or samples from, the posteriors of:

  • fixef( ): “fixed” effects
  • ranef( ): “random” effects
  • fitted( ): posterior distribution of mean response (see posterior_epred)
  • predict( ): posterior distribution of actual responses (see posterior_predict)
  • hypothesis( ): get posterior distributions of functions of various parameters (e.g., difference between two classes)
  • conditional_effects( ): effect of one predictor conditioned on values of others

More tools:

  • parnames( ): list of parameter names
  • pp_check( ): compare response distribution to posterior predictive simulations
  • loo( ): leave-one-out cross-validation for model comparison
  • bayes_R2( ): Bayesian \(r^2\)

More info:

Extracting information from the posterior

The structure of a GLM

\[\begin{aligned} Y &\sim \text{Family}(\text{mean}=\mu) \\ \mu &= \text{link}^{-1}(X \beta) . \end{aligned}\]

We can ask about:

  1. the coefficients, \(\beta\),
  2. the mean, \(\mu = \E[Y]\), or
  3. the response, \(Y\).

… and, a GLMM

\[\begin{aligned} Y &\sim \text{Family}(\text{mean}=\mu) \\ \mu &= \text{link}^{-1}(X \beta + Z u) . \end{aligned}\]

We can ask about:

  1. the coefficients, \(\beta\),
  2. the mean, \(\mu = \E[Y]\),
  3. the response, \(Y\).

… with or without the group-level (random) effects, \(u\).

Example: baseball

Recall the model

batting <- read.csv("data/BattingAveragePlus.csv", header = TRUE,
                    stringsAsFactors = TRUE)
# Standardize height and weight to mean 0, SD 1 so that their coefficients
# are on comparable scales.
batting[c("scaled_height", "scaled_weight")] <- lapply(
  batting[c("height", "weight")],
  function(col) (col - mean(col)) / sd(col)
)
# Binomial model (logit link): one coefficient per primary position (no
# global intercept, because of `0 +`), slopes for scaled weight and height,
# and a random intercept for each PriPos:Player combination.
bb_fit <- brm(
  formula = Hits | trials(AtBats) ~ 0 + scaled_weight + scaled_height +
    PriPos + (1 | PriPos:Player),
  data = batting,
  family = "binomial",
  prior = c(prior(normal(0, 5), class = b),
            prior(normal(0, 5), class = sd)),
  iter = 2000,
  chains = 3
)
## Compiling Stan program...
## Start sampling

1. The coefficients, \(\beta\).

Example: Is height associated with batting average? Which positions tend to be better batters?

Translated: What’s the posterior distribution of the effect of height on logit batting average?

Tools: (all have pars= arguments)

  • summary( )
  • mcmc_hist( )
  • mcmc_intervals( )
  • extract( )
# Population-level coefficients are reported on the logit scale
summary(bb_fit)
##  Family: binomial 
##   Links: mu = logit 
## Formula: Hits | trials(AtBats) ~ 0 + scaled_weight + scaled_height + PriPos + (1 | PriPos:Player) 
##    Data: batting (Number of observations: 886) 
##   Draws: 3 chains, each with iter = 2000; warmup = 1000; thin = 1;
##          total post-warmup draws = 3000
## 
## Group-Level Effects: 
## ~PriPos:Player (Number of levels: 886) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept)     0.14      0.01     0.12     0.16 1.00     1173     1824
## 
## Population-Level Effects: 
##                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## scaled_weight         0.01      0.01    -0.01     0.04 1.00     1298     1966
## scaled_height        -0.01      0.01    -0.03     0.02 1.00     1326     1885
## PriPos1stBase        -1.08      0.03    -1.13    -1.03 1.00     1632     2030
## PriPos2ndBase        -1.09      0.03    -1.15    -1.04 1.00     1482     2158
## PriPos3rdBase        -1.06      0.02    -1.11    -1.01 1.00     1744     2141
## PriPosCatcher        -1.16      0.03    -1.21    -1.11 1.00     1851     2193
## PriPosCenterField    -1.05      0.03    -1.10    -1.00 1.00     1434     1943
## PriPosLeftField      -1.10      0.02    -1.14    -1.05 1.00     1619     1745
## PriPosPitcher        -1.91      0.05    -2.00    -1.82 1.00     3786     2374
## PriPosRightField     -1.05      0.03    -1.11    -0.99 1.00     1104     1437
## PriPosShortstop      -1.10      0.03    -1.16    -1.04 1.00     1653     1911
## 
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# Posterior histograms of the weight and height slopes
mcmc_hist(
  bb_fit,
  pars = c("b_scaled_weight", "b_scaled_height")
)

plot of chunk r plot_coefs

# Posterior intervals for every position coefficient (names matching b_PriPos)
mcmc_intervals(bb_fit, regex_pars = "b_PriPos.*")

plot of chunk r plot_betas

2. The mean, \(\mu = \E[Y]\).

Example: What is the mean batting average of each position, as a function of weight?

Tools:

  • conditional_effects( ): means in abstract conditions
  • fitted( ): expected values, summarized (optionally, for new data)
  • posterior_epred( ): draws from the posterior mean (optionally, for new data)

conditional_effects for average-height-and-weight:

# Mean response by position, other predictors held at reference values
conditional_effects(bb_fit, effects = "PriPos")
## Setting all 'trials' variables to 1 by default if not specified otherwise.

plot of chunk r ce1

conditional_effects by weight at average height:

# Mean response as a function of weight, separately for each position
conditional_effects(bb_fit, effects = "scaled_weight:PriPos")
## Setting all 'trials' variables to 1 by default if not specified otherwise.

plot of chunk r ce2

conditional_effects by weight, for shortstops who are one SD above average height:

# Weight effect for shortstops whose scaled height is 1 (one SD above mean)
conditional_effects(
  bb_fit,
  effects = "scaled_weight",
  conditions = data.frame(PriPos = "Shortstop", scaled_height = 1)
)
## Setting all 'trials' variables to 1 by default if not specified otherwise.

plot of chunk r ce3

The fitted( ) values: for the original data

# Posterior summaries (Estimate, Q2.5, Q97.5) of the expected hits for each
# row of the original data, plotted against the observed hits.
bb_fitted <- fitted(bb_fit)
ggplot(cbind(batting, bb_fitted)) +
  geom_segment(aes(x = Q2.5, xend = Q97.5, y = Hits, yend = Hits),
               col = "red") +
  geom_point(aes(x = Estimate, y = Hits)) +
  geom_abline(intercept = 0, slope = 1) +
  labs(x = "predicted hits", y = "actual hits") +
  coord_fixed()

plot of chunk r fitted_bb

posterior_epred( ): posterior distribution of expected values

Mean number of hits out of 50 at-bats for a shortstop of average weight who is 1.5 SD above average height?

# Posterior draws of expected hits out of AtBats = 50 for one hypothetical
# player; re_formula = NA leaves out the group-level (player) effects.
pp <- posterior_epred(
  bb_fit,
  newdata = data.frame(PriPos = "Shortstop", scaled_weight = 0,
                       scaled_height = 1.5, AtBats = 50),
  re_formula = NA
)
head(pp)
##          [,1]
## [1,] 12.47813
## [2,] 12.26702
## [3,] 12.26739
## [4,] 12.30489
## [5,] 12.54477
## [6,] 12.13928
# Histogram of the posterior draws of expected hits
ggplot(data.frame(hits = pp)) +
  geom_histogram(aes(x = hits), bins = 32)

plot of chunk r plot_pp

3. The response, \(Y\):

Suppose Miguel Cairo and Franklin Gutierrez each go up to bat 150 more times: how many hits will they get?

# Rows and selected columns for the two players of interest (printed)
(mc <- subset(batting,
              Player %in% c("Miguel Cairo", "Franklin Gutierrez"),
              select = c(Player, PriPos, Hits, AtBats,
                         scaled_height, scaled_weight)))
##                 Player       PriPos Hits AtBats scaled_height scaled_weight
## 329 Franklin Gutierrez Center Field   39    150     0.2568063    -0.5613294
## 641       Miguel Cairo     1st Base   28    150    -0.1757610     0.6614721

Tools:

  • predict( ): summary
  • posterior_predict( ): get samples
# Posterior predictive summaries of hits, next to each player's data
cbind(mc, predict(bb_fit, newdata = mc))
##                 Player       PriPos Hits AtBats scaled_height scaled_weight Estimate Est.Error Q2.5 Q97.5
## 329 Franklin Gutierrez Center Field   39    150     0.2568063    -0.5613294 38.90500  6.355636   27    52
## 641       Miguel Cairo     1st Base   28    150    -0.1757610     0.6614721 34.83633  6.001574   24    47

fitted( ) vs predict( )

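predict( ) always has greater uncertainty than fitted( ): it adds the binomial sampling noise in the actual number of hits on top of the uncertainty in the mean.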

# Side by side: predict() (new responses) vs fitted() (expected response)
cbind(
  mc,
  predict = predict(bb_fit, newdata = mc),
  fitted = fitted(bb_fit, newdata = mc)
)
##                 Player       PriPos Hits AtBats scaled_height scaled_weight predict.Estimate predict.Est.Error predict.Q2.5 predict.Q97.5 fitted.Estimate fitted.Est.Error fitted.Q2.5 fitted.Q97.5
## 329 Franklin Gutierrez Center Field   39    150     0.2568063    -0.5613294         38.88267          6.361425           27            52        38.92534         3.247120    32.81933     45.41366
## 641       Miguel Cairo     1st Base   28    150    -0.1757610     0.6614721         34.82467          5.905609           24            47        34.71744         3.186855    28.87710     41.10412

Exercise

\[\begin{aligned} Z_i &\sim \Binom(N_i, \theta_i) \\ \theta_i &= \logistic(\beta_{p_i} + u_i) \\ u_i &\sim \Normal(0, \sigma^2) \end{aligned}\]

where for the \(i^\text{th}\) player,

  • \(N_i\): number of at-bats
  • \(Z_i\): number of hits
  • \(p_i\): position
  • \(\beta_p\): position effect
  • \(u_i\): random player effect
##           Player   PriPos Hits AtBats scaled_height scaled_weight
## 20 Albert Pujols 1st Base  173    607     0.6893737      1.395153

How would we use the posterior to summarize:

  1. Mean batting average of 1st base players of average weight and height?
  2. Range of batting averages of 1st base players?
  3. Albert Pujols’s batting average?
  4. How many hits Albert Pujols would get out of 100 at bats?

pp_check

Question: does our model fit the data?

Possible answer: gee, I dunno, let’s simulate from it and see?

Posterior predictive simulations

  1. Fit a model.

  2. Draw a set of parameters from the posterior distribution.

  3. With these, simulate a new data set.

  4. Do this a few times, and compare the results to the original dataset.

brms lets you do this with the pp_check(brms_fit, type='xyz') method

See the docs for options.

pp_check: datasets

# Histograms of the observed response and of several simulated datasets
pp_check(bb_fit, type = "hist")

plot of chunk r bb_pp1

pp_check: means by group

# Observed vs simulated summary statistic of the response within each position
pp_check(bb_fit, type = "stat_grouped", group = "PriPos")

plot of chunk r bb_pp2

pp_check: responses

# Predictive intervals for each observation, grouped by position
pp_check(bb_fit, type = "intervals_grouped", group = "PriPos")

plot of chunk r bb_pp3

pp_check: scatterplots

# Observed responses plotted against simulated responses
pp_check(bb_fit, type = "scatter")

plot of chunk r bb_pp4

Example: pumpkins

Let’s first go back to the pumpkin data from Week 5:

pumpkins <- read.table("data/pumpkins.tsv", header = TRUE)
# plot, fertilizer and water are categorical: convert each to a factor
pumpkins[c("plot", "fertilizer", "water")] <-
  lapply(pumpkins[c("plot", "fertilizer", "water")], factor)

# Weight by fertilizer x water combination
ggplot(pumpkins) +
  geom_boxplot(aes(x = fertilizer:water, y = weight, fill = water))

plot of chunk r plot_pumpkins

A mixed model with lme4:

Then, we fit a mixed model with lme4:

# One random intercept per distinct plot:water:fertilizer combination
lme_pumpkins <- lmer(
  weight ~ water * fertilizer + (1 | plot:water:fertilizer),
  data = pumpkins
)
summary(lme_pumpkins)
## Linear mixed model fit by REML ['lmerMod']
## Formula: weight ~ water * fertilizer + (1 | plot:water:fertilizer)
##    Data: pumpkins
## 
## REML criterion at convergence: 170.3
## 
## Scaled residuals: 
##      Min       1Q   Median       3Q      Max 
## -2.64930 -0.54801  0.03808  0.70427  1.69905 
## 
## Random effects:
##  Groups                Name        Variance Std.Dev.
##  plot:water:fertilizer (Intercept) 0.01064  0.1031  
##  Residual                          0.21505  0.4637  
## Number of obs: 120, groups:  plot:water:fertilizer, 24
## 
## Fixed effects:
##                             Estimate Std. Error t value
## (Intercept)                   3.0019     0.1158  25.920
## waterwater                    6.0103     0.1638  36.696
## fertilizerlow                -2.0347     0.1638 -12.423
## fertilizermedium             -1.0567     0.1638  -6.452
## waterwater:fertilizerlow     -2.9360     0.2316 -12.676
## waterwater:fertilizermedium  -2.7175     0.2316 -11.732
## 
## Correlation of Fixed Effects:
##                (Intr) wtrwtr frtlzrl frtlzrm wtrwtr:frtlzrl
## waterwater     -0.707                                      
## fertilizrlw    -0.707  0.500                               
## fertilzrmdm    -0.707  0.500  0.500                        
## wtrwtr:frtlzrl  0.500 -0.707 -0.707  -0.354                
## wtrwtr:frtlzrm  0.500 -0.707 -0.354  -0.707   0.500

… with brms:

Here’s the “same thing” with brms:

# Same formula as the lme4 model, fit with brms defaults (Gaussian family,
# default priors). The Stan model is compiled and sampled once, then summarized.
brms_pumpkins <- brm(
  weight ~ water * fertilizer + (1 | plot:water:fertilizer),
  data = pumpkins
)
## Compiling Stan program...
## Start sampling
summary(brms_pumpkins)
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: weight ~ water * fertilizer + (1 | plot:water:fertilizer) 
##    Data: pumpkins (Number of observations: 120) 
##   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
##          total post-warmup draws = 4000
## 
## Group-Level Effects: 
## ~plot:water:fertilizer (Number of levels: 24) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept)     0.11      0.07     0.01     0.28 1.01      824     1490
## 
## Population-Level Effects: 
##                             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept                       3.00      0.13     2.75     3.24 1.00     1462     1780
## waterwater                      6.02      0.18     5.66     6.36 1.00     1368     2067
## fertilizerlow                  -2.03      0.17    -2.37    -1.69 1.00     1474     2420
## fertilizermedium               -1.05      0.17    -1.39    -0.71 1.00     1516     1806
## waterwater:fertilizerlow       -2.94      0.24    -3.43    -2.46 1.00     1499     2099
## waterwater:fertilizermedium    -2.73      0.25    -3.22    -2.24 1.00     1541     2401
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     0.47      0.03     0.41     0.54 1.00     3376     2784
## 
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

Quick comparison:

# brms posterior summary, for comparison with the lme4 summary
summary(brms_pumpkins)
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: weight ~ water * fertilizer + (1 | plot:water:fertilizer) 
##    Data: pumpkins (Number of observations: 120) 
##   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
##          total post-warmup draws = 4000
## 
## Group-Level Effects: 
## ~plot:water:fertilizer (Number of levels: 24) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept)     0.11      0.07     0.01     0.28 1.01      824     1490
## 
## Population-Level Effects: 
##                             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept                       3.00      0.13     2.75     3.24 1.00     1462     1780
## waterwater                      6.02      0.18     5.66     6.36 1.00     1368     2067
## fertilizerlow                  -2.03      0.17    -2.37    -1.69 1.00     1474     2420
## fertilizermedium               -1.05      0.17    -1.39    -0.71 1.00     1516     1806
## waterwater:fertilizerlow       -2.94      0.24    -3.43    -2.46 1.00     1499     2099
## waterwater:fertilizermedium    -2.73      0.25    -3.22    -2.24 1.00     1541     2401
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     0.47      0.03     0.41     0.54 1.00     3376     2784
## 
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# lme4 REML fit of the same formula, for comparison with the brms summary
summary(lme_pumpkins)
## Linear mixed model fit by REML ['lmerMod']
## Formula: weight ~ water * fertilizer + (1 | plot:water:fertilizer)
##    Data: pumpkins
## 
## REML criterion at convergence: 170.3
## 
## Scaled residuals: 
##      Min       1Q   Median       3Q      Max 
## -2.64930 -0.54801  0.03808  0.70427  1.69905 
## 
## Random effects:
##  Groups                Name        Variance Std.Dev.
##  plot:water:fertilizer (Intercept) 0.01064  0.1031  
##  Residual                          0.21505  0.4637  
## Number of obs: 120, groups:  plot:water:fertilizer, 24
## 
## Fixed effects:
##                             Estimate Std. Error t value
## (Intercept)                   3.0019     0.1158  25.920
## waterwater                    6.0103     0.1638  36.696
## fertilizerlow                -2.0347     0.1638 -12.423
## fertilizermedium             -1.0567     0.1638  -6.452
## waterwater:fertilizerlow     -2.9360     0.2316 -12.676
## waterwater:fertilizermedium  -2.7175     0.2316 -11.732
## 
## Correlation of Fixed Effects:
##                (Intr) wtrwtr frtlzrl frtlzrm wtrwtr:frtlzrl
## waterwater     -0.707                                      
## fertilizrlw    -0.707  0.500                               
## fertilzrmdm    -0.707  0.500  0.500                        
## wtrwtr:frtlzrl  0.500 -0.707 -0.707  -0.354                
## wtrwtr:frtlzrm  0.500 -0.707 -0.354  -0.707   0.500

Your turn

Try out:

  1. launch_shinystan(brms_pumpkins)

  2. conditional_effects(brms_pumpkins)

  3. pp_check(brms_pumpkins, type='scatter')

// reveal.js plugins