
Detailed explanation of how to quickly build a neural network with PyTorch and save and extract it

不言 (original article), 2018-04-28 10:56:06

This article shows how to quickly build a neural network in PyTorch and explains in detail how to save it and load it back. I hope it serves as a useful reference.

After training a model, we often want to save it so that next time we can use it directly instead of spending time training it again. This section explains how to quickly build a neural network with PyTorch, and then how to save it and reload it.

1. How to quickly build a neural network with PyTorch

First, the experimental code:

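import torch
import torch.nn.functional as F

# Method 1: build the network by defining a Net class
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()  # call the torch.nn.Module constructor
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))  # hidden layer followed by ReLU activation
        x = self.predict(x)
        return x

net1 = Net(2, 10, 2)
print('Method 1:\n', net1)

# Method 2: build the network quickly with torch.nn.Sequential
net2 = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)
print('Method 2:\n', net2)
# Verified: both networks compute the same function; only the structural details differ slightly

'''
Output (from the PyTorch version used when this article was written; newer
versions print layers as e.g. Linear(in_features=2, out_features=10, bias=True)):
Method 1:
 Net (
  (hidden): Linear (2 -> 10)
  (predict): Linear (10 -> 2)
)
Method 2:
 Sequential (
  (0): Linear (2 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 2)
)
'''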

Method 1, which we learned previously, builds the network by defining a Net class. In Net, we first call the constructor of the parent class torch.nn.Module through super(), then define each layer of the network as an attribute in __init__, and finally specify in the forward method how data flows from one layer to the next. Creating a Net object then completes the network.
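To see what the attribute assignments in Net accomplish, here is a short sketch that redefines the same Net class (so the block runs on its own), lists the submodules PyTorch registered, and counts the parameters. The variable names child_names, n_params and out are illustrative only.

```python
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as submodules
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        # forward() wires the registered layers together
        x = F.relu(self.hidden(x))
        return self.predict(x)


net1 = Net(2, 10, 2)

# The attribute names become the submodule names
child_names = [name for name, _ in net1.named_children()]
print(child_names)  # ['hidden', 'predict']

# Parameters of both layers are collected automatically:
# hidden: 2*10 weights + 10 biases, predict: 10*2 weights + 2 biases
n_params = sum(p.numel() for p in net1.parameters())
print(n_params)  # 52

# Calling the module runs forward(): a batch of 5 samples with 2 features each
out = net1(torch.randn(5, 2))
print(out.shape)  # torch.Size([5, 2])
```

Because the layers are registered automatically, net1.parameters() can be handed straight to an optimizer without listing the weights by hand.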

The other way, which can be considered the quick method, builds the network directly with torch.nn.Sequential, which chains the given layers together in the order they are listed.
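As an aside, torch.nn.Sequential also accepts an OrderedDict of (name, module) pairs, which gives the layers names instead of positions. The sketch below compares the state_dict keys produced by both forms; net2_named is a hypothetical name used only here.

```python
from collections import OrderedDict

import torch

# Positional form, as in Method 2: layers are named '0', '1', '2'
net2 = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)

# Named form: the same layers, passed as an OrderedDict of (name, module) pairs
net2_named = torch.nn.Sequential(OrderedDict([
    ("hidden", torch.nn.Linear(2, 10)),
    ("relu", torch.nn.ReLU()),
    ("predict", torch.nn.Linear(10, 2)),
]))

positional_keys = list(net2.state_dict().keys())
named_keys = list(net2_named.state_dict().keys())
print(positional_keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']
print(named_keys)       # ['hidden.weight', 'hidden.bias', 'predict.weight', 'predict.bias']

# In the named form, a layer can be reached either by name or by index
print(net2_named.hidden is net2_named[0])  # True
```

The ReLU layer has no parameters, which is why no '1.*' or 'relu.*' keys appear.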

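The two methods build networks with the same layers that compute the same function, and print displays the structure of either one. The printed output differs slightly, though: Method 1 applies ReLU through F.relu inside forward, so it does not appear as a layer, whereas Sequential lists torch.nn.ReLU() as a layer of its own; and Sequential names its layers by position (0, 1, 2) rather than by attribute name.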
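One way to check that the two methods really compute the same function is to copy the parameters of net1 into net2 and compare their outputs on the same input. This sketch assumes the layer layout above: net2[0] and net2[2] are the Linear layers, and net2[1] is the parameter-free ReLU.

```python
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        return self.predict(F.relu(self.hidden(x)))


torch.manual_seed(0)
net1 = Net(2, 10, 2)
net2 = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)

# Copy net1's weights into net2 layer by layer (no gradient tracking needed)
with torch.no_grad():
    net2[0].weight.copy_(net1.hidden.weight)
    net2[0].bias.copy_(net1.hidden.bias)
    net2[2].weight.copy_(net1.predict.weight)
    net2[2].bias.copy_(net1.predict.bias)

# Same parameters, same input: the outputs should agree
x = torch.randn(8, 2)
same = torch.allclose(net1(x), net2(x))
print(same)  # True
```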

2. Saving and reloading a PyTorch neural network

When studying or researching deep learning, once a period of training has produced a good model, we naturally want to keep the model and its parameters for later use. That is why we need to be able to save a neural network and later load it back and restore its parameters.

To do this, call torch.save() after the part of the code that defines and trains the network. There are two ways to save. The first saves the entire network, structure and parameters together, by passing the network object net to torch.save(). The second saves only the trained parameters by passing net.state_dict(), a dictionary that maps each parameter name to its tensor. In this article both are written to files with a .pkl extension; the extension is only a convention (torch.save serializes objects with Python's pickle module either way, and .pt or .pth are also common).
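A minimal sketch of both saving methods, assuming a temporary directory in place of the article's working directory; the file names net.pkl and net_params.pkl are hypothetical. It also shows what state_dict() actually contains.

```python
import os
import tempfile

import torch

net1 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

# A temporary directory stands in for wherever you keep your models
save_dir = tempfile.mkdtemp()
full_path = os.path.join(save_dir, "net.pkl")           # whole network
params_path = os.path.join(save_dir, "net_params.pkl")  # parameters only

torch.save(net1, full_path)                 # method 1: structure + parameters
torch.save(net1.state_dict(), params_path)  # method 2: parameters only

# state_dict() maps parameter names to tensors; ReLU has no parameters, so no '1.*' keys
shapes = {name: tuple(t.shape) for name, t in net1.state_dict().items()}
print(shapes)
# {'0.weight': (10, 1), '0.bias': (10,), '2.weight': (1, 10), '2.bias': (1,)}
print(os.path.exists(full_path), os.path.exists(params_path))  # True True
```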

Each saving method has a matching way to reload. If the entire network was saved, torch.load('.pkl') directly returns a ready-to-use network object. If only the parameters were saved, you must first build a network with the same structure and then restore the parameters with net.load_state_dict(torch.load('.pkl')). When the network is relatively large, the first method takes more time. Note also that saving the whole object pickles the network's class by reference, so the class definition must be importable when the file is loaded, and that since PyTorch 2.6 torch.load defaults to weights_only=True, so loading an entire network requires passing weights_only=False (do this only for files you trust).
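The requirement that the structure match can be seen directly. In the sketch below, loading saved parameters into an identical network succeeds, while loading them into a network with a different hidden size raises a RuntimeError. make_net is a hypothetical helper introduced only for this example.

```python
import os
import tempfile

import torch


def make_net(n_hidden):
    # Hypothetical helper: the article's layout with a configurable hidden size
    return torch.nn.Sequential(
        torch.nn.Linear(1, n_hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(n_hidden, 1),
    )


params_path = os.path.join(tempfile.mkdtemp(), "net_params.pkl")
torch.save(make_net(10).state_dict(), params_path)

# Same structure: every saved key finds a tensor of the right shape
net3 = make_net(10)
result = net3.load_state_dict(torch.load(params_path))
print(result.missing_keys, result.unexpected_keys)  # [] []

# Different structure (20 hidden units instead of 10): the tensor shapes do not match
try:
    make_net(20).load_state_dict(torch.load(params_path))
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```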

Code implementation:

import torch
import matplotlib.pyplot as plt

torch.manual_seed(1)  # fix the random seed for reproducibility

# Create toy data: y = x^2 plus uniform noise
# (Variable is no longer needed: since PyTorch 0.4 plain tensors work with autograd)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# Define, train and save the network inside one function
def save():
    # Network structure
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_function = torch.nn.MSELoss()

    # Training
    for i in range(300):
        prediction = net1(x)
        loss = loss_function(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Recompute the prediction with the final parameters, i.e. the ones being saved
    prediction = net1(x)

    # Plotting
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('net1')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)

    # Save the network
    torch.save(net1, '7-net.pkl')                      # save the entire network: structure and parameters
    torch.save(net1.state_dict(), '7-net_params.pkl')  # save only the model parameters

# Reload the entire network, structure and parameters
def reload_net():
    # PyTorch 2.6+ defaults to weights_only=True, which refuses to unpickle a whole model;
    # the weights_only argument itself requires PyTorch 1.13 or later
    net2 = torch.load('7-net.pkl', weights_only=False)
    prediction = net2(x)

    plt.subplot(132)
    plt.title('net2')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)

# Reload only the parameters; the network must have the same structure as the saved one
def reload_params():
    # First build the same network structure
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )

    # Load the model parameters
    net3.load_state_dict(torch.load('7-net_params.pkl'))
    prediction = net3(x)

    plt.subplot(133)
    plt.title('net3')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)

# Run the test
save()
reload_net()
reload_params()
plt.show()
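The plots compare the three networks visually. For a check that needs no matplotlib, the following sketch repeats the same training and both save/reload methods, then compares the predictions exactly. It assumes PyTorch 1.13 or later (for the weights_only argument) and writes to a temporary directory instead of the article's 7-net.pkl files; build is a hypothetical helper.

```python
import os
import tempfile

import torch

torch.manual_seed(1)

# Same toy data as the article: y = x^2 plus uniform noise
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())


def build():
    # Hypothetical helper returning the article's network structure
    return torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )


net1 = build()
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_function = torch.nn.MSELoss()
for _ in range(300):
    loss = loss_function(net1(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

save_dir = tempfile.mkdtemp()
full_path = os.path.join(save_dir, "net.pkl")
params_path = os.path.join(save_dir, "net_params.pkl")
torch.save(net1, full_path)
torch.save(net1.state_dict(), params_path)

# Reload method 1: the whole object (weights_only=False is required on PyTorch 2.6+)
net2 = torch.load(full_path, weights_only=False)
# Reload method 2: rebuild the structure, then load the parameters
net3 = build()
net3.load_state_dict(torch.load(params_path))

with torch.no_grad():
    p1, p2, p3 = net1(x), net2(x), net3(x)
print(torch.equal(p1, p2) and torch.equal(p1, p3))  # True
```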

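Experimental results (figure not preserved): the script draws three panels titled net1, net2 and net3, each plotting the data points and the fitted curve.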
