Confusion Matrix in R

Plotting the Confusion Matrix with ggplot in R
Author

Enrique Pérez Herrero

Published

March 9, 2022

1 Confusion Matrix

A confusion matrix cross-tabulates the predicted classes against the actual classes, which makes it a useful tool for visualizing the performance of a classification algorithm. In this blog post, we provide a function that draws the confusion matrix as an image. The R package caret also includes the confusionMatrix function, which reports the matrix together with overall and per-class statistics.
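As a minimal illustration of what such a matrix contains, the sketch below cross-tabulates two small, made-up label vectors with base R's table(). The vectors `actual` and `predicted` are hypothetical and are not part of the Iris example that follows.

```r
# Hypothetical labels, invented for illustration only
actual    <- factor(c("cat", "cat", "dog", "dog", "dog", "cat"),
                    levels = c("cat", "dog"))
predicted <- factor(c("cat", "dog", "dog", "dog", "cat", "cat"),
                    levels = c("cat", "dog"))

# Rows are predictions and columns are the true classes,
# the same layout that caret prints
cm <- table(Prediction = predicted, Reference = actual)
print(cm)

# Correct predictions sit on the diagonal of the matrix
accuracy <- sum(diag(cm)) / sum(cm)
print(accuracy)
```

Every off-diagonal cell counts one kind of mistake, which is why the matrix is more informative than a single accuracy figure.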

1.1 Classification

We will fit a Naive Bayes classifier to the classic Iris dataset, training it on 80% of the observations and evaluating it on the remaining 20%.

Code
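# Packages: e1071 provides naiveBayes(), caret provides confusionMatrix()
library(e1071)
library(caret)

# Train and test data: 80/20 split stratified on the class label
# (sample.split() expects the label vector, not the whole data frame)
spl <- caTools::sample.split(iris$Species, SplitRatio = 0.8)
train <- iris[spl, ]
test <- iris[!spl, ]

# Fit Naive Bayes on the training set and predict the held-out test set
iris_nb <- naiveBayes(Species ~ ., data = train)
nb_test_predict <- predict(iris_nb, test[, names(test) != "Species"])

# Compare the predictions against the true species
cfm <- confusionMatrix(nb_test_predict, test$Species)
cfm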
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa         10          0         0
  versicolor      0         10         1
  virginica       0          0         9

Overall Statistics
                                          
               Accuracy : 0.9667          
                 95% CI : (0.8278, 0.9992)
    No Information Rate : 0.3333          
    P-Value [Acc > NIR] : 2.963e-13       
                                          
                  Kappa : 0.95            
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           0.9000
Specificity                 1.0000            0.9500           1.0000
Pos Pred Value              1.0000            0.9091           1.0000
Neg Pred Value              1.0000            1.0000           0.9524
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3000
Detection Prevalence        0.3333            0.3667           0.3000
Balanced Accuracy           1.0000            0.9750           0.9500
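
Several of the statistics above can be recomputed by hand from the table, which helps when reading the output. The sketch below re-enters the printed counts as a plain matrix and uses base R only. The object names (`m`, `kappa_hat`, and so on) are chosen here for illustration and do not come from caret.

```r
# Counts copied from the confusion matrix printed above
classes <- c("setosa", "versicolor", "virginica")
m <- matrix(c(10,  0, 0,
               0, 10, 1,
               0,  0, 9),
            nrow = 3, byrow = TRUE,
            dimnames = list(Prediction = classes, Reference = classes))

n <- sum(m)

# Accuracy: share of observations on the diagonal
accuracy <- sum(diag(m)) / n

# Sensitivity: correct predictions divided by the true class size (column sums)
sensitivity <- diag(m) / colSums(m)

# Positive predictive value: correct predictions divided by the
# number of times the class was predicted (row sums)
pos_pred_value <- diag(m) / rowSums(m)

# Specificity: true negatives divided by all observations outside the class
true_neg <- n - rowSums(m) - colSums(m) + diag(m)
specificity <- true_neg / (n - colSums(m))

# Cohen's kappa: accuracy corrected for the agreement expected by chance
expected <- sum(rowSums(m) * colSums(m)) / n^2
kappa_hat <- (accuracy - expected) / (1 - expected)

round(c(accuracy = accuracy, kappa = kappa_hat), 4)
round(rbind(sensitivity, specificity, pos_pred_value), 4)
```

The single misclassified flower, a virginica predicted as versicolor, is what lowers the sensitivity of virginica and the positive predictive value of versicolor.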

1.2 Plotting

To plot the confusion matrix obtained above with ggplot2, we will use the following function, which also relies on the scales package to format the title as percentages:

Code
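library(ggplot2)
library(scales)

ggplot_confusion_matrix <- function(cfm) {
  # Title reports overall accuracy and Cohen's kappa, formatted as percentages
  mytitle <- paste("Accuracy", percent_format()(cfm$overall["Accuracy"]),
                   "Kappa", percent_format()(cfm$overall["Kappa"]))
  p <-
    ggplot(data = as.data.frame(cfm$table),
           aes(x = Reference, y = Prediction)) +
    # Fill by log(Freq) so large counts do not dominate the colour scale;
    # note that log(0) is -Inf for empty cells
    geom_tile(aes(fill = log(Freq)), colour = "white") +
    scale_fill_gradient(low = "white", high = "steelblue") +
    # Print the raw count in each cell
    geom_text(aes(label = Freq)) +
    theme(legend.position = "none") +
    ggtitle(mytitle)
  return(p)
}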
Code
ggplot_confusion_matrix(cfm)
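
The function relies on as.data.frame() turning the contingency table into long format, with one row per (Prediction, Reference) pair and a Freq column, which is the shape ggplot expects. A minimal sketch of that conversion follows, using a hand-entered table with the counts from section 1.1 rather than a caret object. The names `tab` and `long` are hypothetical.

```r
# Counts copied from the confusion matrix in section 1.1
classes <- c("setosa", "versicolor", "virginica")
tab <- as.table(matrix(c(10,  0, 0,
                          0, 10, 1,
                          0,  0, 9),
                       nrow = 3, byrow = TRUE,
                       dimnames = list(Prediction = classes,
                                       Reference = classes)))

# Long format: one row per cell, with the count in the Freq column
long <- as.data.frame(tab)
str(long)

# Only the non-empty cells
long[long$Freq > 0, ]

# Number of empty cells, for which log(Freq) is -Inf in the fill mapping
sum(is.infinite(log(long$Freq)))
```

Because the dimension names of the table become the column names of the data frame, the aesthetics in the function can refer to Reference and Prediction directly.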