Maison  >  Article  >  développement back-end  >  Comment implémenter le réseau neuronal convolutif CNN sur PyTorch

Comment implémenter le réseau neuronal convolutif CNN sur PyTorch

不言
不言original
2018-04-28 10:02:422702parcourir

Cet article présente principalement la méthode d'implémentation du réseau neuronal convolutif CNN sur PyTorch. Je vais maintenant le partager avec vous et vous donner une référence. Jetons un coup d'oeil ensemble

1. Réseau neuronal convolutif

Le réseau neuronal convolutif (CNN) a été conçu à l'origine pour résoudre la reconnaissance d'images Conçu pour un tel problèmes, les applications actuelles de CNN ne se limitent pas aux images et aux vidéos, mais peuvent également être utilisées pour les signaux de séries chronologiques, tels que les signaux audio et les données textuelles. L'attrait initial de CNN en tant qu'architecture d'apprentissage profond est de réduire les exigences de prétraitement des données d'image et d'éviter l'ingénierie de fonctionnalités complexe. Dans un réseau neuronal convolutif, la première couche convolutive acceptera directement l'entrée au niveau des pixels de l'image. Chaque couche de convolution (filtre) extraira les caractéristiques les plus efficaces des données. Cette méthode peut extraire les caractéristiques les plus élémentaires de l'image. Les fonctionnalités sont ensuite combinées et abstraites pour former des fonctionnalités d'ordre supérieur, de sorte que CNN est théoriquement invariant à la mise à l'échelle, à la translation et à la rotation de l'image.

Les points clés du réseau neuronal convolutif CNN sont la connexion locale (LocalConnection), le partage de poids (WeightsSharing) et le sous-échantillonnage (Down-Sampling) dans la couche de pooling (Pooling). Parmi eux, les connexions locales et le partage du poids réduisent le nombre de paramètres, réduisent considérablement la complexité de l'entraînement et atténuent le surajustement. Dans le même temps, le partage de poids donne également au réseau convolutif une tolérance à la traduction, et le sous-échantillonnage des couches de pooling réduit encore la quantité de paramètres de sortie et donne au modèle une tolérance à une légère déformation, améliorant ainsi la capacité de généralisation du modèle. L'opération de convolution de la couche de convolution peut être comprise comme un processus d'extraction de caractéristiques similaires à plusieurs emplacements de l'image avec un petit nombre de paramètres.

2. Mise en œuvre du code

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = True 
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
plt.title('%i' % training_data.train_labels[0]) 
plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) 
# 取前2000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), 
         volatile=True).type(torch.FloatTensor)[:2000]/255 
# (2000, 28, 28) to (2000, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels[:2000] 
 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
print(cnn) 
''''' 
CNN ( 
 (conv1): Sequential ( 
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (conv2): Sequential ( 
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (out): Linear (1568 -> 10) 
) 
''' 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 
for epoch in range(EPOCH): 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x) 
    b_y = Variable(y) 
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      test_output = cnn(test_x) 
      pred_y = torch.max(test_output, 1)[1].data.squeeze() 
      accuracy = sum(pred_y == test_y) / test_y.size(0) 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0], '|test accuracy:%.4f'%accuracy) 
 
test_output = cnn(test_x[:10]) 
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() 
print(pred_y, 'prediction number') 
print(test_y[:10].numpy(), 'real number') 
''''' 
Epoch: 0 |Step: 0 |train loss:2.3145 |test accuracy:0.1040 
Epoch: 0 |Step: 100 |train loss:0.5857 |test accuracy:0.8865 
Epoch: 0 |Step: 200 |train loss:0.0600 |test accuracy:0.9380 
Epoch: 0 |Step: 300 |train loss:0.0996 |test accuracy:0.9345 
Epoch: 0 |Step: 400 |train loss:0.0381 |test accuracy:0.9645 
Epoch: 0 |Step: 500 |train loss:0.0266 |test accuracy:0.9620 
Epoch: 0 |Step: 600 |train loss:0.0973 |test accuracy:0.9685 
Epoch: 0 |Step: 700 |train loss:0.0421 |test accuracy:0.9725 
Epoch: 0 |Step: 800 |train loss:0.0654 |test accuracy:0.9710 
Epoch: 0 |Step: 900 |train loss:0.1333 |test accuracy:0.9740 
Epoch: 0 |Step: 1000 |train loss:0.0289 |test accuracy:0.9720 
Epoch: 0 |Step: 1100 |train loss:0.0429 |test accuracy:0.9770 
[7 2 1 0 4 1 4 9 5 9] prediction number 
[7 2 1 0 4 1 4 9 5 9] real number 
'''

3. Analyse et interprétation

En utilisant torchvision.datasets, vous pouvez obtenir rapidement des données au format ensemble de données qui peuvent être placées directement dans le DataLoader. obtenu via le contrôle des paramètres du train Il s'agit toujours d'un ensemble de données de test, ou il peut être directement converti dans le format de données requis pour la formation lorsqu'il est obtenu.

La construction d'un réseau neuronal convolutif est réalisée en définissant une classe CNN. Les couches convolutives conv1, conv2 et out sont définies sous forme d'attributs de classe. Les informations de connexion entre chaque couche sont définies en avant. Le défini Faites toujours attention au nombre de neurones dans chaque couche.

La structure du réseau de CNN est la suivante :

CNN (
 (conv1): Sequential (
  (0): Conv2d(1, 16,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))
 )
 (conv2): Sequential (
  (0): Conv2d(16, 32,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))
 )
 (out): Linear (1568 ->10)
)

On peut voir à travers des expériences que dans les résultats de formation d'EPOCH= 1, la précision de l'ensemble de test est Elle peut atteindre 97,7 %.

Recommandations associées :

Explication détaillée de la formation par lots PyTorch et de la comparaison des optimiseurs

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn