Maison >Périphériques technologiques >IA >Comment créer un réseau neuronal simple à l'aide de PyTorch

Comment créer un réseau neuronal simple à l'aide de PyTorch

WBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWB
WBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBavant
2024-01-25 09:27:06748parcourir

Comment créer un réseau neuronal simple à laide de PyTorch

PyTorch est un framework d'apprentissage en profondeur basé sur Python pour créer divers réseaux de neurones. Cet article montrera comment utiliser PyTorch pour créer un réseau neuronal simple et fournira des exemples de code.

Tout d'abord, nous devons installer PyTorch. Il peut être installé à partir de la ligne de commande avec :

pip install torch

Ensuite, nous utiliserons PyTorch pour construire un réseau neuronal simple entièrement connecté pour les tâches de classification binaire. Ce réseau de neurones aura deux couches cachées de 10 neurones chacune. Nous utiliserons la fonction d'activation sigmoïde et la fonction de perte d'entropie croisée.

Voici le code complet :

import torch
import torch.nn as nn
import torch.optim as optim

# 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 10)  # 第一个隐藏层
        self.fc2 = nn.Linear(10, 10)  # 第二个隐藏层
        self.fc3 = nn.Linear(10, 1)  # 输出层

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

# 创建数据集
X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

# 创建神经网络实例
net = Net()

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

# 训练神经网络
for epoch in range(10000):
    optimizer.zero_grad()
    output = net(X)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    # 打印训练损失
    if epoch % 1000 == 0:
    print('Epoch {}: loss = {}'.format(epoch, loss.item()))

# 使用训练好的神经网络进行预测
with torch.no_grad():
    output = net(X)
    predicted = (output > 0.5).float()
    print('Predicted: {}\n'.format(predicted))

Tout d'abord, nous définissons une classe appelée Net, qui hérite de nn.Module. Cette classe contient toutes les couches du réseau neuronal. Dans cet exemple, nous définissons trois couches entièrement connectées, dont les deux premières sont des couches cachées et la dernière est la couche de sortie.

Dans la classe Net, en plus de définir une méthode directe pour décrire le processus de propagation vers l'avant du réseau neuronal, nous utilisons également la fonction d'activation sigmoïde pour transmettre la sortie de chaque couche cachée à la couche suivante.

Ensuite, nous avons créé un ensemble de données contenant quatre échantillons, où chaque échantillon possède deux caractéristiques. Nous avons également défini une instance de réseau neuronal nommée net et sélectionné BCELoss comme fonction de perte et SGD comme optimiseur.

Ensuite, nous commençons à entraîner le réseau neuronal. À chaque itération, nous remettons d'abord à zéro le gradient de l'optimiseur, puis transmettons l'ensemble de données X dans le réseau neuronal pour obtenir le résultat. Nous calculons la perte et effectuons une rétropropagation, et enfin mettons à jour les paramètres du réseau à l'aide d'un optimiseur. Nous avons également imprimé la perte d'entraînement toutes les 1 000 itérations.

Une fois la formation terminée, nous utilisons le gestionnaire de contexte no_grad pour faire des prédictions sur l'ensemble de données. Nous allons sortir les quatre prédictions et les imprimer.

Ceci est un exemple simple montrant comment créer un réseau neuronal de base à l'aide de PyTorch. PyTorch fournit de nombreux outils et fonctions pour nous aider à créer et former plus facilement des réseaux de neurones.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Cet article est reproduit dans:. en cas de violation, veuillez contacter admin@php.cn Supprimer