Decision Trees in Python: Predicting Diabetes
In this post, we’ll be learning about decision trees: how they work and what the benefits are of using them. We’ll also apply the algorithm to a real-world dataset to predict diabetes.
So, what are decision trees? Decision trees are a machine learning method for classification or regression. They work by segmenting the dataset through if-else control statements applied to the features.
There are a few algorithms that can be used to implement decision trees and you may have heard of some of them. The most popular are ID3, C4.5 and CART. However, the Scikit-learn Python library only supports the CART algorithm, which stands for Classification and Regression Trees. This article is based entirely on the CART algorithm.
Benefits of Decision Trees
Decision trees are known as ‘white box’ models which means that you can easily find and interpret their decisions. This is in contrast to ‘black box’ neural networks where it is extremely difficult to figure out exactly how final predictions were made. Fortunately, decision tree models are easy to explain in simple terms, along with why and how the predictions were made by the model. Since decision trees are just if-else control statements at heart, you can even apply their rules and make predictions by hand.
Decision trees can be easily visualised in a tree-like plot that makes it even easier to understand and interpret the model. Have a look at this simplified decision tree below based on the data we’ll be analysing later on in this article. We can actually take a single data point and trace the path it would take to reach the final prediction for it.
The Scikit-learn Python library, with its CART implementation, supports binary, categorical, and numerical target variables. For the feature variables, however, only binary and numerical features are supported at this time. This means that each node in the decision tree can have at most 2 branches leading out of it, so every split has to be a yes/no question (for example, whether a numerical feature is below a certain threshold).
The good news is that decision trees require very little data preparation and so you don’t need to worry about centering or normalising the numerical features first. Having said that, there are still a couple of best practices to follow when fitting a decision tree to your data and we’ll chat about them a bit more towards the end of this article.
It’s also good to keep in mind that decision trees are pretty sensitive to even small changes in the data and tend to learn the data like a parrot. This means that the model can easily overfit the data and can even become biased if the target variable classes are unbalanced. For this reason, decision trees must be closely controlled and optimised to prevent these problems (also more on this later).
How Does it Work?
Without getting all technical, let’s go over how the decision tree CART algorithm works. The main goal is to divide the data into distinct regions and to then make predictions based on those regions.
Starting at the top of the tree, the first question the algorithm must answer is “which feature should be chosen at the root?” To answer this question, the algorithm needs a way of evaluating each feature and choosing the ‘best’ feature to start with. Thereafter, the algorithm needs to keep asking a similar question at each node: “which feature should be used to split this node?” It does this by calculating and optimising a metric against each of the available features.
There are a couple of metrics that can be used depending on the problem at hand. For example, if we’re dealing with a regression problem then we can seek to find the feature with the lowest RSS (residual sum of squares). However, if we have a classification problem then we can choose the feature with the lowest entropy (or highest information gain) or the lowest gini impurity.
If you’re wondering whether to choose entropy or gini impurity for classification problems, don’t waste too much time on it as there isn’t a hell of a lot of difference between the resulting decision trees. So toss a (balanced) coin and pick one.
Sklearn uses the gini impurity by default, so we’ll briefly go over this metric. The gini impurity measures the purity of a node – that is, how mixed the target classes of the observations in that node are. If all observations in a node belong to the same class, the gini impurity is 0 and the node is considered ‘pure’.
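To make that concrete, here is a small, self-contained sketch (not part of the modelling code later in this post) of how the gini impurity of a node can be computed from its class counts:

```python
import numpy as np

def gini_impurity(class_counts):
    """Gini impurity of a node, given the number of samples of each class in it."""
    proportions = np.asarray(class_counts, dtype=float) / np.sum(class_counts)
    return 1.0 - np.sum(proportions ** 2)

print(gini_impurity([50, 0]))   # 0.0 -> all samples in one class: a 'pure' node
print(gini_impurity([25, 25]))  # 0.5 -> an even 50/50 split: maximally impure for 2 classes
```

The algorithm evaluates candidate splits by how much they reduce this impurity, weighted by the number of samples that end up on each side of the split.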
The decision tree algorithm is in constant pursuit of pure nodes and it will continue to split the data into deeper and deeper trees until it finally reaches pure leaf nodes (or it runs out of data or features). In addition, the data is split in a greedy fashion using recursive binary splitting. It’s called greedy because the algorithm will make a split based on what is optimal for the step it is currently dealing with and will not choose a split that can result in a more optimal tree further down the line. This also means there are no backsies – the algorithm won’t backtrack and change its mind for a previous split.
We mentioned ‘leaf’ nodes. These are nodes that do not get split any further, and any observation that reaches a leaf node is assigned that node’s predicted class. At each leaf node, the majority class (the class that more than 50% of the node’s samples belong to) serves as the prediction. In the case of a tie, the class that comes first in sklearn’s ordering – here, the non-event class – is chosen.
Predicting Diabetes with Decision Trees in Python
The data in this project contains biographical and medical information that is used to predict whether or not a patient has diabetes. You can find the data on Kaggle.
These are the goals for this project:
- Explore the data – determine if it requires any cleaning and if there are any correlations in the data
- Apply the decision tree classification algorithm (using sklearn)
- Visualise the decision tree
- Evaluate the accuracy of the model
- Optimise the model to improve accuracy
Setup
```python
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn import metrics
import graphviz
import matplotlib.pyplot as plt
import seaborn as sns

sns.set()
custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", rc=custom_params)
sns.set_palette(sns.dark_palette("seagreen", reverse=True))
```
Explore
These are the feature variables in our data:
- Number of pregnancies
- Glucose
- Blood pressure
- Skin thickness
- Insulin
- BMI
- Diabetes Pedigree Function
- Age
The most confusing feature here is the diabetes pedigree function. To understand it, I read the recommended research paper for this dataset and made these notes:
- The criterion for a diabetes diagnosis is that the ‘2 hour post-load plasma glucose was at least 200 mg/dl’.
- This dataset specifically contains women over the age of 21.
- The Diabetes Pedigree Function provides ‘a synthesis of the diabetes mellitus history in relatives and the genetic relationship of these relatives to the subject’. In other words, this score is higher if there is a family history of diabetes and it’s lower if not.
To start our data exploration, let’s have a look at some summary statistics of our data:
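Assuming the raw CSV has been loaded into a dataframe called `df` (the filename below is just what the Kaggle download is typically called), the summary statistics can be produced with:

```python
# Load the raw data (filename assumed from the Kaggle download)
df = pd.read_csv("diabetes.csv")

# Summary statistics for every numerical column
df.describe()
```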
These are a few points we can note about the data:
- All our features are numerical
- We have a total sample size of 768
- There are no missing values to deal with right now
- A few features have a minimum value of 0, which is suspicious for (living) humans:
  - Min glucose = 0
  - Min blood pressure = 0
  - Min skin thickness = 0
  - Min insulin = 0
  - Min BMI = 0
To clean this up, we will convert these zeros to nulls and remove them from our dataset. This takes our dataset down from 768 to 392 (that was painful!).
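One way to do this is sketched below. The column names are assumed to match the CSV headers from Kaggle, and `df_reduced` is the name the modelling code later in this post expects:

```python
# Columns where a value of 0 is not physiologically plausible
cols_with_invalid_zeros = ["Glucose", "BloodPressure", "SkinThickness", "Insulin", "BMI"]

# Convert the zeros to nulls, then drop the affected rows
df_reduced = df.copy()
df_reduced[cols_with_invalid_zeros] = df_reduced[cols_with_invalid_zeros].replace(0, np.nan)
df_reduced = df_reduced.dropna()

print(len(df), len(df_reduced))  # 768 392
```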
Next, let’s have a look at a correlation matrix between all the variables in our data.
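The matrix itself can be plotted as a seaborn heatmap using the imports from the setup section and the cleaned `df_reduced` dataframe (a sketch, not the exact plotting code):

```python
plt.figure(figsize=(9, 7))
sns.heatmap(df_reduced.corr(), annot=True, fmt=".2f", cmap="Greens")
plt.title("Correlation matrix")
plt.show()
```

A few relationships stand out: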
- The outcome is positively correlated with all features which is a good sign for modelling
- The outcome is most strongly correlated with Glucose (which makes sense since this is about Diabetes) and then Age
- There is a strong correlation between Age and Pregnancies – older women = more pregnancies?
- Insulin and Glucose are correlated – higher insulin = higher glucose?
- SkinThickness and BMI are correlated – higher BMI = higher skin thickness?
An additional note to make is that decision trees tend to be sensitive to unbalanced classes. So, we’re also going to take note of how many observations fall into each outcome class from the original dataset:
| Not Diabetes | Diabetes |
|---|---|
| 500 | 268 |
From this table, we can see that the outcome classes are unbalanced – there are almost twice as many non-events as events. This could make it difficult for the model to correctly predict when someone has diabetes. We may need to balance the classes during the optimisation phase to see if it improves the accuracy of the model.
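For reference, the counts in the table above come straight from the original, uncleaned dataframe:

```python
df["Outcome"].value_counts()  # 0: 500, 1: 268
```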
Model
This is the modelling process we’ll follow to fit a decision tree model to the data:
- Separate the features and target into 2 separate dataframes
- Split the data into training and testing sets (80/20) – using `train_test_split` from `sklearn`
- Apply the decision tree classifier – using `DecisionTreeClassifier` from `sklearn`
- Predict the target for the test set
- Evaluate model accuracy – using `metrics` from `sklearn`
- Visualise the decision trees – using `graphviz`
- Optimise the decision tree – using various parameters such as `max_depth`, `min_samples_split`, etc.
```python
# Separate features and target
target = df_reduced["Outcome"]
features = df_reduced.drop("Outcome", axis=1)

# Split into train and test sets
features_train, features_test, target_train, target_test = train_test_split(
    features, target, test_size=0.2, random_state=42
)

# Fit the decision tree model
decisiontree = tree.DecisionTreeClassifier(random_state=42)
decisiontree = decisiontree.fit(features_train, target_train)

# Predict the target for the test set
target_pred = decisiontree.predict(features_test)
```
This fit gives an accuracy score of 72.15%. Not too bad, but I’m sure we can improve it.
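The accuracy quoted here comes from comparing the predicted classes with the true test labels, for example:

```python
# Proportion of test observations classified correctly
accuracy = metrics.accuracy_score(target_test, target_pred)
print(f"Accuracy: {accuracy:.2%}")
```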
Visualising this tree, we can see it is a bit of a mess. Some leaf nodes have just 1 sample in them, and since we don’t have much data to begin with, we will need to control the minimum number of samples per node as well as the maximum depth of the tree – otherwise the model will overfit and perform poorly on unseen data.
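For completeness, here is one way such a plot can be generated with `graphviz` (the output filename and class labels are my own choices):

```python
# Export the fitted tree to DOT format and render it to a PDF
dot_data = tree.export_graphviz(
    decisiontree,
    out_file=None,
    feature_names=list(features.columns),
    class_names=["Not Diabetes", "Diabetes"],
    filled=True,
    rounded=True,
)
graph = graphviz.Source(dot_data)
graph.render("diabetes_tree")  # writes diabetes_tree.pdf
```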
Optimise
We’ll try a couple of strategies to optimise this model:
- `max_depth` = 3, `min_samples_leaf` = 2 – this produced an accuracy of 74.68%
- `max_depth` = 4, `min_samples_leaf` = 2 – this produced an accuracy of 73.42%
- `max_depth` = 5, `min_samples_leaf` = 2 – this produced an accuracy of 75.95%
Increasing the depth any further results in declining performance. A max depth of 5 gave the highest accuracy.
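As a sketch, the best-performing configuration from the list above corresponds to a re-fit along these lines (the variable names here are my own):

```python
# Re-fit the tree with the tuned parameters
decisiontree_tuned = tree.DecisionTreeClassifier(
    max_depth=5, min_samples_leaf=2, random_state=42
)
decisiontree_tuned.fit(features_train, target_train)

target_pred_tuned = decisiontree_tuned.predict(features_test)
print(metrics.accuracy_score(target_test, target_pred_tuned))
```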
Next, we will balance the dataset by randomly selecting only 130 rows from the total of 262 where the outcome = 0 (i.e. class = Not Diabetic). Balancing the dataset, in combination with `max_depth` = 5 and `min_samples_leaf` = 2, produces an accuracy of 78.85%.
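A minimal sketch of that balancing step, assuming `df_reduced` is the cleaned dataframe (the random seeds are arbitrary):

```python
# Downsample the majority class to match the 130 diabetic observations
not_diabetic = df_reduced[df_reduced["Outcome"] == 0].sample(n=130, random_state=42)
diabetic = df_reduced[df_reduced["Outcome"] == 1]
df_balanced = pd.concat([not_diabetic, diabetic])

# Re-split and re-fit with the tuned parameters
target_bal = df_balanced["Outcome"]
features_bal = df_balanced.drop("Outcome", axis=1)
f_train, f_test, t_train, t_test = train_test_split(
    features_bal, target_bal, test_size=0.2, random_state=42
)
model = tree.DecisionTreeClassifier(max_depth=5, min_samples_leaf=2, random_state=42)
model.fit(f_train, t_train)
print(metrics.accuracy_score(t_test, model.predict(f_test)))
```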
Looks like we’ve found a winner 🙂
Here is a visualisation of this optimised tree:
I hope you found this post helpful. If you have any questions don’t hesitate to drop a comment or reach out to me on Twitter!