Lecture Goals

  • Learn how to
    • Perform multi-class classification
    • Use ensemble methods to reduce overfitting
    • Handle imbalanced classes


  • Readings: ISLR ch 4.3.5, 8.2

Multi-Class Classification

  • Output variable (Y) takes a value from three or more categories
    • E.g. optical character recognition (OCR)
  • Some methods naturally handle multiple classes; others can be adapted from binary classification

  • Ordinal classification (ordered categories) is handled differently
    • E.g. rating scale: poor to excellent
    • Typically modeled by thresholding a numerical \(Y\) variable
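
A minimal sketch of the thresholding idea, using base R's cut(). The scores, cut points, and labels are hypothetical and chosen only for illustration:

```r
# Hypothetical numeric scores and cut points, for illustration only.
score <- c(1.2, 3.5, 4.8, 2.0)

# Threshold the numeric variable into ordered categories:
# (-Inf, 2] -> poor, (2, 4] -> fair, (4, Inf) -> excellent
rating <- cut(score, breaks = c(-Inf, 2, 4, Inf),
              labels = c("poor", "fair", "excellent"),
              ordered_result = TRUE)
rating

# Ordered factors keep the ranking, so comparisons between levels work.
rating < "excellent"
```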

Wine Data

  • Chemical analysis of 3 cultivars of wine
library(tidyverse)
wine = read_csv("data/wine.csv") %>% 
  mutate( cultivar = factor(cultivar))
glimpse(wine)
## Observations: 178
## Variables: 14
## $ cultivar        <fct> A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A...
## $ alcohol         <dbl> 14.23, 13.20, 13.16, 14.37, 13.24, 14.20, 14.3...
## $ malic_acid      <dbl> 1.71, 1.78, 2.36, 1.95, 2.59, 1.76, 1.87, 2.15...
## $ ash             <dbl> 2.43, 2.14, 2.67, 2.50, 2.87, 2.45, 2.45, 2.61...
## $ alcalinity      <dbl> 15.6, 11.2, 18.6, 16.8, 21.0, 15.2, 14.6, 17.6...
## $ magnesium       <dbl> 127, 100, 101, 113, 118, 112, 96, 121, 97, 98,...
## $ phenols         <dbl> 2.80, 2.65, 2.80, 3.85, 2.80, 3.27, 2.50, 2.60...
## $ flavanoids      <dbl> 3.06, 2.76, 3.24, 3.49, 2.69, 3.39, 2.52, 2.51...
## $ nonflavanoid    <dbl> 0.28, 0.26, 0.30, 0.24, 0.39, 0.34, 0.30, 0.31...
## $ proanthocyanins <dbl> 2.29, 1.28, 2.81, 2.18, 1.82, 1.97, 1.98, 1.25...
## $ color           <dbl> 5.64, 4.38, 5.68, 7.80, 4.32, 6.75, 5.25, 5.05...
## $ hue             <dbl> 1.04, 1.05, 1.03, 0.86, 1.04, 1.05, 1.02, 1.06...
## $ OD280           <dbl> 3.92, 3.40, 3.17, 3.45, 2.93, 2.85, 3.58, 3.58...
## $ proline         <dbl> 1065, 1050, 1185, 1480, 735, 1450, 1290, 1295,...

Classification Tree

  • Same as before, with each partition (leaf) assigned its majority class as the prediction

Wine Data

library(rpart); library(rpart.plot)
wine %>% rpart(cultivar ~ ., data = .) %>% rpart.plot()

Multiclass from Binary

  • One-vs-Rest: fit one binary classifier per class to distinguish it from other classes
    • Classify each point based on the output of all models: e.g. assign it to the class with the highest probability
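
The one-vs-rest recipe above can be sketched roughly as follows. This is an illustration, not the lecture's own code: it uses the built-in iris data as a stand-in for the wine data, and plain glm() logistic regressions as the binary classifiers.

```r
# One-vs-rest sketch on the built-in iris data (a stand-in for the wine data).
classes    <- levels(iris$Species)
predictors <- setdiff(names(iris), "Species")

# Fit one binary logistic regression per class: "this class" vs "all others".
# suppressWarnings() hides glm's separation warnings (setosa is perfectly separable).
ovr_probs <- sapply(classes, function(k) {
  d <- iris[, predictors]
  d$is_k <- as.integer(iris$Species == k)
  fit <- suppressWarnings(glm(is_k ~ ., data = d, family = binomial))
  predict(fit, type = "response")   # estimated P(class == k) for every row
})

# Assign each point to the class whose model gives the highest probability.
ovr_pred <- factor(classes[max.col(ovr_probs, ties.method = "first")],
                   levels = classes)
mean(ovr_pred == iris$Species)      # in-sample accuracy
```

Note that the per-class probabilities come from separate models, so they need not sum to 1 across classes; only their ranking is used here.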

Multinomial Logistic Regression

  • Fix a reference category, and fit (# classes) - 1 logistic regressions against it
    • E.g. fit B vs A, and C vs A models
library(nnet)
(mltn_out = multinom(cultivar ~ ., data = wine, trace = FALSE))
## Call:
## multinom(formula = cultivar ~ ., data = wine, trace = FALSE)
## 
## Coefficients:
##   (Intercept)  alcohol malic_acid        ash alcalinity  magnesium
## B    848.4648 -50.4198  -16.82837 -376.29272   30.77966 -0.3187808
## C   -149.9645  61.7616  -14.00633  -93.30316   28.10216 -4.2882308
##    phenols flavanoids nonflavanoid proanthocyanins     color       hue
## B 115.4110  -93.03559     358.9488        100.5827 -24.09276  442.3062
## C 267.9405 -256.98401    -389.8206        151.4916  37.77146 -220.3293
##        OD280    proline
## B  -26.29891 -0.5581035
## C -166.22775 -0.4108165
## 
## Residual Deviance: 4.22418e-06 
## AIC: 56
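
For reference, the standard multinomial logit model with A as the reference category can be written as below. The symbols \(\beta_{k0}\) (intercept) and \(\beta_k\) (slope vector) are notation introduced here; each row of the coefficient table (B, C) corresponds to one such pair.

```latex
\log \frac{P(Y = k \mid x)}{P(Y = A \mid x)} = \beta_{k0} + \beta_k^\top x,
\qquad k \in \{B, C\}

P(Y = A \mid x) = \frac{1}{1 + \sum_{j \in \{B, C\}} \exp(\beta_{j0} + \beta_j^\top x)},
\qquad
P(Y = k \mid x) = \frac{\exp(\beta_{k0} + \beta_k^\top x)}
                       {1 + \sum_{j \in \{B, C\}} \exp(\beta_{j0} + \beta_j^\top x)}
```

The three probabilities sum to 1 by construction, which is what distinguishes this model from separately fitted one-vs-rest classifiers.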

Multinomial Logistic Regression

  • At each point, model outputs class probabilities
    • Predict most likely class
tibble(mltn_out$fitted.values, predict(mltn_out)) %>% sample_n(5)
## # A tibble: 5 x 2
##   `mltn_out$fitted.values`[,"A"]    [,"B"]    [,"C"] `predict(mltn_out)`
##                            <dbl>     <dbl>     <dbl> <fct>              
## 1                      1.41e-151 1.00e+  0 1.17e-236 B                  
## 2                      4.17e-277 1.00e+  0 0.        B                  
## 3                      1.22e-155 1.00e+  0 5.06e-178 B                  
## 4                      1.33e-253 3.05e-219 1.00e+  0 C                  
## 5                      4.22e-127 1.00e+  0 1.28e-137 B

Ensemble Methods

  • Ensemble idea: combine multiple models to improve performance

  • Use ensemble methods to
    • Reduce overfitting for complex models (bagging)
    • Increase flexibility of simple models (boosting)
  • Look at bagging (bootstrap aggregating) for classification trees

Bagging

  • Fit multiple models to resampled data (bootstrap)
    • Resample observations (rows) with replacement; random forests additionally sample variables (columns) at each split

Bagging

  • Combine their results (aggregating) through majority vote
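
A rough sketch of the two steps (bootstrap, then aggregate) for classification trees. It assumes the built-in iris data as a stand-in for the wine data; the seed and the number of trees are arbitrary choices, and this hand-rolled version resamples rows only.

```r
library(rpart)   # recommended package shipped with R

set.seed(1)      # arbitrary seed, for reproducibility only
n_trees <- 25    # hypothetical ensemble size
n <- nrow(iris)

# Bootstrap: fit each tree to n rows sampled with replacement.
trees <- lapply(seq_len(n_trees), function(b) {
  boot_rows <- sample(n, n, replace = TRUE)
  rpart(Species ~ ., data = iris[boot_rows, ])
})

# Aggregate: collect each tree's predicted class (one column per tree),
# then take a majority vote across trees for every row.
votes <- sapply(trees, function(tr) {
  as.character(predict(tr, newdata = iris, type = "class"))
})
bagged_pred <- apply(votes, 1, function(v) names(which.max(table(v))))

mean(bagged_pred == iris$Species)   # in-sample accuracy of the bagged trees
```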

Random Forest

  • Bagged trees that also consider only a random subset of variables at each split are called Random Forests
library(randomForest)
(rf_out = randomForest( cultivar ~ ., data = wine, ntree= 500 ))
## 
## Call:
##  randomForest(formula = cultivar ~ ., data = wine, ntree = 500) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##         OOB estimate of  error rate: 2.25%
## Confusion matrix:
##    A  B  C class.error
## A 58  1  0  0.01694915
## B  1 68  2  0.04225352
## C  0  0 48  0.00000000

Random Forest Error

  • Out-of-bag (OOB) error: each observation is predicted using only the trees whose bootstrap sample did not include it
    • Out-of-sample performance measure (similar to CV)
    • Used for choosing # of trees in forest
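
One way to inspect how the OOB error changes with the number of trees is the err.rate component of a fitted randomForest object. The sketch below assumes the built-in iris data as a stand-in for the wine data; the seed and ntree value are arbitrary.

```r
library(randomForest)

set.seed(1)   # arbitrary seed, for reproducibility only
rf <- randomForest(Species ~ ., data = iris, ntree = 300)

# err.rate has one row per tree count; the "OOB" column is the out-of-bag
# error of the forest built from the first 1, 2, ..., ntree trees.
oob <- rf$err.rate[, "OOB"]
oob[c(1, 10, 50, 100, 300)]

# The curve typically levels off once enough trees have been added.
plot(oob, type = "l", xlab = "Number of trees", ylab = "OOB error")
```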

Imbalanced Classification

  • Classification models often perform poorly when classes are severely imbalanced
    • E.g. binary classification with <5% positives; always predicting the majority class gives high accuracy (>95%) but never detects a positive
  • A simple way to address the problem is resampling
    • Oversample minority class, or
    • Undersample majority class
  • Fit the model to the balanced (resampled) data, then apply it to the original data
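
The resampling steps above can be sketched in base R as follows. The data are simulated (roughly 5% positives, one hypothetical predictor x), and logistic regression with a 0.5 cutoff is used only as an example classifier.

```r
set.seed(1)   # arbitrary seed, for reproducibility only
n <- 1000

# Simulated binary outcome with roughly 5% positives (hypothetical data).
y <- factor(ifelse(runif(n) < 0.05, "pos", "neg"), levels = c("neg", "pos"))
d <- data.frame(x = rnorm(n, mean = ifelse(y == "pos", 1.5, 0)), y = y)

mean(d$y == "neg")   # accuracy of always predicting the majority class

pos <- which(d$y == "pos")
neg <- which(d$y == "neg")

# Oversample: draw minority rows with replacement until they match the majority.
over  <- d[c(neg, sample(pos, length(neg), replace = TRUE)), ]
# Undersample: keep only as many majority rows as there are minority rows.
under <- d[c(sample(neg, length(pos)), pos), ]

table(over$y)
table(under$y)

# Fit on the balanced data, then predict on the original data.
fit  <- glm(y ~ x, data = over, family = binomial)
pred <- ifelse(predict(fit, newdata = d, type = "response") > 0.5, "pos", "neg")
table(truth = d$y, predicted = pred)
```

Oversampling keeps all the data but repeats minority rows; undersampling discards majority rows, which can be wasteful when the data set is small.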