Heim >Technologie-Peripheriegeräte >KI >So erstellen Sie mit PyTorch ein einfaches neuronales Netzwerk

So erstellen Sie mit PyTorch ein einfaches neuronales Netzwerk

WBOY
WBOYnach vorne
2024-01-25 09:27:06702Durchsuche

So erstellen Sie mit PyTorch ein einfaches neuronales Netzwerk

PyTorch ist ein Python-basiertes Deep-Learning-Framework zum Aufbau verschiedener neuronaler Netze. In diesem Artikel wird gezeigt, wie Sie mit PyTorch ein einfaches neuronales Netzwerk aufbauen und Codebeispiele bereitstellen.

Zuerst müssen wir PyTorch installieren. Es kann über die Befehlszeile installiert werden mit:

pip install torch

Als nächstes werden wir PyTorch verwenden, um ein einfaches, vollständig verbundenes neuronales Netzwerk für binäre Klassifizierungsaufgaben aufzubauen. Dieses neuronale Netzwerk wird zwei verborgene Schichten mit jeweils 10 Neuronen haben. Wir werden die Sigmoid-Aktivierungsfunktion und die Kreuzentropieverlustfunktion verwenden.

Hier ist der vollständige Code:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 10)  # 第一个隐藏层
        self.fc2 = nn.Linear(10, 10)  # 第二个隐藏层
        self.fc3 = nn.Linear(10, 1)  # 输出层

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

# 创建数据集
X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

# 创建神经网络实例
net = Net()

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

# 训练神经网络
for epoch in range(10000):
    optimizer.zero_grad()
    output = net(X)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    # 打印训练损失
    if epoch % 1000 == 0:
    print('Epoch {}: loss = {}'.format(epoch, loss.item()))

# 使用训练好的神经网络进行预测
with torch.no_grad():
    output = net(X)
    predicted = (output > 0.5).float()
    print('Predicted: {}\n'.format(predicted))

Zuerst definieren wir eine Klasse namens Net, die von nn.Module erbt. Diese Klasse enthält alle Schichten des neuronalen Netzwerks. In diesem Beispiel definieren wir drei vollständig verbundene Schichten, von denen die ersten beiden verborgene Schichten sind und die letzte die Ausgabeschicht ist.

In der Net-Klasse definieren wir nicht nur eine Vorwärtsmethode zur Beschreibung des Vorwärtsausbreitungsprozesses des neuronalen Netzwerks, sondern verwenden auch die Sigmoid-Aktivierungsfunktion, um die Ausgabe jeder verborgenen Schicht an die nächste Schicht weiterzuleiten.

Als nächstes haben wir einen Datensatz mit vier Stichproben erstellt, wobei jede Stichprobe zwei Merkmale aufweist. Wir haben auch eine neuronale Netzwerkinstanz namens net definiert und BCELoss als Verlustfunktion und SGD als Optimierer ausgewählt.

Dann beginnen wir mit dem Training des neuronalen Netzwerks. In jeder Iteration setzen wir zunächst den Gradienten des Optimierers auf Null und übergeben dann den Datensatz X an das neuronale Netzwerk, um die Ausgabe zu erhalten. Wir berechnen den Verlust, führen eine Backpropagation durch und aktualisieren schließlich die Netzwerkparameter mithilfe eines Optimierers. Wir haben auch den Trainingsverlust für alle 1000 Iterationen ausgedruckt.

Nach Abschluss des Trainings verwenden wir den Kontextmanager no_grad, um Vorhersagen für den Datensatz zu treffen. Wir werden die vier Vorhersagen ausgeben und ausdrucken.

Dies ist ein einfaches Beispiel, das zeigt, wie man mit PyTorch ein einfaches neuronales Netzwerk aufbaut. PyTorch bietet viele Tools und Funktionen, die uns helfen, neuronale Netze einfacher aufzubauen und zu trainieren.

Das obige ist der detaillierte Inhalt vonSo erstellen Sie mit PyTorch ein einfaches neuronales Netzwerk. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:163.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen