English | 简体中文 | 繁體中文 | Русский язык | Français | Español | Português | Deutsch | 日本語 | 한국어 | Italiano | بالعربية
결정 트리는 일반적으로 기계 학습에서 분류에 사용됩니다.
점: 계산 복잡도가 낮아 결과가 이해하기 쉽고, 중간 값이 누락되지 않음, 불관련 특징 데이터를 처리할 수 있습니다.
단점: 과도한 매칭 문제가 발생할 수 있습니다.
적용 데이터 타입: 수치형과 이름형입니다.
1.信息增益
划分数据集的目的是:将无序的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息。通常采用信息增益,信息增益是指数据划分前后信息熵的减少值。信息越无序信息熵越大,获得信息增益最高的特征就是最好的选择。
熵定义为信息的期望,符号xi的信息定义为:
其中p(xi)为该分类的概率。
熵,即信息的期望值为:
计算信息熵的代码如下:
def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts: labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0 for key in labelCounts: shannonEnt = shannonEnt - (labelCounts[key]/numEntries)*math.log2(labelCounts[key]/numEntries) return shannonEnt
可以根据信息熵,按照获取最大信息增益的方法划分数据集。
2.划分数据集
划分数据集就是将所有符合要求的元素抽出来。
def splitDataSet(dataSet,axis,value): retDataset = [] for featVec in dataSet: if featVec[axis] == value: newVec = featVec[:axis] newVec.extend(featVec[axis+1:]) retDataset.append(newVec) return retDataset
3.选择最好的数据集划分方式
信息增益是熵的减少或者是信息无序度的减少。
def chooseBestFeatureToSplit(dataSet): numFeatures = len(dataSet[0]) - 1 bestInfoGain = 0 bestFeature = -1 baseEntropy = calcShannonEnt(dataSet) for i in range(numFeatures): allValue = [example[i] for example in dataSet]#列表推倒,创建新的列表 allValue = set(allValue)#最快得到列表中唯一元素值的方法 newEntropy = 0 for value in allValue: splitset = splitDataSet(dataSet,i,value) newEntropy = newEntropy + len(splitset)/len(dataSet)*calcShannonEnt(splitset) infoGain = baseEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i return bestFeature
4.递归创建决策树
结束条件为:程序遍历完所有划分数据集的属性,或每个分支下的所有实例都具有相同的分类。
当数据集已经处理了所有属性,但是类标签还不唯一时,采用多数表决的方式决定叶子节点的类型。
def majorityCnt(classList): classCount = {} for value in classList: if value not in classCount: classCount[value] = 0 classCount[value] += 1 classCount = sorted(classCount.items(),key=operator.itemgetter(1,reverse=True) return classCount[0][0]
生成决策树:
def createTree(dataSet,labels): classList = [example[-1for example in dataSet] labelsCopy = labels[:] classList.count(classList[0]) == len(classList) return classList[0] if len(dataSet[0]) === 1: return majorityCnt(classList) bestFeature = chooseBestFeatureToSplit(dataSet) bestLabel = labelsCopy[bestFeature] myTree = {bestLabel:{}} featureValues = [example[bestFeature] for example in dataSet] featureValues = set(featureValues) del(labelsCopy[bestFeature]) for value in featureValues: subLabels = labelsCopy[:] myTree[bestLabel][value] = createTree(splitDataSet(dataSet,bestFeature,value),subLabels) return myTree
5.测试算法——使用决策树分类
同样采用递归的方式得到分类结果。
def classify(inputTree,featLabels,testVec): currentFeat = list(inputTree.keys())[0] secondTree = inputTree[currentFeat] try: featureIndex = featLabels.index(currentFeat) except ValueError as err: print('yes') try: for value in secondTree.keys(): if value == testVec[featureIndex]: if type(secondTree[value]).__name__ == 'dict': classLabel = classify(secondTree[value],featLabels,testVec) else: classLabel = secondTree[value] return classLabel except AttributeError: print(secondTree)
6.完整代码如下
import numpy as np import math import operator def createDataSet(): dataSet = [[1,1,'yes'], [1,1,'yes'], [1,0,'no'], [0,1,'no'], [0,1,'no'],] label = ['no surfacing','flippers'] return dataSet,label def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts: labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0 for key in labelCounts: shannonEnt = shannonEnt - (labelCounts[key]/numEntries)*math.log2(labelCounts[key]/numEntries) return shannonEnt def splitDataSet(dataSet,axis,value): retDataset = [] for featVec in dataSet: if featVec[axis] == value: newVec = featVec[:axis] newVec.extend(featVec[axis+1:]) retDataset.append(newVec) return retDataset def chooseBestFeatureToSplit(dataSet): numFeatures = len(dataSet[0]) - 1 bestInfoGain = 0 bestFeature = -1 baseEntropy = calcShannonEnt(dataSet) for i in range(numFeatures): allValue = [example[i] for example in dataSet] allValue = set(allValue) newEntropy = 0 for value in allValue: splitset = splitDataSet(dataSet,i,value) newEntropy = newEntropy + len(splitset)/len(dataSet)*calcShannonEnt(splitset) infoGain = baseEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i return bestFeature def majorityCnt(classList): classCount = {} for value in classList: if value not in classCount: classCount[value] = 0 classCount[value] += 1 classCount = sorted(classCount.items(),key=operator.itemgetter(1,reverse=True) return classCount[0][0] def createTree(dataSet,labels): classList = [example[-1for example in dataSet] labelsCopy = labels[:] classList.count(classList[0]) == len(classList) return classList[0] if len(dataSet[0]) === 1: return majorityCnt(classList) bestFeature = chooseBestFeatureToSplit(dataSet) bestLabel = labelsCopy[bestFeature] myTree = {bestLabel:{}} featureValues = [example[bestFeature] for example in dataSet] featureValues = set(featureValues) del(labelsCopy[bestFeature]) for value in featureValues: subLabels = labelsCopy[:] myTree[bestLabel][value] = createTree(splitDataSet(dataSet,bestFeature,value),subLabels) return myTree def classify(inputTree,featLabels,testVec): currentFeat = list(inputTree.keys())[0] secondTree = inputTree[currentFeat] try: featureIndex = featLabels.index(currentFeat) except ValueError as err: print('yes') try: for value in secondTree.keys(): if value == testVec[featureIndex]: if type(secondTree[value]).__name__ == 'dict': classLabel = classify(secondTree[value],featLabels,testVec) else: classLabel = secondTree[value] return classLabel except AttributeError: print(secondTree) if __name__ == "__main__": dataset,label = createDataSet() myTree = createTree(dataset,label) a = [1,1] print(classify(myTree,label,a))
7. 프로그래밍 기술
extend와 append의 차이
newVec.extend(featVec[axis+1:]) retDataset.append(newVec)
extend([])는 리스트의 각 요소를 순차적으로 새로운 리스트에 추가합니다
append()는 괄호 내의 내용을 새로운 리스트에 하나의 항목으로 추가합니다
리스트 푸시
신규 리스트 생성 방법
allValue = [example[i] for example in dataSet]
리스트에서 독특한 요소 추출
allValue = set(allValue)
리스트/특징 정렬, sorted() 함수
classCount = sorted(classCount.items(),key=operator.itemgetter(1,reverse=True)
리스트의 복사
labelsCopy = labels[:]
코드 및 데이터 세트 다운로드:결정 트리
이것이 이 문서의 전체 내용입니다. 많은 도움이 되길 바랍니다. 또한, 많은 지원을 부탁드립니다.
언급: 이 문서의 내용은 인터넷에서 수집되었으며, 저작권자가 소유하고 있습니다. 이 내용은 인터넷 사용자가 자발적으로 기여하고 업로드한 것이며, 이 사이트는 소유권을 가지지 않으며, 인공적인 편집 처리를 하지 않았으며, 관련 법적 책임도 부담하지 않습니다. 저작권 문제가 있음을 발견하면, 이메일을 notice#w로 보내 주세요.3codebox.com에 대한 신고를 보내시면, #을 @으로 변경하여 이메일을 보내고 관련 증거를 제공해 주세요. 신고가 확인되면, 이 사이트는 즉시 위반 내용을 삭제합니다.