Rumah > Artikel > Peranti teknologi > Apakah yang boleh diberitahu oleh graf metrik latihan dan pengesahan dalam pembelajaran mesin?
Dalam artikel ini, kami akan meringkaskan kemungkinan situasi latihan dan pengesahan serta memperkenalkan jenis maklumat yang boleh diberikan oleh carta ini kepada kami.
Mari mulakan dengan beberapa kod mudah Kod berikut mewujudkan rangka kerja proses latihan asas.
from sklearn.model_selection import train_test_split<br>from sklearn.datasets import make_classification<br>import torch<br>from torch.utils.data import Dataset, DataLoader<br>import torch.optim as torch_optim<br>import torch.nn as nn<br>import torch.nn.functional as F<br>import numpy as np<br>import matplotlib.pyplot as pltclass MyCustomDataset(Dataset):<br>def __init__(self, X, Y, scale=False):<br>self.X = torch.from_numpy(X.astype(np.float32))<br>self.y = torch.from_numpy(Y.astype(np.int64))<br><br>def __len__(self):<br>return len(self.y)<br><br>def __getitem__(self, idx):<br>return self.X[idx], self.y[idx]def get_optimizer(model, lr=0.001, wd=0.0):<br>parameters = filter(lambda p: p.requires_grad, model.parameters())<br>optim = torch_optim.Adam(parameters, lr=lr, weight_decay=wd)<br>return optimdef train_model(model, optim, train_dl, loss_func):<br># Ensure the model is in Training mode<br>model.train()<br>total = 0<br>sum_loss = 0<br>for x, y in train_dl:<br>batch = y.shape[0]<br># Train the model for this batch worth of data<br>logits = model(x)<br># Run the loss function. We will decide what this will be when we call our Training Loop<br>loss = loss_func(logits, y)<br># The next 3 lines do all the PyTorch back propagation goodness<br>optim.zero_grad()<br>loss.backward()<br>optim.step()<br># Keep a running check of our total number of samples in this epoch<br>total += batch<br># And keep a running total of our loss<br>sum_loss += batch*(loss.item())<br>return sum_loss/total<br>def train_loop(model, train_dl, valid_dl, epochs, loss_func, lr=0.1, wd=0):<br>optim = get_optimizer(model, lr=lr, wd=wd)<br>train_loss_list = []<br>val_loss_list = []<br>acc_list = []<br>for i in range(epochs): <br>loss = train_model(model, optim, train_dl, loss_func)<br># After training this epoch, keep a list of progress of <br># the loss of each epoch <br>train_loss_list.append(loss)<br>val, acc = val_loss(model, valid_dl, loss_func)<br># Likewise for the validation loss and accuracy<br>val_loss_list.append(val)<br>acc_list.append(acc)<br>print("training loss: %.5f valid loss: %.5f accuracy: %.5f" % (loss, val, acc))<br><br>return train_loss_list, val_loss_list, acc_list<br>def val_loss(model, valid_dl, loss_func):<br># Put the model into evaluation mode, not training mode<br>model.eval()<br>total = 0<br>sum_loss = 0<br>correct = 0<br>batch_count = 0<br>for x, y in valid_dl:<br>batch_count += 1<br>current_batch_size = y.shape[0]<br>logits = model(x)<br>loss = loss_func(logits, y)<br>sum_loss += current_batch_size*(loss.item())<br>total += current_batch_size<br># All of the code above is the same, in essence, to<br># Training, so see the comments there<br># Find out which of the returned predictions is the loudest<br># of them all, and that's our prediction(s)<br>preds = logits.sigmoid().argmax(1)<br># See if our predictions are right<br>correct += (preds == y).float().mean().item()<br>return sum_loss/total, correct/batch_count<br>def view_results(train_loss_list, val_loss_list, acc_list):<br>plt.rcParams["figure.figsize"] = (15, 5)<br>plt.figure()<br>epochs = np.arange(0, len(train_loss_list)) plt.subplot(1, 2, 1)<br>plt.plot(epochs-0.5, train_loss_list)<br>plt.plot(epochs, val_loss_list)<br>plt.title('model loss')<br>plt.ylabel('loss')<br>plt.xlabel('epoch')<br>plt.legend(['train', 'val', 'acc'], loc = 'upper left')<br><br>plt.subplot(1, 2, 2)<br>plt.plot(acc_list)<br>plt.title('accuracy')<br>plt.ylabel('accuracy')<br>plt.xlabel('epoch')<br>plt.legend(['train', 'val', 'acc'], loc = 'upper left')<br>plt.show()<br><br>def get_data_train_and_show(model, batch_size=128, n_samples=10000, n_classes=2, n_features=30, val_size=0.2, epochs=20, lr=0.1, wd=0, break_it=False):<br># We'll make a fictitious dataset, assuming all relevant<br># EDA / Feature Engineering has been done and this is our <br># resultant data<br>X, y = make_classification(n_samples=n_samples, n_classes=n_classes, n_features=n_features, n_informative=n_features, n_redundant=0, random_state=1972)<br><br>if break_it: # Specifically mess up the data<br>X = np.random.rand(n_samples,n_features)<br>X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=val_size, random_state=1972) train_ds = MyCustomDataset(X_train, y_train)<br>valid_ds = MyCustomDataset(X_val, y_val)<br>train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)<br>valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=True) train_loss_list, val_loss_list, acc_list = train_loop(model, train_dl, valid_dl, epochs=epochs, loss_func=F.cross_entropy, lr=lr, wd=wd)<br>view_results(train_loss_list, val_loss_list, acc_list)
Kod di atas sangat mudah, ia merupakan proses asas untuk mendapatkan data, latihan dan pengesahan.
Model Kehilangan kereta api perlahan-lahan berkurangan tanpa mengira hiperparameter, tetapi kehilangan Val tidak berkurangan dan Ketepatannya tidak menunjukkan bahawa ia belajar apa sahaja.
Sebagai contoh, dalam kes ini, ketepatan klasifikasi binari berlegar sekitar 50%.
class Scenario_1_Model_1(nn.Module):<br>def __init__(self, in_features=30, out_features=2):<br>super().__init__()<br>self.lin1 = nn.Linear(in_features, out_features)<br>def forward(self, x):<br>x = self.lin1(x)<br>return x<br>get_data_train_and_show(Scenario_1_Model_1(), lr=0.001, break_it=True)
Tiada maklumat yang mencukupi dalam data untuk membenarkan 'pembelajaran' dan data latihan mungkin tidak mengandungi maklumat yang mencukupi untuk membenarkan model 'belajar'.
Dalam kes ini (data rawak semasa melatih data dalam kod), ini bermakna ia tidak dapat mempelajari apa-apa perkara yang penting.
Data mesti mempunyai maklumat yang mencukupi untuk dipelajari. EDA dan kejuruteraan ciri adalah kunci! Model mempelajari perkara yang boleh dipelajari, bukannya mengada-adakan perkara yang tidak wujud.
Sebagai contoh, kod berikut: lr=0.1, bs=128
class Scenario_2_Model_1(nn.Module):<br>def __init__(self, in_features=30, out_features=2):<br>super().__init__()<br>self.lin1 = nn.Linear(in_features, out_features)<br>def forward(self, x):<br>x = self.lin1(x)<br>return x<br>get_data_train_and_show(Scenario_2_Model_1(), lr=0.1)
"Kadar pembelajaran terlalu tinggi" atau "batch terlalu kecil" Anda boleh cuba mengurangkan kadar pembelajaran daripada 0.1 kepada 0.001, yang bermaksud ia tidak akan "melantun", tetapi akan menurun dengan lancar.
get_data_train_and_show(Scenario_1_Model_1(), lr=0.001)
Selain menurunkan kadar pembelajaran, meningkatkan saiz kelompok juga akan menjadikannya lebih lancar.
get_data_train_and_show(Scenario_1_Model_1(), lr=0.001, batch_size=256)
class Scenario_3_Model_1(nn.Module):<br>def __init__(self, in_features=30, out_features=2):<br>super().__init__()<br>self.lin1 = nn.Linear(in_features, 50)<br>self.lin2 = nn.Linear(50, 150)<br>self.lin3 = nn.Linear(150, 50)<br>self.lin4 = nn.Linear(50, out_features)<br>def forward(self, x):<br>x = F.relu(self.lin1(x))<br>x = F.relu(self.lin2(x))<br>x = F.relu(self.lin3(x))<br>x = self.lin4(x)<br>return x<br>get_data_train_and_show(Scenario_3_Model_1(), lr=0.001)
Ini sudah pasti overfitting: kehilangan latihan adalah rendah dan ketepatan adalah tinggi, manakala kehilangan pengesahan dan kehilangan latihan semakin besar dan lebih besar, yang merupakan penunjuk overfitting klasik.
Pada asasnya, keupayaan pembelajaran model anda terlalu kuat. Ia mengingati data latihan terlalu baik, yang bermaksud ia juga tidak boleh digeneralisasikan kepada data baharu.
Perkara pertama yang boleh kita cuba ialah mengurangkan kerumitan model.
class Scenario_3_Model_2(nn.Module):<br>def __init__(self, in_features=30, out_features=2):<br>super().__init__()<br>self.lin1 = nn.Linear(in_features, 50)<br>self.lin2 = nn.Linear(50, out_features)<br>def forward(self, x):<br>x = F.relu(self.lin1(x))<br>x = self.lin2(x)<br>return x<br>get_data_train_and_show(Scenario_3_Model_2(), lr=0.001)
Ini menjadikannya lebih baik, anda juga boleh memperkenalkan penyelarasan pereputan berat L2 untuk menjadikannya lebih baik semula (untuk model yang lebih cetek ).
get_data_train_and_show(Scenario_3_Model_2(), lr=0.001, wd=0.02)
Jika kita ingin mengekalkan kedalaman dan saiz model, kita boleh cuba menggunakan dropout (untuk model yang lebih mendalam).
class Scenario_3_Model_3(nn.Module):<br>def __init__(self, in_features=30, out_features=2):<br>super().__init__()<br>self.lin1 = nn.Linear(in_features, 50)<br>self.lin2 = nn.Linear(50, 150)<br>self.lin3 = nn.Linear(150, 50)<br>self.lin4 = nn.Linear(50, out_features)<br>self.drops = nn.Dropout(0.4)<br>def forward(self, x):<br>x = F.relu(self.lin1(x))<br>x = self.drops(x)<br>x = F.relu(self.lin2(x))<br>x = self.drops(x)<br>x = F.relu(self.lin3(x))<br>x = self.drops(x)<br>x = self.lin4(x)<br>return x<br>get_data_train_and_show(Scenario_3_Model_3(), lr=0.001)
lr = 0.001,bs = 128(默认,分类类别= 5
class Scenario_4_Model_1(nn.Module):<br>def __init__(self, in_features=30, out_features=2):<br>super().__init__()<br>self.lin1 = nn.Linear(in_features, 2)<br>self.lin2 = nn.Linear(2, out_features)<br>def forward(self, x):<br>x = F.relu(self.lin1(x))<br>x = self.lin2(x)<br>return x<br>get_data_train_and_show(Scenario_4_Model_1(out_features=5), lr=0.001, n_classes=5)
没有足够的学习能力:模型中的其中一层的参数少于模型可能输出中的类。 在这种情况下,当有 5 个可能的输出类时,中间的参数只有 2 个。
这意味着模型会丢失信息,因为它不得不通过一个较小的层来填充它,因此一旦层的参数再次扩大,就很难恢复这些信息。
所以需要记录层的参数永远不要小于模型的输出大小。
class Scenario_4_Model_2(nn.Module):<br>def __init__(self, in_features=30, out_features=2):<br>super().__init__()<br>self.lin1 = nn.Linear(in_features, 50)<br>self.lin2 = nn.Linear(50, out_features)<br>def forward(self, x):<br>x = F.relu(self.lin1(x))<br>x = self.lin2(x)<br>return x<br>get_data_train_and_show(Scenario_4_Model_2(out_features=5), lr=0.001, n_classes=5)
以上就是一些常见的训练、验证时的曲线的示例,希望你在遇到相同情况时可以快速定位并且改进。
Atas ialah kandungan terperinci Apakah yang boleh diberitahu oleh graf metrik latihan dan pengesahan dalam pembelajaran mesin?. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!