Rumah > Artikel > Peranti teknologi > Masalah overfitting dalam algoritma pembelajaran mesin
Masalah terlalu muat dalam algoritma pembelajaran mesin memerlukan contoh kod khusus
Dalam bidang pembelajaran mesin, masalah terlalu sesuai model adalah salah satu cabaran biasa. Apabila model mengatasi data latihan, ia menjadi terlalu sensitif kepada hingar dan outlier, menyebabkan model berprestasi buruk pada data baharu. Untuk menyelesaikan masalah over-fitting, kita perlu mengambil beberapa kaedah yang berkesan semasa proses latihan model.
Pendekatan biasa adalah menggunakan teknik regularization seperti regularization L1 dan regularization L2. Teknik ini mengehadkan kerumitan model dengan menambahkan istilah penalti untuk mengelakkan model daripada terlampau pasang. Yang berikut menggunakan contoh kod khusus untuk menggambarkan cara menggunakan regularization L2 untuk menyelesaikan masalah overfitting.
Kami akan menggunakan bahasa Python dan perpustakaan pembelajaran Scikit untuk melaksanakan model regresi. Pertama, kami perlu mengimport perpustakaan yang diperlukan:
import numpy as np from sklearn.linear_model import Ridge from sklearn.model_selection import train_test_split from sklearn.metrics import mean_squared_error
Seterusnya, kami mencipta set data palsu dengan 10 ciri dan pembolehubah sasaran. Ambil perhatian bahawa kami mensimulasikan data dunia sebenar dengan menambahkan beberapa hingar rawak:
np.random.seed(0) n_samples = 1000 n_features = 10 X = np.random.randn(n_samples, n_features) y = np.random.randn(n_samples) + 2*X[:, 0] + 3*X[:, 1] + np.random.randn(n_samples)*0.5
Kemudian, kami membahagikan set data kepada set latihan dan ujian:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
Kini, kami boleh mencipta model regresi rabung dan menetapkan nilai regularisasi parameter alfa :
model = Ridge(alpha=0.1)
Seterusnya, kami menggunakan set latihan untuk melatih model:
model.fit(X_train, y_train)
Selepas latihan selesai, kami boleh menggunakan set ujian untuk menilai prestasi model:
y_pred = model.predict(X_test) mse = mean_squared_error(y_test, y_pred) print("Mean squared error: ", mse)
Dalam contoh ini, kami menggunakan model regresi rabung , dan tetapkan parameter penaturan alfa kepada 0.1. Dengan menggunakan regularisasi L2, kerumitan model adalah terhad untuk membuat generalisasi yang lebih baik kepada data baharu. Semasa menilai prestasi model, kami mengira ralat kuasa dua min, yang menerangkan perbezaan antara nilai ramalan dan nilai sebenar.
Dengan melaraskan nilai alfa parameter regularisasi, kami boleh mengoptimumkan prestasi model. Apabila nilai alfa kecil, model akan cenderung untuk melebihkan data latihan apabila nilai alfa besar, model akan cenderung tidak sesuai. Dalam amalan, kami biasanya memilih nilai alfa yang optimum melalui pengesahan silang.
Untuk meringkaskan, masalah overfitting ialah cabaran biasa dalam pembelajaran mesin. Dengan menggunakan teknik regularization, seperti regularization L2, kita boleh mengehadkan kerumitan model untuk mengelakkan model daripada overfitting data latihan. Contoh kod di atas menunjukkan cara menggunakan model regresi rabung dan regularisasi L2 untuk menyelesaikan masalah overfitting. Semoga contoh ini akan membantu pembaca lebih memahami dan menggunakan teknik regularisasi.
Atas ialah kandungan terperinci Masalah overfitting dalam algoritma pembelajaran mesin. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!