Heim  >  Artikel  >  Technologie-Peripheriegeräte  >  Neun Schlüsseloperationen von PyTorch!

Neun Schlüsseloperationen von PyTorch!

PHPz
PHPznach vorne
2024-01-06 16:45:47408Durchsuche

Heute werden wir über PyTorch sprechen. Ich habe die neun wichtigsten PyTorch-Operationen zusammengefasst, die Ihnen ein Gesamtkonzept geben.

Neun Schlüsseloperationen von PyTorch!

Tensorerstellung und grundlegende Operationen

PyTorch-Tensoren ähneln NumPy-Arrays, verfügen jedoch über GPU-Beschleunigung und automatische Ableitungsfunktionen. Wir können die Funktion Torch.tensor verwenden, um Tensoren zu erstellen, oder wir können Torch.zeros, Torch.ones und andere Funktionen verwenden, um Tensoren zu erstellen. Diese Funktionen können uns helfen, Tensoren bequemer zu erstellen.

import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)

Automatische Ableitung (Autograd)

Das Modul „torch.autograd“ bietet einen automatischen Ableitungsmechanismus, der die Aufzeichnung von Vorgängen und die Berechnung von Verläufen ermöglicht.

x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)

Neuronale Netzwerkschicht (nn.Module)

torch.nn.Module ist die Grundkomponente für den Aufbau eines neuronalen Netzwerks. Es kann verschiedene Schichten umfassen, z. B. eine lineare Schicht (nn.Linear) und eine Faltungsschicht (nn.Conv2d). ) Warten.

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()

Optimierer

Der Optimierer wird verwendet, um Modellparameter anzupassen, um die Verlustfunktion zu reduzieren. Unten sehen Sie ein Beispiel für die Verwendung des Stochastic Gradient Descent (SGD)-Optimierers.

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)

Verlustfunktion

Die Verlustfunktion wird verwendet, um die Lücke zwischen der Modellausgabe und dem Ziel zu messen. Beispielsweise eignet sich der Kreuzentropieverlust für Klassifizierungsprobleme.

loss_function = nn.CrossEntropyLoss()

Laden und Vorverarbeiten von Daten

Das Modul Torch.utils.data von PyTorch stellt Dataset- und DataLoader-Klassen zum Laden und Vorverarbeiten von Daten bereit. Datensatzklassen können an verschiedene Datenformate und Aufgaben angepasst werden.

from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

Speichern und Laden von Modellen

Sie können Torch.save verwenden, um das Zustandswörterbuch des Modells zu speichern, und Torch.load verwenden, um das Modell zu laden.

# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))

Anpassung der Lernrate

Das Modul „torch.optim.lr_scheduler“ bietet Tools zur Anpassung der Lernrate. Beispielsweise kann StepLR verwendet werden, um die Lernrate nach jeder Epoche zu reduzieren.

from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

Modellbewertung

Nach Abschluss des Modelltrainings muss die Modellleistung bewertet werden. Bei der Auswertung müssen Sie das Modell in den Auswertungsmodus (model.eval()) schalten und den Kontextmanager Torch.no_grad() verwenden, um Gradientenberechnungen zu vermeiden.

model.eval()with torch.no_grad():# 运行模型并计算性能指标

Das obige ist der detaillierte Inhalt vonNeun Schlüsseloperationen von PyTorch!. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:51cto.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen