搜尋
首頁科技週邊人工智慧為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎

對於二元分類,分類器輸出一個實值分數,然後透過對該值進行閾值的區分產生二元的對應。例如,邏輯迴歸輸出一個機率(一個介於0.0和1.0之間的值);得分等於或高於0.5的觀察結果產生正輸出(許多其他模型預設使用0.5閾值)。

但是使用預設的0.5閾值是不理想的。在本文中,我將展示如何從二元分類器中選擇最佳閾值。本文將使用Ploomber並行執行我們的實驗,並使用sklearn-evaluation產生圖。

為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎

這裡以訓練邏輯迴歸為例。假設我們正在開發一個內容審核系統,模型標記包含有害內容的貼文(圖片、影片等);然後,人工會查看並決定內容是否被刪除。

建立簡單的二元分類器

下面的程式碼片段訓練我們的分類器:

import matplotlib.pyplot as plt
 import matplotlib as mpl
 from sklearn import datasets
 from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import train_test_split
 from sklearn_evaluation.plot import ConfusionMatrix
 
 # matplotlib settings
 mpl.rcParams['figure.figsize'] = (4, 4)
 mpl.rcParams['figure.dpi'] = 150
 
 # create sample dataset
 X, y = datasets.make_classification(1000, 10, n_informative=5, class_sep=0.4)
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
 
 # fit model
 clf = LogisticRegression()
 _ = clf.fit(X_train, y_train)

現在讓我們對測試集進行預測,並透過混淆矩陣評估性能:

# predict on the test set
 y_pred = clf.predict(X_test)
 
 # plot confusion matrix
 cm_dot_five = ConfusionMatrix(y_test, y_pred)
 cm_dot_five

為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎

混淆矩陣總結了模型在四個區域的性能:

為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎

##我們希望在左上和在右下象限中獲得盡可能多的觀察值(從測試集),因為這些是我們的模型得到正確的觀察值。其他像限是模型錯誤。

改變模型的閾值將改變混淆矩陣中的值。在前面的範例中,使用clf.predict,傳回一個二元回應(即使用0.5作為閾值);但是我們可以使用clf.predict_proba函數取得原始機率並使用自訂閾值:

y_score = clf.predict_proba(X_test)

我們可以透過設定一個較低的閾值(即標記更多的帖子為有害的)來讓我們的分類器更具侵略性,並創建一個新的混淆矩陣:

cm_dot_four = ConfusionMatrix(y_score[:, 1] >= 0.4, y_pred)

sklearn-evaluation庫可以輕鬆比較兩個矩陣:

cm_dot_five + cm_dot_four

三角形的上面來自0.5的閾值,下面來自0.4的閾值:

    兩個模型對相同數量的觀測結果都預測為0(這是一個巧合)。 0.5閾值:(90 56 = 146)。 0.4閾值:(78 68 = 146)
  • 降低閾值會導致更多的假陰性(從56例降至68例)
  • 降低閾值將大大增加真陽性(從92例增加154例)
微小的閾值變化極大地影響了混淆矩陣。我們只分析了兩個閾值。那麼如果能夠分析跨所有值的模型效能,我們就可以好地理解閾值動態。但是在此之前,需要定義用於模型評估的新指標。

到目前為止,我們都是用絕對數字來評估我們的模型。為了便於比較和評估,我們現在將定義兩個標準化指標(它們的值在0.0和1.0之間)。

精度precision是標記的觀察事件的比例(例如,我們的模型認為有害的帖子,它們是有害的)。召回 recall是我們的模型檢索到的實際事件的比例(即,從所有有害的帖子中,我們能夠檢測到它們的哪個比例)。

為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎

以上圖片來自維基百科,可以很好的說明這兩個指標是如何計算的,精確度和召回率都是比例關係,所以它們都是0比1的比例。

運行實驗

我們將根據幾個閾值獲得精度、召回率和其他統計信息,以便更好地理解閾值如何影響它們。我們也會多次重複這個實驗來測量變異性。

本節的指令都是bash指令。需要在終端機中執行它們,如果使用Jupyter可以使用%%sh魔法命令。

這裡使用Ploomber Cloud來執行我們的實驗。因為它允許我們並行運行實驗並快速檢索結果。

建立了一個適合一個模型的Notebook,並為幾個閾值計算統計數據,並行執行同一個Notebook20次。

curl -O https://raw.githubusercontent.com/ploomber/posts/master/threshold/fit.ipynb?utm_source=medium&utm_medium=blog&utm_campaign=threshold

讓執行這個Notebook(檔案中的設定會告訴Ploomber Cloud並行運行它20次):

ploomber cloud nb fit.ipynb

幾分鐘後,我們就會看到的20個實驗完成了:

ploomber cloud status @latest --summary
 
 status count
 -------- -------
 finished 20
 
 Pipeline finished. Check outputs:
 $ ploomber cloud products

讓我們下載儲存在.csv檔案中的實驗結果:

ploomber cloud download 'threshold-selection/*.csv' --summary

可視化實驗結果

將載入所有實驗的結果,並一次將它們繪製出來。

 from glob import glob
 
 import pandas as pd
 import numpy as np
 paths = glob('threshold-selection/**/*.csv')
 metrics = [pd.read_csv(path) for path in paths]
 
 for idx, df in enumerate(metrics):
plt.plot(df.threshold, df.precision, color='blue', alpha=0.2,
label='precision' if idx == 0 else None)
plt.plot(df.threshold, df.recall, color='green', alpha=0.2,
label='recall' if idx == 0 else None)
plt.plot(df.threshold, df.f1, color='orange', alpha=0.2,
label='f1' if idx == 0 else None)
 
 
 plt.grid()
 plt.legend()
 plt.xlabel('Threshold')
 plt.ylabel('Metric value')
 
 for handle in plt.legend().legendHandles:
handle.set_alpha(1)
 
 ax = plt.twinx()
 
 for idx, df in enumerate(metrics):
ax.plot(df.threshold, df.n_flagged,
label='flagged' if idx == 0 else None,
color='red', alpha=0.2)
 
 plt.ylabel('Flagged')
 ax.legend(loc=0)
 ax.legend().legendHandles[0].set_alpha(1)

為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎#

左边的刻度(从0到1)是我们的三个指标:精度、召回率和F1。F1分为精度与查全率的调和平均值,F1分的最佳值为1.0,最差值为0.0;F1对精度和召回率都是相同对待的,所以你可以看到它在两者之间保持平衡。如果你正在处理一个精确度和召回率都很重要的用例,那么最大化F1是一种可以帮助你优化分类器阈值的方法。

这里还包括一条红色曲线(右侧的比例),显示我们的模型标记为有害内容的案例数量。

在这个的内容审核示例中,可能有X个的工作人员来人工审核模型标记的有害帖子,但是他们人数是有限的,因此考虑标记帖子的总数可以帮助我们更好地选择阈值:例如每天只能检查5000个帖子,那么模型找到10,000帖并不会带来任何的提高。如果我人工每天可以处理10000贴,但是模型只标记了100贴,那么显然也是浪费的。

当设置较低的阈值时,有较高的召回率(我们检索了大部分实际上有害的帖子),但精度较低(包含了许多无害的帖子)。如果我们提高阈值,情况就会反转:召回率下降(错过了许多有害的帖子),但精确度很高(大多数标记的帖子都是有害的)。

所以在为我们的二元分类器选择阈值时,我们必须在精度或召回率上妥协,因为没有一个分类器是完美的。我们来讨论一下如何推理选择合适的阈值。

选择最佳阈值

右边的数据会产生噪声(较大的阈值)。需要稍微清理一下,我们将重新创建这个图,我们将绘制2.5%、50%和97.5%的百分位数,而不是绘制所有值。

shape = (df.shape[0], len(metrics))
 precision = np.zeros(shape)
 recall = np.zeros(shape)
 f1 = np.zeros(shape)
 n_flagged = np.zeros(shape)
 for i, df in enumerate(metrics):
precision[:, i] = df.precision.values
recall[:, i] = df.recall.values
f1[:, i] = df.f1.values
n_flagged[:, i] = df.n_flagged.values
 precision_ = np.quantile(precision, q=0.5, axis=1)
 recall_ = np.quantile(recall, q=0.5, axis=1)
 f1_ = np.quantile(f1, q=0.5, axis=1)
 n_flagged_ = np.quantile(n_flagged, q=0.5, axis=1)
 plt.plot(df.threshold, precision_, color='blue', label='precision')
 plt.plot(df.threshold, recall_, color='green', label='recall')
 plt.plot(df.threshold, f1_, color='orange', label='f1')
 
 plt.fill_between(df.threshold, precision_interval[0],
precision_interval[1], color='blue',
alpha=0.2)
 
 plt.fill_between(df.threshold, recall_interval[0],
recall_interval[1], color='green',
alpha=0.2)
 
 
 plt.fill_between(df.threshold, f1_interval[0],
f1_interval[1], color='orange',
alpha=0.2)
 plt.xlabel('Threshold')
 plt.ylabel('Metric value')
 plt.legend()
 
 ax = plt.twinx()
 ax.plot(df.threshold, n_flagged_, color='red', label='flagged')
 ax.fill_between(df.threshold, n_flagged_interval[0],
n_flagged_interval[1], color='red',
alpha=0.2)
 
 ax.legend(loc=3)
 
 plt.ylabel('Flagged')
 plt.grid()

為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎

我们可以根据自己的需求选择阈值,例如检索尽可能多的有害帖子(高召回率)是否更重要?还是要有更高的确定性,我们标记的必须是有害的(高精度)?

如果两者都同等重要,那么在这些条件下优化的常用方法就是最大化F-1分数:

idx = np.argmax(f1_)
 prec_lower, prec_upper = precision_interval[0][idx], precision_interval[1][idx]
 rec_lower, rec_upper = recall_interval[0][idx], recall_interval[1][idx]
 threshold = df.threshold[idx]
 
 print(f'Max F1 score: {f1_[idx]:.2f}')
 print('Metrics when maximizing F1 score:')
 print(f' - Threshold: {threshold:.2f}')
 print(f' - Precision range: ({prec_lower:.2f}, {prec_upper:.2f})')
 print(f' - Recall range: ({rec_lower:.2f}, {rec_upper:.2f})')
 
 #结果
 Max F1 score: 0.71
 Metrics when maximizing F1 score:
- Threshold: 0.26
- Precision range: (0.58, 0.61)
- Recall range: (0.86, 0.90)

在很多情况下很难决定这个折中,所以加入一些约束条件会有一些帮助。

假设我们有10个人审查有害的帖子,他们可以一起检查5000个。那么让我们看看指标,如果我们修改了阈值,让它标记了大约5000个帖子:

idx = np.argmax(n_flagged_ <= 5000)
 
 prec_lower, prec_upper = precision_interval[0][idx], precision_interval[1][idx]
 rec_lower, rec_upper = recall_interval[0][idx], recall_interval[1][idx]
 threshold = df.threshold[idx]
 
 print('Metrics when limiting to a maximum of 5,000 flagged events:')
 print(f' - Threshold: {threshold:.2f}')
 print(f' - Precision range: ({prec_lower:.2f}, {prec_upper:.2f})')
 print(f' - Recall range: ({rec_lower:.2f}, {rec_upper:.2f})')
 
 # 结果
 Metrics when limiting to a maximum of 5,000 flagged events:
- Threshold: 0.82
- Precision range: (0.77, 0.81)
- Recall range: (0.25, 0.36)

如果需要进行汇报,我们可以在在展示结果时展示一些替代方案:比如在当前约束条件下(5000个帖子)的模型性能,以及如果我们增加团队(比如通过增加一倍的规模),我们可以做得更好。

总结

二元分类器的最佳阈值是针对业务结果进行优化并考虑到流程限制的阈值。通过本文中描述的过程,你可以更好地为用例决定最佳阈值。

另外,Ploomber Cloud!提供一些免费的算力!如果你需要一些免费的服务可以试试它。

以上是為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述
本文轉載於:51CTO.COM。如有侵權,請聯絡admin@php.cn刪除
具有多模式和Azure文檔智能的抹布具有多模式和Azure文檔智能的抹布Apr 13, 2025 am 10:38 AM

介紹 在基於數據運行的當前世界中,關係AI圖(RAG)通過關聯數據並繪製關係來對行業產生很大影響。但是,如果一個人可以再進一步多怎麼辦

在生成AI時代負責的AI在生成AI時代負責的AIApr 13, 2025 am 10:28 AM

介紹 現在,我們生活在人工智能時代,我們周圍的一切都在一天變得更加聰明。最先進的大語言模型(LLM)和AI代理,能夠執行複雜的任務

GPT-4O vs OpenAI O1:新的Openai模型值得炒作嗎?GPT-4O vs OpenAI O1:新的Openai模型值得炒作嗎?Apr 13, 2025 am 10:18 AM

介紹 Openai已根據備受期待的“草莓”建築發布了其新模型。這種稱為O1的創新模型增強了推理能力,使其可以通過問題進行思考

小語言模型的微調和推斷小語言模型的微調和推斷Apr 13, 2025 am 10:15 AM

介紹 想像一下,您正在建立醫療聊天機器人,大量的,渴望資源的大型語言模型(LLMS)似乎滿足您的需求。那是小語言模型(SLM)等傑瑪(SLM)發揮作用

如何訪問OpenAi O1 API |分析Vidhya如何訪問OpenAi O1 API |分析VidhyaApr 13, 2025 am 10:14 AM

介紹 OpenAI的O1系列模型代表了大語言模型(LLM)功能的重大飛躍,尤其是對於復雜的推理任務。這些模型在RESP之前從事深厚的內部思維過程

使用Python的Google表自動化|分析Vidhya使用Python的Google表自動化|分析VidhyaApr 13, 2025 am 10:01 AM

Google表是Excel的最受歡迎和廣泛使用的替代方案之一,它提供了具有實時編輯,版本控制和與Google Suite無縫集成等功能的協作環境,允許U

O1-Mini:一種改變遊戲規則的STEM和推理模型O1-Mini:一種改變遊戲規則的STEM和推理模型Apr 13, 2025 am 09:55 AM

OpenAI引入了O1-Mini,這是一種具有成本效益的推理模型,重點是STEM受試者。該模型在數學和編碼中表現出令人印象深刻的性能,與其前身Openai O1非常相似

用結構化輸出和功能調用增強LLM用結構化輸出和功能調用增強LLMApr 13, 2025 am 09:45 AM

介紹 假設您正在與知識淵博但有時缺乏具體/知情的回答或他/她/她/她在面對複雜問題時不會流利的回應時互動。我們在這裡做什麼

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

AI Hentai Generator

AI Hentai Generator

免費產生 AI 無盡。

熱門文章

R.E.P.O.能量晶體解釋及其做什麼(黃色晶體)
3 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳圖形設置
3 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您聽不到任何人,如何修復音頻
3 週前By尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解鎖Myrise中的所有內容
4 週前By尊渡假赌尊渡假赌尊渡假赌

熱工具

MinGW - Minimalist GNU for Windows

MinGW - Minimalist GNU for Windows

這個專案正在遷移到osdn.net/projects/mingw的過程中,你可以繼續在那裡關注我們。 MinGW:GNU編譯器集合(GCC)的本機Windows移植版本,可自由分發的導入函式庫和用於建置本機Windows應用程式的頭檔;包括對MSVC執行時間的擴展,以支援C99功能。 MinGW的所有軟體都可以在64位元Windows平台上運作。

MantisBT

MantisBT

Mantis是一個易於部署的基於Web的缺陷追蹤工具,用於幫助產品缺陷追蹤。它需要PHP、MySQL和一個Web伺服器。請查看我們的演示和託管服務。

Safe Exam Browser

Safe Exam Browser

Safe Exam Browser是一個安全的瀏覽器環境,安全地進行線上考試。該軟體將任何電腦變成一個安全的工作站。它控制對任何實用工具的訪問,並防止學生使用未經授權的資源。

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

Dreamweaver Mac版

Dreamweaver Mac版

視覺化網頁開發工具