
How to implement convolutional neural network CNN on PyTorch

Apr 28, 2018 am 10:02 AM

This article introduces how to implement a convolutional neural network (CNN) in PyTorch and is shared here as a reference.

1. Convolutional Neural Network

A convolutional neural network (CNN) was originally designed for problems such as image recognition. Its applications are no longer limited to images and video: CNNs can also be applied to time-series signals such as audio, and to text data. The initial appeal of the CNN as a deep learning architecture was that it reduces the need for image preprocessing and avoids complex feature engineering. The first convolutional layer accepts the pixel-level input of the image directly. Each convolutional layer (a set of filters) extracts the most useful features from its input: the early layers pick up the most basic features of the image, and later layers combine and abstract them into higher-order features. As a result, a CNN is, in theory, invariant to some degree to scaling, translation and rotation of the image.
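To make the idea of a filter extracting a feature concrete, here is a minimal sketch that applies one hand-written 3x3 kernel to a tiny synthetic image. The image, the kernel values and the variable names are illustrative assumptions, not part of the original tutorial. In a trained CNN the kernel values are learned rather than written by hand.

```python
import torch
import torch.nn.functional as F

# A 6x6 synthetic "image": left half dark (0), right half bright (1).
# Shape is (batch, channels, height, width), as conv2d expects.
image = torch.zeros(1, 1, 6, 6)
image[..., 3:] = 1.0

# Hand-written 3x3 kernel that responds where brightness increases
# from left to right (a vertical edge). Shape: (out_ch, in_ch, kH, kW).
kernel = torch.tensor([[-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0]]).view(1, 1, 3, 3)

# No padding, stride 1: a 6x6 input gives a 4x4 feature map.
feature_map = F.conv2d(image, kernel)

print(feature_map[0, 0])
```

The feature map is non-zero only in the columns where the kernel window straddles the dark-to-bright boundary, which is the sense in which a filter "detects" a local pattern.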

The key ideas of a CNN are local connections, weight sharing, and down-sampling in the pooling layers. Local connections and weight sharing reduce the number of parameters, which greatly lowers training complexity and alleviates overfitting. Weight sharing also gives the network tolerance to translation, while pooling-layer down-sampling further reduces the size of the output and gives the model tolerance to mild deformation, improving its generalization ability. The convolution operation can be understood as extracting similar features at many locations in the image using a small number of parameters.
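The parameter saving from local connections and weight sharing can be checked with a short calculation. The sketch below counts the parameters of a 5x5 convolution with 16 output channels on a 1x28x28 input, and compares that with a hypothetical fully connected layer producing an output of the same size. The helper count_parameters and the fully connected comparison are assumptions introduced for illustration only.

```python
import torch.nn as nn

def count_parameters(module):
    """Total number of learnable values in a module."""
    return sum(p.numel() for p in module.parameters())

# Convolution: 16 filters, each 1x5x5, plus one bias per filter.
conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2)
conv_params = count_parameters(conv)  # 16 * (1 * 5 * 5) + 16

# Hypothetical dense layer mapping a flattened 1x28x28 input to a
# flattened 16x28x28 output (computed arithmetically, not instantiated).
in_features = 1 * 28 * 28
out_features = 16 * 28 * 28
dense_params = in_features * out_features + out_features

print('conv parameters: ', conv_params)
print('dense parameters:', dense_params)
```

The convolution reuses the same 5x5 weights at every image position, so its parameter count does not depend on the image size, whereas the dense layer's count grows with the product of input and output sizes.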

2. Code implementation

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = True 
 
# Load the training dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # where the dataset is stored 
       train=True, # True loads the training set, False loads the test set 
       transform=torchvision.transforms.ToTensor(), # convert images to tensors scaled to the (0,1) range 
       download=DOWNLOAD_MNIST, 
       ) 
 
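# Print the sizes of the MNIST training images and labels 
# (.data and .targets replace the deprecated train_data and train_labels) 
print(training_data.data.size()) 
print(training_data.targets.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
# Show the first training image with its label as the title 
plt.imshow(training_data.data[0].numpy(), cmap='gray') 
plt.title('%i' % training_data.targets[0].item()) 
plt.show() 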
 
# A dataset obtained from torchvision.datasets can be passed directly to DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# Load the test dataset 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) 
# Take the first 2000 test samples; since PyTorch 0.4 plain tensors are used 
# here instead of Variable(..., volatile=True) 
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]/255 
# (2000, 28, 28) to (2000, 1, 28, 28), in range(0,1) 
test_y = test_data.targets[:2000] 
 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # To keep the Conv2d output the same size as its input (with stride=1), use padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # flatten (batch,32,7,7) to (batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
print(cnn) 
''' 
CNN ( 
 (conv1): Sequential ( 
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (conv2): Sequential ( 
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (out): Linear (1568 -> 10) 
) 
''' 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 
for epoch in range(EPOCH): 
  for step, (b_x, b_y) in enumerate(train_loader): 
    # Since PyTorch 0.4, tensors no longer need to be wrapped in Variable 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() # clear gradients left over from the previous step 
    loss.backward() # backpropagate 
    optimizer.step() # update the weights 
 
    if step % 100 == 0: 
      with torch.no_grad(): # no gradients are needed for evaluation 
        test_output = cnn(test_x) 
      pred_y = torch.max(test_output, 1)[1] # index of the largest logit 
      accuracy = (pred_y == test_y).float().mean().item() 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.item(), '|test accuracy:%.4f'%accuracy) 
 
# Compare predictions with the true labels for the first 10 test samples 
with torch.no_grad(): 
  test_output = cnn(test_x[:10]) 
pred_y = torch.max(test_output, 1)[1].numpy() 
print(pred_y, 'prediction number') 
print(test_y[:10].numpy(), 'real number') 
''' 
Epoch: 0 |Step: 0 |train loss:2.3145 |test accuracy:0.1040 
Epoch: 0 |Step: 100 |train loss:0.5857 |test accuracy:0.8865 
Epoch: 0 |Step: 200 |train loss:0.0600 |test accuracy:0.9380 
Epoch: 0 |Step: 300 |train loss:0.0996 |test accuracy:0.9345 
Epoch: 0 |Step: 400 |train loss:0.0381 |test accuracy:0.9645 
Epoch: 0 |Step: 500 |train loss:0.0266 |test accuracy:0.9620 
Epoch: 0 |Step: 600 |train loss:0.0973 |test accuracy:0.9685 
Epoch: 0 |Step: 700 |train loss:0.0421 |test accuracy:0.9725 
Epoch: 0 |Step: 800 |train loss:0.0654 |test accuracy:0.9710 
Epoch: 0 |Step: 900 |train loss:0.1333 |test accuracy:0.9740 
Epoch: 0 |Step: 1000 |train loss:0.0289 |test accuracy:0.9720 
Epoch: 0 |Step: 1100 |train loss:0.0429 |test accuracy:0.9770 
[7 2 1 0 4 1 4 9 5 9] prediction number 
[7 2 1 0 4 1 4 9 5 9] real number 
'''

3. Analysis and Interpretation

torchvision.datasets provides data in a dataset format that can be passed directly to DataLoader. The train parameter controls whether the training set or the test set is loaded, and the transform parameter can convert the data into the format required for training as it is loaded.
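The way DataLoader splits a dataset into batches can be seen without downloading MNIST. The sketch below uses random tensors in a TensorDataset as a stand-in for the real data. The fake_images and fake_labels tensors and the sample count of 120 are assumptions chosen purely for illustration.

```python
import torch
import torch.utils.data as Data

torch.manual_seed(1)

# Stand-in data with MNIST-like shapes: 120 single-channel 28x28 "images"
# and 120 integer class labels in the range 0..9.
fake_images = torch.rand(120, 1, 28, 28)
fake_labels = torch.randint(0, 10, (120,))

dataset = Data.TensorDataset(fake_images, fake_labels)
loader = Data.DataLoader(dataset=dataset, batch_size=50, shuffle=True)

# Each iteration yields one (inputs, labels) batch.
batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
for shapes in batch_shapes:
    print(shapes)
```

With 120 samples and a batch size of 50, the loader yields two full batches and a final, smaller batch of 20, because drop_last defaults to False.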

The network is built by defining a CNN class. The convolutional blocks conv1 and conv2 and the output layer out are defined as class attributes, and the connections between them are defined in forward. When defining the layers, always keep track of the number of neurons (the output size) of each layer.
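One way to keep track of layer sizes is to compute them with the standard output-size formula and then confirm them by passing a dummy tensor through the layers. The sketch below does both for the same layer settings used in the CNN class. The helper conv_output_size and the standalone features container are assumptions introduced for this check.

```python
import torch
import torch.nn as nn

def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# Sizes predicted by the formula: conv (k=5, s=1, p=2), then pool (k=2, s=2).
after_conv1 = conv_output_size(28, kernel_size=5, stride=1, padding=2)
after_pool1 = conv_output_size(after_conv1, kernel_size=2, stride=2)
after_conv2 = conv_output_size(after_pool1, kernel_size=5, stride=1, padding=2)
after_pool2 = conv_output_size(after_conv2, kernel_size=2, stride=2)

# The same layers as conv1 and conv2 in the CNN class, in one container.
features = nn.Sequential(
    nn.Conv2d(1, 16, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),
)

# Push a dummy batch of one 1x28x28 image through and record each shape.
x = torch.zeros(1, 1, 28, 28)
shapes = []
for layer in features:
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(layer.__class__.__name__, tuple(x.shape))

flat_features = x.view(x.size(0), -1).size(1)
print('inputs to the linear layer:', flat_features)
```

The flattened size is what the in_features argument of the final nn.Linear layer must match, which is why the class uses 32*7*7.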

The network structure of CNN is as follows:

CNN (
 (conv1): Sequential (
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (conv2): Sequential (
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (out): Linear (1568 -> 10)
)

In this experiment, a single epoch of training (EPOCH=1) reaches an accuracy of 97.7% on the 2000-sample test subset.
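The accuracy figure is the fraction of test samples whose highest-scoring class matches the true label. A minimal sketch of that calculation on hand-made values follows. The accuracy helper and the example logits and labels are assumptions for illustration and are not outputs of the trained network.

```python
import torch

def accuracy(logits, labels):
    """Fraction of rows whose arg-max class equals the true label."""
    predictions = logits.argmax(dim=1)
    return (predictions == labels).float().mean().item()

# Four samples, three classes; each row holds the raw class scores.
logits = torch.tensor([[2.0, 0.1, 0.3],   # predicts class 0
                       [0.2, 1.5, 0.1],   # predicts class 1
                       [0.9, 0.8, 0.1],   # predicts class 0 (label is 1)
                       [0.1, 0.2, 3.0]])  # predicts class 2
labels = torch.tensor([0, 1, 1, 2])

acc = accuracy(logits, labels)
print('accuracy: %.2f' % acc)
```

Three of the four predictions match their labels, so the helper returns 0.75. No softmax is applied, since the arg-max of the raw scores is the same as the arg-max of the softmax probabilities.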

