Heim >Technologie-Peripheriegeräte >KI >Eingehende Analyse der Kernpunkte von Pytorch, CNN-Entschlüsselung!

Eingehende Analyse der Kernpunkte von Pytorch, CNN-Entschlüsselung!

王林
王林nach vorne
2024-01-04 19:18:161349Durchsuche

Hallo, ich bin Xiaozhuang!

Anfänger sind möglicherweise nicht mit der Erstellung von Faltungs-Neuronalen Netzen (CNN) vertraut. Lassen Sie uns dies anhand eines vollständigen Falls veranschaulichen.

CNN ist ein Deep-Learning-Modell, das häufig bei der Bildklassifizierung, Zielerkennung, Bildgenerierung und anderen Aufgaben verwendet wird. Es extrahiert automatisch Merkmale von Bildern durch Faltungsschichten und Pooling-Schichten und führt die Klassifizierung durch vollständig verbundene Schichten durch. Der Schlüssel zu diesem Modell besteht darin, Faltungs- und Pooling-Operationen zu verwenden, um lokale Merkmale in Bildern effektiv zu erfassen und sie über mehrschichtige Netzwerke zu kombinieren, um eine erweiterte Merkmalsextraktion und Klassifizierung von Bildern zu erreichen.

Prinzip

1. Faltungsschicht:

Die Faltungsschicht extrahiert Merkmale aus dem Eingabebild durch Faltungsoperationen. Bei diesem Vorgang handelt es sich um einen lernbaren Faltungskern, der über das Eingabebild gleitet und das Skalarprodukt unter dem Schiebefenster berechnet. Dieser Prozess hilft bei der Extraktion lokaler Merkmale und verbessert so die Wahrnehmung der Übersetzungsinvarianz durch das Netzwerk.

Formel:

突破Pytorch核心点,CNN !!!

wobei x die Eingabe, w der Faltungskern und b der Bias ist.

2. Pooling-Schicht:

Die Pooling-Schicht ist eine häufig verwendete Dimensionsreduktionstechnologie. Ihre Funktion besteht darin, die räumliche Dimension der Daten zu reduzieren und dadurch den Berechnungsaufwand zu reduzieren und die wichtigsten Merkmale zu extrahieren. Unter diesen ist Max Pooling eine gängige Pooling-Methode, bei der der größte Wert in jedem Fenster als Vertreter ausgewählt wird. Durch maximales Pooling können wir die Komplexität der Daten reduzieren und die Recheneffizienz des Modells verbessern, während wichtige Informationen erhalten bleiben.

Formel (Max Pooling):

突破Pytorch核心点,CNN !!!

3. Vollständig verbundene Schicht:

Die vollständig verbundene Schicht spielt eine wichtige Rolle im neuronalen Netzwerk. Sie extrahiert die Faltungs- und Pooling-Feature-Maps . Jedes Neuron in der vollständig verbundenen Schicht ist mit allen Neuronen in der vorherigen Schicht verbunden, sodass eine Merkmalssynthese und -klassifizierung erreicht werden kann.

Praktische Schritte und detaillierte Erklärung

1. Schritte

  • Importieren Sie die erforderlichen Bibliotheken und Module.
  • Definieren Sie die Netzwerkstruktur: Verwenden Sie nn.Module, um eine davon geerbte benutzerdefinierte neuronale Netzwerkklasse zu definieren und die Faltungsschicht, die Aktivierungsfunktion, die Pooling-Schicht und die vollständig verbundene Schicht zu definieren.
  • Verlustfunktion und Optimierer definieren.
  • Daten laden und vorverarbeiten.
  • Trainieren Sie das Netzwerk: Trainieren Sie Netzwerkparameter mithilfe von Trainingsdaten iterativ.
  • Testnetzwerk: Verwenden Sie Testdaten, um die Modellleistung zu bewerten.

2. Code-Implementierung

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms# 定义卷积神经网络类class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 10)# 输入大小根据数据调整def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 32 * 7 * 7)x = self.fc1(x)return x# 定义损失函数和优化器net = SimpleCNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载和预处理数据transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 训练网络num_epochs = 5for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')# 测试网络net.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint('Accuracy on the test set: {}%'.format(100 * accuracy))

Dieses Beispiel zeigt ein einfaches CNN-Modell, trainiert und getestet mit dem MNIST-Datensatz.

Als nächstes fügen wir einen Visualisierungsschritt hinzu, um die Leistung und den Trainingsprozess des Modells intuitiver zu verstehen.

Visualisierung

1. Matplotlib importieren

import matplotlib.pyplot as plt

2. Zeichnen Sie Verlust und Genauigkeit während des Trainings auf:

Zeichnen Sie während der Trainingsschleife den Verlust und die Genauigkeit jeder Epoche auf.

# 在训练循环中添加以下代码train_loss_list = []accuracy_list = []for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')epoch_loss = running_loss / len(train_loader)accuracy = correct / totaltrain_loss_list.append(epoch_loss)accuracy_list.append(accuracy)

3. Verlust und Genauigkeit visualisieren:

# 在训练循环后,添加以下代码plt.figure(figsize=(12, 4))# 可视化损失plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()# 可视化准确率plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')plt.title('Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()

Auf diese Weise können wir die Veränderungen des Trainingsverlusts und der Genauigkeit nach dem Trainingsprozess sehen.

Nach dem Import des Codes können Sie den visuellen Inhalt und das Format nach Bedarf anpassen.

Das obige ist der detaillierte Inhalt vonEingehende Analyse der Kernpunkte von Pytorch, CNN-Entschlüsselung!. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:51cto.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen