In supervised learning, one commonly used classification algorithm is the decision tree. It is built from a batch of samples, each consisting of a set of attributes and the corresponding class label. By learning from these samples, the algorithm produces a decision tree that can correctly classify new data.
Suppose we have 14 users, whose personal attributes and whether they purchased a certain product are as follows:
ID | Age | Income | Job | Credit rating | Purchased
---|---|---|---|---|---
01 | <30 | High | Unstable | Poor | No
02 | <30 | High | Unstable | Good | No
03 | 30-40 | High | Unstable | Poor | Yes
04 | >40 | Medium | Unstable | Poor | Yes
05 | >40 | Low | Stable | Poor | Yes
06 | >40 | Low | Stable | Good | No
07 | 30-40 | Low | Stable | Good | Yes
08 | <30 | Medium | Unstable | Poor | No
09 | <30 | Low | Stable | Poor | Yes
10 | >40 | Medium | Stable | Poor | Yes
11 | <30 | Medium | Stable | Good | Yes
12 | 30-40 | Medium | Unstable | Good | Yes
13 | 30-40 | High | Stable | Poor | Yes
14 | >40 | Medium | Unstable | Good | No
For ease of processing, the categorical data is converted to numeric list data according to the following rules:
Age: <30 is coded as 0; 30-40 as 1; >40 as 2
Income: low = 0; medium = 1; high = 2
Job: unstable = 0; stable = 1
Credit rating: poor = 0; good = 1
```python
# Create the data set
def createdataset():
    dataSet = [
        [0, 2, 0, 0, 'N'],
        [0, 2, 0, 1, 'N'],
        [1, 2, 0, 0, 'Y'],
        [2, 1, 0, 0, 'Y'],
        [2, 0, 1, 0, 'Y'],
        [2, 0, 1, 1, 'N'],
        [1, 0, 1, 1, 'Y'],
        [0, 1, 0, 0, 'N'],
        [0, 0, 1, 0, 'Y'],
        [2, 1, 1, 0, 'Y'],
        [0, 1, 1, 1, 'Y'],
        [1, 1, 0, 1, 'Y'],
        [1, 2, 1, 0, 'Y'],
        [2, 1, 0, 1, 'N'],
    ]
    labels = ['age', 'income', 'job', 'credit']
    return dataSet, labels
```
Calling the function returns the data:
```python
ds1, lab = createdataset()
print(ds1)
print(lab)
```
```
[[0, 2, 0, 0, 'N'], [0, 2, 0, 1, 'N'], [1, 2, 0, 0, 'Y'], [2, 1, 0, 0, 'Y'], [2, 0, 1, 0, 'Y'], [2, 0, 1, 1, 'N'], [1, 0, 1, 1, 'Y'], [0, 1, 0, 0, 'N'], [0, 0, 1, 0, 'Y'], [2, 1, 1, 0, 'Y'], [0, 1, 1, 1, 'Y'], [1, 1, 0, 1, 'Y'], [1, 2, 1, 0, 'Y'], [2, 1, 0, 1, 'N']]
['age', 'income', 'job', 'credit']
```
Information entropy, also known as Shannon entropy, is the expected amount of information of a random variable and measures the uncertainty of the information: the larger the entropy, the harder the information is to pin down. Processing information means making it clear, i.e. a process of reducing entropy. For a sample set whose classes occur with proportions p_i, the entropy is H = -Σ_i p_i · log2(p_i).
```python
from math import log

# Compute the Shannon entropy of a data set (the class label is the last column)
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
```
Entropy of the sample data:
```python
shan = calcShannonEnt(ds1)
print(shan)
```
```
0.9402859586706309
```
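As a quick sanity check, the 14 samples contain 9 'Y' and 5 'N' labels, so the entropy can be recomputed directly from the formula above:

```python
from math import log

p_y, p_n = 9 / 14, 5 / 14  # class proportions in the sample data
print(-p_y * log(p_y, 2) - p_n * log(p_n, 2))  # 0.9402859586706309
```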
Information gain measures how much an attribute A reduces the entropy of a sample set X: Gain(X, A) = H(X) - Σ_v (|X_v| / |X|) · H(X_v), where X_v is the subset of samples on which A takes the value v. The larger the information gain, the better suited A is for classifying X.
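The selection code below relies on a helper, splitDataSet, which the article uses but never lists. A minimal sketch following the conventional implementation: it keeps the rows whose column `axis` equals `value`, with that column removed:

```python
# Keep the rows where column `axis` equals `value`, dropping that column
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]            # columns before the split column
            reducedFeatVec.extend(featVec[axis + 1:])  # columns after it
            retDataSet.append(reducedFeatVec)
    return retDataSet
```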
```python
# Choose the attribute (column index) with the largest information gain
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
```
The code above implements the ID3 decision-tree learning algorithm based on information gain. Its core logic: take each attribute in the attribute set in turn and partition the sample set into subsets by that attribute's values; compute the entropy of these subsets, and the difference between the sample set's entropy and their weighted sum is the information gain of splitting on that attribute; the attribute with the largest gain is the one used to split the sample set.
Computing the best attribute for splitting the sample set shows that it is column 0, i.e. the age attribute:
```python
col = chooseBestFeatureToSplit(ds1)
col
```
```
0
```
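To see why age wins, its gain can be recomputed by hand with the splitDataSet sketch above. Age splits the 14 samples into groups of 5 (<30: 2 Y, 3 N), 4 (30-40: all Y) and 5 (>40: 3 Y, 2 N):

```python
gain_age = shan - (5 / 14 * calcShannonEnt(splitDataSet(ds1, 0, 0))
                   + 4 / 14 * calcShannonEnt(splitDataSet(ds1, 0, 1))
                   + 5 / 14 * calcShannonEnt(splitDataSet(ds1, 0, 2)))
print(gain_age)  # about 0.247, the largest gain among the four attributes
```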
```python
import operator

# Return the class label that occurs most often in classList
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # Sort the counting dict by count, descending
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Build the decision tree recursively
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]            # all samples share one class: leaf node
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # no attributes left: majority vote
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
```
The majorityCnt function handles the following situation: ideally, every path of the final decision tree should reach the bottom with all remaining samples sharing the same class label. In real data, however, samples whose attributes are all identical but whose class labels differ inevitably occur; in that case majorityCnt assigns those samples the class label that appears most often.
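A quick check of majorityCnt on a small hypothetical label list:

```python
print(majorityCnt(['Y', 'N', 'Y', 'Y', 'N']))  # 'Y', the most frequent label
```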
createTree is the core task function. It applies the ID3 information-gain computation to every attribute in turn and recursively builds the final decision tree.
Building the decision tree from the sample data:
```python
Tree = createTree(ds1, lab)
print("Sample data decision tree:")
print(Tree)
```
```
Sample data decision tree:
{'age': {0: {'job': {0: 'N', 1: 'Y'}},
         1: 'Y',
         2: {'credit': {0: 'Y', 1: 'N'}}}}
```
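The article goes on to classify new users with this tree, but the prediction routine itself is never listed. Below is a minimal sketch of the usual recursive lookup over the nested-dict tree; the names classify, featLabels and testVec are mine, not from the original article:

```python
# Walk the nested-dict tree until a leaf (a plain class label) is reached
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]    # attribute tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # position of that attribute in testVec
    classLabel = None
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)  # descend
            else:
                classLabel = secondDict[key]  # leaf: the predicted class
    return classLabel
```

Because createTree deletes entries from the label list passed to it, classification needs a fresh copy of the labels. Two hypothetical under-30, medium-income test users reproduce the predictions shown next:

```python
_, labels = createdataset()                  # fresh labels; `lab` was mutated by createTree
print(classify(Tree, labels, [0, 1, 1, 1]))  # stable job, good credit -> 'Y'
print(classify(Tree, labels, [0, 1, 0, 1]))  # unstable job, good credit -> 'N'
```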
Classifying two new users with the tree, both with a good credit rating but differing in job stability:

Job | Credit rating
---|---
Stable | Good
Unstable | Good

result: Y
result1: N

Postscript: plotting the decision tree. The following code draws the decision tree as a figure. It is not the focus of the decision-tree algorithm itself, but interested readers may study it for reference.
```python
import matplotlib.pyplot as plt

# Node and arrow styles for the annotations
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Count the leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # A dict value means an internal node; anything else is a leaf
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Get the number of levels of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Draw a node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Draw the label on the connecting line
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

# Draw the tree structure
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:
            # Leaf node: plot it directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# Create the decision-tree figure
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks, for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.savefig('决策树.png', dpi=300, bbox_inches='tight')
    plt.show()
```
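Rendering the tree built earlier takes a single call (the output file name 决策树.png comes from the listing above):

```python
createPlot(Tree)
```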