Decision Trees
Decision Trees in scikit-learn

Nice work! You’ve written a decision tree from scratch that is able to classify new points. Let’s take a look at how the Python library scikit-learn implements decision trees.

The sklearn.tree module contains the DecisionTreeClassifier class. To create a DecisionTreeClassifier object, call the constructor:

from sklearn.tree import DecisionTreeClassifier

classifier = DecisionTreeClassifier()

Next, we want to create the tree based on our training data. To do this, we’ll use the .fit() method.

.fit() takes a list of data points followed by a list of the labels associated with that data. Note that when we built our tree from scratch, our data points contained strings like "vhigh" or "5more". scikit-learn’s decision trees expect numeric features, so those strings need to be mapped to numbers before fitting. For example, for the first feature representing the price of the car, "low" would map to 1, "med" would map to 2, and so on.

classifier.fit(training_data, training_labels)
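
Before calling .fit(), the string-to-number mapping described above could be sketched roughly as follows. The dictionary PRICE_LEVELS, the helper encode_price, and the sample values are hypothetical names and data chosen for illustration, not the lesson's actual preprocessing. Only the price feature is handled here.

```python
# Hypothetical ordering for the price feature. "low" -> 1 and "med" -> 2
# follow the lesson text; the remaining codes are assumed to continue in order.
PRICE_LEVELS = {"low": 1, "med": 2, "high": 3, "vhigh": 4}


def encode_price(value):
    """Map a price string to its integer code (hypothetical helper)."""
    return PRICE_LEVELS[value]


# Made-up sample of raw price strings.
raw_prices = ["vhigh", "low", "med"]
encoded_prices = [encode_price(p) for p in raw_prices]
print(encoded_prices)
```

An unknown string raises a KeyError in this sketch, which makes unexpected values visible early.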

Finally, once we’ve made our tree, we can use it to classify new data points. The .predict() method takes an array of data points and will return an array of classifications for those data points.

predictions = classifier.predict(test_data)
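
To make the shape of the input and output concrete, here is a minimal sketch on a tiny made-up dataset. The two features (price and safety, already mapped to numbers), the labels, and the random_state value are assumptions for illustration, not the car dataset used in the exercise.

```python
from sklearn.tree import DecisionTreeClassifier

# Toy, made-up data: each point is [price, safety], already mapped to numbers.
training_data = [[1, 3], [2, 3], [4, 1], [3, 1]]
training_labels = ["acc", "acc", "unacc", "unacc"]

# random_state is fixed only so repeated runs behave the same way.
classifier = DecisionTreeClassifier(random_state=0)
classifier.fit(training_data, training_labels)

# .predict() expects one row per data point, even when classifying one point.
test_data = [[1, 3], [4, 1]]
predictions = classifier.predict(test_data)
print(predictions)
```

The result contains one label per row of test_data, in the same order.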

If you’ve split your data into a test set, you can find the accuracy of the model by calling the .score() method using the test data and the test labels as parameters.

print(classifier.score(test_data, test_labels))

.score() returns the mean accuracy: the fraction of data points from the test set that it classified correctly, as a number between 0 and 1.
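
As a rough check of what .score() measures, the sketch below compares it against an accuracy computed by hand from .predict(). The data is made up for illustration, and the variable names manual_accuracy and score are assumptions, not part of the exercise.

```python
from sklearn.tree import DecisionTreeClassifier

# Toy, made-up data for illustration only.
training_data = [[1, 3], [2, 3], [4, 1], [3, 1]]
training_labels = ["acc", "acc", "unacc", "unacc"]
test_data = [[1, 2], [2, 2], [4, 2], [3, 3]]
test_labels = ["acc", "acc", "unacc", "acc"]

classifier = DecisionTreeClassifier(random_state=0)
classifier.fit(training_data, training_labels)

# Accuracy by hand: number of correct predictions divided by the total.
predictions = classifier.predict(test_data)
correct = sum(p == t for p, t in zip(predictions, test_labels))
manual_accuracy = correct / len(test_labels)

# .score() runs the predictions internally and reports the same fraction.
score = classifier.score(test_data, test_labels)
print(score, manual_accuracy)
```

Multiplying the returned value by 100 gives the accuracy as a percentage.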

Instructions

1.

We’ve imported the full car dataset and split it into a training and test set. We’ve also mapped the features that were strings like "vgood" to numbers.

Print training_points[0] and training_labels[0] to see the first car in the training set.

2.

Create a DecisionTreeClassifier and name it classifier.

3.

Build the tree using the training data by calling the .fit() method. .fit() takes two parameters: the training data and the training labels.

4.

Test the decision tree on the testing set and print the results. How accurate was the model?
