第 3 章 决策树
[TOC]
本章内容
- 决策树简介
- 在数据集中度量一致性
- 使用递归构造决策树
- 使用 Matplotlib 绘制树形图
与 k-近邻算法相比,决策树 的主要优势在于数据形式非常容易理解
1. 决策树的构造
决策树 :
- 优点:计算复杂度不高,输出结果易于理解,对中间值得缺失不敏感,可以处理不相关特征数据
- 缺点:可能会产生过度匹配问题
- 适用数据类型:数值型和标称型
本节将一步步地构造决策树算法,首先我们讨论数学上如何使用 信息论 划分数据集,然后编写代码将理论应用到具体的数据集上,最后编写代码构建决策树。
- 在构造决策树时,我们需要解决的第一个问题就是,当前数据集上那个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。
- 完成测试后,原始数据就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。
- 如果某个分支下的数据属于同一类型,则到这里以及正确地划分数据分类,无序进一步对数据集进行分割。
- 如果数据子集内的数据不属于同一类型,则需要重复划分数据子集的过程。
- 如何划分数据子集的算法和划分原始数据集的方法相同,直到所有具有相同类型的数据均在一个数据子集内。
创建分支的伪代码函数 createBranch() 如下:
检测数据集中的每一个子项是否属于同一分类:
if so return 类标签;
else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数 createBranch 并增加返回结果到分支节点中
return 分支节点
决策树的一般流程:
- 收集数据
- 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
- 分析数据:可以使用任何方法,构造树完成之后,应该检查图形是否符合预期
- 训练算法:使用经验树计算错误率
- 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义
如果依据某个属性划分法数据将会产生 4 个可能的值,我们将把数据划分为 4 块,并 创建 4 个不同的分支。这里将使用 ID3 算法划分数据集。
问题 :每次划分数据集时我们只选取一个特征属性,如果训练集中存在 20 个特征,第一次我们选择哪个特征作为话的参考属性呢?
表3-1 的数据包含 5 个海洋动物,特征包括:不浮出水面是否可以生存,以及是否有脚蹼。我们可以将这些动物分成两类:鱼类和非鱼类。在回答这个问题之前,我们必须采用量化的方法判断如何划分数据:
1.1 信息增益
划分数据集的大原则是:将无需的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。我们可以在划分数据之前使用信息论量化度量信息的内容。
在划分数据集之前之后信息发生的变化称为 信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择 。
在可以评测哪种数据划分方式就是最好的数据划分之前,必须学习如何计算信息增益。集合信息的度量方式称为香农熵(information gain) 或者简称为 熵(entropy) 。
熵(entropy),定义为信息的期望值。
信息:如果待分类的事务可能划分在多个分类之中,则符号
的信息定义为
其中
是选择该分类的概率。
为了计算熵,需要计算所有类别所有可能只包含的信息期望值,通过下面的公式得到:
其中 n 是分类的数目。
创建名为 s_2_tree.py 的文件,添加 calcShannonEnt 函数,其功能是计算给定数据集的熵:
import math
def calcShannonEnt(dataSet):
"""
计算给定数据集的香农熵
:param dataSet: 数据集
:return: 香农熵
"""
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1] # 可能值
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # 计数加一
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries # 可能值的期望
shannonEnt -= prob * math.log(prob, 2) # 熵
return shannonEnt
熵越高,则混合的数据也越多,我们可以在数据集中添加更多的分类,观察熵是如何变化的。
得到熵之后,就可以按照获取最大信息增益的方法划分数据集。另一个度量集合无序程度的方法是 基尼不纯度(Gini impurity),简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。
1.2 划分数据集
添加 splitDataSet 函数,按照给定特征划分数据集:
def splitDataSet(dateSet, axis, value):
"""
按照给定特征划分数据集
:param dateSet: 待划分的数据集
:param axis: 划分数据集的特征
:param value: 特征的返回值
:return:划分完的数据集
"""
retDataSet = [] # 因python不用考虑内存分配问题,在函数中传递的是列表的引用,所以需声明一个新列表对象
for featVec in dateSet:
if featVec[axis] == value: # 该特征值等于判断值
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:]) #上一步和这一步是排除掉特征值
retDataSet.append(reducedFeatVec) # 加入返回的数据集
print(retDataSet)
return retDataSet
接下来将遍历整个数据集,循环计算香农熵和 splitDataSet 函数,找到最好的特征划分方式。熵计算将会告诉我们如何划分数据集时最好的数据组织方式。
添加 chooseBestFeatureToSplit 函数,选择最好的数据集划分方式:
def chooseBestFeatureToSplit(dataSet):
"""
选择最好的数据集划分方式
:param dataSet: 待划分的数据集
:return: 最好的特征
"""
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet) # 原始的香农熵
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
# 1.创建唯一的分类标签列表
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
# 2.计算每种划分方式的信息熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet)) # 概率
newEntropy += prob * calcShannonEnt(subDataSet) # 香农熵,也就是信息量
infoGain = baseEntropy - newEntropy # 该特征的熵
# 3.计算最好的信息增益
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
测试上面代码的实际输出结果:
myDataSet,labels = createDataSet()
bestFeature = chooseBestFeatureToSplit(myDataSet)
print(bestFeature) # 输出:0
运行结果告诉我们,第 0 个特征是最好的用于划分数据集的特征。
1.3 递归构建决策树
从数据集构造决策树算法所需要的子功能模块,其工作原理如下:
- 得到原始数据集
- 基于最好的属性值划分数据集,由于特征值可能多余两个,因此可能存在大于两个分支的数据集划分。
- 第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此可以采用递归的原则处理数据集。
- 递归结束的条件是程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类,参见图3-2。
- 如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定叶子节点的分类。
添加 majorityCnt 函数,多数表决来决定叶子节点的分类:
def majorityCnt(classList):
"""
多数表决判断分类
:param classList:分类名称的列表
:return: 表决得到的分类名称
"""
# 1.分类计数
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
# 2.排序,取出次数最多的分类名称
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
添加 createTree 函数,创建树:
def createTree(dataSet, labels):
"""
递归创建树
:param dataSet: 数据集
:param labels: 标签集
:return:树
"""
# 1. 取出所有类别
classList = [example[-1] for example in dataSet]
# 2. 判断:类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 3. 遍历完所有特征时返回出现次数最多的
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 4. 取出最好的分类特征
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLable = labels[bestFeat]
myTree = {bestFeatLable:{}}
del(labels[bestFeat])
# 5. 得到列表包含的所有属性值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
# 6.递归创建树
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLable][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
2. 在 python 中使用 Matplotlib 注解绘制树形图
使用 Matplotlib 可创建树形图,决策树的主要优点就是直观易于理解,如果不能将其直观地显示出来,就无法发挥其优势。
2.1 Matplotlib 注解
Matplotlib 提供了一个 **注解工具 annotations **,可以在数据图形上添加文本注释。注解通常用于解释数据的内容
下面是 treePlotter.py 的代码(使用文本注解绘制树节点):
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
绘制 parentPt 指向 centerPt 带箭头的线,箭头节点的文本为 nodeTxt
:param nodeTxt:节点文本
:param centerPt:子节点
:param parentPt:父节点
:param nodeType:节点样式
:return:
"""
createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',\
xytext=centerPt, textcoords='axes fraction',\
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.axl = plt.subplot(111, frameon=False)
plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
这是第一个版本的 createPlot 函数,与例子文件中的 createPlot 函数有些不同,随着内容的深入,我们将逐步添加缺失的代码。代码定义了树节点格式的常量。然后定义 plotNode 函数执行了实际的绘图功能,该函数需要一个绘图区,该区域由全局变量 createPlot.ax1 定义。python 语言中所有的变量默认都是全局有效的,只要我们清楚知道当前代码的主要功能,并不会引入太大的麻烦。最好定义了 createPlot 函数,它是这段代码的核心。createPlot 函数首先创建了一个新图形并清空绘图区,然后再绘图区上绘制了两个代表不同类型的树节点,后面我们将用这两个节点绘制树图形。
2.2 构造注解树
添加 getNumLeafs 函数,获取叶节点的数目:
def getNumLeafs(myTree):
"""
获取叶节点的数目
:param myTree: 树
:return: 叶节点的数目
"""
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
添加 getTreeDepth 函数,获取树的层数
def getTreeDepth(myTree):
"""
获取树的层数
:param myTree:树
:return: 层数
"""
maxDepth = 0
firstStr = list(myTree.keys())[0] # python3 中 dict.keys 返回的是 dict_keys 对象,支持 iterable,但不支持 indexable,所以要转换成list
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 +getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth : maxDepth = thisDepth
return maxDepth
为了节省时间,函数 retrieveTree 输出预先存储的树信息,避免每次测试都要创建树的麻烦:
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0:{'head': {0:'no', 1: 'yes'}}, 1:'no'}}}}
]
return listOfTrees[i]
接下来是主要的绘图部分:
添加 plotNode 函数,绘制节点:
decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 注解样式
leafNode = dict(boxstyle="round4", fc="0.8") # 注解样式
arrow_args = dict(arrowstyle="<-") # 箭头样式
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
绘制 parentPt 指向 centerPt 带箭头的线,箭头节点的文本为 nodeTxt
:param nodeTxt:节点文本
:param centerPt:子节点
:param parentPt:父节点
:param nodeType:节点样式
:return:
"""
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
添加 plotMidText 函数:
def plotMidText(cntrPt, parentPt, txtString):
"""
计算父节点和子节点的中间位置,并添加文本标签信息
:param cntrPt:子节点位置
:param parentPt:父节点位置
:param txtString:文本标签信息
:return:
"""
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
添加 plotTree 函数,绘制树形图:
def plotTree(myTree, parentPt, nodeTxt):
"""
绘制树形图
:param myTree:树
:param parentPt:父节点位置
:param nodeTxt:节点文本
:return:
"""
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0] # 找到第一个元素,根节点
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) # 节点位置
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr] # 获取节点下的内容
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 减少 y 的值,将树的总深度平分,向下移动(树是向下绘制)
for key in secondDict.keys(): # 键值:0、1
if type(secondDict[key]) == dict: # 判断是 dict 还是 value
plotTree(secondDict[key], cntrPt, str(key)) # 递归调用
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 更新 x 值
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
添加 createPlot 函数,创建绘图:
def createPlot(inTree):
"""
创建绘图
:param inTree:
:return:
"""
fig = plt.figure(1, facecolor='white')
fig.clf()
axprps = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprps) # 定义绘图区
plotTree.totalW = float(getNumLeafs(inTree)) # 存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
# 使用了这两个全局变量追踪已经绘制的节点位置,以及放置下一个节点的恰当位置
plotTree.xOff = -0.5/plotTree.totalW # 存储树在 x 轴的偏移
plotTree.yOff = 1.0 # 存储树在 y 轴的偏移
plotTree(inTree, (0.5,1.0), ' ')
plt.show()
- 函数 createPlot() 是使用的主函数,它调用了 plotTree() ,函数 plotTree() 又依次调用了前面的函数和 plotMidText() 。
- 函数 plotTree() 是递归函数:
- 计算树的宽和高。全局变量 plotTree.totalW 、plotTree.totalD 存储树的宽度、深度,使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。
- 树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放在所有叶子节点的中间,而不仅仅是它子节点的中间。
- 同时我们使用两个全局变量 plotTree.xOff 、plotTree.yOff 追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。
- 另一个需要说明的问题是,绘制图形的 x 轴有效范围是 0.0 到 1.0,y 轴有效范围也是 0.0 到 1.0。
- 通过计算树包含的所有叶子节点数,划分图形的宽度,从而计算得到当前节点的中心位置,也就是说,我们按照叶子节点的数目将 x 轴划分为若干部分。按照图形比例绘制树形图的最大好处是无序关心实际输出图形的大小,一旦图形大小发生了变化,函数会自动按照图形大小重新绘制。
- 绘制出子节点具有的特征值,或者沿此分支向下的数据实例必须具有的特征值。
- 使用 plotMidText() 计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息
- 按比例减少全局变量 plotTree.yOff,并标注此处将要绘制子节点,这些节点即可以是叶子节点也可以是判断节点,此处需要只保存绘制图形的轨迹
- 采用 getNumLeafs() 和 getTreeDepth() 以相同的方式递归遍历整棵树,如果节点是叶子节点则在图形画出叶子节点,如果不是叶子节点则递归调用 plotTree() 函数。在绘制了所有子节点之后,增加全局变量 Y 的偏移
3. 测试和存储分类器
3.1 测试算法:使用决策树执行分类
依靠训练数据构造了决策树之后,可以将它用于实际数据的分类:
- 在执行数据分类时,需要决策树以及用于构造树的标签向量;
- 程序比较测试数据域决策树上的数值,递归执行该过程直到进入叶子节点;
- 将测试数据定义为叶子节点所属的类型
添加 classify() 函数,使用决策树进行分类:
def classify(inputTree, featLabels, testVec):
"""
使用决策树分类
:param inputTree:树
:param featLabels: 标签集
:param testVec:测试向量
:return:分类名称
"""
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]) == dict:
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
3.2 使用算法:决策时的存储
为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。需要使用 python 模块 pickle 序列化对象,然后在磁盘上保存对象,并在需要的时候取出来。
使用 pickle 模块存储决策树:
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close()
def gradTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
通过上面的代码,我们可以将分类器存储在硬盘上,而不用每次对数据分类时重新学习一遍,这也是决策树的优点之一,像 k-近邻算法就无法持久化分类器。可以余弦提炼并存储数据集中包含的知识信息,在需要对事物进行分类时再使用这些知识。
4. 示例:使用决策树预测隐形眼镜类型
本节将通过一个例子讲解决策树如何预测患者需要佩戴的隐形眼镜类型。
示例:使用决策树预测隐形眼镜类型:
- 收集数据:提供的文本文件
- 准备数据:解析 tab 键分隔的数据行
- 分析数据:快速检查数据,确保正确地解析数据内容,使用 createPlot() 函数绘制最终的树形图
- 训练算法:使用 第 1 节的 createTree() 函数
- 测试算法:编写测试函数验证决策树可以正确分类给定的数据实例
- 使用算法:存储树的数据结构,以便下次使用时无需重新构造树
训练代码:
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses,lensesLabels)
createPlot(lensesTree)
本章使用的算法称为 ID3,是一个好的算法但是并不完美。ID3 无法直接处理数值型数据,尽管可以通过量化的方法将数值型数据转化为标称型数值,但是如果蔡遵太多的特征划分,ID3 仍然会面临其他问题。
5. 本章小结
决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据集时,首先需要测量集合中数据的不一致性,也就是熵,然后寻找最优方案划分数据,直到数据集中的所有数据属于同一分类。ID3 算法可以用于划分标称型数据集。构建决策树时,通常采用递归的方法将数据集转化为决策树。