Lecture Goals

  • Understand in-/out-of-sample performance
    • Measure performance in train & test sets
  • Prevent overfitting by
    • Penalizing model complexity
    • Using (cross) validation
  • Readings

Assessing Performance

  • So far, we have chosen classifiers with optimal performance on already available data
    • A.k.a. in-sample performance
  • What we really want is classifiers with optimal performance on new data
    • A.k.a. out-of-sample performance
  • E.g., a memorization (nearest-neighbour, NN) classifier has 100% accuracy on available data!
    • Do you trust it to perfectly classify new data?

Out-of-Sample Performance

  • Estimate out-of-sample performance by running model on unused data

  • Split available data into training and test sets
    • Train/select model based on training set
    • Estimate performance by applying model to test set
  • No modelling decisions should be made based on test set
    • Otherwise performance estimate is biased (optimistic)

WDBC Data

  • Split WDBC data (80% train - 20% test) and measure classification tree accuracy
# assumes the rpart and tidyverse/modelr packages are loaded and the wdbc data has been read in
train = wdbc %>% sample_frac(.8)                       # random 80% of rows for training
test  = wdbc %>% setdiff( train )                      # remaining 20% held out for testing
rpart_out = rpart(diagnosis ~ . - id, data = train)    # fit classification tree on training set only

train %>% add_predictions(rpart_out, type = "class") %>%
  summarise( accuracy = mean( pred == diagnosis) ) %>% pull()
## [1] 0.9692308

test %>% add_predictions(rpart_out, type = "class") %>% 
  summarise( accuracy = mean( pred == diagnosis) ) %>% pull()
## [1] 0.9473684

Overfitting

  • Out-of-sample performance expected to be worse than in-sample
    • Performance decline related to model complexity
    • More complex models tend to do worse out-of-sample
  • Overfitting: model closely describes available data, but fails to generalise to new data
    • Memorising noise instead of learning underlying structure
    • Must be cautious about performance expectations

Model Complexity

  • Model complexity is related to the amount of tuning/choices a model permits, e.g. the number of splits in a tree (see the sketch below)
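  • A sketch for trees (illustrative; the cp values are arbitrary and train is the training set defined earlier): looser complexity settings in rpart() permit more splits, i.e. a more complex model
n_splits = function(cp_val) {
  fit = rpart( diagnosis ~ . - id, data = train,
               control = rpart.control(cp = cp_val) )
  sum( fit$frame$var != "<leaf>" )       # internal nodes = number of splits
}
sapply( c(.2, .05, .01, 0), n_splits )   # split counts grow as the cp penalty shrinks (output not shown)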

Overfitting and Model Complexity

Model Selection

  • Statistical Learning conundrum: choose model based on available data, but want it to perform well on unseen data
    • In-sample performance favors overfitting
  • Two practical ways to prevent overfitting
    • Validation
    • Regularization

WDBC Data

  • Tree with 100% in-sample accuracy
# grow the tree out fully: no complexity penalty (cp = 0) and splits allowed down to single observations
big_tree = rpart( diagnosis ~ . - id, data = train, 
               control = rpart.control(minsplit = 1, cp = 0) )
rpart.plot(big_tree)
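  • A quick check (sketch; output omitted): the fully grown tree classifies the training set perfectly, while its test accuracy is typically noticeably lower, i.e. overfitting
train %>% add_predictions(big_tree, type = "class") %>%
  summarise( accuracy = mean( pred == diagnosis) ) %>% pull()

test %>% add_predictions(big_tree, type = "class") %>%
  summarise( accuracy = mean( pred == diagnosis) ) %>% pull()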

Validation Set

  • Split data into training, validation, and test sets (see the sketch below)
    • Fit competing models on training set
    • Select model based on validation set
    • Assess performance on test set
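  • A minimal sketch of the three-way split, reusing the sampling idiom from earlier (the names train_full, valid, and test2 are illustrative; test2 avoids overwriting the test set created above)
train_full = wdbc %>% sample_frac(.6)                              # 60% for fitting competing models
valid      = wdbc %>% setdiff( train_full ) %>% sample_frac(.5)    # 20% for model selection
test2      = wdbc %>% setdiff( train_full ) %>% setdiff( valid )   # 20% kept untouched for final assessment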

WDBC Data

  • Select the tree model (# splits) with the best validation performance, as sketched below
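  • One way to carry this out (a sketch building on the split above; the cp grid is arbitrary): grow one large tree on train_full, prune it to different sizes, and keep the size with the best validation accuracy
full_tree = rpart( diagnosis ~ . - id, data = train_full,
                   control = rpart.control(minsplit = 1, cp = 0) )
cp_grid = c(.1, .05, .02, .01, .005, 0)
val_acc = sapply( cp_grid, function(cp)
  valid %>% add_predictions( prune(full_tree, cp = cp), type = "class" ) %>%
    summarise( accuracy = mean( pred == diagnosis) ) %>% pull() )
cp_grid[ which.max(val_acc) ]   # cp value (tree size) with the best validation accuracy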

Cross-Validation

  • Cross-Validation (CV) repeatedly partitions data into training & validation sets
    • Each held-out subset is called a fold and serves once as the validation set
  • Estimate performance by averaging the error across all folds (see the sketch below)
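  • A sketch of 5-fold CV done by hand (fold assignment and object names are illustrative); rpart() automates this internally, as shown on the next slide
k = 5
folds = sample( rep(1:k, length.out = nrow(train)) )    # randomly assign each training row to a fold
cv_acc = sapply( 1:k, function(i) {
  fit = rpart( diagnosis ~ . - id, data = train[folds != i, ] )   # fit on all folds except the i-th
  train[folds == i, ] %>%                                         # validate on the held-out fold
    add_predictions(fit, type = "class") %>%
    summarise( accuracy = mean( pred == diagnosis) ) %>% pull()
})
mean(cv_acc)   # CV estimate of out-of-sample accuracy (output not shown)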

WDBC Data

  • CV is performed by default in rpart(): in cptable, xerror is the cross-validated (relative) error for each tree size and xstd its standard error (see also plotcp() below)
big_tree$cptable
##            CP nsplit  rel error    xerror       xstd
## 1 0.814371257      0 1.00000000 1.0000000 0.06156478
## 2 0.041916168      1 0.18562874 0.2634731 0.03775071
## 3 0.017964072      3 0.10179641 0.1976048 0.03312768
## 4 0.011976048      5 0.06586826 0.1676647 0.03069522
## 5 0.008982036      6 0.05389222 0.1497006 0.02910597
## 6 0.005988024      8 0.03592814 0.1377246 0.02798231
## 7 0.003992016     12 0.01197605 0.1137725 0.02555041
## 8 0.000000000     15 0.00000000 0.1257485 0.02679985
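  • The same cross-validated errors can be inspected graphically with rpart's plotcp(), which plots xerror (with one-standard-error bars) against cp / tree size
plotcp(big_tree)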

WDBC Data

Regularization

  • Regularization is another way to prevent overfitting

  • Regularization imposes penalty on model complexity
    • Choose model that minimizes
      \[\text{(training error) + (complexity penalty)}\]
  • Form of complexity penalty can vary
    • Specify based on theoretical results or cross-validation

WDBC Data

Regularization

  • For WDBC tree model
    • Training Error: 1 - accuracy
    • Complexity Penalty: \(cp\) * (# splits), as sketched after this list
  • Select model (optimal \(cp\)) based on
    • Minimum CV error, or
    • Simplest model within one standard error of the min. CV error
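  • A sketch of the penalized criterion at one illustrative value of \(cp\) (rel error is rpart's training error, scaled so the root node has error 1)
as_tibble(big_tree$cptable) %>%
  mutate( penalized = `rel error` + 0.01 * nsplit )   # training error + cp * (# splits), with cp = 0.01 for illustration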

WDBC data

as_tibble(big_tree$cptable ) %>%
  mutate( pick = (xerror <  min(xerror) + xstd ) )   # flag trees within one std. error of the minimum CV error
## # A tibble: 8 x 6
##        CP nsplit `rel error` xerror   xstd pick 
##     <dbl>  <dbl>       <dbl>  <dbl>  <dbl> <lgl>
## 1 0.814        0      1       1     0.0616 FALSE
## 2 0.0419       1      0.186   0.263 0.0378 FALSE
## 3 0.0180       3      0.102   0.198 0.0331 FALSE
## 4 0.0120       5      0.0659  0.168 0.0307 FALSE
## 5 0.00898      6      0.0539  0.150 0.0291 FALSE
## 6 0.00599      8      0.0359  0.138 0.0280 TRUE 
## 7 0.00399     12      0.0120  0.114 0.0256 TRUE 
## 8 0           15      0       0.126 0.0268 TRUE

WDBC data

final_tree = prune(big_tree, cp = .006)                   # prune back to the simplest tree flagged above
test %>% add_predictions( final_tree, type = "class") %>% 
  mutate( pred = fct_relevel(pred, "M"),                  # list the malignant class first in both factors
          diagnosis = fct_relevel(diagnosis, "M") ) %>% 
  xtabs( ~ pred + diagnosis, data = .) %>% prop.table()   # test-set confusion matrix, as proportions
##     diagnosis
## pred          M          B
##    M 0.35964912 0.01754386
##    B 0.03508772 0.58771930
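  • The diagonal of this table gives the overall test accuracy (about 0.947 here); a one-liner, assuming the table above has been saved as conf_tab
sum( diag(conf_tab) )   # proportion of correctly classified test cases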

WDBC data

Recipe for Statistical Learning

  • Split data into a training (or training + validation) set & a test set

  • Fit models of different complexity on training set

  • Choose optimal model complexity using regularization/cross-validation

  • Estimate out-of-sample performance of final model on test set
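  • Putting the recipe together on WDBC (a sketch combining the pieces from the slides above; the names best_cp and final_fit are illustrative)
big_fit = rpart( diagnosis ~ . - id, data = train,
                 control = rpart.control(minsplit = 1, cp = 0) )   # flexible model fit on the training set
best_cp = as_tibble(big_fit$cptable) %>%
  filter( xerror < min(xerror) + xstd ) %>%                        # trees within one std. error of the min CV error
  slice_max(CP) %>% pull(CP)                                       # simplest such tree
final_fit = prune(big_fit, cp = best_cp)                           # complexity chosen by CV + the one-SE rule
test %>% add_predictions(final_fit, type = "class") %>%            # out-of-sample accuracy estimated on the test set
  summarise( accuracy = mean( pred == diagnosis) )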