\[ %% % Add your macros here; they'll be included in pdf and html output. %% \newcommand{\R}{\mathbb{R}} % reals \newcommand{\E}{\mathbb{E}} % expectation \renewcommand{\P}{\mathbb{P}} % probability \DeclareMathOperator{\logit}{logit} \DeclareMathOperator{\logistic}{logistic} \DeclareMathOperator{\SE}{SE} \DeclareMathOperator{\sd}{sd} \DeclareMathOperator{\var}{var} \DeclareMathOperator{\cov}{cov} \DeclareMathOperator{\cor}{cor} \DeclareMathOperator{\Normal}{Normal} \DeclareMathOperator{\LogNormal}{logNormal} \DeclareMathOperator{\Poisson}{Poisson} \DeclareMathOperator{\Beta}{Beta} \DeclareMathOperator{\Binom}{Binomial} \DeclareMathOperator{\Gam}{Gamma} \DeclareMathOperator{\Exp}{Exponential} \DeclareMathOperator{\Cauchy}{Cauchy} \DeclareMathOperator{\Unif}{Unif} \DeclareMathOperator{\Dirichlet}{Dirichlet} \DeclareMathOperator{\Wishart}{Wishart} \DeclareMathOperator{\StudentsT}{StudentsT} \DeclareMathOperator{\Weibull}{Weibull} \newcommand{\given}{\;\vert\;} \]

Overfitting, crossvalidation, and sparsification

Peter Ralph

26 January 2021 – Advanced Biological Statistics

Prediction

Out-of-sample prediction

To test predictive ability (and diagnose overfitting!):

  1. Split the data into test and training pieces.
  2. Fit the model using the training data.
  3. See how well it predicts the test data.

If you repeat this with many different splits, it’s called crossvalidation.

Is it Christmas? No. (99.73% accurate)

https://xkcd.com/2236/

Overfitting: when you have too much information

Example data

from Efron, Hastie, Johnstone, & Tibshirani
library(lars)
data(diabetes)
class(diabetes$x2) <- "matrix"
diabetes                 package:lars                  R Documentation

Blood and other measurements in diabetics

Description:

     The ‘diabetes’ data frame has 442 rows and 3 columns. These are
     the data used in the Efron et al "Least Angle Regression" paper.

Format:

     This data frame contains the following columns:

     x a matrix with 10 columns

     y a numeric vector

     x2 a matrix with 64 columns

The dataset has

  • 442 diabetes patients
  • 10 main variables: age, gender, body mass index, average blood pressure (map), and six blood serum measurements (tc, ldl, hdl, tch, ltg, glu)
  • 45 interactions, e.g. age:ldl
  • 9 quadratic effects, e.g. age^2 (one for each main variable except sex, which is binary)
  • measure of disease progression taken one year later: y

plot of chunk show_cors

cor(cbind(diabetes$x, y=diabetes$y))
##        age    sex    bmi   map    tc   ldl    hdl   tch   ltg   glu      y
## age  1.000  0.174  0.185  0.34 0.260  0.22 -0.075  0.20  0.27  0.30  0.188
## sex  0.174  1.000  0.088  0.24 0.035  0.14 -0.379  0.33  0.15  0.21  0.043
## bmi  0.185  0.088  1.000  0.40 0.250  0.26 -0.367  0.41  0.45  0.39  0.586
## map  0.335  0.241  0.395  1.00 0.242  0.19 -0.179  0.26  0.39  0.39  0.441
## tc   0.260  0.035  0.250  0.24 1.000  0.90  0.052  0.54  0.52  0.33  0.212
## ldl  0.219  0.143  0.261  0.19 0.897  1.00 -0.196  0.66  0.32  0.29  0.174
## hdl -0.075 -0.379 -0.367 -0.18 0.052 -0.20  1.000 -0.74 -0.40 -0.27 -0.395
## tch  0.204  0.332  0.414  0.26 0.542  0.66 -0.738  1.00  0.62  0.42  0.430
## ltg  0.271  0.150  0.446  0.39 0.516  0.32 -0.399  0.62  1.00  0.46  0.566
## glu  0.302  0.208  0.389  0.39 0.326  0.29 -0.274  0.42  0.46  1.00  0.382
## y    0.188  0.043  0.586  0.44 0.212  0.17 -0.395  0.43  0.57  0.38  1.000

Crossvalidation plan

  1. Put aside 20% of the data for testing.

  2. Fit the model to the remaining 80%.

  3. Predict the test data; compute the root-mean-square prediction error over the \(M\) test observations, \[\begin{aligned} S = \sqrt{\frac{1}{M} \sum_{i=1}^M (\hat y_i - y_i)^2} \end{aligned}\]

  4. Repeat for the other four 20%s.

  5. Compare.
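One way to carry out this plan in base R is sketched below. The helper names `assign_folds()` and `rmse()` are hypothetical, and the built-in mtcars data with the formula `mpg ~ wt + hp` stand in for the diabetes data purely for illustration; `rmse()` computes \(S\) from the formula above.

```r
# Hypothetical helper: randomly assign each of n observations to one of K folds.
# rep_len() recycles 1:K to exactly n entries, so fold sizes differ by at most one.
assign_folds <- function (n, K=5) {
    sample(rep_len(1:K, n))
}

# Root-mean-square prediction error S over one set of M test observations.
rmse <- function (y_hat, y) {
    sqrt(mean((y_hat - y)^2))
}

# Each fold is held out once (steps 1-4); the K test errors are kept to compare (step 5).
set.seed(1)
folds <- assign_folds(nrow(mtcars), K=5)
S <- sapply(1:5, function (k) {
    fit <- lm(mpg ~ wt + hp, data=mtcars[folds != k, ])
    rmse(predict(fit, newdata=mtcars[folds == k, ]), mtcars$mpg[folds == k])
})
S
```

Using `rep_len()` rather than `rep(1:K, n/K)` matters when \(n\) is not a multiple of \(K\): the latter silently produces fewer than \(n\) fold labels.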

Crossvalidation

First let’s split the data into testing and training just once:

test_observations <- (rbinom(nrow(diabetes), size=1, prob=0.2) == 1)
test_d <- cbind(data.frame(y=diabetes$y[test_observations]),
                diabetes$x2[test_observations,])
training_d <- cbind(data.frame(y=diabetes$y[!test_observations]),
                diabetes$x2[!test_observations,])

Ordinary least squares

ols <- lm(y ~ ., data=training_d)
summary(ols)
## 
## Call:
## lm(formula = y ~ ., data = training_d)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -149.52  -31.10   -1.21   29.00  147.31 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept)    153.39       2.89   53.09  < 2e-16 ***
## age             68.06      74.61    0.91    0.362    
## sex           -327.29      78.36   -4.18 0.000040 ***
## bmi            401.24      98.21    4.09 0.000058 ***
## map            423.92      84.82    5.00 0.000001 ***
## tc          -13659.91   78291.28   -0.17    0.862    
## ldl          11966.61   68813.44    0.17    0.862    
## hdl           4816.50   29265.52    0.16    0.869    
## tch            -80.94     324.56   -0.25    0.803    
## ltg           5192.60   25741.01    0.20    0.840    
## glu             71.35      83.17    0.86    0.392    
## `age^2`         79.90      78.24    1.02    0.308    
## `bmi^2`         73.83      91.38    0.81    0.420    
## `map^2`         37.13      81.21    0.46    0.648    
## `tc^2`        6702.06    8693.93    0.77    0.441    
## `ldl^2`       4819.05    6659.04    0.72    0.470    
## `hdl^2`        577.56    1920.55    0.30    0.764    
## `tch^2`        495.64     726.06    0.68    0.495    
## `ltg^2`       2086.73    2194.66    0.95    0.343    
## `glu^2`        140.25     109.37    1.28    0.201    
## `age:sex`       94.23      82.12    1.15    0.252    
## `age:bmi`       24.56      89.76    0.27    0.785    
## `age:map`       -7.08      87.87   -0.08    0.936    
## `age:tc`       158.54     690.02    0.23    0.818    
## `age:ldl`     -364.79     554.99   -0.66    0.512    
## `age:hdl`       81.67     318.90    0.26    0.798    
## `age:tch`      148.36     233.76    0.63    0.526    
## `age:ltg`       88.02     256.37    0.34    0.732    
## `age:glu`       -7.31      91.64   -0.08    0.937    
## `sex:bmi`       57.41      88.76    0.65    0.518    
## `sex:map`       80.49      84.62    0.95    0.342    
## `sex:tc`       360.61     750.44    0.48    0.631    
## `sex:ldl`     -199.35     594.98   -0.34    0.738    
## `sex:hdl`     -253.86     353.48   -0.72    0.473    
## `sex:tch`     -310.52     245.72   -1.26    0.207    
## `sex:ltg`      -48.38     276.95   -0.17    0.861    
## `sex:glu`      115.51      82.17    1.41    0.161    
## `bmi:map`      110.59      98.39    1.12    0.262    
## `bmi:tc`      -477.61     746.48   -0.64    0.523    
## `bmi:ldl`      376.47     614.69    0.61    0.541    
## `bmi:hdl`      164.72     380.44    0.43    0.665    
## `bmi:tch`       73.15     304.09    0.24    0.810    
## `bmi:ltg`      126.66     293.18    0.43    0.666    
## `bmi:glu`        5.56     107.09    0.05    0.959    
## `map:tc`      1529.79     815.35    1.88    0.062 .  
## `map:ldl`    -1266.57     683.19   -1.85    0.065 .  
## `map:hdl`     -552.83     370.39   -1.49    0.137    
## `map:tch`       37.08     224.92    0.16    0.869    
## `map:ltg`     -670.41     334.45   -2.00    0.046 *  
## `map:glu`      -49.38     106.26   -0.46    0.643    
## `tc:ldl`    -10711.18   14616.62   -0.73    0.464    
## `tc:hdl`     -2539.93    4646.75   -0.55    0.585    
## `tc:tch`     -1288.03    2112.94   -0.61    0.543    
## `tc:ltg`     -2548.55   16926.60   -0.15    0.880    
## `tc:glu`      -238.97     739.35   -0.32    0.747    
## `ldl:hdl`     1844.84    3868.80    0.48    0.634    
## `ldl:tch`      471.69    1725.54    0.27    0.785    
## `ldl:ltg`     2009.36   14123.54    0.14    0.887    
## `ldl:glu`      139.56     641.01    0.22    0.828    
## `hdl:tch`      390.83    1195.72    0.33    0.744    
## `hdl:ltg`      766.89    5898.00    0.13    0.897    
## `hdl:glu`      277.75     350.76    0.79    0.429    
## `tch:ltg`      116.53     713.76    0.16    0.870    
## `tch:glu`      224.55     268.74    0.84    0.404    
## `ltg:glu`      149.20     329.31    0.45    0.651    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 52 on 277 degrees of freedom
## Multiple R-squared:  0.619,  Adjusted R-squared:  0.531 
## F-statistic: 7.04 on 64 and 277 DF,  p-value: <2e-16
ols_pred <- predict(ols, newdata=test_d)
# root-mean-square prediction error on the held-out test data
ols_rmse <- sqrt(mean((ols_pred - test_d$y)^2))
c(train=sqrt(mean(resid(ols)^2)),
  test=ols_rmse)
## train  test 
##    47    61

With ordinary linear regression, we got a root-mean-square prediction error of 61.17 on the test data, compared to a root-mean-square error of 47.07 on the training data.

The test error is about 30% larger than the training error, which suggests there’s some overfitting going on.

plot(training_d$y, predict(ols), xlab="true values", ylab="OLS predicted", main="training data", pch=20, asp=1)
abline(0,1)
plot(test_d$y, ols_pred, xlab="true values", ylab="OLS predicted", main="test data", pch=20, asp=1)
abline(0,1)

plot of chunk plot_ols

A sparsifying prior

We have a lot of predictors: 64 of them. A good guess is that only a few are really useful. So, we can put a sparsifying prior on the coefficients, i.e., the \(\beta\)s in \[\begin{aligned} y = \beta_0 + \beta_1 x_1 + \cdots + \beta_n x_n + \epsilon \end{aligned}\] with \(n = 64\).

Overfitting: the effect of spurious variables

Who says we don’t do experiments?

  1. Simulate data with \(y = a + b_1 x_1 + b_2 x_2 + \epsilon\), and fit a linear model.
  2. Measure in-sample and out-of-sample prediction error.
  3. Add spurious variables, and report the above as a function of the number of variables.

Basic data: \(y = a + b_1 x_1 + b_2 x_2 + \epsilon\).

N <- 500
df <- data.frame(x1 = rnorm(N),
                 x2 = runif(N))
params <- list(intercept = 2.0,
               x1 = 7.0,
               x2 = -8.0,
               sigma = 1)
pred_y <- params$intercept + params$x1 * df$x1 + params$x2 * df$x2 
df$y <- rnorm(N, mean=pred_y, sd=params$sigma)
pairs(df)

plot of chunk in_class1

Crossvalidation error function

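# K-fold crossvalidation for lm(y ~ ., data=df): returns the
# root-mean-square error on the test and training sets for each fold.
kfold <- function (K, df) {
    # assign rows to folds; rep_len() works even if nrow(df) is not a multiple of K
    Kfold <- sample(rep_len(1:K, nrow(df)))
    results <- data.frame(test_error=rep(NA, K), train_error=rep(NA, K))
    for (k in 1:K) {
        the_lm <- lm(y ~ ., data=df, subset=(Kfold != k))
        results$train_error[k] <- sqrt(mean(resid(the_lm)^2))
        test_y <- df$y[Kfold == k]
        results$test_error[k] <- sqrt(mean(
                       (test_y - predict(the_lm, newdata=subset(df, Kfold==k)))^2 ))
    }
    return(results)
}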

Add spurious variables

max_M <- 300  # upper limit on the number of predictors (x1, x2, plus spurious ones)
# pure-noise predictors z1, z2, ..., unrelated to y
noise_df <- matrix(rnorm(nrow(df) * (max_M-2)), nrow=nrow(df))
colnames(noise_df) <- paste0('z', 1:ncol(noise_df))
new_df <- cbind(df, noise_df)  # columns: x1, x2, y, z1, z2, ...
all_results <- data.frame(m=floor(seq(from=2, to=max_M-1, length.out=40)),
                          test_error=NA, train_error=NA)
for (j in 1:nrow(all_results)) {
    m <- all_results$m[j]
    # first m+1 columns: y plus m predictors (x1, x2, and m-2 noise variables)
    all_results[j,2:3] <- colMeans(kfold(K=10, new_df[,1:(m+1)]))
}

Results

plot(all_results$m, all_results$test_error, type='l', lwd=2,
     xlab='number of variables', ylab='root mean square error', ylim=range(all_results[,2:3], 0))
lines(all_results$m, all_results$train_error, col=2, lwd=2)
legend("topleft", lty=1, col=1:2, lwd=2, legend=paste(c("test", "train"), "error"))

plot of chunk in_class5

Interlude

Estimation of infiltration rate from soil properties using regression model for cultivated land
EIR = 14,195.35 - 141.75 (sand%) - 142.10 (silt%) - 142.56 (clay%)

Use the data to try to reproduce their model:

BIR = 14,195.35 - 141.75 (sand%) - 142.10 (silt%) - 142.56 (clay%)

They’re not wrong! What’s up with those coefficients, though?
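Since sand%, silt%, and clay% always sum to 100, one of the three predictors is redundant, so the coefficients are not identifiable. Substituting silt% = 100 - sand% - clay% gives an equivalent model, BIR = -14.65 + 0.35 (sand%) - 0.46 (clay%), because 14,195.35 - 142.10 × 100 = -14.65, -141.75 + 142.10 = 0.35, and -142.56 + 142.10 = -0.46. The sketch below checks this numerically on made-up soil compositions (the function names and simulated data are hypothetical) and shows that lm() reports NA for one coefficient when all three percentages are included.

```r
# The published coefficients, as given above.
bir <- function (sand, silt, clay) {
    14195.35 - 141.75 * sand - 142.10 * silt - 142.56 * clay
}
# Equivalent formula after substituting silt = 100 - sand - clay.
bir_reduced <- function (sand, clay) {
    -14.65 + 0.35 * sand - 0.46 * clay
}

# Simulated compositions that sum to 100%.
set.seed(2)
sand <- runif(20, 0, 60)
clay <- runif(20, 0, 40)
silt <- 100 - sand - clay
all.equal(bir(sand, silt, clay), bir_reduced(sand, clay))

# With an intercept, the three percentages are exactly collinear,
# so lm() cannot estimate all three coefficients and drops one (NA).
bir_obs <- bir(sand, silt, clay) + rnorm(20)
fit <- lm(bir_obs ~ sand + silt + clay)
coef(fit)
```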

Sparseness and scale mixtures

Encouraging sparseness

Suppose we do regression with a large number of predictor variables.

The resulting coefficients are sparse if most are zero.

The idea is to “encourage” all the coefficients to be zero, unless they really want to be nonzero, in which case we let them be whatever they want.

This tends to discourage overfitting.


To do this, we want a prior that is sharply peaked at zero but flat away from zero (“spike-and-slab”).
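A spike-and-slab prior can be written as a mixture: with some probability the coefficient is exactly zero (the spike), and otherwise it is drawn from a wide distribution (the slab). Here is a minimal simulation, assuming an illustrative slab probability of 0.1 and a Normal slab with standard deviation 10; both values, and the name `rspike_slab()`, are hypothetical.

```r
# Hypothetical sampler: each draw is exactly zero with probability 1 - p_slab,
# and otherwise comes from a wide Normal(0, slab_sd) "slab".
rspike_slab <- function (n, p_slab=0.1, slab_sd=10) {
    in_slab <- rbinom(n, size=1, prob=p_slab) == 1
    ifelse(in_slab, rnorm(n, mean=0, sd=slab_sd), 0)
}

set.seed(3)
beta <- rspike_slab(1e4)
mean(beta == 0)            # should be close to 1 - p_slab = 0.9
summary(beta[beta != 0])   # the nonzero draws are spread widely
```

The scale mixtures below get a similar shape (peaked at zero, heavy tails) with a continuous prior, which is easier for samplers like Stan than a point mass at zero.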

Compare the Normal

\[\begin{aligned} X \sim \Normal(0,1) \end{aligned}\]

to the “exponential scale mixture of Normals”,

\[\begin{aligned} X &\sim \Normal(0,\sigma) \\ \sigma &\sim \Exp(1) . \end{aligned}\]

plot of chunk scale_mixtures
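To see the difference numerically, we can simulate from both distributions and compare how much probability each puts near zero and far out in the tails. This sketch uses R’s standard-deviation parametrization of the Normal, matching \(\Normal(0, \sigma)\) above; the cutoffs 0.1 and 3 are arbitrary choices.

```r
set.seed(4)
n <- 1e5
x_normal <- rnorm(n, mean=0, sd=1)       # X ~ Normal(0, 1)
sigma <- rexp(n, rate=1)                 # sigma ~ Exponential(1)
x_mixture <- rnorm(n, mean=0, sd=sigma)  # X | sigma ~ Normal(0, sigma)

# fraction of draws near zero, and in the tails, for each distribution
rbind(near_zero = c(normal=mean(abs(x_normal) < 0.1), mixture=mean(abs(x_mixture) < 0.1)),
      tails     = c(normal=mean(abs(x_normal) > 3),   mixture=mean(abs(x_mixture) > 3)))
```

The mixture puts more probability both within 0.1 of zero and beyond 3, which is the “more peaked at zero and flatter otherwise” shape.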

Why use a scale mixture?

  1. Lets the data choose the appropriate scale of variation.

  2. Weakly encourages \(\sigma\) to be small: so, as much variation as possible is explained by signal instead of noise.

  3. Gets you a prior that is more peaked at zero and flatter otherwise.

A strongly sparsifying prior

The “horseshoe”:

\[\begin{aligned} \beta_j &\sim \Normal(0, \lambda_j) \\ \lambda_j &\sim \Cauchy(0, \tau), \quad \lambda_j > 0 \text{ (half-Cauchy)} \\ \tau &\sim \Unif(0, 1) \end{aligned}\]
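Simulating from this horseshoe prior shows both features at once: more mass very close to zero and much heavier tails than a standard Normal. This sketch reads \(\Cauchy(0, \tau)\) as a half-Cauchy (the absolute value of a Cauchy draw) and uses Normal standard deviations; it is not necessarily the parametrization that brms’s horseshoe() uses internally.

```r
set.seed(5)
n <- 1e5
tau <- runif(n, 0, 1)                             # tau ~ Unif(0, 1)
lambda <- abs(rcauchy(n, location=0, scale=tau))  # lambda_j ~ half-Cauchy(0, tau)
beta_hs <- rnorm(n, mean=0, sd=lambda)            # beta_j ~ Normal(0, lambda_j)
x_normal <- rnorm(n)                              # for comparison

# probability very near zero, and far out in the tails
c(hs_near_zero=mean(abs(beta_hs) < 0.01), normal_near_zero=mean(abs(x_normal) < 0.01))
c(hs_tail=mean(abs(beta_hs) > 5), normal_tail=mean(abs(x_normal) > 5))
```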

Application

The plan

Fit a linear model to the diabetes dataset with:

  1. no prior (i.e., brms’s default flat prior on the coefficients)
  2. horseshoe priors

OLS in brms

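library(brms)
xy <- cbind(data.frame(y=diabetes$y), diabetes$x2)
# replace ':' and '^' in names like 'age:sex' and 'age^2' with '_',
# so that brms can use them as variable names
names(xy) <- gsub("[:^]", "_", names(xy))

# ols: brms's default (flat) prior on the coefficients
brms_ols <- brm(y ~ ., data=xy, family=gaussian(link='identity'))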
## Compiling Stan program...
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.14, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess

horseshoe in brms

brms_hs <- brm(y ~ ., data=xy, family=gaussian(link='identity'), 
               prior=c(set_prior(horseshoe(), class="b")))
## Compiling Stan program...
## Start sampling

Crossvalidation error function

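# K-fold crossvalidation for a named list of brms fits (uses the global data frame xy);
# returns the RMSE of the posterior mean predictions on training and test sets.
brms_kfold <- function (K, models) {
    stopifnot(!is.null(names(models)))
    # rep_len() gives every row a fold, even though nrow(xy) = 442 is not a multiple of K
    Kfold <- sample(rep_len(1:K, nrow(xy)))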
    results <- data.frame(rep=1:K)
    for (j in seq_along(models)) {
        train <- test <- rep(NA, K)
        for (k in 1:K) {
            new_fit <- update(models[[j]], newdata=subset(xy, Kfold != k))
            train[k] <- sqrt(mean(resid(new_fit)[,"Estimate"]^2))
            test_y <- xy$y[Kfold == k]
            test[k] <- sqrt(mean(
                   (test_y - predict(new_fit, newdata=subset(xy, Kfold==k))[,"Estimate"])^2 ))
        }
        results[[paste0(names(models)[j], "_train")]] <- train
        results[[paste0(names(models)[j], "_test")]] <- test
    }
    return(results)
}
brms_xvals <- brms_kfold(5, list('ols'=brms_ols, 'horseshoe'=brms_hs))
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.43, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.39, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.45, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.36, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.2, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
## Start sampling
## Start sampling
## Warning: There were 1 divergent transitions after warmup. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## to find out why this is a problem and how to eliminate them.
## Warning: Examine the pairs() plot to diagnose sampling problems
## Start sampling
## Start sampling
## Start sampling
## Warning: There were 1 divergent transitions after warmup. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## to find out why this is a problem and how to eliminate them.

## Warning: Examine the pairs() plot to diagnose sampling problems

Crossvalidation results

plot of chunk show_brms_xvals

Coefficients:

plot of chunk brmresults

plot of chunk brmresults2

Model fit

What’s an appropriate noise distribution?

plot of chunk show_y

Aside: quantile-quantile plots

The idea is to plot the quantiles of each distribution against each other.

For two datasets of the same size, this just means plotting their sorted values against each other.

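x <- rnorm(1e4)
y <- rbeta(1e4, 2, 2)
plot(sort(x), sort(y))        # sorted values against each other, by hand
qqplot(x, y, main="qqplot")   # the same comparison, using qqplot()
qqnorm(y, main="qqnorm")      # y against quantiles of the Normal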

plot of chunk qq
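These base R functions can be checked against the description above: with plot.it=FALSE, qqplot() and qqnorm() return the points they would plot instead of drawing them. For two samples of equal size, qqplot() pairs up the sorted values, and qqnorm() pairs the data with Normal quantiles at the probabilities given by ppoints().

```r
set.seed(6)
x <- rnorm(1e3)
y <- rbeta(1e3, 2, 2)

# qqplot() on equal-size samples: just the two sorted vectors
qq <- qqplot(x, y, plot.it=FALSE)
identical(qq$x, sort(x)) && identical(qq$y, sort(y))

# qqnorm(): the data, each matched to a theoretical Normal quantile
qn <- qqnorm(y, plot.it=FALSE)
all.equal(sort(qn$x), qnorm(ppoints(length(y))))
```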

Look at the residuals!

\[\begin{aligned} y_i &= \sum_j \beta_j x_{ij} + \epsilon_i \\ \epsilon_i &\sim \Normal(0, \sigma) . \end{aligned}\]

ols_resids <- resid(lm(y ~ ., data=xy))
qqnorm(ols_resids)
qqline(ols_resids)

plot of chunk the_resids

Posterior predictive checks:

pp_check(brms_hs)
## Using 10 posterior samples for ppc type 'dens_overlay' by default.

plot of chunk the_pp

More general crossvalidation

brms::kfold

The brms::kfold function does \(k\)-fold crossvalidation automatically: it refits the model \(K\) times, each time leaving out one fold. For instance:

(khs <- brms::kfold(brms_hs, K=5))
## Fitting model 1 out of 5
## Fitting model 2 out of 5
## Fitting model 3 out of 5
## Fitting model 4 out of 5
## Fitting model 5 out of 5
## Start sampling
## Start sampling
## Start sampling
## Start sampling
## Start sampling
## 
## Based on 5-fold cross-validation
## 
##            Estimate   SE
## elpd_kfold  -2405.0 14.0
## p_kfold        33.2  3.6
## kfoldic      4810.0 28.0

elpd = “expected log pointwise predictive density”; kfoldic is -2 × elpd_kfold, so 4810.0 = 2 × 2405.0 (and its SE is doubled as well).
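In k-fold crossvalidation, each observation’s log-likelihood is evaluated under the posterior of the fit that held out that observation’s fold, and elpd_kfold sums, over observations, the log of the posterior-averaged predictive density. The sketch below shows only that summary step, assuming you already have a matrix of held-out log-likelihoods (rows are posterior draws, columns are observations); `elpd_from_loglik()` is a hypothetical helper, not a brms or loo function.

```r
# Hypothetical helper: sum over observations i of log( mean over draws s of p(y_i | theta_s) ),
# computed on the log scale with the log-sum-exp trick for numerical stability.
elpd_from_loglik <- function (loglik) {
    S <- nrow(loglik)
    lpd_i <- apply(loglik, 2, function (ll) {
        m <- max(ll)
        m + log(sum(exp(ll - m))) - log(S)
    })
    sum(lpd_i)
}

# Toy check: if every draw gives each of 10 observations density 0.5,
# elpd = 10 * log(0.5), and the information-criterion scale is -2 * elpd.
toy <- matrix(log(0.5), nrow=100, ncol=10)
elpd <- elpd_from_loglik(toy)
c(elpd=elpd, ic=-2 * elpd)

# Averaging happens on the density scale: draws giving 0.2 and 0.6 average to 0.4.
elpd_from_loglik(matrix(log(c(0.2, 0.6)), nrow=2, ncol=1))
```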
