
Decision Trees & Random Forest

Stéphan · Nov 15, 2020 · 42 mins read

Libraries

library(ISLR) # Dataset
library(MASS) # Dataset
library(caTools) # Split data
library(tree) # Trees
library(rpart) # Regression & classification trees
library(rpart.plot) # Nicer tree plots
library(corrplot) # Graphical display of correlation matrices
library(randomForest) # Random forest & bagging models
library(gbm) # Boosted models
library(knitr) # Engine for dynamic reports (kable)

Setup

color_palette = c('royalblue', "rosybrown")
setwd("D:/Stéphan/OneDrive - De Vinci/Année 4/Machine Learning/TD5")
set.seed(18) 

Regression Trees

Single tree

Question 1

To demonstrate regression trees, we will use the Boston dataset from the MASS package, which we already used during the first two practical works. Recall that medv is the response.

Load the Boston dataset from MASS package. Split the dataset randomly in half.

Boston_idx = sample(1:nrow(Boston), nrow(Boston) / 2)
Boston_train = Boston[Boston_idx,]
Boston_test  = Boston[-Boston_idx,]

Question 2

Fit a regression tree to the training data using the rpart() function from the rpart package. Name the tree Boston_tree.

Boston_tree = rpart(medv ~.,data=Boston_train)

Question 3

Plot the obtained tree.

plot(Boston_tree)
text(Boston_tree, pretty = 0)
title(main = "Regression Tree")

Question 4

A better plot can be obtained using the rpart.plot package. Re-plot the tree using it. You can use the rpart.plot() function: by default, when the response is continuous, each node shows the predicted value and the percentage of observations in the node. You can also use the prp() function.

rpart.plot(Boston_tree,shadow.col = "darkgrey",box.palette = "#3399FF")
prp(Boston_tree, branch = 0, box.palette = "#3399FF", main = "Tree")

Question 5

Print the obtained tree and print its summary. Among the things shown in the summary are the CP (complexity parameter) table and the importance of each variable in the model. Print the CP table using the printcp() function to see the cross-validation results. Plot a comparison figure using the plotcp() function.


summary(Boston_tree)
Call:
rpart(formula = medv ~ ., data = Boston_train)
  n= 253

          CP nsplit rel error    xerror       xstd
1 0.51227092      0 1.0000000 1.0079237 0.12063261
2 0.17530106      1 0.4877291 0.6269968 0.08009188
3 0.07439632      2 0.3124280 0.4181044 0.06301976
4 0.02929834      3 0.2380317 0.3886623 0.06765051
5 0.02529297      4 0.2087334 0.3489906 0.06171012
6 0.01419376      5 0.1834404 0.3172825 0.05624605
7 0.01000000      6 0.1692466 0.3051705 0.05617950

Variable importance
lstat    rm   age   tax   dis indus   nox  crim   rad    zn black
   40    17     9     7     7     6     5     5     2     1     1

Node number 1: 253 observations,    complexity param=0.5122709
  mean=21.99802, MSE=78.49055
  left son=2 (219 obs) right son=3 (34 obs)
  Primary splits:
      lstat   < 5.27    to the right, improve=0.5122709, (0 missing)
      rm      < 6.978   to the left,  improve=0.4501026, (0 missing)
      indus   < 7.625   to the right, improve=0.3122330, (0 missing)
      ptratio < 19.65   to the right, improve=0.2830922, (0 missing)
      nox     < 0.5125  to the right, improve=0.2710296, (0 missing)
  Surrogate splits:
      rm    < 6.978   to the left,  agree=0.917, adj=0.382, (0 split)
      age   < 17.1    to the right, agree=0.881, adj=0.118, (0 split)
      crim  < 0.03402 to the right, agree=0.877, adj=0.088, (0 split)
      zn    < 87.5    to the left,  agree=0.870, adj=0.029, (0 split)
      indus < 1.605   to the right, agree=0.870, adj=0.029, (0 split)

Node number 2: 219 observations,    complexity param=0.1753011
  mean=19.49954, MSE=33.88288
  left son=4 (92 obs) right son=5 (127 obs)
  Primary splits:
      lstat   < 14.4    to the right, improve=0.4691352, (0 missing)
      crim    < 5.84803 to the right, improve=0.3530712, (0 missing)
      ptratio < 19.9    to the right, improve=0.3072113, (0 missing)
      nox     < 0.6635  to the right, improve=0.3018543, (0 missing)
      tax     < 434.5   to the right, improve=0.2916064, (0 missing)
  Surrogate splits:
      age   < 83.55   to the right, agree=0.804, adj=0.533, (0 split)
      indus < 16.57   to the right, agree=0.795, adj=0.511, (0 split)
      tax   < 434.5   to the right, agree=0.795, adj=0.511, (0 split)
      dis   < 2.24235 to the left,  agree=0.790, adj=0.500, (0 split)
      nox   < 0.5765  to the right, agree=0.776, adj=0.467, (0 split)

Node number 3: 34 observations,    complexity param=0.07439632
  mean=38.09118, MSE=66.61845
  left son=6 (21 obs) right son=7 (13 obs)
  Primary splits:
      rm    < 7.433   to the left,  improve=0.6522527, (0 missing)
      dis   < 3.6617  to the right, improve=0.3892454, (0 missing)
      crim  < 0.25614 to the left,  improve=0.2629405, (0 missing)
      age   < 47.75   to the left,  improve=0.2073899, (0 missing)
      lstat < 4.66    to the right, improve=0.1551880, (0 missing)
  Surrogate splits:
      crim  < 0.25614 to the left,  agree=0.706, adj=0.231, (0 split)
      dis   < 3.6617  to the right, agree=0.706, adj=0.231, (0 split)
      rad   < 6       to the left,  agree=0.676, adj=0.154, (0 split)
      lstat < 3.58    to the right, agree=0.676, adj=0.154, (0 split)
      zn    < 77.5    to the left,  agree=0.647, adj=0.077, (0 split)

Node number 4: 92 observations,    complexity param=0.02929834
  mean=14.81522, MSE=19.54542
  left son=8 (44 obs) right son=9 (48 obs)
  Primary splits:
      crim  < 5.7819  to the right, improve=0.3235549, (0 missing)
      nox   < 0.603   to the right, improve=0.3025317, (0 missing)
      lstat < 19.83   to the right, improve=0.2864889, (0 missing)
      dis   < 2.0752  to the left,  improve=0.2586920, (0 missing)
      tax   < 567.5   to the right, improve=0.2188006, (0 missing)
  Surrogate splits:
      rad   < 16      to the right, agree=0.891, adj=0.773, (0 split)
      tax   < 567.5   to the right, agree=0.870, adj=0.727, (0 split)
      nox   < 0.657   to the right, agree=0.772, adj=0.523, (0 split)
      dis   < 2.26375 to the left,  agree=0.739, adj=0.455, (0 split)
      black < 320.5   to the left,  agree=0.728, adj=0.432, (0 split)

Node number 5: 127 observations,    complexity param=0.02529297
  mean=22.89291, MSE=16.85845
  left son=10 (104 obs) right son=11 (23 obs)
  Primary splits:
      rm      < 6.611   to the left,  improve=0.2345937, (0 missing)
      lstat   < 10.14   to the right, improve=0.2008000, (0 missing)
      tax     < 223.5   to the right, improve=0.2002331, (0 missing)
      indus   < 3.105   to the right, improve=0.1032093, (0 missing)
      ptratio < 18.65   to the right, improve=0.0835695, (0 missing)
  Surrogate splits:
      tax     < 222.5   to the right, agree=0.835, adj=0.087, (0 split)
      crim    < 0.01778 to the right, agree=0.827, adj=0.043, (0 split)
      ptratio < 13.85   to the right, agree=0.827, adj=0.043, (0 split)

Node number 6: 21 observations
  mean=32.90476, MSE=32.67474

Node number 7: 13 observations
  mean=46.46923, MSE=7.806746

Node number 8: 44 observations
  mean=12.18864, MSE=14.72873

Node number 9: 48 observations
  mean=17.22292, MSE=11.83968

Node number 10: 104 observations
  mean=21.95769, MSE=9.53321

Node number 11: 23 observations,    complexity param=0.01419376
  mean=27.12174, MSE=28.14344
  left son=22 (16 obs) right son=23 (7 obs)
  Primary splits:
      tax     < 268.5   to the right, improve=0.4354419, (0 missing)
      ptratio < 18.7    to the right, improve=0.3278305, (0 missing)
      indus   < 4.415   to the right, improve=0.2124938, (0 missing)
      rm      < 6.978   to the left,  improve=0.1449609, (0 missing)
      nox     < 0.5225  to the right, improve=0.1122918, (0 missing)
  Surrogate splits:
      indus   < 2.21    to the right, agree=0.783, adj=0.286, (0 split)
      rm      < 7.1455  to the left,  agree=0.783, adj=0.286, (0 split)
      ptratio < 13.9    to the right, agree=0.783, adj=0.286, (0 split)

Node number 22: 16 observations
  mean=24.80625, MSE=16.75934

Node number 23: 7 observations
  mean=32.41429, MSE=13.89837 

Complexity Table

printcp(Boston_tree)

Regression tree:
rpart(formula = medv ~ ., data = Boston_train)

Variables actually used in tree construction:
[1] crim  lstat rm    tax

Root node error: 19858/253 = 78.491

n= 253

        CP nsplit rel error  xerror     xstd
1 0.512271      0   1.00000 1.00792 0.120633
2 0.175301      1   0.48773 0.62700 0.080092
3 0.074396      2   0.31243 0.41810 0.063020
4 0.029298      3   0.23803 0.38866 0.067651
5 0.025293      4   0.20873 0.34899 0.061710
6 0.014194      5   0.18344 0.31728 0.056246
7 0.010000      6   0.16925 0.30517 0.056180
plotcp(Boston_tree, col = color_palette[1])
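The CP table is also what one would use to prune the tree. A common rule, sketched below, keeps the subtree whose cross-validated error (xerror) is smallest; here the smallest xerror is already at the last row, so pruning leaves the full tree unchanged.

best_cp = Boston_tree$cptable[which.min(Boston_tree$cptable[, "xerror"]), "CP"]
Boston_tree_pruned = prune(Boston_tree, cp = best_cp) # no-op here: the full tree already minimizes xerror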

Question 5 bis

This function computes the RMSE (root mean squared error), i.e. sqrt(mean((predicted - actual)^2)), for two vectors of the same length containing the actual and predicted values.

RMSE = function(actual, predicted){
    return(sqrt(mean((predicted - actual)^2)))
}

Here is a simple example:

#EXAMPLE
actual = c(1.5, 1.0, 2.0, 7.4, 5.8, 6.6)
predicted = c(1.0, 1.1, 2.5, 7.3, 6.0, 6.2)
RMSE(actual, predicted)
[1] 0.3464102

Question 6

Use the function predict() to predict the response on the test set. Then calculate the RMSE obtained with tree model.

tree_Predict=predict(Boston_tree,newdata=Boston_test)
rmse_singleTree = RMSE(Boston_test$medv,tree_Predict)
rmse_singleTree
[1] 5.051138

Here are some predictions vs. the actual values:

as.vector(round(tree_Predict,1))[1:10]
Boston_test$medv[1:10]
 [1] 32.9 32.4 32.9 17.2 17.2 22.0 22.0 22.0 22.0 17.2
 [1] 34.7 36.2 28.7 16.5 15.0 20.4 18.2 19.9 23.1 17.5

Question 7

Fit a linear regression model on the training set. Then predict the response on the test set using the linear model. Calculate the RMSE and compare the performance of the tree and the linear regression model.

LinearModel = lm(medv ~ ., data = Boston_train)
LinearPrediction = predict(LinearModel, Boston_test)
rmseLinear = RMSE(Boston_test$medv, LinearPrediction)
rmseLinear
[1] 5.016083

Results

Linear Model

plot(LinearPrediction,Boston_test$medv,
            col=color_palette[1],
            xlab = "Prediction from linear model",
            ylab = "Actual value", cex=1,
            pch=20)
title('Comparison between actual & predicted value from linear model')
abline(0,1, col=color_palette[2], lwd=2.5)

Tree Model

plot(tree_Predict,Boston_test$medv,
            col=color_palette[1],
            xlab = "Prediction from tree model",
            ylab = "Actual value",)
title('Comparison between actual & predicted value from tree model')
abline(0,1, col=color_palette[2])

Bagging

Question 8

Fit a bagged model, using the randomForest() function from the randomForest package.

For bagging, mtry is set equal to p, the number of predictors (a sketch for deriving p from the data follows the importance plot).

baggedmodel=randomForest(medv~.,data = Boston_train, mtry=13) #Number of predictors = 13
baggedmodel$importance
        IncNodePurity
crim        643.59117
zn           50.09744
indus       100.33552
chas         12.57610
nox         470.50166
rm         4068.82093
age         237.34529
dis         486.89991
rad          74.26411
tax         454.53904
ptratio     129.32766
black       291.56419
lstat     12669.04283
varImpPlot(baggedmodel, col = color_palette, main = "Variable Importance of the Bagged Model")
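Rather than hardcoding 13, the number of predictors can be derived from the data; a minimal sketch, equivalent to the call above:

p = ncol(Boston_train) - 1 # every column except the response medv
baggedmodel = randomForest(medv ~ ., data = Boston_train, mtry = p)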

Question 9

Predict the response on the test set using the bagging model. Calculate the RMSE. Is the performance of the model better than linear regression or a simple tree?

baggedmodel_Predict=predict(baggedmodel, Boston_test)
#table(baggedmodel_Predict, Boston_test$medv) # medv is continuous, so a confusion table is not meaningful here
rmse_bagged = RMSE(Boston_test$medv,baggedmodel_Predict)
rmse_bagged
[1] 3.909783
plot(baggedmodel[["mse"]], type = 'l',
     col = color_palette,
     main = 'Bagged Trees: Error vs Number of Trees',
     xlab='trees',
     ylab='Error',
     lwd= 1)

The bagged model performs better than both the linear regression and the single tree.

Random Forests

Question 10

Fit a random forest on the training set, with the number of predictors sampled at each split equal to p/3, i.e. 13/3 ≈ 4 here.

rdmForestmodel=randomForest(medv~.,data = Boston_train, mtry=4)
rdmForestmodel_Predict=predict(rdmForestmodel, Boston_test)
rmse_rdmForest = RMSE(Boston_test$medv,rdmForestmodel_Predict)
rmse_rdmForest
[1] 4.169829
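The value mtry = 4 follows the p/3 rule of thumb for regression forests. To check this choice, one could sweep a few values of mtry and compare test errors; a sketch reusing the RMSE helper from Question 5 bis (the grid values are arbitrary):

# Test RMSE for a few values of mtry (mtry = 13 is the bagged model)
mtry_grid = c(2, 4, 7, 10, 13)
rmse_by_mtry = sapply(mtry_grid, function(m) {
    fit = randomForest(medv ~ ., data = Boston_train, mtry = m)
    RMSE(Boston_test$medv, predict(fit, Boston_test))
})
names(rmse_by_mtry) = mtry_grid
rmse_by_mtry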

We observe that the RMSE is lower for the bagged model than for the random forest, so on this split bagging seems better.

Question 11

The three most important predictors are lstat, rm and indus. These do not coincide with the most significant predictors of the linear regression model (see the sketch after the importance table).

rdmForestmodel$importance
        IncNodePurity
crim       1398.87015
zn          156.36048
indus      1480.59041
chas         92.22241
nox        1272.65652
rm         4906.72407
age         471.37480
dis         871.61607
rad         163.50122
tax         747.88426
ptratio    1131.53288
black       444.69195
lstat      6412.05749
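As a quick check of the claim about the linear model, its coefficients can be sorted by p-value; a sketch using the LinearModel fitted in Question 7:

# The five most significant terms of the linear model (smallest p-values first)
coefs = summary(LinearModel)$coefficients
head(coefs[order(coefs[, "Pr(>|t|)"]), ], 5)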

Question 12

Plot the variable importance of the random forest model.

varImpPlot(rdmForestmodel, col = color_palette, main = "Variable Importance of the Random Forest Model")

Boosting

Question 13

Using the gbm() function as follows, fit a boosted model on the training set. Then compare its performance with the previous models by computing the predictions and the RMSE.

Boston_boost = gbm(medv ~ ., data = Boston_train, distribution = "gaussian",
                    n.trees = 5000, interaction.depth = 4, shrinkage = 0.01)
Boston_boost_Predict = predict(Boston_boost, Boston_test, n.trees = 5000)
rmse_boost = RMSE(Boston_test$medv,Boston_boost_Predict)
rmse_boost
[1] 3.645138
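The prediction above uses all 5000 trees. gbm can also estimate the optimal number of iterations, for instance from the out-of-bag improvements; a sketch (best_iter and pred_best are our own names):

best_iter = gbm.perf(Boston_boost, method = "OOB") # OOB estimate of the best iteration (tends to be conservative)
pred_best = predict(Boston_boost, Boston_test, n.trees = best_iter)
RMSE(Boston_test$medv, pred_best)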

Question 14

Show the summary of the boosted model; it also produces a variable-importance figure.

summary(Boston_boost)

            var    rel.inf
lstat     lstat 45.8125176
rm           rm 25.5296071
dis         dis  5.7982413
nox         nox  4.8147082
crim       crim  4.6772738
black     black  3.6608466
age         age  2.9176428
ptratio ptratio  2.6588683
tax         tax  2.2412215
indus     indus  0.9206316
rad         rad  0.6724993
zn           zn  0.2290519
chas       chas  0.0668899

Comparison

comparison.frame <- data.frame(
   Model_name = c("Single Tree", "Linear regression", "Bagging", "Random Forest", "Boosting"),
   RMSE = c(rmse_singleTree, rmseLinear, rmse_bagged, rmse_rdmForest, rmse_boost),
   stringsAsFactors = FALSE
)
kable(comparison.frame)
|Model_name        |     RMSE|
|:-----------------|--------:|
|Single Tree       | 5.051138|
|Linear regression | 5.016083|
|Bagging           | 3.909784|
|Random Forest     | 4.169829|
|Boosting          | 3.645138|
plot(tree_Predict,Boston_test$medv,
            col=color_palette[1],
            xlab = "Predicted",
            ylab = "Actual value",
            pch = 20,
            panel.first = grid(lty=2)
            )
title('Single Tree, Test data')
abline(0,1, col=color_palette[2])

plot(baggedmodel_Predict,Boston_test$medv,
            col=color_palette[1],
            xlab = "Predicted",
            ylab = "Actual value", cex=1.1,
            pch=20, panel.first = grid(lty=2))
title('Bagging, Test data')
abline(0,1, col=color_palette[2])

plot(rdmForestmodel_Predict,Boston_test$medv,
            col=color_palette[1],
            xlab = "Predicted",
            ylab = "Actual value",
            pch=20,
            panel.first = grid(lty=2)
     )
title('Random Forest, Test data')
abline(0,1, col=color_palette[2])

plot(Boston_boost_Predict,Boston_test$medv,
            col=color_palette[1],
            xlab = "Predicted",
            ylab = "Actual value",
            cex=1,
            pch=20,
            panel.first = grid(lty=2)
     )
title('Boosting, Test data')
abline(0,1, col=color_palette[2], lwd=2)

Classification Trees

spam <- read.csv('spam.csv')
spam$spam = as.factor(spam$spam)
#spam$spam

Exploring the dataset: The Spam dataset

#View(spam)
str(spam)
'data.frame':   4601 obs. of  58 variables:
 $ spam      : Factor w/ 2 levels "FALSE","TRUE": 2 2 2 2 2 2 2 2 2 2 ...
 $ make      : num  0 0.21 0.06 0 0 0 0 0 0.15 0.06 ...
 $ address   : num  0.64 0.28 0 0 0 0 0 0 0 0.12 ...
 $ all       : num  0.64 0.5 0.71 0 0 0 0 0 0.46 0.77 ...
 $ X3d       : num  0 0 0 0 0 0 0 0 0 0 ...
 $ our       : num  0.32 0.14 1.23 0.63 0.63 1.85 1.92 1.88 0.61 0.19 ...
 $ over      : num  0 0.28 0.19 0 0 0 0 0 0 0.32 ...
 $ remove    : num  0 0.21 0.19 0.31 0.31 0 0 0 0.3 0.38 ...
 $ internet  : num  0 0.07 0.12 0.63 0.63 1.85 0 1.88 0 0 ...
 $ order     : num  0 0 0.64 0.31 0.31 0 0 0 0.92 0.06 ...
 $ mail      : num  0 0.94 0.25 0.63 0.63 0 0.64 0 0.76 0 ...
 $ receive   : num  0 0.21 0.38 0.31 0.31 0 0.96 0 0.76 0 ...
 $ will      : num  0.64 0.79 0.45 0.31 0.31 0 1.28 0 0.92 0.64 ...
 $ people    : num  0 0.65 0.12 0.31 0.31 0 0 0 0 0.25 ...
 $ report    : num  0 0.21 0 0 0 0 0 0 0 0 ...
 $ addresses : num  0 0.14 1.75 0 0 0 0 0 0 0.12 ...
 $ free      : num  0.32 0.14 0.06 0.31 0.31 0 0.96 0 0 0 ...
 $ business  : num  0 0.07 0.06 0 0 0 0 0 0 0 ...
 $ email     : num  1.29 0.28 1.03 0 0 0 0.32 0 0.15 0.12 ...
 $ you       : num  1.93 3.47 1.36 3.18 3.18 0 3.85 0 1.23 1.67 ...
 $ credit    : num  0 0 0.32 0 0 0 0 0 3.53 0.06 ...
 $ your      : num  0.96 1.59 0.51 0.31 0.31 0 0.64 0 2 0.71 ...
 $ font      : num  0 0 0 0 0 0 0 0 0 0 ...
 $ X000      : num  0 0.43 1.16 0 0 0 0 0 0 0.19 ...
 $ money     : num  0 0.43 0.06 0 0 0 0 0 0.15 0 ...
 $ hp        : num  0 0 0 0 0 0 0 0 0 0 ...
 $ hpl       : num  0 0 0 0 0 0 0 0 0 0 ...
 $ george    : num  0 0 0 0 0 0 0 0 0 0 ...
 $ X650      : num  0 0 0 0 0 0 0 0 0 0 ...
 $ lab       : num  0 0 0 0 0 0 0 0 0 0 ...
 $ labs      : num  0 0 0 0 0 0 0 0 0 0 ...
 $ telnet    : num  0 0 0 0 0 0 0 0 0 0 ...
 $ X857      : num  0 0 0 0 0 0 0 0 0 0 ...
 $ data      : num  0 0 0 0 0 0 0 0 0.15 0 ...
 $ X415      : num  0 0 0 0 0 0 0 0 0 0 ...
 $ X85       : num  0 0 0 0 0 0 0 0 0 0 ...
 $ technology: num  0 0 0 0 0 0 0 0 0 0 ...
 $ X1999     : num  0 0.07 0 0 0 0 0 0 0 0 ...
 $ parts     : num  0 0 0 0 0 0 0 0 0 0 ...
 $ pm        : num  0 0 0 0 0 0 0 0 0 0 ...
 $ direct    : num  0 0 0.06 0 0 0 0 0 0 0 ...
 $ cs        : num  0 0 0 0 0 0 0 0 0 0 ...
 $ meeting   : num  0 0 0 0 0 0 0 0 0 0 ...
 $ original  : num  0 0 0.12 0 0 0 0 0 0.3 0 ...
 $ project   : num  0 0 0 0 0 0 0 0 0 0.06 ...
 $ re        : num  0 0 0.06 0 0 0 0 0 0 0 ...
 $ edu       : num  0 0 0.06 0 0 0 0 0 0 0 ...
 $ table     : num  0 0 0 0 0 0 0 0 0 0 ...
 $ conference: num  0 0 0 0 0 0 0 0 0 0 ...
 $ ch.       : num  0 0 0.01 0 0 0 0 0 0 0.04 ...
 $ ch..1     : num  0 0.132 0.143 0.137 0.135 0.223 0.054 0.206 0.271 0.03 ...
 $ ch..2     : num  0 0 0 0 0 0 0 0 0 0 ...
 $ ch..3     : num  0.778 0.372 0.276 0.137 0.135 0 0.164 0 0.181 0.244 ...
 $ ch..4     : num  0 0.18 0.184 0 0 0 0.054 0 0.203 0.081 ...
 $ ch..5     : num  0 0.048 0.01 0 0 0 0 0 0.022 0 ...
 $ crl.ave   : num  3.76 5.11 9.82 3.54 3.54 ...
 $ crl.long  : int  61 101 485 40 40 15 4 11 445 43 ...
 $ crl.tot   : int  278 1028 2259 191 191 54 112 49 1257 749 ...
summary(spam, na=FALSE) 
    spam           make           address            all
 FALSE:2788   Min.   :0.0000   Min.   : 0.000   Min.   :0.0000
 TRUE :1813   1st Qu.:0.0000   1st Qu.: 0.000   1st Qu.:0.0000
              Median :0.0000   Median : 0.000   Median :0.0000
              Mean   :0.1046   Mean   : 0.213   Mean   :0.2807
              3rd Qu.:0.0000   3rd Qu.: 0.000   3rd Qu.:0.4200
              Max.   :4.5400   Max.   :14.280   Max.   :5.1000
      X3d                our               over            remove
 Min.   : 0.00000   Min.   : 0.0000   Min.   :0.0000   Min.   :0.0000
 1st Qu.: 0.00000   1st Qu.: 0.0000   1st Qu.:0.0000   1st Qu.:0.0000
 Median : 0.00000   Median : 0.0000   Median :0.0000   Median :0.0000
 Mean   : 0.06542   Mean   : 0.3122   Mean   :0.0959   Mean   :0.1142
 3rd Qu.: 0.00000   3rd Qu.: 0.3800   3rd Qu.:0.0000   3rd Qu.:0.0000
 Max.   :42.81000   Max.   :10.0000   Max.   :5.8800   Max.   :7.2700
    internet           order              mail            receive
 Min.   : 0.0000   Min.   :0.00000   Min.   : 0.0000   Min.   :0.00000
 1st Qu.: 0.0000   1st Qu.:0.00000   1st Qu.: 0.0000   1st Qu.:0.00000
 Median : 0.0000   Median :0.00000   Median : 0.0000   Median :0.00000
 Mean   : 0.1053   Mean   :0.09007   Mean   : 0.2394   Mean   :0.05982
 3rd Qu.: 0.0000   3rd Qu.:0.00000   3rd Qu.: 0.1600   3rd Qu.:0.00000
 Max.   :11.1100   Max.   :5.26000   Max.   :18.1800   Max.   :2.61000
      will            people            report           addresses
 Min.   :0.0000   Min.   :0.00000   Min.   : 0.00000   Min.   :0.0000
 1st Qu.:0.0000   1st Qu.:0.00000   1st Qu.: 0.00000   1st Qu.:0.0000
 Median :0.1000   Median :0.00000   Median : 0.00000   Median :0.0000
 Mean   :0.5417   Mean   :0.09393   Mean   : 0.05863   Mean   :0.0492
 3rd Qu.:0.8000   3rd Qu.:0.00000   3rd Qu.: 0.00000   3rd Qu.:0.0000
 Max.   :9.6700   Max.   :5.55000   Max.   :10.00000   Max.   :4.4100
      free            business          email             you
 Min.   : 0.0000   Min.   :0.0000   Min.   :0.0000   Min.   : 0.000
 1st Qu.: 0.0000   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.: 0.000
 Median : 0.0000   Median :0.0000   Median :0.0000   Median : 1.310
 Mean   : 0.2488   Mean   :0.1426   Mean   :0.1847   Mean   : 1.662
 3rd Qu.: 0.1000   3rd Qu.:0.0000   3rd Qu.:0.0000   3rd Qu.: 2.640
 Max.   :20.0000   Max.   :7.1400   Max.   :9.0900   Max.   :18.750
     credit              your              font              X000
 Min.   : 0.00000   Min.   : 0.0000   Min.   : 0.0000   Min.   :0.0000
 1st Qu.: 0.00000   1st Qu.: 0.0000   1st Qu.: 0.0000   1st Qu.:0.0000
 Median : 0.00000   Median : 0.2200   Median : 0.0000   Median :0.0000
 Mean   : 0.08558   Mean   : 0.8098   Mean   : 0.1212   Mean   :0.1016
 3rd Qu.: 0.00000   3rd Qu.: 1.2700   3rd Qu.: 0.0000   3rd Qu.:0.0000
 Max.   :18.18000   Max.   :11.1100   Max.   :17.1000   Max.   :5.4500
     money                hp               hpl              george
 Min.   : 0.00000   Min.   : 0.0000   Min.   : 0.0000   Min.   : 0.0000
 1st Qu.: 0.00000   1st Qu.: 0.0000   1st Qu.: 0.0000   1st Qu.: 0.0000
 Median : 0.00000   Median : 0.0000   Median : 0.0000   Median : 0.0000
 Mean   : 0.09427   Mean   : 0.5495   Mean   : 0.2654   Mean   : 0.7673
 3rd Qu.: 0.00000   3rd Qu.: 0.0000   3rd Qu.: 0.0000   3rd Qu.: 0.0000
 Max.   :12.50000   Max.   :20.8300   Max.   :16.6600   Max.   :33.3300
      X650             lab                labs            telnet
 Min.   :0.0000   Min.   : 0.00000   Min.   :0.0000   Min.   : 0.00000
 1st Qu.:0.0000   1st Qu.: 0.00000   1st Qu.:0.0000   1st Qu.: 0.00000
 Median :0.0000   Median : 0.00000   Median :0.0000   Median : 0.00000
 Mean   :0.1248   Mean   : 0.09892   Mean   :0.1029   Mean   : 0.06475
 3rd Qu.:0.0000   3rd Qu.: 0.00000   3rd Qu.:0.0000   3rd Qu.: 0.00000
 Max.   :9.0900   Max.   :14.28000   Max.   :5.8800   Max.   :12.50000
      X857              data               X415              X85
 Min.   :0.00000   Min.   : 0.00000   Min.   :0.00000   Min.   : 0.0000
 1st Qu.:0.00000   1st Qu.: 0.00000   1st Qu.:0.00000   1st Qu.: 0.0000
 Median :0.00000   Median : 0.00000   Median :0.00000   Median : 0.0000
 Mean   :0.04705   Mean   : 0.09723   Mean   :0.04784   Mean   : 0.1054
 3rd Qu.:0.00000   3rd Qu.: 0.00000   3rd Qu.:0.00000   3rd Qu.: 0.0000
 Max.   :4.76000   Max.   :18.18000   Max.   :4.76000   Max.   :20.0000
   technology          X1999           parts              pm
 Min.   :0.00000   Min.   :0.000   Min.   :0.0000   Min.   : 0.00000
 1st Qu.:0.00000   1st Qu.:0.000   1st Qu.:0.0000   1st Qu.: 0.00000
 Median :0.00000   Median :0.000   Median :0.0000   Median : 0.00000
 Mean   :0.09748   Mean   :0.137   Mean   :0.0132   Mean   : 0.07863
 3rd Qu.:0.00000   3rd Qu.:0.000   3rd Qu.:0.0000   3rd Qu.: 0.00000
 Max.   :7.69000   Max.   :6.890   Max.   :8.3300   Max.   :11.11000
     direct              cs             meeting           original
 Min.   :0.00000   Min.   :0.00000   Min.   : 0.0000   Min.   :0.0000
 1st Qu.:0.00000   1st Qu.:0.00000   1st Qu.: 0.0000   1st Qu.:0.0000
 Median :0.00000   Median :0.00000   Median : 0.0000   Median :0.0000
 Mean   :0.06483   Mean   :0.04367   Mean   : 0.1323   Mean   :0.0461
 3rd Qu.:0.00000   3rd Qu.:0.00000   3rd Qu.: 0.0000   3rd Qu.:0.0000
 Max.   :4.76000   Max.   :7.14000   Max.   :14.2800   Max.   :3.5700
    project              re               edu              table
 Min.   : 0.0000   Min.   : 0.0000   Min.   : 0.0000   Min.   :0.000000
 1st Qu.: 0.0000   1st Qu.: 0.0000   1st Qu.: 0.0000   1st Qu.:0.000000
 Median : 0.0000   Median : 0.0000   Median : 0.0000   Median :0.000000
 Mean   : 0.0792   Mean   : 0.3012   Mean   : 0.1798   Mean   :0.005444
 3rd Qu.: 0.0000   3rd Qu.: 0.1100   3rd Qu.: 0.0000   3rd Qu.:0.000000
 Max.   :20.0000   Max.   :21.4200   Max.   :22.0500   Max.   :2.170000
   conference            ch.              ch..1           ch..2
 Min.   : 0.00000   Min.   :0.00000   Min.   :0.000   Min.   :0.00000
 1st Qu.: 0.00000   1st Qu.:0.00000   1st Qu.:0.000   1st Qu.:0.00000
 Median : 0.00000   Median :0.00000   Median :0.065   Median :0.00000
 Mean   : 0.03187   Mean   :0.03857   Mean   :0.139   Mean   :0.01698
 3rd Qu.: 0.00000   3rd Qu.:0.00000   3rd Qu.:0.188   3rd Qu.:0.00000
 Max.   :10.00000   Max.   :4.38500   Max.   :9.752   Max.   :4.08100
     ch..3             ch..4             ch..5             crl.ave
 Min.   : 0.0000   Min.   :0.00000   Min.   : 0.00000   Min.   :   1.000
 1st Qu.: 0.0000   1st Qu.:0.00000   1st Qu.: 0.00000   1st Qu.:   1.588
 Median : 0.0000   Median :0.00000   Median : 0.00000   Median :   2.276
 Mean   : 0.2691   Mean   :0.07581   Mean   : 0.04424   Mean   :   5.191
 3rd Qu.: 0.3150   3rd Qu.:0.05200   3rd Qu.: 0.00000   3rd Qu.:   3.706
 Max.   :32.4780   Max.   :6.00300   Max.   :19.82900   Max.   :1102.500
    crl.long          crl.tot
 Min.   :   1.00   Min.   :    1.0
 1st Qu.:   6.00   1st Qu.:   35.0
 Median :  15.00   Median :   95.0
 Mean   :  52.17   Mean   :  283.3
 3rd Qu.:  43.00   3rd Qu.:  266.0
 Max.   :9989.00   Max.   :15841.0  
data = spam
data[data==0] <- NA
boxplot(free ~ spam, data=data, col = color_palette[1], main="Boxplot free ~ spam", na.action = na.omit)

boxplot(mail ~ spam, data=spam,col = color_palette[1], main="Boxplot mail ~ spam", na.action = na.omit)

Class counts: FALSE: 2788, TRUE: 1813. There are a lot of zero values, which is why zeros are replaced by NA before the first boxplot.

Splitting data

split = sample.split(spam$spam,SplitRatio = 0.75)
training_set = subset(spam,split == TRUE)
test_set = subset(spam,split == FALSE)
training_set$spam = ifelse(training_set$spam==TRUE, 1,0)
test_set$spam = ifelse(test_set$spam==TRUE, 1,0)
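sample.split() from caTools preserves the class ratio in each subset; this can be verified after the conversion to 0/1 (spam is about 39.4% of the full dataset):

prop.table(table(training_set$spam)) # share of ham (0) and spam (1) in the training set
prop.table(table(test_set$spam))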

Fit models

1. Logistic regression model

classifier.logreg = glm(spam ~ ., family = binomial,data = training_set)
classifier.logreg

Call:  glm(formula = spam ~ ., family = binomial, data = training_set)

Coefficients:
(Intercept)         make      address          all          X3d          our
  -1.605492    -0.155358    -0.136123     0.196375     2.105549     0.556877
       over       remove     internet        order         mail      receive
   1.742541     2.098125     0.532078     0.478785     0.282097    -0.519753
       will       people       report    addresses         free     business
  -0.100753    -0.210284     0.281151     1.717516     1.127949     0.917546
      email          you       credit         your         font         X000
   0.157576     0.036241     0.887090     0.192171     0.139295     2.045004
      money           hp          hpl       george         X650          lab
   0.370496    -2.240999    -0.847755   -18.067950     0.509855    -2.303256
       labs       telnet         X857         data         X415          X85
  -0.322576    -4.854319     5.895160    -0.885034     0.609476    -2.166808
 technology        X1999        parts           pm       direct           cs
   0.881314    -0.166003    -0.665962    -0.954530    -0.176493   -51.161622
    meeting     original      project           re          edu        table
  -2.798705    -1.922537    -1.570142    -0.785419    -1.540670    -2.356250
 conference          ch.        ch..1        ch..2        ch..3        ch..4
  -4.359194    -1.331310    -0.156926    -1.170922     0.254091     4.819150
      ch..5      crl.ave     crl.long      crl.tot
   2.702294     0.052523     0.007204     0.001247

Degrees of Freedom: 3450 Total (i.e. Null);  3393 Residual
Null Deviance:      4628
Residual Deviance: 1345     AIC: 1461
pred.glm = predict(classifier.logreg, newdata = test_set,type="response")
pred.glm_0_1 = ifelse(pred.glm >= 0.5, 1,0)
cm = table(test_set[,1],pred.glm_0_1)
cm
   pred.glm_0_1
      0   1
  0 665  32
  1  52 401

2. Simple classification tree

Spam_tree = rpart(spam ~.,data=training_set,method="class")
plot(Spam_tree)
text(Spam_tree, pretty = 0)
title(main = "Classification Tree")

pred.tree = predict(Spam_tree,newdata=test_set, type ="class")
cm2 = table(test_set[,1],pred.tree)
cm2
   pred.tree
      0   1
  0 660  37
  1  93 360

3. Bagging

Spam_bagging=randomForest(spam~.,data = training_set, mtry=57) # mtry = p = 57 predictors; spam is now 0/1, so this is a regression forest
pred.bagging =predict(Spam_bagging, test_set, type ="class")
pred.bagging_0_1 = ifelse(pred.bagging >=0.5, 1,0)
cm3 = table(test_set[,1],pred.bagging_0_1)
cm3
   pred.bagging_0_1
      0   1
  0 672  25
  1  38 415

4. Random forest

Spam_forest=randomForest(spam~.,data = training_set, mtry=7) # mtry = 7 ≈ sqrt(57), the usual choice for classification
pred.forest =predict(Spam_forest, test_set, type ="class")
pred.forest_0_1 = ifelse(pred.forest >=0.5, 1,0)
cm4 = table(test_set[,1],pred.forest_0_1)
cm4
   pred.forest_0_1
      0   1
  0 674  23
  1  34 419

5. Boosting model

Spam_boost = gbm(spam ~ ., data = training_set, distribution = "bernoulli",
                    n.trees = 5000, interaction.depth = 4, shrinkage = 0.01)
pred.boost = predict(Spam_boost, test_set, n.trees = 5000, type = 'response')
pred.boost_0_1 = ifelse(pred.boost >= 0.5, 1,0)
cm5 = table(test_set[,1],pred.boost_0_1)
cm5
   pred.boost_0_1
      0   1
  0 673  24
  1  30 423

Prediction accuracy of our models

accuracy = function(cm){
    # (TN + TP) / total, using linear indexing of the 2x2 confusion matrix
    return((cm[1] + cm[4]) / sum(cm))
}
accuracy_logreg = accuracy(cm)
accuracy_logreg
[1] 0.9269565
accuracy_singletree = accuracy(cm2)
accuracy_singletree
[1] 0.8869565
accuracy_bagging = accuracy(cm3)
accuracy_bagging
[1] 0.9452174
accuracy_rdmforest= accuracy(cm4)
accuracy_rdmforest
[1] 0.9504348
accuracy_boosting = accuracy(cm5)
accuracy_boosting
[1] 0.9530435
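Mirroring the RMSE comparison above, the accuracies can be gathered in a single table with kable; a sketch (accuracy.frame is our own name):

accuracy.frame <- data.frame(
   Model_name = c("Logistic regression", "Single Tree", "Bagging", "Random Forest", "Boosting"),
   Accuracy = c(accuracy_logreg, accuracy_singletree, accuracy_bagging,
                accuracy_rdmforest, accuracy_boosting),
   stringsAsFactors = FALSE
)
kable(accuracy.frame)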