\[ %% % Add your macros here; they'll be included in pdf and html output. %% \newcommand{\R}{\mathbb{R}} % reals \newcommand{\E}{\mathbb{E}} % expectation \renewcommand{\P}{\mathbb{P}} % probability \DeclareMathOperator{\logit}{logit} \DeclareMathOperator{\logistic}{logistic} \DeclareMathOperator{\SE}{SE} \DeclareMathOperator{\sd}{sd} \DeclareMathOperator{\var}{var} \DeclareMathOperator{\cov}{cov} \DeclareMathOperator{\cor}{cor} \DeclareMathOperator{\Normal}{Normal} \DeclareMathOperator{\MVN}{MVN} \DeclareMathOperator{\LogNormal}{logNormal} \DeclareMathOperator{\Poisson}{Poisson} \DeclareMathOperator{\Beta}{Beta} \DeclareMathOperator{\Binom}{Binomial} \DeclareMathOperator{\Gam}{Gamma} \DeclareMathOperator{\Exp}{Exponential} \DeclareMathOperator{\Cauchy}{Cauchy} \DeclareMathOperator{\Unif}{Unif} \DeclareMathOperator{\Dirichlet}{Dirichlet} \DeclareMathOperator{\Wishart}{Wishart} \DeclareMathOperator{\StudentsT}{StudentsT} \DeclareMathOperator{\Weibull}{Weibull} \newcommand{\given}{\;\vert\;} \]

Shrinkage: sparsifying priors

Peter Ralph

Advanced Biological Statistics

Overfitting: when you have too much information

Example data

from Efron, Hastie, Johnstone, & Tibshirani
library(lars)
data(diabetes)
class(diabetes$x2) <- "matrix"
diabetes                 package:lars                  R Documentation

Blood and other measurements in diabetics

Description:

     The ‘diabetes’ data frame has 442 rows and 3 columns. These are
     the data used in the Efron et al "Least Angle Regression" paper.

Format:

     This data frame contains the following columns:

     x a matrix with 10 columns

     y a numeric vector

     x2 a matrix with 64 columns

The dataset has

  • 442 diabetes patients
  • 10 main variables: age, gender, body mass index, average blood pressure (map), and six blood serum measurements (tc, ldl, hdl, tch, ltg, glu)
  • 45 interactions, e.g. age:ldl
  • 9 quadratic effects, e.g. age^2
  • measure of disease progression taken one year later: y

plot of chunk r show_cors

cor(cbind(diabetes$x, y=diabetes$y))
##        age    sex    bmi   map    tc   ldl    hdl   tch   ltg   glu      y
## age  1.000  0.174  0.185  0.34 0.260  0.22 -0.075  0.20  0.27  0.30  0.188
## sex  0.174  1.000  0.088  0.24 0.035  0.14 -0.379  0.33  0.15  0.21  0.043
## bmi  0.185  0.088  1.000  0.40 0.250  0.26 -0.367  0.41  0.45  0.39  0.586
## map  0.335  0.241  0.395  1.00 0.242  0.19 -0.179  0.26  0.39  0.39  0.441
## tc   0.260  0.035  0.250  0.24 1.000  0.90  0.052  0.54  0.52  0.33  0.212
## ldl  0.219  0.143  0.261  0.19 0.897  1.00 -0.196  0.66  0.32  0.29  0.174
## hdl -0.075 -0.379 -0.367 -0.18 0.052 -0.20  1.000 -0.74 -0.40 -0.27 -0.395
## tch  0.204  0.332  0.414  0.26 0.542  0.66 -0.738  1.00  0.62  0.42  0.430
## ltg  0.271  0.150  0.446  0.39 0.516  0.32 -0.399  0.62  1.00  0.46  0.566
## glu  0.302  0.208  0.389  0.39 0.326  0.29 -0.274  0.42  0.46  1.00  0.382
## y    0.188  0.043  0.586  0.44 0.212  0.17 -0.395  0.43  0.57  0.38  1.000

Crossvalidation plan

  1. Put aside 20% of the data for testing.

  2. Refit the model to the remaining 80% (the training data).

  3. Predict the test data; compute the root-mean-square prediction error over the \(M\) test observations, \[\begin{aligned} S = \sqrt{\frac{1}{M} \sum_{k=1}^M (\hat y_k - y_k)^2} \end{aligned}\]

  4. Repeat for the other four 20%s.

  5. Compare.
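
As a rough illustration of the plan above, here is a minimal sketch of five-fold crossvalidation with lm(). It uses simulated data rather than the diabetes data: the data frame d, the fold labels in fold, and the helper rmse() are all made up for this example.

```r
# Hypothetical simulated data: 200 observations, 20 predictors, only 2 of which matter.
set.seed(1)
n <- 200; p <- 20
X <- matrix(rnorm(n * p), nrow=n)
colnames(X) <- paste0("x", 1:p)
d <- data.frame(y = 3 * X[,1] - 2 * X[,2] + rnorm(n), X)

K <- 5
# assign each row to one of K folds, as evenly as possible
fold <- sample(rep(1:K, length.out=nrow(d)))
# root-mean-square error between predictions and observations
rmse <- function (pred, obs) sqrt(mean((pred - obs)^2))

xval <- data.frame(fold=1:K, train=NA, test=NA)
for (k in 1:K) {
    # fit to everything except fold k
    fit <- lm(y ~ ., data=d[fold != k, ])
    xval$train[k] <- rmse(fitted(fit), d$y[fold != k])
    # predict the held-out fold
    xval$test[k] <- rmse(predict(fit, newdata=d[fold == k, ]), d$y[fold == k])
}
xval
```

Comparing the train and test columns fold by fold is the "Compare" step: a test error much larger than the training error suggests overfitting.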

Crossvalidation

First let’s split the data into testing and training just once:

test_observations <- (rbinom(nrow(diabetes), size=1, prob=0.2) == 1)
test_d <- cbind(data.frame(y=diabetes$y[test_observations]),
                diabetes$x2[test_observations,])
training_d <- cbind(data.frame(y=diabetes$y[!test_observations]),
                diabetes$x2[!test_observations,])

Ordinary least squares

ols <- lm(y ~ ., data=training_d)
summary(ols)
## 
## Call:
## lm(formula = y ~ ., data = training_d)
## 
## Residuals:
##    Min     1Q Median     3Q    Max 
## -144.3  -32.5   -1.1   30.8  150.4 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept)   150.688      2.952   51.04  < 2e-16 ***
## age            84.949     76.868    1.11  0.27004    
## sex          -269.634     74.608   -3.61  0.00036 ***
## bmi           472.782     95.168    4.97  1.2e-06 ***
## map           360.868     83.425    4.33  2.1e-05 ***
## tc          -5344.136  61836.313   -0.09  0.93119    
## ldl          4723.000  54348.535    0.09  0.93081    
## hdl          1680.538  23107.326    0.07  0.94207    
## tch           -85.184    310.280   -0.27  0.78387    
## ltg          2350.565  20327.550    0.12  0.90802    
## glu            89.653     82.363    1.09  0.27730    
## `age^2`        66.327     81.718    0.81  0.41767    
## `bmi^2`       -15.807     98.076   -0.16  0.87207    
## `map^2`       -52.390     81.903   -0.64  0.52291    
## `tc^2`       4501.315   7881.979    0.57  0.56839    
## `ldl^2`      1315.348   5909.806    0.22  0.82403    
## `hdl^2`      1030.418   1782.583    0.58  0.56369    
## `tch^2`      1153.506    714.697    1.61  0.10764    
## `ltg^2`      1092.080   1792.281    0.61  0.54280    
## `glu^2`       128.336    105.649    1.21  0.22547    
## `age:sex`     148.196     90.659    1.63  0.10323    
## `age:bmi`       0.261     91.854    0.00  0.99773    
## `age:map`      20.594     92.257    0.22  0.82352    
## `age:tc`     -381.435    724.342   -0.53  0.59889    
## `age:ldl`     210.218    572.335    0.37  0.71367    
## `age:hdl`     200.905    332.316    0.60  0.54595    
## `age:tch`      61.393    261.202    0.24  0.81435    
## `age:ltg`     226.963    253.924    0.89  0.37217    
## `age:glu`     123.565     97.013    1.27  0.20381    
## `sex:bmi`     151.251     90.586    1.67  0.09608 .  
## `sex:map`      34.898     92.578    0.38  0.70649    
## `sex:tc`      710.474    742.200    0.96  0.33925    
## `sex:ldl`    -583.090    593.504   -0.98  0.32671    
## `sex:hdl`     -89.042    339.134   -0.26  0.79308    
## `sex:tch`     -61.988    232.199   -0.27  0.78969    
## `sex:ltg`    -210.102    273.395   -0.77  0.44283    
## `sex:glu`       2.214     83.695    0.03  0.97891    
## `bmi:map`     232.783    105.268    2.21  0.02781 *  
## `bmi:tc`     -449.811    783.837   -0.57  0.56652    
## `bmi:ldl`     449.714    655.611    0.69  0.49331    
## `bmi:hdl`     123.457    381.237    0.32  0.74630    
## `bmi:tch`    -132.984    266.229   -0.50  0.61781    
## `bmi:ltg`     132.106    300.483    0.44  0.66053    
## `bmi:glu`      88.775    100.372    0.88  0.37720    
## `map:tc`      164.889    829.983    0.20  0.84267    
## `map:ldl`     -35.365    692.990   -0.05  0.95934    
## `map:hdl`     -84.927    384.110   -0.22  0.82517    
## `map:tch`    -114.154    239.770   -0.48  0.63437    
## `map:ltg`       3.840    326.807    0.01  0.99063    
## `map:glu`    -244.481    107.705   -2.27  0.02396 *  
## `tc:ldl`    -4837.758  13111.399   -0.37  0.71242    
## `tc:hdl`    -2183.968   4297.196   -0.51  0.61169    
## `tc:tch`    -2109.486   1982.692   -1.06  0.28825    
## `tc:ltg`    -2127.476  13625.947   -0.16  0.87604    
## `tc:glu`      950.939    944.060    1.01  0.31465    
## `ldl:hdl`     750.673   3596.170    0.21  0.83480    
## `ldl:tch`     685.721   1687.047    0.41  0.68471    
## `ldl:ltg`    1301.031  11332.178    0.11  0.90868    
## `ldl:glu`    -997.177    828.415   -1.20  0.22970    
## `hdl:tch`    1423.428   1141.877    1.25  0.21358    
## `hdl:ltg`     579.689   4796.146    0.12  0.90388    
## `hdl:glu`    -207.188    418.504   -0.50  0.62094    
## `tch:ltg`     231.583    710.805    0.33  0.74481    
## `tch:glu`     195.247    265.217    0.74  0.46223    
## `ltg:glu`    -262.969    369.010   -0.71  0.47666    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 53 on 284 degrees of freedom
## Multiple R-squared:  0.627,  Adjusted R-squared:  0.543 
## F-statistic: 7.47 on 64 and 284 DF,  p-value: <2e-16
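# predict the held-out test data
ols_pred <- predict(ols, newdata=test_d)
# root-mean-square prediction error on the test data
ols_rmse <- sqrt(mean((ols_pred - test_d$y)^2))
# compare to the root-mean-square error on the training data
c(train=sqrt(mean(resid(ols)^2)),
  test=ols_rmse)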
## train  test 
##    48    61

With an ordinary linear model, the root-mean-square prediction error is 61.35 on the test data, compared to a root-mean-square error of 47.91 on the training data.

The model predicts the data it was fit to noticeably better than it predicts new data, which suggests some overfitting.

plot(training_d$y, predict(ols), xlab="true values", ylab="OLS predicted", main="training data", pch=20, asp=1)
abline(0,1)
plot(test_d$y, ols_pred, xlab="true values", ylab="OLS predicted", main="test data", pch=20, asp=1)
abline(0,1)

plot of chunk r plot_ols

A sparsifying prior

We have a lot of predictors: 64 of them. A good guess is that only a few are really useful. So, we can put a sparsifying prior on the coefficients, i.e., the \(\beta\)s in \[\begin{aligned} y = \beta_0 + \beta_1 x_1 + \cdots + \beta_n x_n + \epsilon \end{aligned}\]

Sparseness and scale mixtures

Encouraging sparseness

Suppose we do a linear model with a large number of predictor variables.

The resulting coefficients are sparse if most are zero.

The idea is to “encourage” all the coefficients to be zero, unless they really want to be nonzero, in which case we let them be whatever they want.

This tends to discourage overfitting.

To do this, we want a prior that is sharply peaked at zero but flat away from zero (“spike-and-slab”).

Options:

  1. Student’s \(t\) (e.g., Cauchy)
  2. Scale mixtures

Compare the Normal

\[\begin{aligned} X \sim \Normal(0,1) \end{aligned}\]

to the “exponential scale mixture of Normals”,

\[\begin{aligned} X &\sim \Normal(0,\sigma) \\ \sigma &\sim \Exp(1) . \end{aligned}\]

plot of chunk r scale_mixtures
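
To get a feel for the difference, one can simulate from both distributions and compare how much mass sits very close to zero and how much sits far out in the tails. This is only a sketch: it assumes \(\Normal(0,\sigma)\) means standard deviation \(\sigma\) and that \(\Exp(1)\) is parameterized by its rate, and the cutoffs 0.1 and 4 are arbitrary choices.

```r
set.seed(2)
n <- 1e5
# plain Normal(0, 1)
x_normal <- rnorm(n, mean=0, sd=1)
# scale mixture: draw a separate sigma for each observation (ASSUMPTION: rate = 1)
sigma <- rexp(n, rate=1)
x_mix <- rnorm(n, mean=0, sd=sigma)

# proportion of draws very close to zero, and far from zero
rbind(
    normal  = c(near_zero=mean(abs(x_normal) < 0.1), far=mean(abs(x_normal) > 4)),
    mixture = c(near_zero=mean(abs(x_mix) < 0.1),    far=mean(abs(x_mix) > 4))
)
```

The mixture should put more draws both very near zero and far from zero than the Normal does, which is the "peaked but heavy-tailed" shape we are after.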

Exercise:

  1. Write a function rexponormal(n, lambda) that returns n independent draws from: \[\begin{aligned} X &\sim \Normal(0,\sigma) \\ \sigma &\sim \Exp(\lambda) . \end{aligned}\]

  2. Find, by trial-and-error, a value of \(\lambda\) so that this gives you random numbers that are usually between \(\pm 200\) but are sometimes as big as \(\pm 500\) or \(\pm 1000\).

Why use a scale mixture?

  1. Lets the data choose the appropriate scale of variation.

  2. Weakly encourages \(\sigma\) to be small: so, as much variation as possible is explained by signal instead of noise.

  3. Gets you a prior that is more peaked at zero and flatter otherwise.

A strongly sparsifying prior

The “horseshoe”:

\[\begin{aligned} \beta_j &\sim \Normal(0, \lambda_j) \\ \lambda_j &\sim \Cauchy(0, \tau) \\ \tau &\sim \Unif(0, 1) \end{aligned}\]
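
To see what this prior implies, here is a minimal sketch that draws coefficients from the hierarchy as written above. The function name rhorseshoe is made up for this example, and taking the absolute value of the Cauchy draws (so that each \(\lambda_j\) is a valid standard deviation) is an assumption about how the slide's notation should be read. It is not meant to reproduce how brms parameterizes its own horseshoe() prior.

```r
# Draw n coefficients from the horseshoe hierarchy on the slide.
# ASSUMPTION: lambda_j is a half-Cauchy, i.e., the absolute value of a Cauchy(0, tau).
rhorseshoe <- function (n) {
    tau <- runif(1, min=0, max=1)                      # global scale
    lambda <- abs(rcauchy(n, location=0, scale=tau))   # local scales, one per coefficient
    rnorm(n, mean=0, sd=lambda)                        # coefficients
}

set.seed(3)
beta <- rhorseshoe(64)
# most draws are small, but a few can be very large
summary(abs(beta))
```

Each coefficient gets its own scale \(\lambda_j\), so a few can escape to large values while the shared \(\tau\) pulls the rest towards zero.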

Application

The plan

Fit a linear model to the diabetes dataset with:

  1. no prior
  2. horseshoe priors

OLS in brms

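library(brms)

# combine the response and the 64 predictors into one data frame
xy <- cbind(data.frame(y=diabetes$y), diabetes$x2)
# replace ":" and "^" in column names with "_" so they are safe to use in a formula
names(xy) <- gsub("[:^]", "_", names(xy))

# ols: no prior specified, so brms uses its default (flat) prior on the coefficients
brms_ols <- brm(y ~ ., data=xy, family=gaussian(link='identity'), file="cache/diabetes_ols.rds")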
## Compiling Stan program...
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.33, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess

horseshoe in brms

brms_hs <- brm(y ~ ., data=xy, family=gaussian(link='identity'), 
               prior=c(set_prior(horseshoe(), class="b")), file="cache/diabetes_hs_prior.rds")
## Compiling Stan program...
## Start sampling
## Warning: There were 1 divergent transitions after warmup. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## to find out why this is a problem and how to eliminate them.
## Warning: Examine the pairs() plot to diagnose sampling problems

Crossvalidation error function

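brms_kfold <- function (K, models) {
    # models: a named list of brms fits, all fit to the data frame `xy`
    stopifnot(!is.null(names(models)))
    # randomly assign each row to one of K folds;
    # length.out gives one label per row even if K does not divide nrow(xy)
    Kfold <- sample(rep(1:K, length.out=nrow(xy)))
    results <- data.frame(rep=1:K)
    for (j in seq_along(models)) {
        train <- test <- rep(NA, K)
        for (k in 1:K) {
            # refit model j, leaving out fold k
            new_fit <- update(models[[j]], newdata=subset(xy, Kfold != k))
            # root-mean-square error on the training data
            train[k] <- sqrt(mean(resid(new_fit)[,"Estimate"]^2))
            # root-mean-square prediction error on the held-out fold
            test_y <- xy$y[Kfold == k]
            test[k] <- sqrt(mean(
                   (test_y - predict(new_fit, newdata=subset(xy, Kfold==k))[,"Estimate"])^2 ))
        }
        results[[paste0(names(models)[j], "_train")]] <- train
        results[[paste0(names(models)[j], "_test")]] <- test
    }
    return(results)
}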
brms_xvals <- brms_kfold(5, list('ols'=brms_ols, 'horseshoe'=brms_hs))
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.4, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.64, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.14, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.14, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.32, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 2 divergent transitions after warmup. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## to find out why this is a problem and how to eliminate them.
## Warning: Examine the pairs() plot to diagnose sampling problems
## Start sampling
## Start sampling
## Warning: There were 1 divergent transitions after warmup. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## to find out why this is a problem and how to eliminate them.

## Warning: Examine the pairs() plot to diagnose sampling problems
## Start sampling
## Start sampling

Crossvalidation results

plot of chunk r show_brms_xvals

Coefficients:

plot of chunk r brmresults

plot of chunk r brmresults2

Coefficients, zoomed in:

plot of chunk r brmresultsB

plot of chunk r brmresults2B

Model fit

What’s an appropriate noise distribution?

plot of chunk r show_y

Aside: quantile-quantile plots

The idea is to plot the quantiles of each distribution against each other.

If these are datasets, this means just plotting their sorted values against each other.

x <- rnorm(1e4)
y <- rbeta(1e4, 2, 2)
plot(sort(x), sort(y))
qqplot(x, y, main="qqplot")
qqnorm(y, main="qqnorm")

plot of chunk r qq

The noise distribution applies to the residuals!

\[\begin{aligned} y_i &= \sum_j \beta_j x_{ij} + \epsilon_i \\ \epsilon_i &\sim \Normal(0, \sigma^2) . \end{aligned}\]

ols_resids <- resid(lm(y ~ ., data=xy))
qqnorm(ols_resids)
qqline(ols_resids)

plot of chunk r the_resids

Exercise

Suppose we measure the expression of gene FLC in 100 individuals with heights between 15cm and 100cm. The mean expression level in plants of height \(h\)cm is \(2.3 h\) RPKM, with a standard deviation of 10 RPKM. Simulate data, then make histograms of:

  1. the expression levels
  2. the residual expression levels, given height.

Which looks Normally distributed?
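
One possible way to set up this simulation is sketched below. The exercise does not say how heights are distributed, so drawing them uniformly between 15cm and 100cm is an assumption, as is using Normal noise; the variable names are made up for this example.

```r
set.seed(4)
n <- 100
# ASSUMPTION: heights are spread uniformly between 15cm and 100cm
height <- runif(n, min=15, max=100)
# expression has mean 2.3 * height and standard deviation 10 (in RPKM)
# ASSUMPTION: the noise around the mean is Normal
expression <- rnorm(n, mean=2.3 * height, sd=10)
# residuals: what is left after subtracting the mean given height
resids <- expression - 2.3 * height

layout(t(1:2))
hist(expression, main="expression levels", xlab="RPKM")
hist(resids, main="residuals given height", xlab="RPKM")
```

Here the residuals are computed from the known mean \(2.3 h\); with real data one would use the residuals from a fitted model, e.g., resid(lm(expression ~ height)).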

Posterior predictive checks:

pp_check(brms_hs)
## Using 10 posterior draws for ppc type 'dens_overlay' by default.

plot of chunk r the_pp

More general crossvalidation

brms::kfold

The kfold function will automatically do \(k\)-fold crossvalidation! For instance:

(khs <- brms::kfold(brms_hs, K=5))
## Fitting model 1 out of 5
## Fitting model 2 out of 5
## Fitting model 3 out of 5
## Fitting model 4 out of 5
## Fitting model 5 out of 5
## Start sampling
## Start sampling
## Start sampling
## Start sampling
## Start sampling
## 
## Based on 5-fold cross-validation
## 
##            Estimate   SE
## elpd_kfold  -2405.0 13.7
## p_kfold        33.3  3.1
## kfoldic      4810.0 27.4

elpd = “expected log pointwise predictive density”
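
To connect the rows of this table: as far as I understand the output, the elpd estimate adds up, over all \(N\) observations, the log predictive density of each observation under the model fit without that observation's fold, and kfoldic is the same quantity put on the deviance scale. The notation \(k(i)\) for the fold containing observation \(i\) and \(y_{-k(i)}\) for the data with that fold removed is introduced here for illustration.

\[\begin{aligned} \text{elpd}_\text{kfold} &= \sum_{i=1}^N \log p(y_i \given y_{-k(i)}) \\ \text{kfoldic} &= -2 \, \text{elpd}_\text{kfold} \end{aligned}\]

This agrees with the numbers above: \(-2 \times (-2405.0) = 4810.0\), and the standard error doubles from 13.7 to 27.4. Larger elpd (smaller kfoldic) means better out-of-sample prediction.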
