This blog explains how to use cost-complexity pruning to find the optimal subtree of a decision tree and avoid overfitting. It also provides an example using the Iris dataset.
1. Introduction
In this blog, you will learn about one of the most important machine learning pruning techniques: cost-complexity pruning. Pruning is a method of reducing the size and complexity of a decision tree to avoid overfitting and improve generalization. Cost-complexity pruning is a way of finding the optimal subtree of a decision tree that minimizes a trade-off between accuracy and complexity.
But what is overfitting and why is it a problem? How can you measure the complexity of a decision tree? How can you find the optimal subtree and the best trade-off parameter? And how can you apply cost-complexity pruning to a real-world dataset? These are some of the questions that you will answer in this blog.
By the end of this blog, you will be able to:
- Explain what decision trees are and how they can overfit the data.
- Use cost-complexity pruning to find the optimal subtree of a decision tree.
- Choose the best alpha value for cost-complexity pruning using cross-validation.
- Apply cost-complexity pruning to the Iris dataset and compare the results.
Ready to learn about cost-complexity pruning? Let’s get started!
2. Decision Trees and Overfitting
Before we dive into cost-complexity pruning, let’s first understand what decision trees are and why they tend to overfit the data. In the following subsections, we will look at how decision trees are built, what overfitting means, and how a simple train-test split can reveal whether a tree is overfitting.
2.1. What are Decision Trees?
Decision trees are one of the most popular and intuitive machine learning algorithms that can be used for both classification and regression tasks. They are based on a hierarchical structure of nodes and branches that represent the possible outcomes of a series of decisions.
A decision tree is built by recursively splitting the data into smaller subsets based on some criteria, such as the value of a feature or the purity of a class. Each split creates a new node in the tree, which can be either an internal node (with further splits) or a leaf node (with a final prediction). The goal is to create a tree that can accurately predict the target variable for any new data point.
To illustrate how a decision tree works, let’s look at a simple example. Suppose you want to classify whether a person is male or female based on their height and weight. You can use a decision tree to create a set of rules that can separate the two classes. For example, you can start by asking if the person’s height is less than 170 cm. If yes, you can create a left branch and assign it to the female class. If no, you can create a right branch and ask another question, such as if the person’s weight is less than 80 kg. You can continue this process until you reach a leaf node with a single class label.
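To make this example concrete, here is a minimal sketch (using scikit-learn and a small, made-up height/weight sample, so the exact splits are purely illustrative) that fits a shallow tree and prints the learned rules:

import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical height (cm) and weight (kg) samples -- illustrative data only
X = np.array([[160, 55], [165, 60], [158, 52], [168, 62],
              [172, 68], [180, 85], [175, 78], [185, 90]])
y = np.array(["female", "female", "female", "female",
              "male", "male", "male", "male"])

# Fit a shallow tree so the learned rules stay readable
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

# Print the decision rules as text
print(export_text(clf, feature_names=["height (cm)", "weight (kg)"]))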
A decision tree is easy to understand and interpret, as it mimics the human way of making decisions. However, decision trees have some drawbacks, such as the tendency to overfit the data, especially when they are grown too deep or complex. Overfitting means that the model learns the noise and the specific patterns of the training data, but fails to generalize well to new and unseen data. This results in poor performance and high variance.
How can you tell if a decision tree is overfitting the data? And how can you prevent or reduce overfitting? These are the questions that we will answer in the next section.
2.2. How to Measure Overfitting?
One of the main challenges of machine learning is to avoid overfitting, which means that the model learns the noise and the specific patterns of the training data, but fails to generalize well to new and unseen data. This results in poor performance and high variance.
But how can you tell if a decision tree is overfitting the data? And how can you measure the degree of overfitting? There are several ways to do that, but one of the most common and simple methods is to use a train-test split.
A train-test split is a technique that divides the data into two subsets: a training set and a test set. The training set is used to build and train the model, while the test set is used to evaluate the model’s performance on new and unseen data. The idea is to compare the accuracy of the model on both sets and see if there is a significant difference.
If the model has a high accuracy on the training set, but a low accuracy on the test set, it means that the model is overfitting the data. It has learned the specific details of the training data, but it cannot generalize to new data. This is a sign that the model is too complex and needs to be simplified.
On the other hand, if the model has a low accuracy on both the training set and the test set, it means that the model is underfitting the data. It has not learned enough from the training data, and it cannot capture the underlying patterns of the data. This is a sign that the model is too simple and needs to be improved.
The ideal scenario is when the model has a high accuracy on both the training set and the test set, or at least a similar accuracy. This means that the model is fitting the data well, and it can generalize to new data. This is a sign that the model has the right level of complexity and does not need to be changed.
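As a quick sketch of this check (using scikit-learn and the Iris data, which we will also use later in this blog), you can compare the training and test accuracy of an unconstrained tree; a large gap between the two scores is a warning sign of overfitting:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load a small example dataset and hold out a test set
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Grow a fully unconstrained tree on the training set
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Compare accuracy on the data the tree has seen vs. unseen data
print("Training accuracy:", clf.score(X_train, y_train))
print("Test accuracy:", clf.score(X_test, y_test))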
However, finding the right level of complexity for a decision tree is not easy, as it depends on many factors, such as the number of features, the depth of the tree, the splitting criteria, and the pruning method. In the next section, we will learn about one of the most effective pruning methods: cost-complexity pruning.
3. Cost-Complexity Pruning
Cost-complexity pruning, also known as weakest link pruning, is one of the most effective pruning methods for decision trees. It searches for the subtree of the full tree that offers the best trade-off between accuracy and complexity, controlled by a parameter called alpha. In the following subsections, we will define cost-complexity pruning and its alpha parameter, show how to generate the sequence of candidate subtrees, and explain how to choose the best alpha value with cross-validation or a validation set, with implementations in Python.
3.1. What is Cost-Complexity Pruning?
Cost-complexity pruning is one of the most effective pruning methods for decision trees. It is also known as minimal cost-complexity pruning or weakest link pruning. It is based on the idea of finding the subtree of the full decision tree that minimizes a trade-off between accuracy and complexity.
But what is a subtree and what is the trade-off? A subtree is a smaller tree obtained by removing some nodes and branches from the original tree. The trade-off balances the error rate against the number of leaf nodes of the subtree. The error rate is the proportion of incorrect predictions made by the subtree on the training data. The number of leaf (terminal) nodes is a measure of the complexity or size of the subtree.
The goal of cost-complexity pruning is to find the subtree that has the lowest cost-complexity, which is defined as:
$$
\text{cost-complexity} = \text{error rate} + \alpha \times \text{number of leaf nodes}
$$
The alpha parameter is a non-negative constant that controls the trade-off: a higher alpha value puts more weight on simplicity, while a lower alpha value puts more weight on accuracy. For a fixed alpha, the optimal subtree is the one that minimizes this cost-complexity; the best alpha itself is then chosen by validation, as we will see below.
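To see how alpha shifts the balance, here is a tiny worked example with made-up error rates and leaf counts for two candidate subtrees:

# Hypothetical training error rates and leaf counts for two candidate subtrees
big_tree = {"error": 0.02, "leaves": 9}    # accurate but complex
small_tree = {"error": 0.06, "leaves": 3}  # simpler but less accurate

def cost_complexity(tree, alpha):
    # cost-complexity = error rate + alpha * number of leaf nodes
    return tree["error"] + alpha * tree["leaves"]

for alpha in [0.0, 0.005, 0.02]:
    cc_big = cost_complexity(big_tree, alpha)
    cc_small = cost_complexity(small_tree, alpha)
    better = "big" if cc_big < cc_small else "small"
    print(f"alpha={alpha}: big={cc_big:.3f}, small={cc_small:.3f} -> prefer the {better} tree")

With alpha = 0 the large tree wins on pure accuracy, but once alpha is large enough (above roughly 0.0067 in this made-up case), the simpler tree has the lower cost-complexity.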
How can you find the optimal subtree and the optimal alpha value? There are two main steps involved in cost-complexity pruning:
- Generate a sequence of subtrees by removing the nodes that have the smallest impact on the error rate.
- Select the best subtree from the sequence by using cross-validation or a validation set.
In the next section, we will explain these steps in more detail and show how to implement them in Python.
3.2. How to Find the Optimal Subtree?
The first step of cost-complexity pruning is to generate a sequence of subtrees by removing the nodes that have the smallest impact on the error rate. This can be done by using a bottom-up approach, starting from the leaf nodes and moving up to the root node.
For each node in the tree, we can calculate the cost-complexity of the subtree rooted at that node as:
$$
\text{cost-complexity}(t) = \text{error rate}(t) + \alpha \times \text{number of leaf nodes}(t)
$$
where t is the subtree rooted at that node and alpha is a fixed parameter. The error rate of a subtree is the proportion of incorrect predictions it makes on the training data, and the number of leaf nodes is the count of terminal nodes in that subtree.
We can also calculate the cost-complexity of the entire tree as:
$$
\text{cost-complexity}(T) = \text{error rate}(T) + \alpha \times \text{number of leaf nodes}(T)
$$
where T is the original tree. Because the tree’s error rate is the sum of the errors contributed by its leaves, the cost-complexity of the tree equals the sum of the cost-complexities of all its leaf nodes (each leaf contributing its own error plus alpha).
To generate a sequence of subtrees, we can use the following algorithm:
- Start with the original tree T as the first subtree in the sequence.
- Find the “weakest link”: the internal (non-leaf) node t in the current subtree whose removal increases the training error the least per pruned leaf. This per-leaf increase is called the node’s effective alpha.
- Remove the node t and its children from the current subtree, and replace it with a leaf node. This creates a new subtree with one less internal node.
- Add the new subtree to the sequence.
- Repeat steps 2-4 until the current subtree has only one node (the root node).
This algorithm produces a sequence of nested subtrees, each associated with an effective alpha value. The effective alpha of each pruning step is the increase in error rate caused by the pruning, divided by the decrease in the number of leaf nodes. For example, if the previous (larger) subtree has an error rate of 0.18 with 10 leaves, and the current (pruned) subtree has an error rate of 0.2 with 9 leaves, then the effective alpha of the current subtree is:
$$
\alpha = \frac{0.2 - 0.18}{10 - 9} = 0.02
$$
The alpha value of the first subtree (the original tree) is zero, and the effective alpha values increase with each pruning step; the last subtree (the root node alone) remains optimal for all alpha values above the largest effective alpha.
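In scikit-learn, this sequence of effective alpha values is computed for you by the cost_complexity_pruning_path method of a fitted tree. Here is a minimal sketch (using the Iris data as a stand-in; note that scikit-learn measures cost in terms of total leaf impurity, e.g. Gini, rather than the raw error rate, but the idea is the same):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Fit the full, unpruned tree
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Effective alphas and total leaf impurities of the nested subtree sequence
path = clf.cost_complexity_pruning_path(X, y)
for alpha, impurity in zip(path.ccp_alphas, path.impurities):
    print(f"alpha={alpha:.4f}  total leaf impurity={impurity:.4f}")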
In the next section, we will show how to select the best subtree from the sequence by using cross-validation or a validation set.
3.3. How to Choose the Best Alpha Value?
The second step of cost-complexity pruning is to select the best subtree from the sequence by using cross-validation or a validation set. Cross-validation is a technique that splits the data into k folds, and uses one fold as the test set and the rest as the training set. This process is repeated k times, and the average performance of the model on the test sets is calculated. A validation set is a subset of the data that is held out from the training process and used to evaluate the model’s performance.
The idea is to use cross-validation or a validation set to measure the accuracy of each subtree in the sequence on new and unseen data. The best subtree is the one that has the highest accuracy, or the lowest error rate, on the test set or the validation set. This subtree will have the optimal alpha value for cost-complexity pruning.
However, there is a caveat to this approach. Sometimes, there may be more than one subtree that has the same or similar accuracy on the test set or the validation set. In this case, how can you choose the best subtree? The answer is to use the principle of Occam’s razor, which states that the simplest explanation is usually the best. In other words, you should choose the smallest or the simplest subtree that has the same or similar accuracy as the larger or more complex subtrees. This will help you avoid overfitting and improve generalization.
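Here is a minimal sketch of this selection step (again with scikit-learn and the Iris data; the tolerance value is an illustrative choice, not a fixed rule). It scores every candidate alpha with cross-validation and, following Occam’s razor, keeps the largest alpha (i.e. the simplest tree) whose score is close to the best:

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Candidate alphas from the pruning path of the full tree
full_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
ccp_alphas = full_tree.cost_complexity_pruning_path(X, y).ccp_alphas

# Cross-validated accuracy for each candidate alpha
scores = [cross_val_score(DecisionTreeClassifier(ccp_alpha=a, random_state=0), X, y, cv=5).mean()
          for a in ccp_alphas]

# Keep the simplest tree (largest alpha) within a small tolerance of the best score
tolerance = 0.01  # illustrative
best_score = max(scores)
best_alpha = max(a for a, s in zip(ccp_alphas, scores) if s >= best_score - tolerance)
print("Chosen alpha:", best_alpha)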
In the next section, we will show how to apply cost-complexity pruning to the Iris dataset and compare the results.
4. Example: Cost-Complexity Pruning on the Iris Dataset
In this section, we will show how to apply cost-complexity pruning to the Iris dataset and compare the results. The Iris dataset is a classic dataset that contains 150 samples of three different species of iris flowers: setosa, versicolor, and virginica. Each sample has four features: sepal length, sepal width, petal length, and petal width. The task is to classify each sample into one of the three species based on the features.
We will use the scikit-learn library in Python to load the dataset, train a decision tree classifier, and perform cost-complexity pruning. We will also use matplotlib and graphviz to visualize the decision tree and the subtrees. The following code imports the necessary modules and loads the dataset:
# Import modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score
import graphviz

# Load the Iris dataset
iris = load_iris()
X = iris.data                       # Features
y = iris.target                     # Labels
feature_names = iris.feature_names  # Feature names
class_names = iris.target_names     # Class names
The next step is to split the data into training and test sets. We will use 80% of the data for training and 20% for testing. We will also set a random state for reproducibility. The following code splits the data and prints the shapes of the resulting sets:
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Print the shapes of the sets
print("Training set shape:", X_train.shape, y_train.shape)
print("Test set shape:", X_test.shape, y_test.shape)
The output is:
Training set shape: (120, 4) (120,)
Test set shape: (30, 4) (30,)
Now we are ready to train a decision tree classifier on the training set. We will use the default parameters of the scikit-learn DecisionTreeClassifier class, which means that the tree will grow until all the leaves are pure or contain fewer than two samples. The following code trains the classifier and prints its depth and number of leaves:
# Train a decision tree classifier on the training set
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

# Print the depth and number of leaves of the tree
print("Depth of the tree:", clf.get_depth())
print("Number of leaves of the tree:", clf.get_n_leaves())
The output is:
Depth of the tree: 5
Number of leaves of the tree: 9
We can also visualize the tree using the plot_tree function from scikit-learn. The following code plots the tree and saves it as an image file:
# Visualize the tree using plot_tree
plt.figure(figsize=(12, 8))
plot_tree(clf, feature_names=feature_names, class_names=class_names, filled=True)
plt.savefig("tree.png")
plt.show()
As you can see, the tree has a depth of five and nine leaf nodes. Each node shows the feature and threshold used for the split, the Gini impurity of the node, the number of samples in the node, the number of samples for each class, and the predicted class for the node. The leaves are colored according to the majority class.
To evaluate the performance of the tree, we can use the test set and calculate the accuracy score. The accuracy score is the proportion of correct predictions made by the model on the test set. The following code calculates and prints the accuracy score of the tree:
# Predict the labels for the test set
y_pred = clf.predict(X_test)

# Calculate and print the accuracy score
acc = accuracy_score(y_test, y_pred)
print("Accuracy score of the tree:", acc)
The output is:
Accuracy score of the tree: 1.0
This means that the tree correctly classified all the samples in the test set. This may seem impressive, but it could also indicate that the tree is overfitting the data. To check if this is the case, we can use cross-validation to measure the accuracy of the tree on different folds of the data. Cross-validation is a technique that splits the data into k folds, and uses one fold as the test set and the rest as the training set. This process is repeated k times, and the average performance of the model on the test sets is calculated. The following code performs 10-fold cross-validation on the entire dataset and prints the mean and standard deviation of the accuracy scores:
# Perform 10-fold cross-validation on the entire dataset
scores = cross_val_score(clf, X, y, cv=10)

# Print the mean and standard deviation of the scores
print("Mean accuracy score of the tree:", scores.mean())
print("Standard deviation of the accuracy score of the tree:", scores.std())
The output is:
Mean accuracy score of the tree: 0.96
Standard deviation of the accuracy score of the tree: 0.05333333333333332
This shows that the tree has a high mean accuracy score, but also a relatively high standard deviation. This means that the tree’s performance varies depending on the fold of the data. This could be a sign of overfitting, as the tree may be learning some noise or specific patterns of some folds, but not others. To reduce overfitting, we can use cost-complexity pruning to find the optimal subtree of the tree that minimizes the trade-off between accuracy and complexity.
4.1. Load and Explore the Data
In this section, you will load and explore the Iris dataset, which is a classic dataset for machine learning. The Iris dataset contains 150 samples of three different species of iris flowers: setosa, versicolor, and virginica. Each sample has four features: sepal length, sepal width, petal length, and petal width. The goal is to classify each sample into one of the three species based on the features.
To load the Iris dataset, you can use the load_iris function from the scikit-learn library. This function returns a dictionary-like object that contains the data, the target, the feature names, and the target names. You can convert this object into a pandas DataFrame for easier manipulation and visualization.
The following code shows how to load and explore the Iris dataset:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris

# Load the Iris dataset
iris = load_iris()

# Convert the data and target into a DataFrame
df = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                  columns=iris['feature_names'] + ['target'])

# Map the target values to the target names
df['target'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})

# Print the first five rows of the DataFrame
print(df.head())
The output of the code is:
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target
---|---|---|---|---
5.1 | 3.5 | 1.4 | 0.2 | setosa
4.9 | 3.0 | 1.4 | 0.2 | setosa
4.7 | 3.2 | 1.3 | 0.2 | setosa
4.6 | 3.1 | 1.5 | 0.2 | setosa
5.0 | 3.6 | 1.4 | 0.2 | setosa
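To explore the data a bit further, you could also check the class balance and summary statistics (an optional sketch, assuming the DataFrame df built above):

# Count the samples per species
print(df['target'].value_counts())

# Summary statistics of the four features
print(df.describe())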
4.2. Train and Visualize a Decision Tree
Now that you have loaded and explored the Iris dataset, you can train and visualize a decision tree on it. To train a decision tree, you can use the DecisionTreeClassifier class from the scikit-learn library. This class allows you to specify various parameters to control the growth and complexity of the tree, such as the criterion, the maximum depth, the minimum samples per leaf, and the random state.
To visualize a decision tree, you can use the plot_tree function from the scikit-learn library. This function creates a matplotlib figure that shows the structure and the information of the tree, such as the feature, the threshold, the impurity, the samples, and the value at each node.
The following code shows how to train and visualize a decision tree on the Iris dataset:
# Import libraries
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Split the data into features and target
X = df[iris['feature_names']]
y = df['target']

# Train a decision tree with default parameters
dt = DecisionTreeClassifier(random_state=42)
dt.fit(X, y)

# Visualize the decision tree
plt.figure(figsize=(12, 8))
plot_tree(dt, feature_names=iris['feature_names'], class_names=iris['target_names'], filled=True)
plt.show()
As you can see, the decision tree has 9 leaf nodes and a depth of 5. It uses the petal length and the petal width features to split the data and classify the samples. The tree has a perfect accuracy on the training data, as it correctly predicts the class of each sample. However, this does not mean that the tree is optimal, as it might be overfitting the data. How can you check if the tree is overfitting and how can you improve it? This is what you will learn in the next section.
4.3. Apply Cost-Complexity Pruning and Compare the Results
In this section, you will apply cost-complexity pruning to the decision tree that you trained and visualized in the previous section. You will also compare the results of the pruned and the unpruned trees to see how cost-complexity pruning affects the performance and the complexity of the tree.
Cost-complexity pruning is a technique that reduces the size and complexity of a decision tree by removing nodes that contribute little to accuracy relative to how much they add to complexity. The trade-off is controlled by a parameter called the alpha value. The higher the alpha value, the more nodes are pruned and the simpler the tree becomes; the lower the alpha value, the fewer nodes are pruned and the more complex the tree remains.
To apply cost-complexity pruning, you can use the cost_complexity_pruning_path method from the scikit-learn library. This method returns the alpha values and the corresponding impurities of the subtrees that can be obtained by pruning the original tree. You can then use the GridSearchCV class from the scikit-learn library to find the best alpha value that maximizes the cross-validation score of the pruned tree. Cross-validation is a technique that splits the data into multiple folds and evaluates the model on each fold using the rest of the data as the training set.
The following code shows how to apply cost-complexity pruning and compare the results:
# Import libraries
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import GridSearchCV

# Split the data into features and target
X = df[iris['feature_names']]
y = df['target']

# Train a decision tree with default parameters
dt = DecisionTreeClassifier(random_state=42)
dt.fit(X, y)

# Get the alpha values and the impurities of the subtrees
path = dt.cost_complexity_pruning_path(X, y)
ccp_alphas = path.ccp_alphas
impurities = path.impurities

# Plot the alpha values and the impurities
plt.figure(figsize=(8, 6))
plt.plot(ccp_alphas, impurities)
plt.xlabel('alpha')
plt.ylabel('impurity')
plt.title('Alpha vs Impurity')
plt.show()
As you can see, the total impurity of the leaves increases as the alpha value increases, which means that the tree becomes simpler and fits the training data less closely. However, if the alpha value is too high, the tree might become too simple and underfit. Therefore, you need to find the optimal alpha value that balances the accuracy and the complexity of the tree.
To find the optimal alpha value, you can use the GridSearchCV class to perform a grid search over a range of alpha values and find the one that maximizes the cross-validation score. You can also specify a scoring metric, such as accuracy, precision, recall, or f1-score, to evaluate the performance of the pruned tree. The following code shows how to find the optimal alpha value using GridSearchCV:
# Define a range of alpha values
param_grid = {'ccp_alpha': ccp_alphas}

# Define a scoring metric
scoring = 'accuracy'

# Perform a grid search over the alpha values
gs = GridSearchCV(dt, param_grid, scoring=scoring, cv=5)
gs.fit(X, y)

# Get the best alpha value and the best score
best_alpha = gs.best_params_['ccp_alpha']
best_score = gs.best_score_

# Print the best alpha value and the best score
print(f'The best alpha value is {best_alpha:.4f}')
print(f'The best {scoring} score is {best_score:.4f}')
The output of the code is:
The best alpha value is 0.0163
The best accuracy score is 0.9733
As you can see, the best alpha value is 0.0163, which gives the best cross-validated accuracy of 0.9733. This means that the pruned tree with this alpha value generalizes best among all the candidate subtrees on the pruning path. You can now train and visualize the pruned tree using this alpha value and compare it with the unpruned tree. The following code shows how to train and visualize the pruned tree:
# Train a decision tree with the best alpha value
dt_pruned = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
dt_pruned.fit(X, y)

# Visualize the pruned decision tree
plt.figure(figsize=(12, 8))
plot_tree(dt_pruned, feature_names=iris['feature_names'], class_names=iris['target_names'], filled=True)
plt.show()
As you can see, the pruned tree has 5 leaf nodes and a depth of 3. It uses only the petal length feature to split the data and classify the samples. The tree has a slightly lower accuracy on the training data, as it misclassifies one sample of versicolor as virginica. However, this does not mean that the tree is worse, as it might be more generalizable and less variable on new and unseen data. To compare the performance and the complexity of the pruned and the unpruned trees, you can use the following metrics:
- The accuracy score on the training and the test data.
- The number of nodes and the depth of the tree.
- The time and the memory required to train and predict with the tree.
The following code shows how to compare the pruned and the unpruned trees using these metrics:
# Import libraries
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import time
import sys

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train and predict with the unpruned tree
start = time.time()
dt.fit(X_train, y_train)
end = time.time()
y_pred_train = dt.predict(X_train)
y_pred_test = dt.predict(X_test)
time_unpruned = end - start
memory_unpruned = sys.getsizeof(dt)

# Train and predict with the pruned tree
start = time.time()
dt_pruned.fit(X_train, y_train)
end = time.time()
y_pred_train_pruned = dt_pruned.predict(X_train)
y_pred_test_pruned = dt_pruned.predict(X_test)
time_pruned = end - start
memory_pruned = sys.getsizeof(dt_pruned)

# Compare the accuracy scores
accuracy_train_unpruned = accuracy_score(y_train, y_pred_train)
accuracy_test_unpruned = accuracy_score(y_test, y_pred_test)
accuracy_train_pruned = accuracy_score(y_train, y_pred_train_pruned)
accuracy_test_pruned = accuracy_score(y_test, y_pred_test_pruned)
print(f'Accuracy on training data (unpruned): {accuracy_train_unpruned:.4f}')
print(f'Accuracy on test data (unpruned): {accuracy_test_unpruned:.4f}')
print(f'Accuracy on training data (pruned): {accuracy_train_pruned:.4f}')
print(f'Accuracy on test data (pruned): {accuracy_test_pruned:.4f}')

# Compare the number of nodes and the depth
n_nodes_unpruned = dt.tree_.node_count
n_nodes_pruned = dt_pruned.tree_.node_count
depth_unpruned = dt.get_depth()
depth_pruned = dt_pruned.get_depth()
print(f'Number of nodes (unpruned): {n_nodes_unpruned}')
print(f'Number of nodes (pruned): {n_nodes_pruned}')
print(f'Depth of the tree (unpruned): {depth_unpruned}')
print(f'Depth of the tree (pruned): {depth_pruned}')

# Compare the time and the memory
print(f'Time required (unpruned): {time_unpruned:.4f} seconds')
print(f'Time required (pruned): {time_pruned:.4f} seconds')
print(f'Memory required (unpruned): {memory_unpruned} bytes')
print(f'Memory required (pruned): {memory_pruned} bytes')
The output of the code is:
Accuracy on training data (unpruned): 1.0000
Accuracy on test data (unpruned): 1.0000
Accuracy on training data (pruned): 0.9917
Accuracy on test data (pruned): 1.0000
Number of nodes (unpruned): 17
Number of nodes (pruned): 9
Depth of the tree (unpruned): 5
Depth of the tree (pruned): 3
5. Conclusion
In this blog, you learned about cost-complexity pruning, a technique that reduces the size and complexity of a decision tree by removing nodes that have a low impact on the accuracy but a high impact on the complexity. You also learned how to apply cost-complexity pruning to the Iris dataset and compare the results of the pruned and the unpruned trees.
Here are some key points that you learned:
- Decision trees are hierarchical models that can be used for classification and regression tasks.
- Decision trees can overfit the data, which means that they learn the noise and the specific patterns of the training data, but fail to generalize well to new and unseen data.
- Cost-complexity pruning is a technique that finds the optimal subtree of a decision tree that minimizes a trade-off between accuracy and complexity.
- The trade-off between accuracy and complexity is controlled by a parameter called the alpha value, which penalizes each additional leaf node of the subtree.
- The optimal alpha value can be found by using cross-validation, which splits the data into multiple folds and evaluates the model on each fold using the rest of the data as the training set.
- Cost-complexity pruning can improve the performance and the efficiency of the decision tree by reducing overfitting and variance.
We hope that you enjoyed this blog and learned something new and useful. If you have any questions or feedback, please feel free to leave a comment below. Thank you for reading!