Exercise: A Decision Tree in scikit-learn

Learn how to model a decision tree on our case study data and visualize it using graphviz.

Modeling a decision tree on the case study data

In this exercise, we will use the case study data to grow a decision tree, where we specify the maximum depth. We’ll also use some handy functionality to visualize the decision tree, in the form of the graphviz package. Perform the following steps to complete the exercise:

  1. Load several of the packages that we’ve been using, and an additional one, graphviz, so that we can visualize decision trees:

    import numpy as np  # numerical computation
    import pandas as pd  # data wrangling
    import matplotlib.pyplot as plt  # plotting package
    # Next line is a Jupyter magic command that renders plots in the notebook
    %matplotlib inline
    import matplotlib as mpl  # additional plotting functionality
    mpl.rcParams['figure.dpi'] = 400  # high-resolution figures
    import graphviz  # to visualize decision trees
    
  2. Load the cleaned case study data:

    df = pd.read_csv('Chapter_1_cleaned_data.csv')
    
  3. Get a list of column names of the DataFrame:

    features_response = df.columns.tolist()
    
  4. Make a list of columns to remove that aren’t features or the response variable:

    items_to_remove = ['ID', 'SEX', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'EDUCATION_CAT',
                       'graduate school', 'high school', 'none', 'others', 'university']
    
  5. Use a list comprehension to remove these column names from our list of features and the response variable:

    features_response = [item for item in features_response if item not in items_to_remove] 
    features_response
    

    This should output the list of features and the response variable:

    ['LIMIT_BAL', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_1', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6', 'default payment next month']
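
    As a small illustration of this filtering step, the sketch below applies the same list comprehension to a shortened, made-up list of column names (all_columns is a hypothetical stand-in for df.columns.tolist()). It shows that the original order is preserved, so the response variable stays in the last position and features_response[:-1] yields only the features.

```python
# Shortened, hypothetical column list standing in for df.columns.tolist()
all_columns = ['ID', 'LIMIT_BAL', 'SEX', 'AGE', 'PAY_1', 'PAY_2',
               'default payment next month']
items_to_remove = ['ID', 'SEX', 'PAY_2']

# Keep every name that is not in the removal list; order is preserved
features_response = [item for item in all_columns if item not in items_to_remove]

# The response variable is still last, so slicing it off leaves the features
feature_names = features_response[:-1]
response_name = features_response[-1]

print(feature_names)
print(response_name)
```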
    

    Now the list of features is prepared. Next, we will make some imports from scikit-learn. We want to make a train/test split, which we are already familiar with. We also want to import the decision tree functionality.

  6. Run this code to make imports from scikit-learn:

    from sklearn.model_selection import train_test_split 
    from sklearn import tree
    

    The tree module of scikit-learn contains the decision tree classes and related helper functions.

  7. Split the data into training and testing sets using the same random seed that we have used throughout the course:

    X_train, X_test, y_train, y_test = train_test_split(
        df[features_response[:-1]].values,
        df['default payment next month'].values,
        test_size=0.2, random_state=24)
    

    Here, we use all but the last element of the list, features_response[:-1], to get the names of the features without the response variable. We use these names to select columns from the DataFrame, and then retrieve the underlying NumPy array with the .values attribute. We do the same for the response variable, but specify its column name directly. In making the train/test split, we’ve used the same random seed and the same split size as in previous work. This way, we can directly compare the results of this section with previous results, and we continue to reserve the same “unseen test set” from the model development process.
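
    To make the role of the random seed concrete, here is a minimal sketch on synthetic data (the arrays X and y are made up and are not the case study data). It assumes only that train_test_split is called twice with the same test_size and random_state, and checks that both calls return identical splits.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in data: 100 samples, 3 features, binary response
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = rng.integers(0, 2, size=100)

# Two independent calls with the same seed and split size
split_a = train_test_split(X, y, test_size=0.2, random_state=24)
split_b = train_test_split(X, y, test_size=0.2, random_state=24)

X_train, X_test, y_train, y_test = split_a
print(X_train.shape, X_test.shape)  # 80 training rows and 20 test rows

# Compare the four returned arrays pairwise
same_split = all(np.array_equal(a, b) for a, b in zip(split_a, split_b))
print(same_split)
```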

    Now we are ready to instantiate the decision tree class.

  8. Instantiate the decision tree class by setting the max_depth parameter to 2:

    dt = tree.DecisionTreeClassifier(max_depth=2)
    

    We have used the DecisionTreeClassifier class because we have a classification problem. Because we specified max_depth=2, when we grow the decision tree using the case study data, the tree will grow to a depth of at most 2. Let’s now train this model.

  9. Use this code to fit the decision tree model and grow the tree:

    dt.fit(X_train, y_train)
    

    This should display the following output:

    DecisionTreeClassifier(max_depth=2)
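
    As a quick check of what max_depth controls, the following sketch fits the same kind of model on synthetic data generated with make_classification (an assumption for illustration; it is not the case study data) and inspects the fitted tree with get_depth and get_n_leaves.

```python
from sklearn.datasets import make_classification
from sklearn import tree

# Synthetic stand-in data: 500 samples, 6 features, binary response
X, y = make_classification(n_samples=500, n_features=6, n_informative=4,
                           random_state=24)

dt = tree.DecisionTreeClassifier(max_depth=2)
dt.fit(X, y)

depth = dt.get_depth()        # cannot exceed max_depth
n_leaves = dt.get_n_leaves()  # a binary tree of depth 2 has at most 2**2 leaves

print(depth, n_leaves)
```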
    

    Now that we have fit this decision tree model, we can use the graphviz package to display a graphical representation of the tree.

  10. Export the trained model in a format that can be read by the graphviz package using this code:

    dot_data = tree.export_graphviz(
        dt, out_file=None, filled=True, rounded=True,
        feature_names=features_response[:-1], proportion=True,
        class_names=['Not defaulted', 'Defaulted'])
    

    Here, we’ve provided a number of options to the export_graphviz function. First, we need to say which trained model we’d like to graph, which is dt. Next, we say we don’t want an output file: out_file=None. Instead, the function returns its output, which we store in the dot_data variable.

    The rest of the options are set as follows:

    • filled=True: Each node will be filled with a color.

    • rounded=True: The node boxes will be drawn with rounded corners instead of sharp ones.

    • feature_names=features_response[:-1]: The names of the features from our list will be used as opposed to generic names such as X[0].

    • proportion=True: The proportion of training samples in each node will be displayed (we’ll discuss this more later).

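    • class_names=['Not defaulted', 'Defaulted']: The name of the predicted class will be displayed for each node. The names are listed in ascending order of the class labels, so the first name corresponds to 0 and the second to 1.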

    What is the output of this method?

    If you examine the contents of dot_data, you will see that it is a long text string. The graphviz package can interpret this text string to create a visualization.
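
    The sketch below looks at this string more closely. It uses synthetic data and hypothetical feature names (feature_0 to feature_3) rather than the case study data, and exports the same fitted tree twice to contrast proportion=True with proportion=False. The returned value is ordinary text in Graphviz’s DOT language, and the root node reports either a percentage or a raw sample count.

```python
from sklearn.datasets import make_classification
from sklearn import tree

# Synthetic stand-in data: 500 samples, 4 features, binary response
X, y = make_classification(n_samples=500, n_features=4, n_informative=2,
                           n_redundant=0, random_state=24)
# Hypothetical feature names, used only for this illustration
feature_names = ['feature_0', 'feature_1', 'feature_2', 'feature_3']

dt = tree.DecisionTreeClassifier(max_depth=2).fit(X, y)

# Same tree exported with and without proportions
dot_proportion = tree.export_graphviz(
    dt, out_file=None, feature_names=feature_names, proportion=True,
    class_names=['Not defaulted', 'Defaulted'])
dot_counts = tree.export_graphviz(
    dt, out_file=None, feature_names=feature_names, proportion=False,
    class_names=['Not defaulted', 'Defaulted'])

print(type(dot_proportion))   # a plain Python string
print(dot_proportion[:300])   # beginning of the DOT description

# The root node holds all training samples
root_shows_percent = 'samples = 100.0%' in dot_proportion
root_shows_count = 'samples = 500' in dot_counts
print(root_shows_percent, root_shows_count)
```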

  11. Use the Source class of the graphviz package to create an image from dot_data and display it:

    graph = graphviz.Source(dot_data) 
    graph
    

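    The output should be a rendering of the fitted decision tree, with the root node at the top and at most four leaf nodes at the bottom level.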
