博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
decisionTree填坑记
阅读量:7108 次
发布时间:2019-06-28

本文共 13983 字,大约阅读时间需要 46 分钟。

hot3.png

定义

其实决策树的定义用一个邮件分类系统去诠释一下。

if 我们收到一个来自邮件域为funny.com的邮件 :
    我们将这份邮件甩到无聊的时候才会去读的栏目中
elif 收到的邮件中包含字眼“爱人” :
    赶紧放到非常重要并且非常迅速回复邮件
else:
    什么玩意儿,拒绝浏览

哈哈,这就是决策树的一个比较好理解的例子啦,有点类似于 程序流程图哈(不过这是树结构)

优点

计算复杂度不高,输出的结果便于理解,而且对中间值的缺失不敏感,可以处理不相关特征数据

缺点

会产生过度匹配的问题,这个时候我们就需要进行“剪枝”等一系列操作咯

一般流程

  1. 收集数据:可以使用任何方法
  2. 准备数据:树构造算法只适用于标称型数据,如果是数值型必须得离散化
  3. 分析数据:可以使用任何方法,构造完树后,应该检查图形是否符合预期
  4. 训练数据:构造树的数据结构
  5. 测试算法:使用经验树计算一个错误率
  6. 使用算法:这个方法可以适用于任何监督学习算法,而使用决策树可以更好地理解数据结构的内在含义

ID3算法

粗略查看我们的ID3的创建者Ross·Quinlan的解释,我翻译了一下

ID3算法首先是将原始的数据集S当成一个树根,然后通过算法每次迭代计算出S中没用被使用过的属性的entropy(熵) H(S)或者Infomation gain(信息增益) IG(S)。接着我们选择最小熵或者最大信息增益的属性去分割S产生新的子数据集合。然后算法继续递归每个子集合中的没有被使用过的属性。当出现以下三种情况结束递归:
第一种,子集中每个元素的属性属于相同的类型,节点变成叶子节点并被打上示例中类型的标签
第二种,没有更多的属性被选择,这个时候子集中仍然存在不同的属性,这个时候该结点也变成叶子节点同时被打上示例中最常见的属性标签
第三种,父集合中没有示例与所选属性的特殊值匹配(比如age>100没有这个示例),那么就会创建这个叶子节点并且打上在父集合的示例中最常见的属性标签

总结

第一步,计算S中的每个正在使用的属性的熵值

第二步,选择熵最小或者信息收益最大的属性切割S
第三步,创建包含这个属性的决策树节点
第四步,使用剩下的属性去递归子集

用法

通过训练数据集S去产生一个储存在内存中的决策树,在运行时,决策树用来识别新的不被看见的测试用例,通过使用这个测试用例值的决策树去处理如何抵达最终的节点并且告诉你这个测试用例属于哪个分类

ID3指标

Entrop:     H(S)

S:被计算的数据集S
X:S中的所有类别集合
p(x): 类别x在集合S中出现的比例

H(S)=0的时候说明已经被完美识别了,如果这个值越大说明改进识别能力的潜在可能性就越高

Information gain :         IG(A,S)

 

H(S): 集合S的熵

T:被属性A切割的子集
p(t): t在集合S中出现的比例
H(t):子集t的熵

案例

使用决策树去预测隐形眼镜类型

描述

我们通过使用一个小数据集,帮助眼科医生为患者选择需要佩带的眼镜片类型

步骤

  1. 收集数据:提供的文本文件
  2. 准备数据:解析tab健分隔的数据行
  3. 分析数据:快速检查数据,确保正确地解析数据内容,使用createPlot函数去绘制最终的树形图
  4. 训练数据:使用createTree()函数
  5. 测试算法:编写测试函数验证决策树可以正确分类给定的数据实例
  6. 使用算法:储存树的数据结构,以便下次使用的时候不用再重新构造树结构

代码

在编码的过程中,风采依旧,恩,入坑了 当然已经填上啦

#!/usr/bin/env python3# -*- coding: utf-8 -*-' a DT module '__author__ = 'OJ cheng'from math import logimport operatorimport matplotlib.pyplot as pltimport plotTreeimport pickle"""如何去划分数据集,将无序的数据变得更加有序当然方法有很多,这里我们使用信息论量化度量信息的内容信息增益指的是在划分数据之前之后信息发生的变化"""#第一步:计算香农熵#--------------------------------函数开始---------------------def calcShannonEnt(dataSet):	#数据的总长度	numEntries = len(dataSet)	#定义一个字典:用来对每个标签进行计数	labelCounts = {}	#迭代数据集从中获取键值	for featVec in dataSet:		#当前标签的键值是最后一列的值		#比如:第一行数据最后一列“属于鱼类”的值是:“是的”		currentLabel = featVec[-1]		#如果当前键值不存在,将其添加到字典中		#字典中每个键值表示该标签出现的次数		if currentLabel not in labelCounts.keys():			labelCounts[currentLabel] = 0		labelCounts[currentLabel] += 1	#开始计算香农熵	shannonEnt = 0.0	for key in labelCounts:		#计算每个标签出现的频率		prob = float(labelCounts[key])/numEntries		#计算香农熵		shannonEnt -= prob*log(prob,2)	return shannonEnt#--------------------------------函数结束---------------------"""测试以上计算香农熵的函数"""#--------------------------------函数开始---------------------def createDataSet():	#根据书中表数据创建一下数据集	dataSet = [		[1,1,'yes'],		[1,1,'yes'],		[1,0,'no'],		[0,1,'no'],		[0,1,'no']	]	labels = ['no surfacing','flippers']	return dataSet,labels#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------myDat,labels = createDataSet()print(len(myDat[0]))# print(calcShannonEnt(myDay))#--------------------------------测试结束---------------------"""计算出香农熵后,我们需要去计算一下IG值在计算之前,我们又需要考虑:如果在一个分布在二维空间上的数据散点图,我们又如何去划分成两部分呢?按照x轴还是y轴呢?"""#第二步,按照给定特征划分数据集#--------------------------------函数开始---------------------"""参数:dataSet:待划分数据集,value值也是一个列表axis:划分数据集的特征value:特征值的返回值用到的函数:-append:在列表末尾添加新的对象好比,a=[1,2,3] b=[4,5,6]a.append(b) #a结果为[1,2,3,[4,5,6]]-extend:在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)好比,a=[1,2,3] b=[4,5,6]a.extend(b) #a结果为[1,2,3,4,5,6]在这里我们需要注意的是:python不用去考虑内存分配的问题,但是在函数传递过程中该参数是一个列表的引用,所以在函数体内修改相关数据后,原始列表数据也会变化"""#--------------------------------函数开始---------------------def splitDataSet(dataSet,axis,value) :	#定义一个临时列表,防止修改原始数据列表	tempDataSet = []	#遍历数据集中的每个元素	for featVec in dataSet:		#发现符合要求的特征值后,抽取出来并且添加到临时列表		if featVec[axis] == value:			reducedFeatVec = featVec[:axis]			reducedFeatVec.extend(featVec[axis+1:])			tempDataSet.append(reducedFeatVec)	return tempDataSet#--------------------------------函数结束---------------------"""测试以上splitDataSet()函数"""#--------------------------------测试开始---------------------# myDat,labels = createDataSet()# print(splitDataSet(myDat,0,1))# print(splitDataSet(myDat,0,0))#--------------------------------测试结束---------------------"""知道如何去切割数据集后,我们需要去计算最大IG去切割数据集"""#第三步,选择最好的数据集切割方式"""chooseBestFeatureToSplit()函数需要注意的是:dataSet必须满足两个要求:	①一种由列表元素构成的列表且列表元素的长度得一致	②数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签"""#--------------------------------函数开始---------------------def chooseBestFeatureToSplit(dataSet):	#初始化bestIG和bestFeature	bestIG = 0.0	bestFeature = -1	#这里可能有的人不理解,其实就是求有多少个特征,因为dataSet是由列表元素构成的列表嘛,请注意,最后一列是类标签。	#在这里写dataSet[1]等等都是可以的,不过感觉怪怪的哈	numFeatures = len(dataSet[0]) - 1	#求出原始的香农熵	baseEntrop = calcShannonEnt(dataSet)	#迭代每个列表元素	for i in range(numFeatures):		#分别迭每个代列表元素的特征值		featList = [example[i] for example in dataSet]		#获取唯一的特征集合		uniqueVals = set(featList)		#计算IG		newEntrop = 0.0		for value in uniqueVals:			subDataSet  = splitDataSet(dataSet,i,value)			prob = len(subDataSet) / float(len(dataSet))			newEntrop += prob*calcShannonEnt(subDataSet)		infoGain = baseEntrop - newEntrop		if(infoGain > bestIG):			bestIG = infoGain			bestFeature = i	return bestFeature#--------------------------------函数结束---------------------"""测试一下chooseBestFeatureToSplit()"""#--------------------------------测试开始---------------------# myDat,labels = createDataSet()# print(chooseBestFeatureToSplit(myDat)) #告诉我们用第0个特征去划分数据集最好#--------------------------------测试结束---------------------"""根据ID3算法结束的条件,其中有一条就是当数据集已经处理了所有的属性,但是类标签依然不是唯一的,这个时候我们采用多数表决的方法决定叶子节点的分类其实换句话说:找到分类名出现次数最多的并返回"""#第四步:多数表决方法#--------------------------------函数开始---------------------def majorityCnt(classList):	classCount = {}	for vote in classList:		if vote not in classCount.keys():  classCount[vote] = 0		classCount[vote] += 1	sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)	return sortedClassCount[0][0]#--------------------------------函数结束---------------------"""有了之前的函数的铺垫,我们现在开始创建一个决策树参数:dataSet:数据集labels:标签列表(包含了数据集中的所有特征的标签)算法本身不需要这个labels参数,主要是为了给出数据明确的含义下面函数使用的函数:-count():在list中统计括号中的值出现的次数,这其中有三个参数,想要知道详细的自己查一下哈,在这里我只介绍一下这里的用法比如:a=[1,2,1,3,4,1] a.count(1) #结果为2注意:使用该函数的仍然需要满足前面提到的两个条件"""#第五步:构建决策树#--------------------------------函数开始---------------------def createTree(dataSet,labels):	#创建数据集中所有的类标签列表	#对于书中例子而言,最后一列便是类标签	classList = [example[-1] for example in dataSet]	#如果所有的类标签都一样,返回该类标签	if classList.count(classList[0]) == len(classList) :		return classList[0]	#如果所有特征都使用过,但是仍然不能划分出唯一类别的分组【这儿可以换成len(labels) == 1  这样更好理解】	#则进行多数表决	if len(dataSet[0]) == 1 :		print('using majorityCnt')		return majorityCnt(classList)	#返回最佳特征值的索引值	bestFeature = chooseBestFeatureToSplit(dataSet)	#找到最佳特征值	bestFeatLabel = labels[bestFeature]	#使用字典类型储存树的所有信息	myTree = {bestFeatLabel:{}}	#删除已经使用过的最佳特征值	del(labels[bestFeature])	featValues = [ example[bestFeature] for example in dataSet]	uniqueVals = set(featValues)	#在每个数据集划分上进行递归	for value in uniqueVals:		#这里是切片的一个作业:完全复制labels类标签列表,防止因为操作改变了原始列表的值		subLabels = labels[:]		myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value),subLabels)	return myTree#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# myDat,labels = createDataSet()# myTree = createTree(myDat,labels)# print(myTree)#这里我犯了一个错误,主要是将索引值错当成了类标签值#--------------------------------测试结束---------------------"""在了解了决策树中原理以及相关函数操作后,我们去构建分类器并且去测试一下效果需要注意的是:一般储存带有特征的数据会面临一个问题:前面的函数中我们第一个用于划分数据集的特征是no surfacing,但是实际数据中这个属性到底在哪个位置,我们其实不知道的这个函数我们就可以使用index()方法找到"""#第六步,使用决策树执行分类#--------------------------------函数开始---------------------def classify(inputTree,featLabels,testVec):	firstStr = list(inputTree.keys())[0]	secondDict = inputTree[firstStr]	featIndex = featLabels.index(firstStr)	for key in secondDict.keys():		if testVec[featIndex] == key:			if type(secondDict[key]).__name__ == 'dict':				classLabel = classify(secondDict[key],featLabels,testVec)			else:				classLabel = secondDict[key]	return classLabel#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# myDat,labels = createDataSet()# myTree = plotTree.returnTree(0)# print(classify(myTree,labels,[1,1]))#--------------------------------测试结束---------------------"""我们知道,构建决策树是一个比较耗时间的事,哪怕是一个小小决策树因此,我们可以将决策树进行序列化保存到我们的磁盘上,用的时候直接读出来就好了"""#决策树的储存#--------------------------------函数开始---------------------def storeTree(inputTree,filename):	#这里我们需要注意的是:我们是以二进制储存所以读的时候也应该是以二进制读出	with open(filename,"wb",) as f:		pickle.dump(inputTree,f)def getTree(filename):	with open(filename,"rb") as f:		return pickle.load(f)#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# storeTree(myTree,'E:\PythonDemo\store.txt')# print(getTree('E:\PythonDemo\store.txt'))#--------------------------------测试结束---------------------"""这个数据来自于UCI数据库,然后为了显示做了一下简单的处理"""# 测试一下效果哈#--------------------------------函数开始---------------------def testMain():	with open("E:\ML_Data\lenses.txt","r") as f:		lenses = [inst.strip().split('\t') for inst in f.readlines()]	lensesLabels = ['age','prescript','astigmatic','tearRate']	lensesTree = createTree(lenses,lensesLabels)	return lensesTree#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# lensesTree = testMain()# print(lensesTree)# plotTree.createPlot(lensesTree)#--------------------------------测试结束---------------------

 

以上是构建决策树的核心代码,下面的代码主要是使用Matplotlib画出决策树,因为在这里我们主要去了解决策树算法(这里是ID3算法的核心),所以我就没有标注上画图过程中的各参数以及如何画,后期我也po出Matplotlib的一些简单画法以及博客中所出现的各种骚操作~

#!/usr/bin/env python3# -*- coding: utf-8 -*-' a plotTree module '__author__ = 'OJ cheng'import matplotlib.pyplot as plt"""决策树方便主要就是在于我们可以通过树图形直观地看到结果接下来根据书上的点儿学习如何用Matplotlib去画一个树"""#我们采用Matplotlib的注解工具annotation:用来给图像添加文本信息#第一步:使用文本注解绘制树节点,文件头部导入对应的包#--------------------------------函数开始---------------------decisionNode = dict(boxstyle="sawtooth",fc="0.8")leafNode = dict(boxstyle="round4",fc="0.8")arrow_args = dict(arrowstyle="<-")def plotNode(nodeTxt,centerPt,parentPt,nodeType):	createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)# def createPlot():# 	fig = plt.figure(1,facecolor='white')# 	fig.clf()# 	createPlot.ax1 = plt.subplot(111,frameon=False)# 	plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)# 	plotNode('a leaf node', (0.8,0.1),(0.3,0.8),leafNode)# 	plt.show()#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# createPlot()#--------------------------------测试结束---------------------"""为了减少创建树的时候耗费时间,在这里我们直接预先返回我们的树结构"""#返回树结构#--------------------------------函数开始---------------------def returnTree(i):	listOfTrees = [			{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},	        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}        ] 	return listOfTrees[i]#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# print(returnTree(0))#--------------------------------测试结束---------------------"""知道如何绘制树节点了,那么我们如何去放置所有的树节点?确定x轴长度:去计算有多少个节点确定y轴长度:去计算树有多少层"""#第二步,获取叶子节点数以及树的层数#--------------------------------函数开始---------------------#获取叶子节点数def getNumLeafs(myTree):	numLeafs = 0	firstStr = list(myTree.keys())[0]	secondDict  = myTree[firstStr]	for key in secondDict.keys():		if type(secondDict[key]).__name__ == 'dict':			numLeafs += getNumLeafs(secondDict[key])		else:			numLeafs += 1	return numLeafsdef getTreeDepth(myTree):	maxDepth = 0	firstStr = list(myTree.keys())[0]	secondDict = myTree[firstStr]	for key in secondDict.keys():		if type(secondDict[key]).__name__ == 'dict':			thisDepth = 1 + getTreeDepth(secondDict[key])		else:			thisDepth = 1		if thisDepth > maxDepth :			maxDepth = thisDepth	return maxDepth#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# myTree = returnTree(0) # print(getNumLeafs(myTree))# print(getTreeDepth(myTree))#--------------------------------测试结束---------------------"""我们知道了树的节点数以及树的层数后,我们结合之前的函数,更新createPlot()函数"""#第三步,画出树形图#--------------------------------函数开始---------------------def plotMidText(cntrpt,parentPt,txtString):	xMid = (parentPt[0] - cntrpt[0])/2.0 + cntrpt[0]	yMid = (parentPt[1] - cntrpt[1])/2.0 + cntrpt[1]	createPlot.ax1.text(xMid,yMid,txtString)def plotTree(myTree,parentPt,nodeTxt):	numLeafs = getNumLeafs(myTree)	depth = getTreeDepth(myTree)	firstStr = list(myTree.keys())[0]	cntrpt = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)	plotMidText(cntrpt,parentPt,nodeTxt)	plotNode(firstStr,cntrpt, parentPt, decisionNode)	secondDict = myTree[firstStr]	plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD	for key in secondDict.keys():		if type(secondDict[key]).__name__  == 'dict' :			plotTree(secondDict[key], cntrpt,str(key))		else:			plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW			plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrpt,leafNode)			plotMidText((plotTree.xOff,plotTree.yOff),cntrpt,str(key))	plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef createPlot(inTree):	fig = plt.figure(1,facecolor='white')	fig.clf()	axprops  = dict(xticks=[],yticks=[])	createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)	plotTree.totalW = float(getNumLeafs(inTree))	plotTree.totalD = float(getTreeDepth(inTree))	plotTree.xOff = -0.5 / plotTree.totalW	plotTree.yOff = 1.0	plotTree(inTree,(0.5,1.0),'')	plt.show()#--------------------------------函数结束---------------------#--------------------------------测试开始---------------------# myTree = returnTree(0)# createPlot(myTree)# myTree = returnTree(0)# myTree['no surfacing'][3] = 'maybe'# createPlot(myTree)#--------------------------------测试结束---------------------

对应数据:

百度云:https://pan.baidu.com/s/1b3UDoM    对应decisionTree文件夹

相关书籍pdf :https://github.com/apachecn/MachineLearning/tree/python-3.6/books

大家如果想要深入了解,请翻阅<机器学习实战>这一本书

或者直接访问开源https://github.com/apachecn/MachineLearning

 

 

喜欢可以多多收藏哦~转载请标注原文地址呦~

 

 

 

 

转载于:https://my.oschina.net/uncleoj/blog/1585220

你可能感兴趣的文章
明天回家了
查看>>
linux之SQL语句简明教程---INSERT INTO
查看>>
实时监控远程用户防问服务器的IP所属位置
查看>>
我的友情链接
查看>>
Java函数之Split的用法
查看>>
mysql 的CLOSE_WAIT 的问题
查看>>
1Python全栈之路系列之Django初体验
查看>>
获取资源的“手段”
查看>>
Freebsd9.0安装Nginx+PHP-FPM+MySQL+eAccelerator+Memcached
查看>>
更新 Exchange 2013 CU22后的问题
查看>>
[shell]线上环境puppet证书异常 重新进行认证脚本
查看>>
【读书笔记】数据库逻辑结构
查看>>
洛谷——1616 疯狂的采药(完全背包)
查看>>
Nginx脚本方式切割日志
查看>>
朋友圈H5营销怎么玩?又见情感营销
查看>>
$(window).scroll不能在样式内写overflow属性
查看>>
centos 6.0 dhcp 报错及解决
查看>>
查看linux 版本
查看>>
我的友情链接
查看>>
TNS-12537: TNS:connection closed ORA-609错误处理
查看>>