A Python implementation of a decision tree

  • 2020-04-02 14:21:34
  • OfStack

This article illustrates a Python implementation of a decision tree, shared here for your reference. The specific implementation method is as follows:

Advantages and disadvantages of the decision tree algorithm:

Advantages: the computational complexity is not high, the output is easy to understand, it is insensitive to missing intermediate values, and it can handle irrelevant feature data.

Disadvantages: there may be an overfitting problem.

Applicable data types: numeric and nominal

Algorithm idea:

1. Overall idea of decision tree construction:

Put bluntly, a decision tree is an if-else structure: the result is a tree along which you can start at the root node, keep making judgments and choices, and end up at a leaf node. The catch is that these if-else tests are not ones we sit down and think up ourselves; what we need to supply is a method by which a computer can derive the decision tree we need from the data. The focus of this method is how to pick out the valuable features from the many available, and how to choose them in the best order from root to leaf. Once we can do that, we can construct the decision tree recursively.
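As a concrete picture, here is a hand-written if-else sketch of the tree that the code at the end of this article learns from its small sample data set (the feature names 'no surfacing' and 'flippers' come from that data set):

def classify_by_hand(no_surfacing, flippers):
    # hand-written equivalent of the tree built later in this article;
    # the generated version encodes exactly these tests as nested dicts
    if no_surfacing == 0:      # root test
        return 'no'
    if flippers == 0:          # second-level test
        return 'no'
    return 'yes'               # leaf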

2. Information gain

The guiding principle for partitioning a data set is to make disordered data more ordered. Since this involves how ordered or disordered information is, it is natural to think of information entropy (an alternative measure is Gini impurity). Here we compute the Shannon entropy, whose formula is:

H = -Σ p(xᵢ) · log₂ p(xᵢ)

where the sum runs over all class labels xᵢ and p(xᵢ) is the proportion of instances with label xᵢ. The information gain of a split is the base entropy minus the weighted sum of the entropies of the resulting partitions; the best feature to split on is the one whose split maximizes this gain.
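As a quick worked example using the class labels of the sample data set defined in the code below (two 'yes' and three 'no' out of five instances):

from math import log

# two classes: p('yes') = 2/5, p('no') = 3/5
entropy = -(2/5.0) * log(2/5.0, 2) - (3/5.0) * log(3/5.0, 2)
print(entropy)  # ~0.971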

Requirements on the data:

The data must be a list of lists, and all of the inner lists must have the same length.
The last column of the data (i.e. the last element of each instance) must be the class label of that instance.

(The createDataSet function in the code below builds a small sample data set in exactly this format.)

Functions:

calcShannonEnt(dataSet)
Calculates the Shannon entropy of the data set in two steps: first count the frequency of each class label, then compute the entropy from those frequencies using the formula above.
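On the sample data set defined in the code below, this comes out as:

>>> dataSet, labels = createDataSet()
>>> calcShannonEnt(dataSet)
0.9709505944546686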

splitDataSet(dataSet, axis, value)
Splits the data set: gathers the instances that satisfy X[axis] == value into one partition and returns it (with the axis attribute used for splitting removed, since it is no longer needed).
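For example, splitting the sample data set on feature 0 with value 1 keeps the three matching instances and strips off column 0:

>>> splitDataSet(dataSet, 0, 1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]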

chooseBestFeatureToSplit(dataSet)
Chooses the best attribute to split on. The idea is simple: try splitting on each attribute in turn and see which split gives the largest information gain. A set is used to extract the unique values of each attribute from the list, which is a quick way to do it.
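Continuing the same session: on the sample data, feature 0 ('no surfacing') gives an information gain of roughly 0.420 against roughly 0.171 for feature 1 ('flippers'), so it wins:

>>> chooseBestFeatureToSplit(dataSet)
0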

majorityCnt(classList)
Because we build the decision tree recursively by consuming attributes, we may run out of attributes while the instances are still not all of one class; at that point the node's class is decided by majority vote.
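For example, with two 'no' votes against one 'yes':

>>> majorityCnt(['yes', 'no', 'no'])
'no'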

createTree(dataSet, labels)
Builds the decision tree recursively. The labels here are the names of the classification features, included only so that the resulting tree is easier to read and understand.
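On the sample data set this produces the following nested-dict tree, which is also what the main function below prints (key order may vary):

>>> createTree(dataSet, labels)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}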


#coding=utf-8
import operator
from math import log
import time

def createDataSet():
    dataSet=[[1,1,'yes'],
            [1,1,'yes'],
            [1,0,'no'],
            [0,1,'no'],
            [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels

# Calculate the Shannon entropy of the data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for feaVec in dataSet:
        currentLabel = feaVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

# Split the data set: keep the instances where featVec[axis] == value
# and strip out the axis column itself
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
   
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
           
# Because we build the tree by recursively consuming attributes, we may run
# out of attributes before every leaf is pure; in that case the node's class
# is decided by majority vote
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # sort by vote count: max(classCount) alone would return the largest
    # key, not the most frequent class
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
   
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # stop splitting once all classes are identical
        return classList[0]
    if len(dataSet[0]) == 1:  # all features have been used up
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]# The original list was copied so as not to change its contents
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,
                                        bestFeat, value),subLabels)
    return myTree
   
def main():
    data, label = createDataSet()
    t1 = time.time()
    myTree = createTree(data, label)
    t2 = time.time()
    print(myTree)
    print('execute for', t2 - t1)

if __name__ == '__main__':
    main()
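The article stops once the tree is built. As a small extension, here is a minimal sketch of a classify helper (not part of the original code) that walks the nested-dict tree createTree returns:

# A sketch of a classifier for the nested-dict tree format used above.
# featLabels must list the feature names in their original column order.
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]    # feature tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # map feature name to column
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):       # internal node: keep walking
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat                      # leaf: the class label

# e.g. classify(myTree, ['no surfacing', 'flippers'], [1, 0]) -> 'no'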

I hope this article has helped you with your Python programming.

