Peter Ralph
26 January 2021 – Advanced Biological Statistics
To test predictive ability (and diagnose overfitting!), we can hold out some data, refit, and predict the held-out values.
If you do this a lot of times, it’s called crossvalidation.
diabetes package:lars R Documentation
Blood and other measurements in diabetics
Description:
The ‘diabetes’ data frame has 442 rows and 3 columns. These are
the data used in the Efron et al "Least Angle Regression" paper.
Format:
This data frame contains the following columns:
x a matrix with 10 columns
y a numeric vector
x2 a matrix with 64 columns
The 64 columns of x2 are the 10 base predictors, all 45 pairwise interactions (like age:ldl), and 9 squared terms (like age^2); y is the response.
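The call that produced the correlation matrix below isn’t shown; presumably it was something like:

round(cor(cbind(diabetes$x, y=diabetes$y)), 3)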
## age sex bmi map tc ldl hdl tch ltg glu y
## age 1.000 0.174 0.185 0.34 0.260 0.22 -0.075 0.20 0.27 0.30 0.188
## sex 0.174 1.000 0.088 0.24 0.035 0.14 -0.379 0.33 0.15 0.21 0.043
## bmi 0.185 0.088 1.000 0.40 0.250 0.26 -0.367 0.41 0.45 0.39 0.586
## map 0.335 0.241 0.395 1.00 0.242 0.19 -0.179 0.26 0.39 0.39 0.441
## tc 0.260 0.035 0.250 0.24 1.000 0.90 0.052 0.54 0.52 0.33 0.212
## ldl 0.219 0.143 0.261 0.19 0.897 1.00 -0.196 0.66 0.32 0.29 0.174
## hdl -0.075 -0.379 -0.367 -0.18 0.052 -0.20 1.000 -0.74 -0.40 -0.27 -0.395
## tch 0.204 0.332 0.414 0.26 0.542 0.66 -0.738 1.00 0.62 0.42 0.430
## ltg 0.271 0.150 0.446 0.39 0.516 0.32 -0.399 0.62 1.00 0.46 0.566
## glu 0.302 0.208 0.389 0.39 0.326 0.29 -0.274 0.42 0.46 1.00 0.382
## y 0.188 0.043 0.586 0.44 0.212 0.17 -0.395 0.43 0.57 0.38 1.000
Put aside 20% of the data for testing.
Refit the model to the remaining 80%.
Predict the test data; compute \[\begin{aligned} S = \sqrt{\frac{1}{M} \sum_{k=1}^M (\hat y_k - y_k)^2} , \end{aligned}\] where \(M\) is the number of test observations.
Repeat for the other four 20%s.
Compare.
First let’s split the data into testing and training just once:
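The code itself isn’t shown in these notes; here is a minimal sketch, assuming the data frame was built from y and the 64-column x2 matrix (variable names inferred from the summary output below):

library(lars)
data(diabetes)
# response plus the 64 expanded predictors
d <- cbind(data.frame(y=diabetes$y), diabetes$x2)
# hold out a random 20% for testing
test_i <- sample(nrow(d), size=round(0.2 * nrow(d)))
test_d <- d[test_i, ]
training_d <- d[-test_i, ]
ols <- lm(y ~ ., data=training_d)
summary(ols)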
##
## Call:
## lm(formula = y ~ ., data = training_d)
##
## Residuals:
## Min 1Q Median 3Q Max
## -149.52 -31.10 -1.21 29.00 147.31
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 153.39 2.89 53.09 < 2e-16 ***
## age 68.06 74.61 0.91 0.362
## sex -327.29 78.36 -4.18 0.000040 ***
## bmi 401.24 98.21 4.09 0.000058 ***
## map 423.92 84.82 5.00 0.000001 ***
## tc -13659.91 78291.28 -0.17 0.862
## ldl 11966.61 68813.44 0.17 0.862
## hdl 4816.50 29265.52 0.16 0.869
## tch -80.94 324.56 -0.25 0.803
## ltg 5192.60 25741.01 0.20 0.840
## glu 71.35 83.17 0.86 0.392
## `age^2` 79.90 78.24 1.02 0.308
## `bmi^2` 73.83 91.38 0.81 0.420
## `map^2` 37.13 81.21 0.46 0.648
## `tc^2` 6702.06 8693.93 0.77 0.441
## `ldl^2` 4819.05 6659.04 0.72 0.470
## `hdl^2` 577.56 1920.55 0.30 0.764
## `tch^2` 495.64 726.06 0.68 0.495
## `ltg^2` 2086.73 2194.66 0.95 0.343
## `glu^2` 140.25 109.37 1.28 0.201
## `age:sex` 94.23 82.12 1.15 0.252
## `age:bmi` 24.56 89.76 0.27 0.785
## `age:map` -7.08 87.87 -0.08 0.936
## `age:tc` 158.54 690.02 0.23 0.818
## `age:ldl` -364.79 554.99 -0.66 0.512
## `age:hdl` 81.67 318.90 0.26 0.798
## `age:tch` 148.36 233.76 0.63 0.526
## `age:ltg` 88.02 256.37 0.34 0.732
## `age:glu` -7.31 91.64 -0.08 0.937
## `sex:bmi` 57.41 88.76 0.65 0.518
## `sex:map` 80.49 84.62 0.95 0.342
## `sex:tc` 360.61 750.44 0.48 0.631
## `sex:ldl` -199.35 594.98 -0.34 0.738
## `sex:hdl` -253.86 353.48 -0.72 0.473
## `sex:tch` -310.52 245.72 -1.26 0.207
## `sex:ltg` -48.38 276.95 -0.17 0.861
## `sex:glu` 115.51 82.17 1.41 0.161
## `bmi:map` 110.59 98.39 1.12 0.262
## `bmi:tc` -477.61 746.48 -0.64 0.523
## `bmi:ldl` 376.47 614.69 0.61 0.541
## `bmi:hdl` 164.72 380.44 0.43 0.665
## `bmi:tch` 73.15 304.09 0.24 0.810
## `bmi:ltg` 126.66 293.18 0.43 0.666
## `bmi:glu` 5.56 107.09 0.05 0.959
## `map:tc` 1529.79 815.35 1.88 0.062 .
## `map:ldl` -1266.57 683.19 -1.85 0.065 .
## `map:hdl` -552.83 370.39 -1.49 0.137
## `map:tch` 37.08 224.92 0.16 0.869
## `map:ltg` -670.41 334.45 -2.00 0.046 *
## `map:glu` -49.38 106.26 -0.46 0.643
## `tc:ldl` -10711.18 14616.62 -0.73 0.464
## `tc:hdl` -2539.93 4646.75 -0.55 0.585
## `tc:tch` -1288.03 2112.94 -0.61 0.543
## `tc:ltg` -2548.55 16926.60 -0.15 0.880
## `tc:glu` -238.97 739.35 -0.32 0.747
## `ldl:hdl` 1844.84 3868.80 0.48 0.634
## `ldl:tch` 471.69 1725.54 0.27 0.785
## `ldl:ltg` 2009.36 14123.54 0.14 0.887
## `ldl:glu` 139.56 641.01 0.22 0.828
## `hdl:tch` 390.83 1195.72 0.33 0.744
## `hdl:ltg` 766.89 5898.00 0.13 0.897
## `hdl:glu` 277.75 350.76 0.79 0.429
## `tch:ltg` 116.53 713.76 0.16 0.870
## `tch:glu` 224.55 268.74 0.84 0.404
## `ltg:glu` 149.20 329.31 0.45 0.651
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 52 on 277 degrees of freedom
## Multiple R-squared: 0.619, Adjusted R-squared: 0.531
## F-statistic: 7.04 on 64 and 277 DF, p-value: <2e-16
# root-mean-square prediction error on the held-out data
ols_pred <- predict(ols, newdata=test_d)
ols_rmse <- sqrt(mean((ols_pred - test_d$y)^2))
c(train=sqrt(mean(resid(ols)^2)),
  test=ols_rmse)
## train test
## 47 61
With ordinary linear regression, we got a root-mean-square-prediction-error of 61.17 (on the test data), compared to a root-mean-square-error of 47.07 for the training data.
This suggests there’s some overfitting going on.
We have a lot of predictors: 64 of them. A good guess is that only a few are really useful. So, we can put a sparsifying prior on the coefficients, i.e., the \(\beta\)s in \[\begin{aligned} y = \beta_0 + \beta_1 x_1 + \cdots + \beta_n x_n + \epsilon . \end{aligned}\]
For example, with two predictors the model is y ~ a + b x[1] + c x[2], and we fit a linear model.

First, though, let’s quantify the overfitting, with a function that does K-fold crossvalidation for the linear model:

kfold <- function (K, df) {
    # randomly assign each row to one of K folds
    # (length.out guards against nrow(df) not being a multiple of K)
    Kfold <- sample(rep(1:K, length.out=nrow(df)))
    results <- data.frame(test_error=rep(NA, K), train_error=rep(NA, K))
    for (k in 1:K) {
        # fit to everything except fold k
        the_lm <- lm(y ~ ., data=df, subset=(Kfold != k))
        results$train_error[k] <- sqrt(mean(resid(the_lm)^2))
        # root-mean-square prediction error on fold k
        test_y <- df$y[Kfold == k]
        results$test_error[k] <- sqrt(mean(
            (test_y - predict(the_lm, newdata=subset(df, Kfold==k)))^2 ))
    }
    return(results)
}
# add up to 300 spurious (pure noise) predictors and watch what happens
max_M <- 300 # max number of spurious variables
noise_df <- matrix(rnorm(nrow(df) * (max_M-2)), nrow=nrow(df))
colnames(noise_df) <- paste0('z', 1:ncol(noise_df))
new_df <- cbind(df, noise_df)  # y is the first column of df
all_results <- data.frame(m=floor(seq(from=2, to=max_M-1, length.out=40)),
                          test_error=NA, train_error=NA)
for (j in 1:nrow(all_results)) {
    m <- all_results$m[j]
    # crossvalidate using y plus the first m predictors
    all_results[j,2:3] <- colMeans(kfold(K=10, new_df[,1:(m+1)]))
}
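The resulting plot isn’t reproduced here; a sketch of how one might draw it (column names from the code above):

# plot train and test RMSE against the number of predictors
plot(all_results$m, all_results$test_error, type='l', col='red',
     xlab='number of predictors', ylab='RMSE',
     ylim=range(all_results[,2:3]))
lines(all_results$m, all_results$train_error, col='black')
legend('topleft', lty=1, col=c('red', 'black'), legend=c('test', 'train'))

Expect training error to keep falling as predictors are added, while test error climbs: that gap is overfitting.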
Use the data to try to reproduce their model:
BIR = 14,195.35 - 141.75 (sand%) - 142.10 (silt%) - 142.56 (clay%)
They’re not wrong! What’s up with those coefficients, though?
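A toy simulation (not the paper’s data; all names here are made up) suggests what’s going on with compositional predictors:

# sand%, silt%, clay% sum to exactly 100, so together with the
# intercept they are perfectly collinear
n <- 100
p <- matrix(rexp(3 * n), ncol=3)
p <- 100 * p / rowSums(p)
colnames(p) <- c("sand", "silt", "clay")
bir <- 50 + 0.5 * p[, "sand"] + rnorm(n)
coef(lm(bir ~ p))            # lm drops one coefficient (NA)
coef(lm(bir ~ round(p, 1)))  # rounding breaks exact collinearity:
                             # typically huge, offsetting coefficients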
Suppose we do regression with a large number of predictor variables.
The resulting coefficients are sparse if most are zero.
The idea is to “encourage” all the coefficients to be zero, unless they really want to be nonzero, in which case we let them be whatever they want.
This tends to discourage overfitting.
To do this, we want a prior that is very peaked at zero but flat away from zero (“spike-and-slab”).
Compare the Normal
\[\begin{aligned} X \sim \Normal(0,1) \end{aligned}\]
to the “exponential scale mixture of Normals”,
\[\begin{aligned} X &\sim \Normal(0,\sigma) \\ \sigma &\sim \Exp(1) . \end{aligned}\]
Lets the data choose the appropriate scale of variation.
Weakly encourages \(\sigma\) to be small: so, as much variation as possible is explained by signal instead of noise.
Gets you a prior that is more peaked at zero and flatter otherwise.
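A quick simulation sketch (not from the original notes) to visualize the difference:

# standard Normal versus the exponential scale mixture of Normals
x_normal <- rnorm(1e5)
sigma <- rexp(1e5, rate=1)
x_mix <- rnorm(1e5, mean=0, sd=sigma)
# the mixture piles up near zero but has much heavier tails
plot(density(x_mix), col='red', xlim=c(-5, 5), main='')
lines(density(x_normal), col='black')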
The “horseshoe”:
\[\begin{aligned} \beta_j &\sim \Normal(0, \lambda_j) \\ \lambda_j &\sim \Cauchy(0, \tau) \\ \tau &\sim \Unif(0, 1) \end{aligned}\]
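To see the shape this implies, one can draw from the prior as written above (a sketch; note the scales \(\lambda_j\) must be positive, i.e., half-Cauchy):

# simulate draws of beta under the horseshoe prior
tau <- runif(1e5, 0, 1)
lambda <- abs(rcauchy(1e5, location=0, scale=tau))  # half-Cauchy
beta <- rnorm(1e5, mean=0, sd=lambda)
# result: a sharp spike at zero, with very heavy tails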
Fit a linear model to the diabetes dataset with brms, using (a) the default priors and (b) the horseshoe prior:

xy <- cbind(data.frame(y=diabetes$y), diabetes$x2)
# brms doesn't allow ':' or '^' in variable names, so replace them
names(xy) <- gsub("[:^]", "_", names(xy))
# (a) default priors ("ols")
brms_ols <- brm(y ~ ., data=xy, family=gaussian(link='identity'))
## Compiling Stan program...
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.14, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
# (b) horseshoe prior on the coefficients
brms_hs <- brm(y ~ ., data=xy, family=gaussian(link='identity'),
               prior=c(set_prior(horseshoe(), class="b")))
## Compiling Stan program...
## Start sampling
brms_kfold <- function (K, models) {
    stopifnot(!is.null(names(models)))
    # randomly assign each row to one of K folds
    # (length.out guards against nrow(xy) not being a multiple of K)
    Kfold <- sample(rep(1:K, length.out=nrow(xy)))
    results <- data.frame(rep=1:K)
    for (j in seq_along(models)) {
        train <- test <- rep(NA, K)
        for (k in 1:K) {
            # refit the brms model to everything except fold k
            new_fit <- update(models[[j]], newdata=subset(xy, Kfold != k))
            train[k] <- sqrt(mean(resid(new_fit)[,"Estimate"]^2))
            # root-mean-square prediction error on fold k
            test_y <- xy$y[Kfold == k]
            test[k] <- sqrt(mean(
                (test_y - predict(new_fit, newdata=subset(xy, Kfold==k))[,"Estimate"])^2 ))
        }
        results[[paste0(names(models)[j], "_train")]] <- train
        results[[paste0(names(models)[j], "_test")]] <- test
    }
    return(results)
}
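The output below presumably came from a call like the following (the exact invocation isn’t shown; K=5 and the two models are inferred from the output):

kfold_results <- brms_kfold(5, list(ols=brms_ols, hs=brms_hs))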
## Start sampling
## Warning: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: The largest R-hat is 1.43, indicating chains have not mixed.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#r-hat
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
(The same treedepth, R-hat (largest values 1.2–1.45), and ESS warnings were repeated for each of the five refits of the default-prior model; they are elided here.)
## Start sampling
(The five refits of the horseshoe model sampled cleanly, apart from two runs that each reported a single divergent transition after warmup.)
The idea of a quantile-quantile (Q-Q) plot is to plot the quantiles of each distribution against each other.
If these are datasets, this means just plotting their sorted values against each other.
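For example (made-up data, base R):

# Q-Q plot of two datasets: plot their sorted values against each other
a <- rnorm(500)
b <- rexp(500)
plot(sort(a), sort(b), xlab='quantiles of a', ylab='quantiles of b')
qqplot(a, b)  # the built-in equivalent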
\[\begin{aligned} y_i &= \sum_j \beta_j x_{ij} + \epsilon_i \\ \epsilon_i &\sim \Normal(0, \sigma^2) . \end{aligned}\]
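The message below is from a posterior predictive check of the fitted model; presumably a call like:

pp_check(brms_hs)  # hypothetical: the exact call isn't shown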
## Using 10 posterior samples for ppc type 'dens_overlay' by default.
brms::kfold: The kfold function will automatically do \(k\)-fold crossvalidation! For instance:
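A call like the following would produce the output below (the model and arguments are assumptions, inferred from the output):

kfold(brms_hs, K=5)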
## Fitting model 1 out of 5
## Fitting model 2 out of 5
## Fitting model 3 out of 5
## Fitting model 4 out of 5
## Fitting model 5 out of 5
## Start sampling
## Start sampling
## Start sampling
## Start sampling
## Start sampling
##
## Based on 5-fold cross-validation
##
## Estimate SE
## elpd_kfold -2405.0 14.0
## p_kfold 33.2 3.6
## kfoldic 4810.0 28.0
elpd = “expected log pointwise predictive density”