Heim >Technologie-Peripheriegeräte >KI >Neun Schlüsseloperationen von PyTorch!
Heute werden wir über PyTorch sprechen. Ich habe die neun wichtigsten PyTorch-Operationen zusammengefasst, die Ihnen ein Gesamtkonzept geben.
PyTorch-Tensoren ähneln NumPy-Arrays, verfügen jedoch über GPU-Beschleunigung und automatische Ableitungsfunktionen. Wir können die Funktion Torch.tensor verwenden, um Tensoren zu erstellen, oder wir können Torch.zeros, Torch.ones und andere Funktionen verwenden, um Tensoren zu erstellen. Diese Funktionen können uns helfen, Tensoren bequemer zu erstellen.
import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)
Das Modul „torch.autograd“ bietet einen automatischen Ableitungsmechanismus, der die Aufzeichnung von Vorgängen und die Berechnung von Verläufen ermöglicht.
x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)
torch.nn.Module ist die Grundkomponente für den Aufbau eines neuronalen Netzwerks. Es kann verschiedene Schichten umfassen, z. B. eine lineare Schicht (nn.Linear) und eine Faltungsschicht (nn.Conv2d). ) Warten.
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()
Der Optimierer wird verwendet, um Modellparameter anzupassen, um die Verlustfunktion zu reduzieren. Unten sehen Sie ein Beispiel für die Verwendung des Stochastic Gradient Descent (SGD)-Optimierers.
import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)
Die Verlustfunktion wird verwendet, um die Lücke zwischen der Modellausgabe und dem Ziel zu messen. Beispielsweise eignet sich der Kreuzentropieverlust für Klassifizierungsprobleme.
loss_function = nn.CrossEntropyLoss()
Das Modul Torch.utils.data von PyTorch stellt Dataset- und DataLoader-Klassen zum Laden und Vorverarbeiten von Daten bereit. Datensatzklassen können an verschiedene Datenformate und Aufgaben angepasst werden.
from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
Sie können Torch.save verwenden, um das Zustandswörterbuch des Modells zu speichern, und Torch.load verwenden, um das Modell zu laden.
# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))
Das Modul „torch.optim.lr_scheduler“ bietet Tools zur Anpassung der Lernrate. Beispielsweise kann StepLR verwendet werden, um die Lernrate nach jeder Epoche zu reduzieren.
from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
Nach Abschluss des Modelltrainings muss die Modellleistung bewertet werden. Bei der Auswertung müssen Sie das Modell in den Auswertungsmodus (model.eval()) schalten und den Kontextmanager Torch.no_grad() verwenden, um Gradientenberechnungen zu vermeiden.
model.eval()with torch.no_grad():# 运行模型并计算性能指标
Das obige ist der detaillierte Inhalt vonNeun Schlüsseloperationen von PyTorch!. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!