


Detailed explanation of how to quickly build a neural network with PyTorch and save and extract it
This article mainly introduces PyTorch to quickly build a neural network and a detailed explanation of its saving and extraction methods. Now I share it with you and give it a reference. Let’s take a look together
Sometimes we have trained a model and want to save it for direct use next time without spending time training it again next time. In this section we will explain how to quickly build a neural network with PyTorch and its Detailed explanation of the saving and extraction method
1. How to quickly build a neural network with PyTorch
First look at the experimental code:
import torch import torch.nn.functional as F # 方法1,通过定义一个Net类来建立神经网络 class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) x = self.predict(x) return x net1 = Net(2, 10, 2) print('方法1:\n', net1) # 方法2 通过torch.nn.Sequential快速建立神经网络结构 net2 = torch.nn.Sequential( torch.nn.Linear(2, 10), torch.nn.ReLU(), torch.nn.Linear(10, 2), ) print('方法2:\n', net2) # 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 ''''' 方法1: Net ( (hidden): Linear (2 -> 10) (predict): Linear (10 -> 2) ) 方法2: Sequential ( (0): Linear (2 -> 10) (1): ReLU () (2): Linear (10 -> 2) ) '''
I previously learned how to build a neural network by defining a Net class. In classNet, first inherit the construction method of the torch.nn.Module module through the super function, and then pass Build the structural information of each layer of the neural network by adding attributes, improve the connection information between each layer of the neural network in the forward method, and then complete the construction of the neural network structure by defining Net class objects.
Another way to build a neural network, which can also be said to be a quick construction method, is to directly complete the establishment of the neural network through torch.nn.Sequential.
The neural network structures constructed by the two methods are exactly the same, and the network information can be printed out through the print function, but the print results will be slightly different.
2. PyTorch neural network storage and extraction
When learning and researching deep learning, when we go through a certain period of training, When we get a better model, of course we want to save the model and model parameters for later use, so it is necessary to save the neural network and extract and reload the model parameters.
First of all, we need to save the network structure and model parameters through torch.save() after the definition and training part of the neural network that needs to save the network structure and its model parameters. There are two saving methods: one is to save the structural information and model parameter information of the entire neural network, and the object of save is the network net; the other is to save only the training model parameters of the neural network, and the object of save is net.state_dict(), The saved results are stored in the form of .pkl files.
Corresponds to the above two saving methods, and there are two reloading methods. Corresponding to the first complete network structure information, you can directly initialize the new neural network object through torch.load(‘.pkl’) when reloading. Corresponding to the second method, which only saves model parameter information, you need to first build the same neural network structure and complete the reloading of model parameters through net.load_state_dict(torch.load('.pkl')). When the network is relatively large, the first method will take more time.
Code implementation:
import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed(1) # 设定随机数种子 # 创建数据 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) y = x.pow(2) + 0.2*torch.rand(x.size()) x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) # 将待保存的神经网络定义在一个函数中 def save(): # 神经网络结构 net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ) optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) loss_function = torch.nn.MSELoss() # 训练部分 for i in range(300): prediction = net1(x) loss = loss_function(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() # 绘图部分 plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('net1') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 保存神经网络 torch.save(net1, '7-net.pkl') # 保存整个神经网络的结构和模型参数 torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 # 载入整个神经网络的结构及其模型参数 def reload_net(): net2 = torch.load('7-net.pkl') prediction = net2(x) plt.subplot(132) plt.title('net2') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 def reload_params(): # 首先搭建相同的神经网络结构 net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ) # 载入神经网络的模型参数 net3.load_state_dict(torch.load('7-net_params.pkl')) prediction = net3(x) plt.subplot(133) plt.title('net3') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 运行测试 save() reload_net() reload_params()
Experimental results:
Related recommendations:
Method of implementing convolutional neural network CNN on PyTorch
Detailed explanation of PyTorch batch training and optimizer comparison
mnist classification example for getting started with Pytorch
The above is the detailed content of Detailed explanation of how to quickly build a neural network with PyTorch and save and extract it. For more information, please follow other related articles on the PHP Chinese website!

Python excels in automation, scripting, and task management. 1) Automation: File backup is realized through standard libraries such as os and shutil. 2) Script writing: Use the psutil library to monitor system resources. 3) Task management: Use the schedule library to schedule tasks. Python's ease of use and rich library support makes it the preferred tool in these areas.

To maximize the efficiency of learning Python in a limited time, you can use Python's datetime, time, and schedule modules. 1. The datetime module is used to record and plan learning time. 2. The time module helps to set study and rest time. 3. The schedule module automatically arranges weekly learning tasks.

Python excels in gaming and GUI development. 1) Game development uses Pygame, providing drawing, audio and other functions, which are suitable for creating 2D games. 2) GUI development can choose Tkinter or PyQt. Tkinter is simple and easy to use, PyQt has rich functions and is suitable for professional development.

Python is suitable for data science, web development and automation tasks, while C is suitable for system programming, game development and embedded systems. Python is known for its simplicity and powerful ecosystem, while C is known for its high performance and underlying control capabilities.

You can learn basic programming concepts and skills of Python within 2 hours. 1. Learn variables and data types, 2. Master control flow (conditional statements and loops), 3. Understand the definition and use of functions, 4. Quickly get started with Python programming through simple examples and code snippets.

Python is widely used in the fields of web development, data science, machine learning, automation and scripting. 1) In web development, Django and Flask frameworks simplify the development process. 2) In the fields of data science and machine learning, NumPy, Pandas, Scikit-learn and TensorFlow libraries provide strong support. 3) In terms of automation and scripting, Python is suitable for tasks such as automated testing and system management.

You can learn the basics of Python within two hours. 1. Learn variables and data types, 2. Master control structures such as if statements and loops, 3. Understand the definition and use of functions. These will help you start writing simple Python programs.

How to teach computer novice programming basics within 10 hours? If you only have 10 hours to teach computer novice some programming knowledge, what would you choose to teach...


Hot AI Tools

Undresser.AI Undress
AI-powered app for creating realistic nude photos

AI Clothes Remover
Online AI tool for removing clothes from photos.

Undress AI Tool
Undress images for free

Clothoff.io
AI clothes remover

AI Hentai Generator
Generate AI Hentai for free.

Hot Article

Hot Tools

Atom editor mac version download
The most popular open source editor

MinGW - Minimalist GNU for Windows
This project is in the process of being migrated to osdn.net/projects/mingw, you can continue to follow us there. MinGW: A native Windows port of the GNU Compiler Collection (GCC), freely distributable import libraries and header files for building native Windows applications; includes extensions to the MSVC runtime to support C99 functionality. All MinGW software can run on 64-bit Windows platforms.

EditPlus Chinese cracked version
Small size, syntax highlighting, does not support code prompt function

Dreamweaver Mac version
Visual web development tools

Notepad++7.3.1
Easy-to-use and free code editor