哈嘍,我是小壯!
初學者對於創建卷積神經網路(CNN)可能不太熟悉,下面我們以一個完整的案例來進行說明。
CNN是廣泛應用於影像分類、目標偵測、影像生成等任務的深度學習模型。它透過卷積層和池化層自動提取影像的特徵,並透過全連接層進行分類。這種模型的關鍵在於利用捲積和池化的操作,有效地捕捉影像中的局部特徵,並透過多層網路進行組合,從而實現對影像的高級特徵提取和分類。
原理
1.卷積層(Convolutional Layer):
#卷積層透過卷積操作來提取輸入影像中的特徵。這個操作涉及一個可學習的捲積核,它在輸入影像上滑動併計算滑動視窗下的點積。這個過程有助於提取局部特徵,從而增強網路對平移不變性的感知能力。
公式:
其中,x是輸入,w是卷積核,b是偏移。
2.池化層(Pooling Layer):
池化層是一種常用的降維技術,其作用是減少資料的空間維度,從而降低計算量,並提取出最顯著的特徵。其中,最大池化是一種常見的池化方式,它會在每個視窗中選擇最大的值作為代表。透過最大池化,我們可以在保留重要資訊的同時,減少資料的複雜度,提高模型的運算效率。
公式(最大池化):
3.全連接層(Fully Connected Layer):
#全連接層在神經網路中扮演著重要的角色,它將捲積和池化層提取的特徵映射連接到輸出類別。全連接層的每個神經元都與前一層的所有神經元相連,這樣可以實現特徵的綜合和分類。
實戰步驟與詳解
1.步驟
- 匯入必要的函式庫和模組。
- 定義網路結構:使用nn.Module定義一個繼承自它的自訂神經網路類,定義卷積層、激活函數、池化層和全連接層。
- 定義損失函數和最佳化器。
- 載入和預處理資料。
- 訓練網路:使用訓練資料迭代訓練網路參數。
- 測試網路:使用測試資料評估模型效能。
2.程式碼實作
import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms# 定义卷积神经网络类class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 10)# 输入大小根据数据调整def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 32 * 7 * 7)x = self.fc1(x)return x# 定义损失函数和优化器net = SimpleCNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载和预处理数据transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 训练网络num_epochs = 5for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')# 测试网络net.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint('Accuracy on the test set: {}%'.format(100 * accuracy))
這個範例展示了一個簡單的CNN模型,使用MNIST資料集進行訓練和測試。
接下來,咱們加入視覺化步驟,更直觀地了解模型的表現和訓練過程。
視覺化
1.導入matplotlib
import matplotlib.pyplot as plt
2.在訓練過程中記錄損失和準確率:
在訓練循環中,記錄每個epoch的損失和準確率。
# 在训练循环中添加以下代码train_loss_list = []accuracy_list = []for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')epoch_loss = running_loss / len(train_loader)accuracy = correct / totaltrain_loss_list.append(epoch_loss)accuracy_list.append(accuracy)
3.視覺化損失和準確率:
# 在训练循环后,添加以下代码plt.figure(figsize=(12, 4))# 可视化损失plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()# 可视化准确率plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')plt.title('Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()
這樣,咱們就可以在訓練過程結束後看到訓練損失和準確率的變化。
匯入程式碼後,大家可以依照需求調整視覺化的內容和格式。
以上是深入分析Pytorch核心要點,CNN解密!的詳細內容。更多資訊請關注PHP中文網其他相關文章!

最近,隨著大語言模型和AI的興起,我們看到了自然語言處理方面的無數進步。文本,代碼和圖像/視頻生成等域中的模型具有存檔的人類的推理和P

介紹 從面部圖像中檢測性別是計算機視覺的眾多迷人應用之一。在這個項目中,我們將OpenCV結合在一起,以解決位置與性別分類的Roboflow API

介紹 自易貨系統概念以來,廣告世界一直在進化。廣告商找到了創造性的方法來引起我們的關注。在當前年齡,消費者期望BR

介紹 9月12日,OpenAI發布了一項名為“與LLM的學習推理”的更新。他們介紹了O1模型,該模型是使用強化學習來應對複雜推理任務的訓練。是什麼設置了此mod

介紹 OpenAI O1模型家族大大提高了推理能力和經濟表現,尤其是在科學,編碼和解決問題方面。 Openai的目標是創建越來越高的AI和O1模型

介紹 如今,客戶查詢管理的世界正在以前所未有的速度移動,每天都有新的工具成為頭條新聞。大型語言模型(LLM)代理是在這種情況下的最新創新,增強了Cu

介紹 採用生成AI可能是任何公司的變革旅程。但是,Genai實施過程通常會繁瑣且令人困惑。 Niit Lim的董事長兼聯合創始人Rajendra Singh Pawar

介紹 人工智能革命引起了創造力的新時代,文本對圖像模型正在重新定義藝術,設計和技術的交集。 pixtral 12b和qwen2-vl-72b是兩個開創性的力量。


熱AI工具

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool
免費脫衣圖片

Clothoff.io
AI脫衣器

AI Hentai Generator
免費產生 AI 無盡。

熱門文章

熱工具

Atom編輯器mac版下載
最受歡迎的的開源編輯器

SecLists
SecLists是最終安全測試人員的伙伴。它是一個包含各種類型清單的集合,這些清單在安全評估過程中經常使用,而且都在一個地方。 SecLists透過方便地提供安全測試人員可能需要的所有列表,幫助提高安全測試的效率和生產力。清單類型包括使用者名稱、密碼、URL、模糊測試有效載荷、敏感資料模式、Web shell等等。測試人員只需將此儲存庫拉到新的測試機上,他就可以存取所需的每種類型的清單。

DVWA
Damn Vulnerable Web App (DVWA) 是一個PHP/MySQL的Web應用程序,非常容易受到攻擊。它的主要目標是成為安全專業人員在合法環境中測試自己的技能和工具的輔助工具,幫助Web開發人員更好地理解保護網路應用程式的過程,並幫助教師/學生在課堂環境中教授/學習Web應用程式安全性。 DVWA的目標是透過簡單直接的介面練習一些最常見的Web漏洞,難度各不相同。請注意,該軟體中

SublimeText3 Linux新版
SublimeText3 Linux最新版

EditPlus 中文破解版
體積小,語法高亮,不支援程式碼提示功能