>기술 주변기기 >일체 포함 >Pytorch의 핵심 핵심인 CNN 복호화를 심층 분석!

Pytorch의 핵심 핵심인 CNN 복호화를 심층 분석!

王林
王林앞으로
2024-01-04 19:18:161349검색

안녕하세요 샤오좡입니다!

초보자는 CNN(컨볼루션 신경망) 생성에 익숙하지 않을 수 있습니다. 아래의 전체 사례를 통해 설명해 보겠습니다.

CNN은 이미지 분류, 타겟 감지, 이미지 생성 및 기타 작업에 널리 사용되는 딥 러닝 모델입니다. Convolutional Layer와 Pooling Layer를 통해 자동으로 이미지의 특징을 추출하고, Fully Connected Layer를 통해 분류를 수행합니다. 이 모델의 핵심은 컨볼루션 및 풀링 작업을 사용하여 이미지의 로컬 특징을 효과적으로 캡처하고 다층 네트워크를 통해 이를 결합하여 고급 특징 추출 및 이미지 분류를 달성하는 것입니다.

Principle

1. Convolutional Layer:

Convolutional 레이어는 Convolution 연산을 통해 입력 이미지에서 특징을 추출합니다. 이 작업에는 입력 이미지 위로 이동하고 슬라이딩 창 아래에서 내적을 계산하는 학습 가능한 컨볼루션 커널이 포함됩니다. 이 프로세스는 로컬 특징을 추출하는 데 도움이 되므로 번역 불변성에 대한 네트워크의 인식이 향상됩니다.

공식:

突破Pytorch核心点,CNN !!!

여기서 x는 입력, w는 컨볼루션 커널, b는 편향입니다.

2. 풀링 레이어:

풀링 레이어는 일반적으로 사용되는 차원 축소 기술로, 그 기능은 데이터의 공간 차원을 줄여 계산량을 줄이고 가장 중요한 특징을 추출하는 것입니다. 그 중 맥스 풀링(max pooling)은 각 윈도우에서 가장 큰 값을 대표로 선택하는 일반적인 풀링 방법이다. Max Pooling을 통해 중요한 정보를 유지하면서 데이터의 복잡성을 줄이고 모델의 계산 효율성을 향상시킬 수 있습니다.

공식(최대 풀링):

突破Pytorch核心点,CNN !!!

3. 완전 연결 레이어:

완전 연결 레이어는 신경망에서 중요한 역할을 하며 기능 맵은 출력 카테고리에 연결됩니다. . 완전 연결 계층의 각 뉴런은 이전 계층의 모든 뉴런과 연결되므로 특징 합성 및 분류가 가능합니다.

실용 단계 및 자세한 설명

1. 단계

  • 필요한 라이브러리와 모듈을 가져옵니다.
  • 네트워크 구조 정의: nn.Module을 사용하여 상속된 사용자 정의 신경망 클래스를 정의하고 컨볼루션 계층, 활성화 함수, 풀링 계층 및 완전 연결 계층을 정의합니다.
  • 손실 함수 및 최적화 프로그램을 정의합니다.
  • 데이터 로드 및 전처리.
  • 네트워크 훈련: 훈련 데이터를 사용하여 네트워크 매개변수를 반복적으로 훈련합니다.
  • 테스트 네트워크: 테스트 데이터를 사용하여 모델 성능을 평가합니다.

2. 코드 구현

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms# 定义卷积神经网络类class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 10)# 输入大小根据数据调整def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 32 * 7 * 7)x = self.fc1(x)return x# 定义损失函数和优化器net = SimpleCNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载和预处理数据transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 训练网络num_epochs = 5for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')# 测试网络net.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint('Accuracy on the test set: {}%'.format(100 * accuracy))

이 예는 MNIST 데이터 세트를 사용하여 훈련 및 테스트된 간단한 CNN 모델을 보여줍니다.

다음으로 모델의 성능과 훈련 과정을 보다 직관적으로 이해하기 위해 시각화 단계를 추가합니다.

Visualization

1. matplotlib 가져오기

import matplotlib.pyplot as plt

2. 훈련 중 손실 및 정확도 기록:

훈련 루프 동안 각 에포크의 손실 및 정확도를 기록합니다.

# 在训练循环中添加以下代码train_loss_list = []accuracy_list = []for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')epoch_loss = running_loss / len(train_loader)accuracy = correct / totaltrain_loss_list.append(epoch_loss)accuracy_list.append(accuracy)

3. 손실 및 정확도 시각화:

# 在训练循环后,添加以下代码plt.figure(figsize=(12, 4))# 可视化损失plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()# 可视化准确率plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')plt.title('Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()

이러한 방식으로 훈련 과정 후 훈련 손실 및 정확도의 변화를 확인할 수 있습니다.

코드를 가져온 후 필요에 따라 시각적 콘텐츠와 형식을 조정할 수 있습니다.

위 내용은 Pytorch의 핵심 핵심인 CNN 복호화를 심층 분석!의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
이 기사는 51cto.com에서 복제됩니다. 침해가 있는 경우 admin@php.cn으로 문의하시기 바랍니다. 삭제