Don’t get lost in a forest

Exploit the power and simplicity of tree-based models in R.

What you’ll find in this post

This is a brief, non-technical intro to tree-based models and their implementation in R. By the end of this post you’ll be able to apply various learning algorithms to a dataset. Since the post is already pretty long, you’ll find only some of the code here, but no worries! The full code is in this post’s repo on GitHub and in a Kaggle notebook.

The structure of the post:

  • Why trees are here to stay: a brief intro to decision trees and the rationale behind them
  • When to use trees: reasons to use trees over other algorithms
  • R preparation: packages used in the post
  • Data preparation: data retrieval, loading and cleaning
  • Logistic regression benchmark: build a benchmark prediction
  • Fast and Frugal Trees: very simple decision trees with good predictive power
  • It’s time to party!: using the party package for recursive trees
  • Ensemble models: intro to ensemble models
    • Bagging: start by growing hundreds of trees
    • Random Forest: thousands of simple trees are better!
    • Boosting: learning slowly might be even better
  • What we learned: a recap of lessons learned

Why trees are here to stay

Trees are a family of machine learning algorithms usually used for classification. They are among the first algorithms taught to learners because they’re simple and effective. You probably won’t see them in the latest machine learning papers and research, but trees are still widely used in the real world.

Two of the main reasons for their widespread use are their simplicity and their interpretability. Below you can see a simple decision tree predicting whether the sky is going to be overcast or not.

Decision tree

The nice thing about this method is that it gives us a way to predict a variable by feeding data to the model, but, probably even more importantly, it lets us infer the relationships among the predictors. This means that we can start from the bottom and see what makes the outlook overcast.

In this case, if the wind is weak and it looks like it is going to rain, the outlook will be overcast. For simple models these rules can be learned and applied by humans, or we can turn them into checklists to assist the decision process. By visualizing the tree we can understand how the machine is working and why it classifies some days as overcast and others as not.

While this may seem trivial, in many cases you want to know why the model is making certain predictions. Consider a model predicting whether to admit a patient with chest pain. After testing many advanced models, the doctors wanted to understand why the algorithm was sending some people home, since patients’ lives were at stake. So they ran a tree-based model on the data, and it turned out that patients with chest pain and asthma were considered not at risk.

This was a huge mistake: physicians know very well that asthma and chest pain must be treated immediately, so these patients were always admitted, treated and then discharged. Can you see the issue? The data used for modeling considered these patients at low risk precisely because all of them had been treated, and so very few of them died afterwards.

When to use trees?

As you saw earlier, trees are very good when interpretability matters, even if only to understand what can go wrong with predictions. That said, tree-based models can also become very complex, losing interpretability but gaining accuracy. There’s always a trade-off.

Another reason to use trees is that they are easy to understand and explain. When there are a few strong predictors, trees can be used to build simple models that both machines and humans can apply. One example that comes to mind is a tree predicting whether a customer will eventually make a purchase or not.

Benchmarking is also one of the areas where these methods shine: you’ll soon find out that even pretty simple tree-based models are very difficult to beat by a wide margin in classification contexts. Personally, I often run random forest (more on this later) on the dataset I’m working on, and then I try to beat it.

R preparation

If you are new to R, or if you just want a refresher, you’ll find every step needed to set it up in this tutorial. We will use a few packages (dplyr, intubate, FFTrees, party, randomForest and gbm), so you’d better install them now.

These are among the main packages for working with trees and for data munging in R, but they are not the only ones. There are dozens of packages for almost every kind of tree you may want: run a quick search on Crantastic if you don’t believe me.

Now it’s time to grow some trees! I’ve decided to use the Titanic dataset, one of the most famous datasets in the machine learning community. You can get the data from Kaggle or from this post’s GitHub repo. I’ll jump directly to cleaning and modeling; if you need help downloading or loading the data, or you get lost somewhere, refer to my previous post or to the full code on GitHub.

Data Preparation

The first thing to do is to take a look at the data to see what we are working with. Below is a table describing every field.

Field name | Description | Type
PassengerId | The id number of the passenger | integer (1, 2, ..., n)
Survived | Whether the passenger survived or not | integer (0 = not survived, 1 = survived)
Pclass | The class the passenger was traveling in | integer (1 = 1st class, 2 = 2nd class, 3 = 3rd class)
Name | Name of the passenger | character
Sex | Sex of the passenger | character (male, female)
Age | Age of the passenger | numeric
SibSp | Number of siblings/spouses aboard | integer
Parch | Number of parents/children aboard | integer
Ticket | Ticket number | character
Fare | The passenger fare | numeric
Cabin | Cabin number | character
Embarked | Port of embarkation | character (S = Southampton, Q = Queenstown, C = Cherbourg)

I really dislike uppercase variable names; luckily, we can lowercase them all in just one line of code using the tolower() function.
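A minimal sketch of that one-liner, assuming the training file has already been read into a data frame called train. Here train is a two-row stand-in with a subset of the Kaggle column names, so the snippet runs on its own:

```r
# Two-row stand-in for the Kaggle training data (subset of columns)
train <- data.frame(
  PassengerId = c(1L, 2L),
  Survived = c(0L, 1L),
  Pclass = c(3L, 1L),
  Sex = c("male", "female"),
  Age = c(22, 38),
  stringsAsFactors = FALSE
)

# Lowercase every column name in one line
names(train) <- tolower(names(train))
names(train)   # passengerid survived pclass sex age
```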

Then we convert the sex and embarked variables to factors.
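A sketch of the factor conversion on a toy stand-in for train. The explicit levels are my own choice, not necessarily what the author did: listing them means that an empty embarked string (one way a missing port can show up after reading a csv) becomes NA instead of an extra level.

```r
# Toy stand-in for train; the empty string mimics a missing port of embarkation
train <- data.frame(
  sex = c("male", "female", "female", "male"),
  embarked = c("S", "C", "", "Q"),
  stringsAsFactors = FALSE
)

train$sex <- factor(train$sex, levels = c("female", "male"))
# Values not listed in levels (here the empty string) become NA
train$embarked <- factor(train$embarked, levels = c("C", "Q", "S"))
str(train)
```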

One of the most important steps when modeling is dealing with NAs. Many R modeling functions handle missing values automatically, but usually by simply dropping the observations that contain an NA. This means less training data for our models, and almost surely less accuracy.

There are various techniques to impute NAs: we can substitute the mean, median or mode, or we can predict the missing values with a model. In this case we will use a linear regression to fill in the missing values of the age variable.

At first this concept may look a bit scary or weird. You may think: “Are you saying that to improve my model I should use another model?!” But it isn’t as difficult as it looks, especially if we use linear regression.

First, let’s see how many NAs there are in the age variable.
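Counting missing values only needs is.na(). The toy age vector below is made up; on the Kaggle training set the same lines would give 177 missing ages out of 891 rows.

```r
# Made-up ages with two missing values
train <- data.frame(age = c(22, NA, 26, 35, NA, 54, 2, 27, 14, 4))

sum(is.na(train$age))    # number of missing ages: 2
mean(is.na(train$age))   # share of missing ages: 0.2
sum(!is.na(train$age))   # rows a model would actually use: 8
```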

Almost 20% of passengers (177 out of 891) don’t have a recorded age, which means that had we run a model on the dataset without substituting the missing values, we would have trained it on 714 records rather than on all 891 passengers.

Time to run a linear regression on the data.

What did we do here? We asked R to find the values of $\alpha$ and $\beta_n$ that best fit (in the least-squares sense) the linear equation

$$\text{age} = \alpha + \beta_1 \cdot \text{survived} + \beta_2 \cdot \text{pclass} + \beta_3 \cdot \text{fare}$$

To check the result of our linear regression we call summary() on the created model. R will spit out some statistics that we need to check in order to understand what is going on with our data.

The first line of the output (Call) reminds us which model produced these results; then come the residuals and the coefficients. For each coefficient we see its estimate, its standard error, and the t and p-values. Further down there are other statistics, and we see that R actually removed the rows with NAs (177 observations deleted due to missingness).
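A sketch of the age model, fitted on simulated data (the values and the 20 knocked-out ages are hypothetical, not the Kaggle file). It uses the same formula as the equation above, and summary() reports the dropped rows the same way it does on the real data.

```r
# Simulated stand-in for the cleaned training data
set.seed(1)
n <- 100
train <- data.frame(
  survived = rbinom(n, 1, 0.4),
  pclass = sample(1:3, n, replace = TRUE),
  fare = round(runif(n, 5, 100), 2)
)
train$age <- round(40 - 5 * train$pclass + 0.05 * train$fare + rnorm(n, sd = 8))
train$age[sample(n, 20)] <- NA   # knock out 20 ages to mimic missing data

# age = alpha + beta_1 * survived + beta_2 * pclass + beta_3 * fare
age_model <- lm(age ~ survived + pclass + fare, data = train)
summary(age_model)   # ends with "(20 observations deleted due to missingness)"
nobs(age_model)      # 80 rows actually used for fitting
```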

Now we can use the model to impute the NAs with the predict() function.
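A sketch of the imputation step on the same kind of simulated data (hypothetical values): predict() is called only on the rows whose age is missing, and the predictions are written back into exactly those rows.

```r
# Same simulated setup as in the regression sketch
set.seed(1)
n <- 100
train <- data.frame(
  survived = rbinom(n, 1, 0.4),
  pclass = sample(1:3, n, replace = TRUE),
  fare = round(runif(n, 5, 100), 2)
)
train$age <- round(40 - 5 * train$pclass + 0.05 * train$fare + rnorm(n, sd = 8))
train$age[sample(n, 20)] <- NA

age_model <- lm(age ~ survived + pclass + fare, data = train)

# Fill only the missing ages with the model's predictions
missing_age <- is.na(train$age)
train$age[missing_age] <- predict(age_model, newdata = train[missing_age, ])
sum(is.na(train$age))   # 0
```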

Logistic Regression Benchmark

For binary classification problems – one can survive or not – logistic regression can be very difficult to beat. We will predict who survived the Titanic with a logistic regression, and we will use this result as our benchmark.

Don’t worry, performing logistic regression in R is pretty straightforward, but if you want a refresher on basic modeling and data wrangling (for instance, if you have issues with the %>% operator), take a look at this brief tutorial about dplyr and intubate. To run it, we just need to call the glm() function and pass as arguments the response variable (survived) and the predictors (age, pclass, etc.), while indicating the dataset. The last thing to do is to pass binomial to the family argument, telling R that the response is a binary outcome.
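A base-R sketch of the logistic regression, written without the %>% pipe the author used and fitted on simulated passengers whose survival probability depends on sex, class and age through a made-up formula:

```r
# Simulated stand-in for the Titanic training data (hypothetical survival mechanism)
set.seed(2)
n <- 200
train <- data.frame(
  sex = factor(sample(c("female", "male"), n, replace = TRUE)),
  pclass = sample(1:3, n, replace = TRUE),
  age = round(runif(n, 1, 70)),
  fare = round(runif(n, 5, 100), 2)
)
lin <- 1.5 - 2.5 * (train$sex == "male") - 0.6 * (train$pclass - 2) - 0.02 * (train$age - 30)
train$survived <- rbinom(n, 1, plogis(lin))

# family = binomial tells glm() the response is a binary outcome (logit link by default)
logit_model <- glm(survived ~ sex + pclass + age + fare,
                   family = binomial, data = train)
summary(logit_model)
```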

Afterwards we can look at the model’s predictions on the training data to see how the logistic regression performs.

The confusion matrix tells us how the model does on the training data: 572 people are predicted to die (0) and 319 to survive (1). The main diagonal holds the correct predictions (480 and 250), while the off-diagonal cells hold the errors: the logistic regression predicts as “not survived” 92 people who actually survived, and as “survived” 69 people who actually didn’t.
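The arithmetic behind the 82% figure can be checked directly from the counts above. The confusion() helper is a hypothetical name for illustration: it shows one way to build such a table from predicted probabilities with a 0.5 cut-off.

```r
# Hypothetical helper: confusion matrix from predicted probabilities and true labels
confusion <- function(prob, actual, threshold = 0.5) {
  table(predicted = factor(as.integer(prob > threshold), levels = 0:1),
        actual = factor(actual, levels = 0:1))
}
confusion(prob = c(0.9, 0.2, 0.6), actual = c(1, 0, 0))

# Counts reported for the logistic regression (rows = predicted, columns = actual)
conf <- matrix(c(480, 69, 92, 250), nrow = 2,
               dimnames = list(predicted = c("0", "1"), actual = c("0", "1")))
rowSums(conf)                 # 572 predicted deaths, 319 predicted survivors
sum(diag(conf)) / sum(conf)   # 730 / 891, about 0.819
```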

An 82% prediction accuracy (730 correct predictions out of 891) out of the box is pretty good for such a simple model, right? But we want to test it on data it has never seen, so let’s load the test set and try our model on it.

Then we predict survival on the test data.

We save our result as a csv file because the test data is not labeled, so we don’t know whether our predictions are right. To check the results we have to upload them to Kaggle. If you don’t want to upload results, you’ll have to trust me (see the picture below): 77.5% correct predictions.
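A sketch of the predict-and-save step on simulated train and test sets (make_passengers() and the file name are hypothetical). The two columns, PassengerId and Survived, follow the format of Kaggle’s Titanic submissions, whose test ids start at 892.

```r
# Hypothetical helper that simulates passengers (not the Kaggle data)
set.seed(3)
make_passengers <- function(n) {
  data.frame(
    sex = factor(sample(c("female", "male"), n, replace = TRUE),
                 levels = c("female", "male")),
    pclass = sample(1:3, n, replace = TRUE)
  )
}
train <- make_passengers(150)
train$survived <- rbinom(150, 1, ifelse(train$sex == "female", 0.75, 0.2))
test <- make_passengers(20)
test$passengerid <- 892:911

logit_model <- glm(survived ~ sex + pclass, family = binomial, data = train)

# Predicted survival probabilities, turned into 0/1 with a 0.5 cut-off
prob <- predict(logit_model, newdata = test, type = "response")
submission <- data.frame(PassengerId = test$passengerid,
                         Survived = as.integer(prob > 0.5))

# row.names = FALSE keeps only the two expected columns in the file
out_file <- file.path(tempdir(), "logit_submission.csv")
write.csv(submission, out_file, row.names = FALSE)
```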

Logistic regression result

Fast and Frugal Trees

Time to grow some trees! The first models we will try are Fast and Frugal Trees (FFTrees). These are very simple trees in which every node has at least one exit, so a decision is reached after checking only a few variables. In R we can build them with the FFTrees package.

After loading the package, we just have to grow the FFTree on selected variables.

Fitting takes a few seconds because several FFTrees are grown and tested on the training data. The result is an FFTrees object with all the tested trees, each built by considering from 1 to 5 pieces of information and ignoring all the rest.

By printing the object in the console we see that the algorithm tested 8 trees using at most 4 predictors, and that the best performing one is tree #5. Then we have some statistics about the performance of this tree. The output is helpful, but the best way to understand what’s going on (and the reason this is one of my favourite packages) is the plotting method.

fftree Titanic

We get a lot of information in just one plot, starting from the top: number of observations, number of classes, the number of the tree and then the tree itself and some diagnostics. Let’s focus on the tree.

The very first node considers the sex variable: if the passenger is female (sex != male), we directly exit the tree predicting survived. Brutal, but pretty effective. If the passenger is male, we go through the second node: pclass. Here, if we are in third class, we exit the tree predicting not survived. If we are in first or second class, we’d better have paid more than £26.96 (fare), because if we paid less we exit the tree predicting not survived. The last node considers age: if we are older than 21.35 we wouldn’t survive the Titanic; otherwise the tree predicts survived.

In the Performance section of the plot we can check the confusion matrix on the left, and other statistics on the right. In this case we care mostly about the confusion matrix that we can compare with the one we got with the logistic regression.

Or, we can look at the ROC curve on the right. The FFTrees package automatically performs a logistic regression and a CART (another kind of tree) on the data and compares them with the modeled trees. If you look closely you can see that the purple circle (logistic regression) is almost completely hidden by the circle #5 indicating the plotted tree. This means that the performances of the two models are comparable.

Now we have to classify the test data and submit it through Kaggle. As I told you before, the good thing about these trees is that they are dead simple. When I was explaining how the tree works, every sentence at every node started with an “if”. This means that we can follow the same structure and build a classifier or a checklist, or we could even memorize it.

Four nested ifelse() statements are all we need to classify the whole dataset. We get just 2 NAs, which I decided to classify as “not survived”. All we have left to do is upload a csv of the results and check how the model performed.
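A sketch of those four nested ifelse() calls, transcribing the rules described above for the FFTree. The function name fft_predict() and the toy passengers are hypothetical, and whether each threshold is strict (< or >) or inclusive is my assumption.

```r
# Hypothetical transcription of the FFTree rules (threshold inclusivity assumed)
fft_predict <- function(d) {
  pred <- ifelse(d$sex != "male", 1,        # node 1: women -> survived
          ifelse(d$pclass == 3, 0,          # node 2: men in 3rd class -> not survived
          ifelse(d$fare <= 26.96, 0,        # node 3: cheaper tickets -> not survived
          ifelse(d$age > 21.35, 0, 1))))    # node 4: older men -> not survived, else survived
  pred[is.na(pred)] <- 0                    # leftover NAs classified as "not survived"
  pred
}

# Toy passengers, including one with a missing fare
test <- data.frame(
  sex = c("female", "male", "male", "male", "male", "male"),
  pclass = c(3, 3, 1, 1, 2, 1),
  fare = c(7.25, 8.05, 10.5, 80, 30, NA),
  age = c(30, 20, 40, 45, 18, 35)
)
fft_predict(test)   # 1 0 0 0 1 0
```

A female passenger exits at the first node even if her fare or age is missing, because ifelse() only uses the inner result for rows that reach it.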

fftree result

That’s right: our 4 if-else statements scored just 1% less than our benchmark. This is remarkable considering the simplicity of the model.

It’s time to Party!

The party package grows conditional inference trees: at each node, statistical (permutation) tests decide whether to split and on which variable. This makes them more complicated than FFTrees, but only under the hood: the final result is still a nice tree.

The difference between this package and other tree implementations is that with party the tree isn’t built by simply splitting nodes in order of importance: split variables are chosen through significance tests that take the distribution of the data into account. To use it we just have to load the package and build a tree with the ctree() function.

Once the model is fitted, the package’s plotting method draws the resulting decision tree: calling plot(ctree_result) is enough. In this case we don’t care much about the bells and whistles, and I like a clean look, so let’s use some optional arguments.

The resulting tree

Unfortunately, large trees take a lot of space, and just a couple more nodes would have made the plot nearly useless. Comparing this tree to the FFTree above, we see it’s more complicated: before we were predicting almost every male as not survived, while now the model is trying to split males as well.

The added complexity reduces training error to 15% as we can see by predicting the training set itself. This is an improvement compared to the FFTree above.

But I’m afraid we’re going to learn one of the most important lessons in machine learning the hard way. In fact, after predicting on the test data and uploading the result to Kaggle, we get just 73.7% correct classifications.

ctree result on Kaggle

You may ask: how is that even possible? We have just seen overfitting in action. The model picked up patterns in some variables that turn out to be just noise. The result is an improvement on the training set, but a drop in performance on data it has never seen before. There are various ways to deal with this issue; in this case pruning would likely help. Pruning means cutting branches, and we would do it on our model by decreasing the maximum depth allowed for the tree. This practice, coupled with cross-validation, is likely to improve the result on test data.
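A sketch of depth-limited pruning with party’s ctree_control(maxdepth = ...), on simulated data (the survival formula is made up). It only caps the depth; choosing the cap with cross-validation, as suggested above, would mean repeating the fit for several maxdepth values and comparing held-out error.

```r
library(party)

# Simulated passengers (hypothetical survival formula, not the Kaggle data)
set.seed(4)
n <- 300
train <- data.frame(
  sex = factor(sample(c("female", "male"), n, replace = TRUE)),
  pclass = sample(1:3, n, replace = TRUE),
  age = round(runif(n, 1, 70)),
  fare = round(runif(n, 5, 100), 2)
)
surv_prob <- plogis(1.5 - 2.5 * (train$sex == "male") - 0.6 * (train$pclass - 2))
train$survived <- factor(rbinom(n, 1, surv_prob))

# Unrestricted conditional inference tree
full_tree <- ctree(survived ~ sex + pclass + age + fare, data = train)

# Same tree with at most two levels of splits: a simple form of pruning
shallow_tree <- ctree(survived ~ sex + pclass + age + fare, data = train,
                      controls = ctree_control(maxdepth = 2))

# Training error of each tree (predict() without newdata uses the training data)
full_err <- mean(predict(full_tree) != train$survived)
shallow_err <- mean(predict(shallow_tree) != train$survived)
c(full = full_err, shallow = shallow_err)
```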

Ensemble models

Up to now we developed single learners, meaning that we found solutions with just one model. Another family of machine learning algorithms is ensemble models, built from many so-called weak learners. The idea is that by combining the choices of many learners (in this case decision trees), each using only a few variables, we get a good result.

Ensemble models vary depending on how the individual models are built and how their results are combined into a single answer. It can look a bit messy, but some of these methods tend to work well out of the box and are a good choice for benchmarking and setting a bar for improvement.

The aim of these models is to reduce variance. As we saw with the decision tree above, we got a good result on the training set but a much worse error rate on the test set. This is typical of these learners: with a different training set we would have gotten a completely different result.

We will see three different algorithms: Bagging, Random Forest and Boosting.

Bagging

The main idea of bagging is fairly simple: if we grow many large trees on different training sets, we get many models with high variance but low bias. By averaging the predictions of every tree (for classification, taking a majority vote), we get one classification with relatively low variance and low bias.

You may have already spotted an issue: we don’t have many training sets. To deal with this we create them with the bootstrap, which is just repeated sampling with replacement. We can bootstrap a dataset in base R without using any package.
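A base-R sketch of the bootstrap; the 891-row data frame is a stand-in with the size of the Kaggle training set, and the 500 resamples simply mirror the default number of trees mentioned below.

```r
# Stand-in data frame with as many rows as the Kaggle training set
set.seed(5)
train <- data.frame(id = 1:891, x = rnorm(891))

# One bootstrap sample: draw nrow(train) row indices with replacement
boot_idx <- sample(nrow(train), replace = TRUE)
boot_sample <- train[boot_idx, ]

nrow(boot_sample)                        # same size as the original: 891
length(unique(boot_idx)) / nrow(train)   # roughly two thirds of the rows appear

# Many training sets: one bootstrap sample per tree we plan to grow
boot_sets <- lapply(1:500, function(i) train[sample(nrow(train), replace = TRUE), ])
```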

To run bagging on the Titanic data we can use the randomForest package. This is because bagging and random forest are similar and differ only in how many predictors are considered at each split when building the trees. With bagging we consider every predictor in the dataset, and we control this parameter with the mtry argument.

After selecting the predictor and response columns, we call the randomForest() function. We can use a formula as usual, but we have to tell the function that we want classification trees and not regression trees (yes, random forests work for regression as well) by passing the response as a factor.

As I said before, the mtry parameter sets how many predictors the algorithm considers at each split, in every one of the 500 trees it grows by default. If you want to grow more trees, add the ntree argument and set a higher number.
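A sketch of bagging with randomForest() on simulated data with five predictors (the data and the survival formula are made up). Setting mtry to the number of predictors makes every predictor a candidate at every split, which is what turns the random forest into bagging:

```r
library(randomForest)

# Simulated passengers with five predictors (hypothetical data, not the Kaggle file)
set.seed(6)
n <- 300
train <- data.frame(
  sex = factor(sample(c("female", "male"), n, replace = TRUE)),
  pclass = sample(1:3, n, replace = TRUE),
  age = round(runif(n, 1, 70)),
  fare = round(runif(n, 5, 100), 2),
  sibsp = sample(0:3, n, replace = TRUE)
)
surv_prob <- plogis(1.5 - 2.5 * (train$sex == "male") - 0.6 * (train$pclass - 2))
# A factor response makes randomForest() grow classification trees
train$survived <- factor(rbinom(n, 1, surv_prob))

predictors <- c("sex", "pclass", "age", "fare", "sibsp")
bag_formula <- reformulate(predictors, response = "survived")

bagging_model <- randomForest(bag_formula, data = train,
                              mtry = length(predictors),  # all 5 predictors at every split: bagging
                              ntree = 1000)               # more trees than the default 500
bagging_model
```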

The issue with this approach is that the model doesn’t predict observations with missing values: it simply returns NA for them. To avoid further feature engineering but still get a valid result from Kaggle, I decided to substitute the NAs in the test set with the median and then predict who survived. Unfortunately, this limited the predictive power, and the result is 66.5% correct predictions.
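A base-R sketch of median imputation on a made-up test set. As a side note, the randomForest package also ships na.roughfix(), which does median (or most-frequent-level) imputation in a single call.

```r
# Made-up test rows with missing age and fare
test <- data.frame(
  age = c(34.5, NA, 62, 27, NA),
  fare = c(7.83, 7, NA, 8.66, 12.29)
)

# Replace each column's NAs with the median of its observed values
for (col in c("age", "fare")) {
  test[[col]][is.na(test[[col]])] <- median(test[[col]], na.rm = TRUE)
}
test   # age NAs -> 34.5, fare NA -> 8.245
```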

Random Forest

Random Forest is one of the most famous machine learning algorithms, because it usually performs insanely well out of the box. The method is the same as for bagging, but in this case we want weaker learners: at each split, a tree considers only a random subset of the predictors.

You may ask what difference it makes whether we use all predictors or only some of them. The answer is that when all predictors are available, trees grown on different bootstrapped datasets will very likely share the same first and second splits, because the strongest predictors win those splits every time. So our 500 bagged trees will be very similar, and so will their predictions, and averaging very similar predictions doesn’t reduce variance much.

To limit this behaviour we use random forest, restricting the predictors with the mtry argument. To decide the “best” value for it we can use cross-validation, or try some empirical rules. The randomForest() defaults are sqrt(p) for classification and p/3 for regression (rounded down), where p is the number of predictors, but in this case I’m going to use 3.

My suggestion is to experiment with different values and check what happens.
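A sketch of the random forest with mtry = 3, followed by a quick experiment over every possible mtry value, on simulated data with five predictors (data and formula are hypothetical). It compares the out-of-bag (OOB) error that randomForest() tracks, a cheap internal estimate, instead of running full cross-validation:

```r
library(randomForest)

# Simulated passengers with five predictors (hypothetical data)
set.seed(7)
n <- 300
train <- data.frame(
  sex = factor(sample(c("female", "male"), n, replace = TRUE)),
  pclass = sample(1:3, n, replace = TRUE),
  age = round(runif(n, 1, 70)),
  fare = round(runif(n, 5, 100), 2),
  sibsp = sample(0:3, n, replace = TRUE)
)
surv_prob <- plogis(1.5 - 2.5 * (train$sex == "male") - 0.6 * (train$pclass - 2))
train$survived <- factor(rbinom(n, 1, surv_prob))

# Random forest: 3 randomly chosen predictors are candidates at each split
rf_model <- randomForest(survived ~ ., data = train, mtry = 3, ntree = 500)

# OOB error for every possible mtry (mtry = 5 is plain bagging here)
oob_error <- sapply(1:5, function(m) {
  fit <- randomForest(survived ~ ., data = train, mtry = m, ntree = 500)
  fit$err.rate[fit$ntree, "OOB"]
})
names(oob_error) <- paste0("mtry=", 1:5)
oob_error
```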

The result is 74.6%, much better than bagging, but a bit worse than logistic regression.

Since there are many implementations of random forest, we might also try the party package, which uses conditional inference trees in its cforest() function.

As you can see the code is similar to the previous one, but what about the result?

We can consider this result a draw with logistic regression. The reason is that we can see results on only half of the test set, while the other half will become public in 2017, so a difference of half a percentage point could translate into better performance on the full test data.

Boosting

With boosting we try to learn slowly, not in the “hard” way of the previous algorithms. Previously, to avoid overfitting, we had to grow thousands of small trees and average all their predictions. Boosting works differently: grow a tree, then grow another one on the errors left by the first, and keep growing trees, each correcting the previous ones, as long as needed.

Boosting learns more slowly than the other tree-based algorithms. This can help prevent overfitting, but we have to tune the learning speed (the shrinkage argument in gbm) carefully. You’ll see that the parameters are similar to those of random forest, and at this point you should have understood how the “thing” works.

We use the gbm() function from the package of the same name. After the formula, we tell the function that this is a binary classification problem by setting distribution to bernoulli. We want 5000 trees (n.trees), each with a maximum depth of three, set with interaction.depth.
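A sketch of the boosting call on simulated data (data and formula are hypothetical). The n.trees and interaction.depth values match the text; the shrinkage (learning rate) of 0.001 is my assumption, since the post doesn’t state one:

```r
library(gbm)

# Simulated passengers (hypothetical data, not the Kaggle file)
set.seed(8)
n <- 300
train <- data.frame(
  sex = factor(sample(c("female", "male"), n, replace = TRUE)),
  pclass = sample(1:3, n, replace = TRUE),
  age = round(runif(n, 1, 70)),
  fare = round(runif(n, 5, 100), 2)
)
surv_prob <- plogis(1.5 - 2.5 * (train$sex == "male") - 0.6 * (train$pclass - 2))
# distribution = "bernoulli" expects a numeric 0/1 response, not a factor
train$survived <- rbinom(n, 1, surv_prob)

boost_model <- gbm(survived ~ sex + pclass + age + fare,
                   data = train,
                   distribution = "bernoulli",
                   n.trees = 5000,          # trees grown one after another
                   interaction.depth = 3,   # maximum depth of each tree
                   shrinkage = 0.001)       # learning rate (assumed value)

# Survival probabilities from all 5000 trees, then a 0.5 cut-off
prob <- predict(boost_model, newdata = train, n.trees = 5000, type = "response")
pred <- as.integer(prob > 0.5)
mean(pred == train$survived)   # training accuracy
```

A smaller shrinkage usually needs more trees to reach the same fit, which is why the two parameters are normally tuned together.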

A result of 76% puts boosting on par with logistic regression, random forest and FFTrees.

What we learned

Yes! Right, a lot of people died on the Titanic. But this is not the lesson of this post. There are a few important concepts to take home with you:

  • Complex models > simple models == FALSE. Logistic regression and FFTrees were tough to beat, and we could have beaten them only by improving feature engineering
  • Feature engineering > complex models == TRUE. Feature engineering is an art. It’s one of the most powerful weapons for a data scientist and we can use it to improve our predictions
  • Model building == FUN! Data science can be fun, and though R sometimes makes learning and training a bit frustrating, it’s very rewarding. If you want to investigate further or want a step-by-step guide, head over to GitHub to find the full code to reproduce this post and an intro to dplyr and intubate

If you enjoyed this, let me know by commenting below or by sharing this post. You can also follow the blog and/or subscribe to the newsletter.