ホームページ  >  記事  >  テクノロジー周辺機器  >  Pytorch、CNN解読の核心ポイントを徹底分析!

Pytorch、CNN解読の核心ポイントを徹底分析!

王林
王林転載
2024-01-04 19:18:161251ブラウズ

こんにちは、Xiaozhuangです!

初心者は畳み込みニューラル ネットワーク (CNN) の作成に慣れていないかもしれません。以下の完全なケースで説明しましょう。

CNN は、画像分類、ターゲット検出、画像生成、その他のタスクで広く使用されている深層学習モデルです。畳み込み層とプーリング層を通じて画像の特徴を自動的に抽出し、全結合層を通じて分類を実行します。このモデルの鍵は、畳み込み操作とプーリング操作を使用して画像内の局所的な特徴を効果的にキャプチャし、それらを多層ネットワークを通じて組み合わせて、画像の高度な特徴抽出と分類を実現することです。

原理

1. 畳み込み層:

畳み込み層は、畳み込み演算を通じて入力画像から特徴を抽出します。この操作には、入力画像上をスライドし、スライディング ウィンドウの下でドット積を計算する学習可能な畳み込みカーネルが含まれます。このプロセスは局所的な特徴を抽出するのに役立ち、それによって翻訳の不変性に対するネットワークの認識が強化されます。

式:

突破Pytorch核心点,CNN !!!

ここで、x は入力、w はコンボリューション カーネル、b はバイアスです。

2. プーリング レイヤー:

プーリング レイヤーは、一般的に使用される次元削減テクノロジであり、その機能はデータの空間次元を削減し、それによって計算量を削減し、最大限のデータを抽出することです。重要な機能。このうち、最大プーリングは各ウィンドウの最大値を代表として選択するプーリング方法が一般的です。最大プーリングにより、重要な情報を保持しながら、データの複雑さを軽減し、モデルの計算効率を向上させることができます。

式 (最大プーリング):

突破Pytorch核心点,CNN !!!

3. 完全接続層:

完全接続層はニューラル ネットワーク内にあります。畳み込み層とプーリング層によって抽出された特徴マップを出力カテゴリに接続する際に重要な役割を果たします。全結合層の各ニューロンは前の層のすべてのニューロンに接続されているため、特徴の合成と分類を行うことができます。

実践的な手順と詳細な説明

1. 手順

  • 必要なライブラリとモジュールをインポートします。
  • ネットワーク構造の定義: nn.Module を使用して、そこから継承されたカスタム ニューラル ネットワーク クラスを定義し、畳み込み層、活性化関数、プーリング層、および全結合層を定義します。
  • 損失関数とオプティマイザーを定義します。
  • データのロードと前処理。
  • トレーニング ネットワーク: トレーニング データを使用して、ネットワーク パラメーターを反復的にトレーニングします。
  • テスト ネットワーク: テスト データを使用してモデルのパフォーマンスを評価します。

2. コードの実装

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms# 定义卷积神经网络类class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 10)# 输入大小根据数据调整def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 32 * 7 * 7)x = self.fc1(x)return x# 定义损失函数和优化器net = SimpleCNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载和预处理数据transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 训练网络num_epochs = 5for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')# 测试网络net.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint('Accuracy on the test set: {}%'.format(100 * accuracy))

この例は、MNIST データ セットを使用してトレーニングおよびテストされた単純な CNN モデルを示しています。

次に、モデルのパフォーマンスとトレーニング プロセスをより直観的に理解するために、視覚化ステップを追加します。

視覚化

1. matplotlibをインポート

import matplotlib.pyplot as plt

2. トレーニング中の損失と精度を記録します:

トレーニング ループで、各エポックの損失と精度を記録します。 。

# 在训练循环中添加以下代码train_loss_list = []accuracy_list = []for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')epoch_loss = running_loss / len(train_loader)accuracy = correct / totaltrain_loss_list.append(epoch_loss)accuracy_list.append(accuracy)

3. 損失​​と精度の視覚化:

# 在训练循环后,添加以下代码plt.figure(figsize=(12, 4))# 可视化损失plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()# 可视化准确率plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')plt.title('Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()

このようにして、トレーニング プロセス後のトレーニングの損失と精度の変化を確認できます。

コードをインポートした後、必要に応じてビジュアル コンテンツと形式を調整できます。

以上がPytorch、CNN解読の核心ポイントを徹底分析!の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事は51cto.comで複製されています。侵害がある場合は、admin@php.cn までご連絡ください。