
In-depth analysis of the core points of PyTorch: CNN demystified!

王林
2024-01-04 19:18:16

Hello, I am Xiaozhuang!

Beginners may not be familiar with building a convolutional neural network (CNN). Let's walk through it with a complete example.

A CNN is a deep learning model widely used for image classification, object detection, image generation, and other tasks. It automatically extracts image features through convolutional and pooling layers and performs classification through fully connected layers. The key idea is that convolution and pooling capture local features in the image, and stacking multiple layers combines them into progressively higher-level features for classification.

Principle

1. Convolutional Layer:

The convolutional layer extracts features from the input image through convolution operations. A learnable convolution kernel slides over the input image and, at each position, computes the dot product between the kernel and the window it covers. Because the same kernel is applied at every position, a local feature is detected wherever it appears in the image, which helps make the network robust to translation.

Formula:

$$y(i, j) = \sum_{m} \sum_{n} x(i + m,\, j + n)\, w(m, n) + b$$

Where x is the input, w is the convolution kernel, b is the bias, and y(i, j) is the output at position (i, j).
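The sliding-window dot product can be sketched in plain Python so that every multiplication is visible. This is only an illustration under simplifying assumptions (a single channel, stride 1, no padding); the function name `conv2d_single` and the sample values are hypothetical, and a real model would use `nn.Conv2d` instead.

```python
def conv2d_single(x, w, b=0.0):
    """Slide kernel w over the 2-D input x (stride 1, no padding) and add bias b."""
    kh, kw = len(w), len(w[0])
    out_h = len(x) - kh + 1
    out_w = len(x[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Dot product between the kernel and the window
            # whose top-left corner sits at (i, j).
            acc = b
            for m in range(kh):
                for n in range(kw):
                    acc += x[i + m][j + n] * w[m][n]
            row.append(acc)
        out.append(row)
    return out


# Hypothetical 4x4 input and 2x2 kernel, chosen only for illustration.
x = [
    [1, 2, 3, 0],
    [4, 5, 6, 1],
    [7, 8, 9, 2],
    [0, 1, 2, 3],
]
w = [
    [1, 0],
    [0, -1],
]
feature_map = conv2d_single(x, w, b=1.0)
for row in feature_map:
    print(row)
```

A 4x4 input and a 2x2 kernel give a 3x3 feature map, because the kernel fits in three positions along each axis.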

2. Pooling Layer:

The pooling layer is a commonly used downsampling technique. It reduces the spatial dimensions of the data, which cuts the amount of computation while keeping the most significant features. Max pooling is a common pooling method: it selects the largest value in each window as that window's representative. With max pooling, we reduce the size of the data and improve the computational efficiency of the model while retaining the important information.

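Formula (max pooling):

$$y(i, j) = \max_{0 \le m,\, n < k} x(i \cdot s + m,\; j \cdot s + n)$$

Where x is the input, k is the pooling window size, and s is the stride.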
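As a rough illustration of max pooling, the sketch below keeps the largest value in each window of a small 2-D input. It assumes a single channel and no padding; `max_pool2d` and the sample matrix are hypothetical names and values used only for this example, not part of the article's model.

```python
def max_pool2d(x, kernel_size=2, stride=2):
    """Keep the largest value in each kernel_size x kernel_size window."""
    out_h = (len(x) - kernel_size) // stride + 1
    out_w = (len(x[0]) - kernel_size) // stride + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Collect the values covered by the window starting at
            # (i * stride, j * stride) and keep only the maximum.
            window = [
                x[i * stride + m][j * stride + n]
                for m in range(kernel_size)
                for n in range(kernel_size)
            ]
            row.append(max(window))
        out.append(row)
    return out


# Hypothetical 4x4 input, chosen only for illustration.
x = [
    [1, 3, 2, 0],
    [4, 6, 1, 1],
    [7, 2, 9, 5],
    [0, 1, 3, 8],
]
pooled = max_pool2d(x)
for row in pooled:
    print(row)
```

With a 2x2 window and stride 2, each spatial dimension is halved, so the 4x4 input becomes a 2x2 output.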

3. Fully Connected Layer:

The fully connected layer plays an important role in the neural network: it connects the feature maps extracted by the convolutional and pooling layers to the output categories. Each neuron in the fully connected layer is connected to all neurons in the previous layer, which lets it combine the extracted features and perform classification.
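To make "connected to all neurons in the previous layer" concrete, here is a minimal sketch that flattens two tiny feature maps into one vector and computes one score per output class as a weighted sum plus a bias. The helper names `flatten` and `linear`, and all weights and values, are assumptions made up for illustration; in PyTorch this role is played by `nn.Linear`.

```python
def flatten(feature_maps):
    """Concatenate a list of 2-D feature maps into one flat vector."""
    return [v for fmap in feature_maps for row in fmap for v in row]


def linear(x, weights, biases):
    """One score per output neuron: a weighted sum over every input value."""
    return [
        sum(w_k * x_k for w_k, x_k in zip(w_row, x)) + b
        for w_row, b in zip(weights, biases)
    ]


# Two hypothetical 2x2 feature maps (two channels).
feature_maps = [
    [[1.0, 0.0], [2.0, 1.0]],  # channel 0
    [[0.0, 3.0], [1.0, 0.0]],  # channel 1
]
features = flatten(feature_maps)  # 2 channels * 2 * 2 = 8 values

# Hypothetical weights for two output classes: one row of 8 weights per class.
weights = [
    [1, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 1, 1, 0, 1, 0],
]
biases = [0.5, -1.0]

scores = linear(features, weights, biases)
predicted_class = max(range(len(scores)), key=lambda k: scores[k])
print(scores, predicted_class)
```

The class with the highest score is taken as the prediction, which mirrors how the network's final layer output is interpreted.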

Practical steps and detailed explanation

1. Steps

  • Import the necessary libraries and modules.
  • Define the network structure: create a custom neural network class that inherits from nn.Module, and define its convolutional layers, activation function, pooling layer and fully connected layer.
  • Define the loss function and optimizer.
  • Load and preprocess the data.
  • Train the network: use the training data to iteratively update the network parameters.
  • Test the network: use the test data to evaluate model performance.

2. Code implementation

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms

    # Define the convolutional neural network class
    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            # Convolutional layer 1
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
            self.relu = nn.ReLU()
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            # Convolutional layer 2
            self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
            # Fully connected layer
            self.fc1 = nn.Linear(32 * 7 * 7, 10)  # adjust the input size to match your data

        def forward(self, x):
            x = self.conv1(x)
            x = self.relu(x)
            x = self.pool(x)
            x = self.conv2(x)
            x = self.relu(x)
            x = self.pool(x)
            x = x.view(-1, 32 * 7 * 7)  # flatten the feature maps
            x = self.fc1(x)
            return x

    # Define the loss function and optimizer
    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    # Load and preprocess the data
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

    # Train the network
    num_epochs = 5
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            if (i+1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')

    # Test the network
    net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)  # index of the highest score per image
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print('Accuracy on the test set: {}%'.format(100 * accuracy))

This example shows a simple CNN model, trained and tested using the MNIST data set.
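The `32 * 7 * 7` input size of the fully connected layer follows from the layer settings and the 28x28 MNIST images. The sketch below traces the spatial size through each layer using the standard output-size formula for convolution and pooling with dilation 1; the helper name `conv_output_size` is hypothetical and exists only for this illustration.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial axis for a conv or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28  # MNIST images are 28x28 pixels
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)
after_conv1 = size   # 3x3 kernel with padding 1 keeps the size
size = conv_output_size(size, kernel_size=2, stride=2)
after_pool1 = size   # 2x2 pooling with stride 2 halves it
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)
after_conv2 = size
size = conv_output_size(size, kernel_size=2, stride=2)
after_pool2 = size

out_channels = 32    # output channels of the second convolutional layer
flattened = out_channels * after_pool2 * after_pool2
print(after_conv1, after_pool1, after_conv2, after_pool2, flattened)
```

If the input images had a different size, or the kernel, stride, or padding changed, the same arithmetic would give the value to use in place of `32 * 7 * 7`.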

Next, we add visualization steps to understand the performance and training process of the model more intuitively.

Visualization

1. Import matplotlib

import matplotlib.pyplot as plt

2. Record loss and accuracy during training:

In the training loop, record the loss and accuracy for each epoch.

    # Use this version of the training loop, which also records the loss and accuracy per epoch
    train_loss_list = []
    accuracy_list = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the loss and the number of correct predictions
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            if (i+1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')

        # Average loss over the batches and accuracy over the samples of this epoch
        epoch_loss = running_loss / len(train_loader)
        accuracy = correct / total
        train_loss_list.append(epoch_loss)
        accuracy_list.append(accuracy)
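The bookkeeping in this loop can be checked by hand on a toy example. The sketch below reproduces the same arithmetic in plain Python, using made-up batch losses, class scores, and labels (all hypothetical): the predicted class is the index of the highest score, the epoch loss is the mean of the batch losses, and the accuracy is the number of correct predictions divided by the number of samples.

```python
def argmax(scores):
    """Index of the largest score, i.e. the predicted class."""
    return max(range(len(scores)), key=lambda k: scores[k])


# Hypothetical mini-epoch of two batches:
# (batch loss, class scores per sample, true labels).
batches = [
    (1.0, [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]], [0, 2]),
    (0.5, [[0.1, 0.2, 3.0], [1.2, 0.3, 0.4]], [2, 0]),
]

running_loss, correct, total = 0.0, 0, 0
for batch_loss, scores, labels in batches:
    running_loss += batch_loss
    predicted = [argmax(s) for s in scores]
    correct += sum(p == y for p, y in zip(predicted, labels))
    total += len(labels)

epoch_loss = running_loss / len(batches)  # mean of the per-batch losses
accuracy = correct / total                # fraction of correct predictions
print(epoch_loss, accuracy)
```

Note that dividing by the number of batches gives every batch equal weight, so a smaller final batch counts as much as a full one; for a dataset the size of MNIST the effect is negligible.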

3. Visualize loss and accuracy:

    # Add the following code after the training loop
    plt.figure(figsize=(12, 4))

    # Plot the training loss
    plt.subplot(1, 2, 1)
    plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')
    plt.title('Training Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    # Plot the accuracy
    plt.subplot(1, 2, 2)
    plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')
    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.show()

In this way, we can see how the training loss and accuracy changed over the course of training.

After adding this code, you can adjust the content and formatting of the plots as needed.


Statement:
This article is reproduced from 51cto.com. If there is any infringement, please contact admin@php.cn to have it removed.