Maison >Périphériques technologiques >IA >Neuf opérations clés de PyTorch !

Neuf opérations clés de PyTorch !

PHPz
PHPzavant
2024-01-06 16:45:47459parcourir

Aujourd'hui, nous allons parler de PyTorch. J'ai résumé les neuf opérations PyTorch les plus importantes, ce qui vous donnera un concept global.

Neuf opérations clés de PyTorch !

Création de tenseurs et opérations de base

Les tenseurs PyTorch sont similaires aux tableaux NumPy, mais ils ont des fonctions d'accélération GPU et de dérivation automatique. Nous pouvons utiliser la fonction torch.tensor pour créer des tenseurs, ou nous pouvons utiliser torch.zeros, torch.ones et d'autres fonctions pour créer des tenseurs. Ces fonctions peuvent nous aider à créer des tenseurs plus facilement. Le module

import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)

Autograd (Autograd)

torch.autograd fournit un mécanisme de dérivation automatique, permettant les opérations d'enregistrement et le calcul des gradients.

x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)

Couche de réseau neuronal (nn.Module)

torch.nn.Module est le composant de base pour la construction d'un réseau neuronal. Il peut inclure diverses couches, telles qu'une couche linéaire (nn.Linear), une couche convolutive (nn.Conv2d). ) attendez.

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()

Optimiseur

L'optimiseur est utilisé pour ajuster les paramètres du modèle afin de réduire la fonction de perte. Vous trouverez ci-dessous un exemple utilisant l'optimiseur Stochastic Gradient Descent (SGD).

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)

Fonction de perte

La fonction de perte est utilisée pour mesurer l'écart entre la sortie du modèle et la cible. Par exemple, la perte d’entropie croisée convient aux problèmes de classification.

loss_function = nn.CrossEntropyLoss()

Chargement et prétraitement des données

Le module torch.utils.data de PyTorch fournit les classes Dataset et DataLoader pour le chargement et le prétraitement des données. Les classes d'ensembles de données peuvent être personnalisées pour s'adapter à différents formats de données et tâches.

from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

Sauvegarde et chargement du modèle

Vous pouvez utiliser torch.save pour enregistrer le dictionnaire d'état du modèle et torch.load pour charger le modèle. Le module

# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))

Ajustement du taux d'apprentissage

torch.optim.lr_scheduler fournit des outils pour l'ajustement du taux d'apprentissage. Par exemple, StepLR peut être utilisé pour réduire le taux d'apprentissage après chaque époque.

from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

Évaluation du modèle

Une fois la formation du modèle terminée, les performances du modèle doivent être évaluées. Lors de l'évaluation, vous devez passer le modèle en mode d'évaluation (model.eval()) et utiliser le gestionnaire de contexte torch.no_grad() pour éviter les calculs de dégradé.

model.eval()with torch.no_grad():# 运行模型并计算性能指标

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Cet article est reproduit dans:. en cas de violation, veuillez contacter admin@php.cn Supprimer