Decision Trees
Recursive Tree Building

Now that we can find the best feature to split the dataset on, we can repeat this process to build the full tree. This is a recursive algorithm! We start with every data point from the training set, find the best feature to split on, split the data based on that feature, and then repeat the process on each subset that the split created.

We’ll stop the recursion when we can no longer find a feature that results in any information gain. In other words, we want to create a leaf of the tree when we can’t find a way to split the data that makes purer subsets.

The leaf should keep track of the classes of the data points from the training set that ended up in the leaf. In our implementation, we’ll use a Counter object to keep track of the counts of labels.

We’ll use these counts to make predictions about new data that we give the tree.
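As a small illustration of what a leaf might hold, the sketch below builds a Counter from a handful of labels and reads off the most common one. The label values are made up for the example, and treating the most common label as the prediction is an assumption about how the counts will be used, not something this lesson specifies.

```python
from collections import Counter

# Hypothetical labels of the training points that ended up in one leaf.
leaf_labels = ["unacc", "acc", "unacc", "unacc", "good"]

# The leaf stores how many points of each class it received.
leaf = Counter(leaf_labels)

# most_common(1) returns a one-element list holding the (label, count)
# pair with the highest count. Using it as the prediction is an assumption.
predicted_label, count = leaf.most_common(1)[0]
print(leaf)
print(predicted_label, count)
```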

Instructions

1.

We’ve given you the function find_best_split() that takes a set of data points and a set of labels.

The function returns the index of the feature that causes the best split and the information gain caused by that split.

For now, at the bottom of your code, call this function using car_data and car_labels as parameters and store the values in variables named best_feature and best_gain.

Print those two variables. What was the best feature to split on and what was the information gain?
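To make step 1 concrete outside the lesson environment, here is a rough stand-in for find_best_split(). The lesson supplies its own version, so everything below is an assumption: the helper gini(), the grouping logic in split(), the use of Gini impurity with a weighted information gain, and the toy data that takes the place of car_data and car_labels.

```python
from collections import Counter

def gini(labels):
    """Gini impurity of a list of labels (0 means the list is pure)."""
    total = len(labels)
    return 1 - sum((n / total) ** 2 for n in Counter(labels).values())

def split(data, labels, feature):
    """Group rows and labels by the value each row has at index `feature`."""
    groups = {}
    for row, label in zip(data, labels):
        rows, labs = groups.setdefault(row[feature], ([], []))
        rows.append(row)
        labs.append(label)
    data_subsets = [rows for rows, _ in groups.values()]
    label_subsets = [labs for _, labs in groups.values()]
    return data_subsets, label_subsets

def find_best_split(data, labels):
    """Return (feature index, information gain) for the best split found."""
    best_feature, best_gain = 0, 0
    for feature in range(len(data[0])):
        _, label_subsets = split(data, labels, feature)
        # Weight each subset's impurity by its share of the data.
        weighted = sum(len(s) / len(labels) * gini(s) for s in label_subsets)
        gain = gini(labels) - weighted
        if gain > best_gain:
            best_feature, best_gain = feature, gain
    return best_feature, best_gain

# Toy stand-ins for car_data and car_labels.
toy_data = [["low", "2"], ["low", "4"], ["high", "2"], ["high", "4"]]
toy_labels = ["acc", "acc", "unacc", "unacc"]

best_feature, best_gain = find_best_split(toy_data, toy_labels)
print(best_feature, best_gain)
```

In this toy data the first feature separates the two classes perfectly, so it is the one reported as the best split.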

2.

Let’s create a function called build_tree() that takes data and labels as parameters.

Move your call of find_best_split() inside this function, but change the parameters from car_data and car_labels to data and labels.

If best_gain is 0, return a Counter object of labels. We’ve reached the base case — there’s no way to gain any more information so we want to create a leaf.

3.

After the if statement, we want to start working on the recursive case.

In the recursive case, we want to split the data into subsets using the best feature, and then recursively call the build_tree() function on those subsets to create subtrees. Finally, we want to return a list of all those subtrees.

Let’s begin by splitting the data. You can do this by using the split() function which takes three parameters — the data and labels that you want to split and the index of the feature you want to split on.

Store the result of the split() function in two variables named data_subsets and label_subsets.

For now, return data_subsets at the bottom of your function.
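The lesson provides split() for you. To show the shape of what it returns, the sketch below is one possible implementation, written under the assumption that it groups rows by the value of the chosen feature and returns two parallel lists. The toy data is invented for the example.

```python
def split(data, labels, feature):
    """Group rows and labels by the value each row has at index `feature`.

    Returns two parallel lists: data_subsets[i] and label_subsets[i]
    describe the same group of training points.
    """
    groups = {}
    for row, label in zip(data, labels):
        rows, labs = groups.setdefault(row[feature], ([], []))
        rows.append(row)
        labs.append(label)
    data_subsets = [rows for rows, _ in groups.values()]
    label_subsets = [labs for _, labs in groups.values()]
    return data_subsets, label_subsets

# Toy stand-ins for the car data: each row is [price, doors].
toy_data = [["low", "2"], ["low", "4"], ["high", "2"], ["high", "4"]]
toy_labels = ["acc", "acc", "unacc", "acc"]

data_subsets, label_subsets = split(toy_data, toy_labels, 0)
print(data_subsets)
print(label_subsets)
```

Because the two lists are parallel, the same index i can be used to pass matching data and labels to each recursive call.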

4.

Before that final return statement, create an empty list named branches. This list will store all of the subtrees we’re about to make from our recursive calls.

We now want to loop through all of the subsets of data and labels. Set up your for loop like this:

for i in range(len(data_subsets)):

Inside the for loop, call build_tree using data_subsets[i] and label_subsets[i] as parameters and append the result to branches.

Finally, outside the for loop, return branches instead of data_subsets.
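Putting steps 2 through 4 together, build_tree() might look roughly like the sketch below. Only build_tree() follows the instructions directly. The helpers gini(), split(), and find_best_split() are simplified stand-ins for the functions the lesson provides (their internals are assumptions), and they are included only so the block runs on its own with invented toy data.

```python
from collections import Counter

def gini(labels):
    """Gini impurity of a list of labels (stand-in helper)."""
    total = len(labels)
    return 1 - sum((n / total) ** 2 for n in Counter(labels).values())

def split(data, labels, feature):
    """Group rows and labels by the value at index `feature` (stand-in)."""
    groups = {}
    for row, label in zip(data, labels):
        rows, labs = groups.setdefault(row[feature], ([], []))
        rows.append(row)
        labs.append(label)
    return ([rows for rows, _ in groups.values()],
            [labs for _, labs in groups.values()])

def find_best_split(data, labels):
    """Return (feature index, information gain) of the best split (stand-in)."""
    best_feature, best_gain = 0, 0
    for feature in range(len(data[0])):
        _, label_subsets = split(data, labels, feature)
        weighted = sum(len(s) / len(labels) * gini(s) for s in label_subsets)
        gain = gini(labels) - weighted
        if gain > best_gain:
            best_feature, best_gain = feature, gain
    return best_feature, best_gain

def build_tree(data, labels):
    best_feature, best_gain = find_best_split(data, labels)
    # Base case: no split improves purity, so this node becomes a leaf.
    if best_gain == 0:
        return Counter(labels)
    # Recursive case: split on the best feature and build one subtree per subset.
    data_subsets, label_subsets = split(data, labels, best_feature)
    branches = []
    for i in range(len(data_subsets)):
        branches.append(build_tree(data_subsets[i], label_subsets[i]))
    return branches

toy_data = [["low", "2"], ["low", "4"], ["high", "2"], ["high", "4"]]
toy_labels = ["acc", "acc", "unacc", "acc"]
tree = build_tree(toy_data, toy_labels)
print(tree)
```

The recursion always terminates here: a split with positive gain produces at least two strictly smaller subsets, and a subset whose labels are all the same has zero gain and becomes a leaf.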

5.

Let’s test our function! At the bottom of your code outside of your function definition, call build_tree() using car_data and car_labels as parameters and store the result in a variable named tree.

We’ve written a function called print_tree() that will help you visualize the tree. Call print_tree() using tree as a parameter.
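The lesson's print_tree() is already written for you, and its actual output format is not shown here. As a rough idea of how such a helper could walk the nested lists that build_tree() returns, here is a hypothetical version. The helper format_tree(), the "Split" and "Leaf" wording, and the hand-built example tree are all assumptions made for illustration.

```python
from collections import Counter

def format_tree(tree, depth=0):
    """Return one text line per node, indented two spaces per level."""
    indent = "  " * depth
    # A Counter is a leaf; anything else is a list of subtrees.
    if isinstance(tree, Counter):
        return [indent + "Leaf: " + str(dict(tree))]
    lines = [indent + "Split"]
    for branch in tree:
        lines.extend(format_tree(branch, depth + 1))
    return lines

def print_tree(tree):
    """Hypothetical stand-in for the lesson's print_tree()."""
    print("\n".join(format_tree(tree)))

# A hand-built tree in the same shape build_tree() returns:
# lists for internal nodes, Counters for leaves.
tree = [Counter({"acc": 2}), [Counter({"unacc": 1}), Counter({"acc": 1})]]
print_tree(tree)
```

Note that a tree stored as plain lists does not record which feature or value each branch corresponds to, so this sketch can only show the nesting and the leaf counts.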
