Rumah >Peranti teknologi >AI >Analisis mendalam tentang titik teras Pytorch, penyahsulitan CNN!

Analisis mendalam tentang titik teras Pytorch, penyahsulitan CNN!

王林
王林ke hadapan
2024-01-04 19:18:161361semak imbas

Helo, saya Xiaozhuang!

Pemula mungkin tidak biasa membuat rangkaian saraf konvolusi (CNN). Mari kita ilustrasikan dengan kes lengkap di bawah.

CNN ialah model pembelajaran mendalam yang digunakan secara meluas dalam pengelasan imej, pengesanan sasaran, penjanaan imej dan tugasan lain. Ia secara automatik mengekstrak ciri imej melalui lapisan konvolusi dan lapisan pengumpulan, dan melakukan pengelasan melalui lapisan bersambung sepenuhnya. Kunci kepada model ini ialah menggunakan operasi lilitan dan pengumpulan untuk menangkap ciri tempatan secara berkesan dalam imej dan menggabungkannya melalui rangkaian berbilang lapisan untuk mencapai pengekstrakan ciri lanjutan dan pengelasan imej.

Prinsip

1. Lapisan Konvolusi:

Lapisan konvolusi mengekstrak ciri daripada imej input melalui operasi konvolusi. Operasi ini melibatkan kernel lilitan yang boleh dipelajari yang meluncur ke atas imej input dan mengira produk titik di bawah tetingkap gelongsor. Proses ini membantu mengekstrak ciri tempatan, dengan itu meningkatkan persepsi rangkaian terhadap invarian terjemahan.

Formula:

突破Pytorch核心点,CNN !!!

di mana, x ialah input, w ialah isirong lilitan, dan b ialah bias.

2. Lapisan Pengumpulan:

Lapisan penyatuan ialah teknologi pengurangan dimensi yang biasa digunakan adalah untuk mengurangkan dimensi spatial data, dengan itu mengurangkan jumlah pengiraan dan mengekstrak ciri yang paling ketara. Antaranya, pengumpulan maks ialah kaedah pengumpulan biasa, yang memilih nilai terbesar dalam setiap tetingkap sebagai wakil. Melalui pengumpulan maksimum, kami boleh mengurangkan kerumitan data dan meningkatkan kecekapan pengiraan model sambil mengekalkan maklumat penting.

Formula (penghimpunan maksimum):

突破Pytorch核心点,CNN !!!

3 Lapisan Bersambung Sepenuhnya:

Lapisan yang bersambung sepenuhnya memainkan peranan penting dalam rangkaian saraf Ia mengekstrak lapisan konvolusi dan penyatuan disambungkan kepada lapisan keluaran . Setiap neuron dalam lapisan bersambung sepenuhnya disambungkan kepada semua neuron dalam lapisan sebelumnya, supaya sintesis dan pengelasan ciri boleh dicapai.

Langkah praktikal dan penerangan terperinci

1. Langkah

  • Import perpustakaan dan modul yang diperlukan.
  • Tentukan struktur rangkaian: Gunakan nn.Module untuk menentukan kelas rangkaian saraf tersuai yang diwarisi daripadanya, dan tentukan lapisan lilitan, fungsi pengaktifan, lapisan pengumpulan dan lapisan bersambung sepenuhnya.
  • Tentukan fungsi kehilangan dan pengoptimum.
  • Muat dan praproses data.
  • Latih rangkaian: Latih parameter rangkaian secara berulang menggunakan data latihan.
  • Rangkaian ujian: Gunakan data ujian untuk menilai prestasi model.

2. Pelaksanaan kod

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms# 定义卷积神经网络类class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 10)# 输入大小根据数据调整def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 32 * 7 * 7)x = self.fc1(x)return x# 定义损失函数和优化器net = SimpleCNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载和预处理数据transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 训练网络num_epochs = 5for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')# 测试网络net.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint('Accuracy on the test set: {}%'.format(100 * accuracy))

Contoh ini menunjukkan model CNN yang mudah, dilatih dan diuji menggunakan set data MNIST.

Seterusnya, kami menambah langkah visualisasi untuk memahami prestasi dan proses latihan model dengan lebih intuitif.

Visualisasi

1. Import matplotlib

import matplotlib.pyplot as plt

2 Rekod kehilangan dan ketepatan semasa latihan:

Semasa gelung latihan, rekodkan kehilangan dan ketepatan setiap zaman.

# 在训练循环中添加以下代码train_loss_list = []accuracy_list = []for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')epoch_loss = running_loss / len(train_loader)accuracy = correct / totaltrain_loss_list.append(epoch_loss)accuracy_list.append(accuracy)

3 Visualisasikan kehilangan dan ketepatan:

# 在训练循环后,添加以下代码plt.figure(figsize=(12, 4))# 可视化损失plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()# 可视化准确率plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')plt.title('Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()

Dengan cara ini, kita dapat melihat perubahan kehilangan dan ketepatan latihan selepas proses latihan.

Selepas mengimport kod, anda boleh melaraskan kandungan dan format visual mengikut keperluan.

Atas ialah kandungan terperinci Analisis mendalam tentang titik teras Pytorch, penyahsulitan CNN!. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:51cto.com. Jika ada pelanggaran, sila hubungi admin@php.cn Padam