- Learn how to
  + Perform multi-class classification
  + Use ensemble methods to prevent overfitting
  + Handle imbalanced classes
- Readings:
  + ISLR ch 4.3.5, 8.2
Some methods naturally handle multiple classes; others can be adapted from binary classification
```r
library(tidyverse)

wine = read_csv("data/wine.csv") %>%
  mutate(cultivar = factor(cultivar))

glimpse(wine)
## Observations: 178
## Variables: 14
## $ cultivar        <fct> A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A...
## $ alcohol         <dbl> 14.23, 13.20, 13.16, 14.37, 13.24, 14.20, 14.3...
## $ malic_acid      <dbl> 1.71, 1.78, 2.36, 1.95, 2.59, 1.76, 1.87, 2.15...
## $ ash             <dbl> 2.43, 2.14, 2.67, 2.50, 2.87, 2.45, 2.45, 2.61...
## $ alcalinity      <dbl> 15.6, 11.2, 18.6, 16.8, 21.0, 15.2, 14.6, 17.6...
## $ magnesium       <dbl> 127, 100, 101, 113, 118, 112, 96, 121, 97, 98,...
## $ phenols         <dbl> 2.80, 2.65, 2.80, 3.85, 2.80, 3.27, 2.50, 2.60...
## $ flavanoids      <dbl> 3.06, 2.76, 3.24, 3.49, 2.69, 3.39, 2.52, 2.51...
## $ nonflavanoid    <dbl> 0.28, 0.26, 0.30, 0.24, 0.39, 0.34, 0.30, 0.31...
## $ proanthocyanins <dbl> 2.29, 1.28, 2.81, 2.18, 1.82, 1.97, 1.98, 1.25...
## $ color           <dbl> 5.64, 4.38, 5.68, 7.80, 4.32, 6.75, 5.25, 5.05...
## $ hue             <dbl> 1.04, 1.05, 1.03, 0.86, 1.04, 1.05, 1.02, 1.06...
## $ OD280           <dbl> 3.92, 3.40, 3.17, 3.45, 2.93, 2.85, 3.58, 3.58...
## $ proline         <dbl> 1065, 1050, 1185, 1480, 735, 1450, 1290, 1295,...
```
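The three cultivars are fairly balanced here, but it is worth confirming that before fitting anything, since one of the learning goals is handling imbalanced classes. A quick check, assuming the tidyverse is loaded as above:

```r
# How many wines of each cultivar? Roughly equal counts mean overall
# accuracy is a reasonable summary; skewed counts call for per-class metrics.
wine %>% count(cultivar)
```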
```r
library(rpart)
library(rpart.plot)

wine %>%
  rpart(cultivar ~ ., data = .) %>%
  rpart.plot()
```
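The plot shows the fitted splits; to see how well a single tree classifies the training data, one option is to keep the fitted tree and tabulate its predictions against the observed cultivars. A minimal sketch (the object name `tree_out` is just illustrative):

```r
# Refit the tree, keep it, and build an in-sample confusion matrix.
# In-sample error is optimistic: a single deep tree can overfit.
tree_out = rpart(cultivar ~ ., data = wine)

table(predicted = predict(tree_out, type = "class"),
      observed  = wine$cultivar)
```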
```r
library(nnet)

(mltn_out = multinom(cultivar ~ ., data = wine, trace = FALSE))
## Call:
## multinom(formula = cultivar ~ ., data = wine, trace = FALSE)
## 
## Coefficients:
##   (Intercept)  alcohol malic_acid        ash alcalinity  magnesium
## B    848.4648 -50.4198  -16.82837 -376.29272   30.77966 -0.3187808
## C   -149.9645  61.7616  -14.00633  -93.30316   28.10216 -4.2882308
##    phenols flavanoids nonflavanoid proanthocyanins     color       hue
## B 115.4110  -93.03559     358.9488        100.5827 -24.09276  442.3062
## C 267.9405 -256.98401    -389.8206        151.4916  37.77146 -220.3293
##        OD280    proline
## B  -26.29891 -0.5581035
## C -166.22775 -0.4108165
## 
## Residual Deviance: 4.22418e-06
## AIC: 56
```
```r
tibble(mltn_out$fitted.values, predict(mltn_out)) %>%
  sample_n(5)
## # A tibble: 5 x 2
##   `mltn_out$fitted.values`[,"A"]    [,"B"]    [,"C"] `predict(mltn_out)`
##                            <dbl>     <dbl>     <dbl> <fct>
## 1                      1.41e-151 1.00e+ 0  1.17e-236 B
## 2                      4.17e-277 1.00e+ 0  0.        B
## 3                      1.22e-155 1.00e+ 0  5.06e-178 B
## 4                      1.33e-253 3.05e-219 1.00e+ 0  C
## 5                      4.22e-127 1.00e+ 0  1.28e-137 B
```
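`multinom()` handles all three cultivars directly. As noted above, a binary classifier can also be adapted with a one-vs-rest scheme: fit one binary model per class and predict whichever class gets the highest fitted probability. A rough sketch using binary logistic regression (`glm`); because these classes are almost perfectly separable, expect warnings about fitted probabilities of 0 or 1:

```r
# One-vs-rest: one binary logistic regression per cultivar.
classes = levels(wine$cultivar)

ovr_fits = lapply(classes, function(cl) {
  dat = wine %>%
    mutate(is_class = as.numeric(cultivar == cl)) %>%
    select(-cultivar)
  glm(is_class ~ ., data = dat, family = binomial)
})

# Predicted probability of each class, then pick the largest per row.
ovr_probs = sapply(ovr_fits, predict, type = "response")
ovr_pred  = factor(classes[max.col(ovr_probs)], levels = classes)

# Compare the adapted binary approach to multinom's native predictions.
table(one_vs_rest = ovr_pred, multinom = predict(mltn_out))
```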
Ensemble idea: combine the predictions of many models; averaging or majority-voting across them is typically more stable than relying on any single model
Look at bagging (bootstrap aggregating) for classification trees: grow one tree per bootstrap resample of the data and classify by majority vote (sketched below)
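Before calling a packaged implementation, the bagging idea can be written out directly: resample the rows with replacement, grow one tree per resample, and let the trees vote. A minimal sketch (the number of trees and the object names are arbitrary, and the in-sample vote below is optimistic compared to an out-of-bag estimate):

```r
library(rpart)

set.seed(1)
n_trees = 100

# One classification tree per bootstrap resample of the rows;
# each column of bag_preds holds one tree's predicted cultivars.
bag_preds = sapply(seq_len(n_trees), function(b) {
  boot_rows = sample(nrow(wine), replace = TRUE)
  tree_b = rpart(cultivar ~ ., data = wine[boot_rows, ])
  as.character(predict(tree_b, newdata = wine, type = "class"))
})

# Majority vote across trees for each wine, then a confusion matrix.
vote = apply(bag_preds, 1, function(p) names(which.max(table(p))))
table(bagged = vote, observed = wine$cultivar)
```

A random forest does exactly this bagging step and, in addition, samples a subset of the predictors at each split (3 of the 13 in the output below), which further de-correlates the trees.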
```r
library(randomForest)

(rf_out = randomForest(cultivar ~ ., data = wine, ntree = 500))
## 
## Call:
##  randomForest(formula = cultivar ~ ., data = wine, ntree = 500)
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##         OOB estimate of  error rate: 2.25%
## Confusion matrix:
##    A  B  C class.error
## A 58  1  0  0.01694915
## B  1 68  2  0.04225352
## C  0  0 48  0.00000000
```
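The OOB estimate already gives an honest error rate without a separate test set. Two natural follow-ups, sketched under the assumption that `rf_out` is the fit above: inspect which predictors the forest relied on, and use `predict()` for new observations.

```r
# Variable importance: mean decrease in Gini impurity per predictor
# (larger values = the predictor contributed more to the splits).
importance(rf_out)
varImpPlot(rf_out)

# New observations are classified with predict(), e.g.
# predict(rf_out, newdata = some_new_wines)   # some_new_wines is hypothetical
```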