Heim >Backend-Entwicklung >Python-Tutorial >Ausführliche Erklärung, wie man mit PyTorch schnell ein neuronales Netzwerk aufbaut und speichert und extrahiert

Ausführliche Erklärung, wie man mit PyTorch schnell ein neuronales Netzwerk aufbaut und speichert und extrahiert

不言
不言Original
2018-04-28 10:56:062573Durchsuche

In diesem Artikel wird hauptsächlich der schnelle Aufbau eines neuronalen Netzwerks mit PyTorch vorgestellt und die Speicher- und Extraktionsmethoden ausführlich erläutert. Jetzt teile ich es mit Ihnen und gebe Ihnen eine Referenz. Werfen wir gemeinsam einen Blick darauf

Manchmal haben wir ein Modell trainiert und möchten es für die direkte Verwendung beim nächsten Mal speichern, ohne es beim nächsten Mal erneut trainieren zu müssen. In diesem Abschnitt erklären wir, wie wir schnell ein neuronales Netzwerk aufbauen können PyTorch und seine detaillierte Erklärung der Speicher- und Extraktionsmethode

1. PyTorch-Methode zum schnellen Aufbau eines neuronalen Netzwerks

Schauen Sie sich an Zuerst der experimentelle Code:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

Zuvor habe ich gelernt, wie man ein neuronales Netzwerk aufbaut, indem man in classNet zunächst die Konstruktionsmethode erbt Erstellen Sie die Strukturinformationen jeder Schicht des neuronalen Netzwerks durch Hinzufügen von Attributen, verbessern Sie die Verbindungsinformationen zwischen jeder Schicht des neuronalen Netzwerks in der Vorwärtsmethode und schließen Sie sie dann ab der Aufbau der neuronalen Netzwerkstruktur durch Definition von Net-Klassenobjekten.

Eine andere Möglichkeit zum Aufbau eines neuronalen Netzwerks, die auch als schnelle Konstruktionsmethode bezeichnet werden kann, besteht darin, den Aufbau des neuronalen Netzwerks direkt über Torch.nn.Sequential abzuschließen.

Die durch die beiden Methoden aufgebauten neuronalen Netzwerkstrukturen sind genau gleich, und die Netzwerkinformationen können über die Druckfunktion ausgedruckt werden, die Druckergebnisse unterscheiden sich jedoch geringfügig.

2. Erhaltung und Extraktion des neuronalen Netzwerks von PyTorch

Wenn wir tiefes Lernen lernen und erforschen, wenn wir eine bestimmte Trainingsphase durchlaufen, wann Wir erhalten ein besseres Modell. Natürlich möchten wir das Modell und die Modellparameter für die spätere Verwendung speichern. Daher ist es erforderlich, das neuronale Netzwerk zu speichern und die Modellparameter zu extrahieren und neu zu laden.

Zuerst müssen wir die Netzwerkstruktur und die Modellparameter über Torch.save() speichern, nachdem der Definitions- und Trainingsteil des neuronalen Netzwerks abgeschlossen ist, der die Netzwerkstruktur und seine Modellparameter speichern muss. Es gibt zwei Speichermethoden: Eine besteht darin, die Strukturinformationen und Modellparameterinformationen des gesamten neuronalen Netzwerks zu speichern, und die andere besteht darin, nur die Trainingsmodellparameter des neuronalen Netzwerks zu speichern Das Objekt des Speicherns ist net.state_dict(). Die gespeicherten Ergebnisse werden in Form von .pkl-Dateien gespeichert.

entspricht den beiden oben genannten Speichermethoden, und es gibt auch zwei Nachlademethoden. Entsprechend den ersten vollständigen Netzwerkstrukturinformationen können Sie das neue neuronale Netzwerkobjekt beim Neuladen direkt über Torch.load(‘.pkl’) initialisieren. Entsprechend der zweiten Methode, bei der nur Modellparameterinformationen gespeichert werden, müssen Sie zunächst dieselbe neuronale Netzwerkstruktur aufbauen und das Neuladen der Modellparameter über net.load_state_dict (torch.load ('.pkl')) abschließen. Wenn das Netzwerk relativ groß ist, dauert die erste Methode länger.

Code-Implementierung:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

Experimentelle Ergebnisse:

Verwandte Empfehlungen:

So implementieren Sie das Faltungs-Neuronale Netzwerk CNN auf PyTorch

Detaillierte Erläuterung des PyTorch-Batch-Trainings und des Optimierervergleichs

Mnist-Klassifizierungsbeispiel für den Einstieg in Pytorch

Das obige ist der detaillierte Inhalt vonAusführliche Erklärung, wie man mit PyTorch schnell ein neuronales Netzwerk aufbaut und speichert und extrahiert. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn