ホームページ >テクノロジー周辺機器 >AI >LightGBM実戦+ランダム探索パラメータ調整:命中率96.67%
皆さんこんにちは、ピーターです〜
LightGBM は古典的な機械学習アルゴリズムであり、その背景、原理、特性は非常に研究する価値があります。 LightGBM のアルゴリズムは、効率、拡張性、高精度などの機能をもたらします。この記事では、LightGBM の特徴と原理、および LightGBM とランダム検索最適化に基づくいくつかの事例を簡単に紹介します。
機械学習の分野では、勾配ブースティング マシン (GBM) は、弱い学習器 (通常は決定木) を徐々に追加することで予測誤差を最小限に抑える強力なアンサンブル学習アルゴリズムのクラスです。 GBM は、予測誤差を最小限に抑え、残差関数または損失関数を最小限に抑えることで達成できる強力なモデルを構築するためによく使用されます。このアルゴリズムは広く使用されており、デシジョン ツリーなどの弱い学習器で構築された強力なモデルの予測誤差を最小限に抑えるためによく使用されます。
ビッグデータの時代では、データセットのサイズが劇的に増大しており、従来の GBM はコンピューティングとストレージのコストが高いため、効果的に拡張することが困難です。
これらの問題を解決するために、Microsoft は 2017 年に、より高速でメモリ消費量が低く、パフォーマンスの高い勾配ブースティング フレームワークである LightGBM (Light Gradient Boosting Machine) を発売しました。
公式学習アドレス: https://lightgbm.readthedocs.io/en/stable/
1. ヒストグラムに基づく決定木アルゴリズム:
2. 深さ制限のあるリーフごとのツリー成長戦略:
3. 片側勾配サンプリング (GOSS):
4. 相互排他的機能バンドリング (EFB):
5. 並列学習と分散学習のサポート:
6. キャッシュの最適化:
7. 複数の損失関数をサポートします:
8. 正則化と枝刈り:
9. モデルの解釈可能性:
In [1]:
import numpy as npimport lightgbm as lgbfrom sklearn.model_selection import train_test_split, RandomizedSearchCVfrom sklearn.datasets import load_irisfrom sklearn.metrics import accuracy_scoreimport warningswarnings.filterwarnings("ignore")
パブリックアヤメデータセットのロード:
In [2]:
# 加载数据集data = load_iris()X, y = data.data, data.targety = [int(i) for i in y]# 将标签转换为整数
In [3]:
X[:3]
アウト[3]:
array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], [4.7, 3.2, 1.3, 0.2]])
In [4]:
y[:10]
Out[4]:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
In [5]:
# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
同時に:
[6]内:
lgb_train = lgb.Dataset(X_train, label=y_train)
[7]内:
# 设置参数范围param_dist = {'boosting_type': ['gbdt', 'dart'],# 提升类型梯度提升决策树(gbdt)和Dropouts meet Multiple Additive Regression Trees(dart)'objective': ['binary', 'multiclass'],# 目标;二分类和多分类'num_leaves': range(20, 150),# 叶子节点数量'learning_rate': [0.01, 0.05, 0.1],# 学习率'feature_fraction': [0.6, 0.8, 1.0],# 特征采样比例'bagging_fraction': [0.6, 0.8, 1.0],# 数据采样比例'bagging_freq': range(0, 80),# 数据采样频率'verbose': [-1]# 是否显示训练过程中的详细信息,-1表示不显示}
[8]内:
# 初始化模型model = lgb.LGBMClassifier()# 使用随机搜索进行参数调优random_search = RandomizedSearchCV(estimator=model, param_distributinotallow=param_dist, # 参数组合 n_iter=100, cv=5, # 5折交叉验证 verbose=2, random_state=42, n_jobs=-1)# 模型训练random_search.fit(X_train, y_train)Fitting 5 folds for each of 100 candidates, totalling 500 fits
最適なパラメータを出力組み合わせ:
In [9]:
# 输出最佳参数print("Best parameters found: ", random_search.best_params_)Best parameters found:{'verbose': -1, 'objective': 'multiclass', 'num_leaves': 87, 'learning_rate': 0.05, 'feature_fraction': 0.6, 'boosting_type': 'gbdt', 'bagging_freq': 22, 'bagging_fraction': 0.6}
In [10]:
# 使用最佳参数训练模型best_model = random_search.best_estimator_best_model.fit(X_train, y_train)# 预测y_pred = best_model.predict(X_test)y_pred = [round(i) for i in y_pred]# 将概率转换为类别# 评估模型print('Accuracy: %.4f' % accuracy_score(y_test, y_pred))Accuracy: 0.9667
以上がLightGBM実戦+ランダム探索パラメータ調整:命中率96.67%の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。