首頁 >後端開發 >Python教學 >Python中的隨機森林演算法實例

Python中的隨機森林演算法實例

王林
王林原創
2023-06-10 13:12:073073瀏覽

隨機森林(Random Forest)是一種整合學習(Ensemble Learning)演算法,其透過結合多個決策樹的預測結果來提高準確性和穩健性。隨機森林在各領域都有廣泛的應用,例如金融、醫療、電商等。

本文將介紹如何使用Python實現隨機森林分類器,並使用鳶尾花資料集進行測試。

一、鳶尾花資料集

鳶尾花資料集是機器學習中一個經典的資料集,包含了150筆記錄,每筆記錄有4個特徵和1個類別標籤。其中4個特徵分別是花萼長度、花萼寬度、花瓣長度和花瓣寬度,類別標籤則表示鳶尾花的三個品種之一(山鳶尾、變色鳶尾、維吉尼亞鳶尾花)。

在Python中,我們可以使用scikit-learn這個強大的機器學習函式庫來載入鳶尾花資料集。具體操作如下:

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

二、建立隨機森林分類器

使用scikit-learn建立隨機森林分類器非常簡單。首先,我們需要從sklearn.ensemble中導入RandomForestClassifier類,並實例化一個物件:

from sklearn.ensemble import RandomForestClassifier

rfc = RandomForestClassifier(n_estimators=10)

其中,n_estimators參數指定了隨機森林中包含的決策樹數量。此處,我們將隨機森林中的決策樹數量設定為10。

接著,我們需要將鳶尾花資料集分成訓練資料和測試資料。使用train_test_split函數將資料集隨機劃分為訓練集和測試集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

其中,test_size參數指定了測試集所佔比例,random_state參數指定了偽隨機數產生器的種子,以確保每次運行程序得到相同的結果。

然後,我們可以使用訓練資料來訓練隨機森林分類器:

rfc.fit(X_train, y_train)

三、測試隨機森林分類器

一旦分類器已經訓練完畢,我們可以使用測試數據來測試其性能。使用predict函數對測試集進行預測,並使用accuracy_score函數計算模型的準確率:

from sklearn.metrics import accuracy_score

y_pred = rfc.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

最後,我們可以使用matplotlib庫將分類器的決策邊界可視化,以便更好地理解分類器的行為:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
z_min, z_max = X[:, 2].min() - .5, X[:, 2].max() + .5
xx, yy, zz = np.meshgrid(np.arange(x_min, x_max, 0.2), np.arange(y_min, y_max, 0.2), np.arange(z_min, z_max, 0.2))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

Z = rfc.predict(np.c_[xx.ravel(), yy.ravel(), zz.ravel()])
Z = Z.reshape(xx.shape)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y)
ax.set_xlabel('Sepal length')
ax.set_ylabel('Sepal width')
ax.set_zlabel('Petal length')
ax.set_title('Decision Boundary')

ax.view_init(elev=30, azim=120)
ax.plot_surface(xx, yy, zz, alpha=0.3, facecolors='blue')

plt.show()

上述程式碼將得到一個三維圖像,其中資料點的顏色表示鳶尾花的品種,決策邊界則用半透明的藍色面來表示。

四、總結

本文介紹如何使用Python實作隨機森林分類器,並使用鳶尾花資料集進行測試。由於隨機森林演算法的穩健性和準確性,它在實際應用中有廣泛的應用前景。如果您對該演算法感興趣,建議多實踐並閱讀相關的文獻。

以上是Python中的隨機森林演算法實例的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn