ホームページ >バックエンド開発 >Python チュートリアル >Pythonでデシジョンツリー分類アルゴリズムを実装する方法
書き換えられた文: 教師あり学習では、一般的に使用される分類アルゴリズムは決定木です。これはサンプルのバッチに基づいており、各サンプルには一連の属性と対応する分類結果が含まれています。これらのサンプルを学習に使用すると、アルゴリズムは新しいデータを正しく分類できるデシジョン ツリーを生成できます
14 人の既存ユーザーとその個人属性があると仮定します。特定の商品を購入するかどうかは次のとおりです。
年齢 | 収入範囲 | 職種 | 信用格付け | 購入決定 | |||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
高い | 不安定 | 悪い | No | ||||||||||||
##高 | 不安定 | 良好 | なし | ||||||||||||
30-40 | 高い | 不安定 | 悪い | は | |||||||||||
>40 | 中 | 不安定 | 悪い | はい | #05 | ||||||||||
低 | 安定 | 悪い | はい | 06 | |||||||||||
低 | 安定 | 良好 | なし | 07 | |||||||||||
低 | 安定 | 良い | はい | 08 | |||||||||||
不安定 | 悪い | No | 09 | ||||||||||||
安定 | 悪い | は | 10 | ||||||||||||
中 | 安定 | 悪い | はい | ##11 | 中 | 安定 | 良い | はい | ##12 | 30-40 | |||||
不安定 | 良い | はい | 13 | 30-40 | |||||||||||
安定 | 悪い | #は | 14 | >40 | |||||||||||
不安定 | 良い | いいえ | ## |
仕事の内容 | 信用格付け | ## | |
---|---|---|---|
# | 高 | 不安定 | |
def classify(inputtree,featlabels,testvec): firststr = list(inputtree.keys())[0] seconddict = inputtree[firststr] featindex = featlabels.index(firststr) for key in seconddict.keys(): if testvec[featindex]==key: if type(seconddict[key]).__name__=='dict': classlabel=classify(seconddict[key],featlabels,testvec) else: classlabel=seconddict[key] return classlabel labels=['age','income','job','credit'] tsvec=[0,0,1,1] print('result:',classify(Tree,labels,tsvec)) tsvec1=[0,2,0,1] print('result1:',classify(Tree,labels,tsvec1)) |
result: Y | result1: N
import matplotlib.pyplot as plt decisionNode = dict(box, fc="0.8") leafNode = dict(box, fc="0.8") arrow_args = dict(arrow) #获取叶节点的数目 def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点 numLeafs += getNumLeafs(secondDict[key]) else: numLeafs +=1 return numLeafs #获取树的层数 def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点 thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth #绘制节点 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) #绘制连接线 def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) #绘制树结构 def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on numLeafs = getNumLeafs(myTree) #this determines the x width of this tree depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] #the text label for this node should be this cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key],cntrPt,str(key)) #recursion else: #it's a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #创建决策树图形 def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), '') plt.savefig('决策树.png',dpi=300,bbox_inches='tight') plt.show()
以上がPythonでデシジョンツリー分類アルゴリズムを実装する方法の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。