Decision Tree
langu_xyz

0x01 DT

分类算法

优点

  • 计算复杂度不高
  • 输出结果易于理解
  • 中间值缺失不敏感
  • 可处理不相关特征

缺点

  • 可能会产生过度匹配问题

适用数据类型:

  • 数值型
  • 标称型

0x02 准备数据

算法描述

1.根节点开始,测试待分类项中相应的特征属性

2.按照其值选择输出分支,直到到达叶子节点

3.将叶子节点存放的类别作为决策结果
  • 划分数据集

将无序的数据变得更加有序

信息增益:划分数据集之后信息发生的变化
:信息的期望值

  • 熵计算公式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def calcShannonEnt(dataSet):
numEntries = len(dataSet) #计算数据集中实例总数
labelCounts = {}
#统计每个键值的数量,dict
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
#计算香农熵
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt

  • 划分数据集

按照给定特征划分数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def splitDataSet(dataSet, axis, value):
'''

:param dataSet: 待划分数据集
:param axis: 特征
:param value: 特征值
:return: 符合条件的值列表
'''
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:]) #把特征列除去
retDataSet.append(reducedFeatVec)
return retDataSet

选择最好的数据集划分方式

熵越高,则混合的数据就越多

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def chooseBestFeatureToSplit(dataSet):
'''
:param dataSet: 数据集
:return:
'''
numFeatures = len(dataSet[0]) - 1 #特征列的长度,-1为label
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet] #创建一个list包含所有数据的第i个feature
uniqueVals = set(featList) #转变为set格式
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value) #遍历featList中的所有feature,对每个feture划分一次数据集
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet) #计算当前feature的香农熵
infoGain = baseEntropy - newEntropy #计算熵差,信息增益
if (infoGain > bestInfoGain): #计算最大信息增益
bestInfoGain = infoGain
bestFeature = i
return bestFeature #返回最好的feature

递归构建决策树

1.得到数据集
2.最好feature划分
3.递归划分

当处理了所有feature后,类标签仍然不唯一时,采用多数表决方式决定子节点分类

1
2
3
4
5
6
7
8
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]

利用递归构建tree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet] #数据集的所有类标签
if classList.count(classList[0]) == len(classList):
return classList[0] #当类标签完全相同返回该类标签
if len(dataSet[0]) == 1: #当所有属性都处理完,label仍然不唯一时,采用表决方式
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat] #当前数据集选取的最好特征变量
myTree = {bestFeatLabel: {}}
del(labels[bestFeat]) #删除用过的feature
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) #利用递归构建tree
return myTree
  • 绘制树形图

利用Matplotlib annotations实现绘制树形图

实现效果如下图

0x03 测试和储存分类器

  • 将标签字符串转换为索引
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def classify(inputTree,featLabels,testVec):
'''

:param inputTree: tree dict
:param featLabels: labels
:param testVec: 位置,eg.[1, 0]
:return:
'''
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat
return classLabel
  • 存储决策树

使用pickle持久化对象

pickle.dump(obj, file[, protocol])

1
2
3
4
5
6
7
8
9
10
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()

def grabTree(filename):
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)

0x04 使用决策树预测隐形眼镜类型

  • 收集数据

lenses

  • 准备数据

解析通过’\t’分隔的数据

  • 分析数据&训练模型
1
2
labels = ['age', 'prescript', 'astigmatic', 'tearRate']
lenses_tree = createTree(lenses, labels)

  • 测试模型

0x05 其它模型

  • ID3(分类树)

    每次根据“最大信息熵增益”选取当前最佳的特征来分割数据,并按照该特征的所有取值来切分

  • C4.5(分类树)

    ID3的升级版,采用信息增益比率,通过引入一个被称作分裂信息(Split information)的项来惩罚取值较多的Feature
    弥补了ID3中不能处理特征属性值连续的问题

  • CART(分类回归树)

    CART是一棵二叉树,采用二元切分法,每次把数据切成两份,分别进入左子树、右子树。而且每个非叶子节点都有两个孩子,所以CART的叶子节点比非叶子多1

0x05 安全领域

  • 分析恶意网络攻击和入侵
  • 口令爆破检测
  • 僵尸流量检测
  • Post title:Decision Tree
  • Post author:langu_xyz
  • Create time:2019-08-01 21:00:00
  • Post link:https://blog.langu.xyz/Decision Tree/
  • Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stating additionally.