
Nine key operations of PyTorch!


Today we will talk about PyTorch. I have summarized the nine most important PyTorch operations to give you an overall picture of the framework.


Tensor creation and basic operations

PyTorch tensors are similar to NumPy arrays, but they add GPU acceleration and automatic differentiation. You can create tensors with torch.tensor, or with convenience constructors such as torch.zeros and torch.ones.

import torch

# Create tensors
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# Tensor addition
c = a + b
print(c)
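
A minimal sketch of the convenience constructors mentioned above, plus moving a tensor to the GPU (the .to('cuda') call is an assumption that a CUDA device is available and is skipped otherwise):

zeros = torch.zeros(2, 3)        # 2x3 tensor filled with zeros
ones = torch.ones(2, 3)          # 2x3 tensor filled with ones
rand = torch.rand(2, 3)          # 2x3 tensor with uniform random values
if torch.cuda.is_available():
    rand = rand.to('cuda')       # move the tensor to the GPU
print(zeros, ones, rand)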

Automatic differentiation (Autograd)

The torch.autograd module provides automatic differentiation: it records the operations performed on tensors and computes gradients through them.

x = torch.tensor([1.0], requires_grad=True)
y = x**2
y.backward()
print(x.grad)
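
As a slightly richer illustrative sketch, autograd also tracks gradients through a chain of operations and with respect to several tensors at once:

x = torch.tensor([2.0], requires_grad=True)
w = torch.tensor([3.0], requires_grad=True)
y = w * x**2 + 1      # y = w * x^2 + 1
y.backward()
print(x.grad)          # dy/dx = 2*w*x = 12
print(w.grad)          # dy/dw = x^2 = 4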

Neural network layers (nn.Module)

torch.nn.Module is the basic building block for constructing a neural network. A module can contain various layers, such as linear layers (nn.Linear) and convolutional layers (nn.Conv2d).

import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

model = SimpleNN()
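
A quick sanity check of the module above, assuming a random batch of 3 samples with 10 features each:

x = torch.randn(3, 10)   # batch of 3 samples, 10 features
out = model(x)           # forward pass through nn.Linear(10, 5)
print(out.shape)         # torch.Size([3, 5])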

Optimizer

The optimizer adjusts the model's parameters to minimize the loss function. Below is an example using the stochastic gradient descent (SGD) optimizer.

import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=0.01)
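
A sketch of a single training step with this optimizer; the dummy inputs and targets here are illustrative, and the loss is the cross-entropy criterion introduced in the next section:

inputs = torch.randn(8, 10)              # dummy batch: 8 samples, 10 features
targets = torch.randint(0, 5, (8,))      # dummy class labels for 5 classes
criterion = nn.CrossEntropyLoss()        # loss function (see next section)

optimizer.zero_grad()                    # clear old gradients
outputs = model(inputs)                  # forward pass
loss = criterion(outputs, targets)       # compute the loss
loss.backward()                          # backpropagate
optimizer.step()                         # update parameters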

Loss Function

The loss function is used to measure the difference between the model output and the target. For example, cross-entropy loss is suitable for classification problems.

loss_function = nn.CrossEntropyLoss()
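
For instance, cross-entropy loss expects raw logits of shape (batch, num_classes) and integer class labels; a small illustrative sketch with made-up values:

logits = torch.randn(4, 3)              # 4 samples, 3 classes (raw, unnormalized scores)
labels = torch.tensor([0, 2, 1, 0])     # ground-truth class indices
loss = loss_function(logits, labels)
print(loss.item())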

Data loading and preprocessing

PyTorch's torch.utils.data module provides the Dataset and DataLoader classes for loading and preprocessing data. You can subclass Dataset to handle different data formats and tasks.

from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    # Implement the dataset's initialization and __getitem__ method
    ...

dataset = CustomDataset()   # instantiate with your data
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
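
A complete, runnable version of the skeleton above, using randomly generated tensors as stand-in data (the features/labels fields are assumptions for illustration):

import torch
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

dataset = CustomDataset(torch.randn(100, 10), torch.randint(0, 5, (100,)))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

for batch_features, batch_labels in dataloader:
    print(batch_features.shape, batch_labels.shape)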

Model saving and loading

You can use torch.save to save the model's state dictionary, and torch.load together with load_state_dict to restore it into a model instance.

# Save the model
torch.save(model.state_dict(), 'model.pth')

# Load the model
loaded_model = SimpleNN()
loaded_model.load_state_dict(torch.load('model.pth'))
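
Beyond the bare state dictionary, a common pattern is to save a checkpoint that also includes the optimizer state and the current epoch; a sketch under that assumption (the 'checkpoint.pth' filename and epoch variable are illustrative):

checkpoint = {
    'epoch': epoch,                                   # current epoch from your training loop (illustrative)
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

# Restore later
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])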

Learning rate adjustment

The torch.optim.lr_scheduler module provides tools for adjusting the learning rate during training. For example, StepLR multiplies the learning rate by gamma every step_size epochs (every 5 epochs in the example below).

from torch.optim import lr_scheduler

scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
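
Typically scheduler.step() is called once per epoch, after the optimizer updates; a sketch of how it fits into a training loop (train_one_epoch is a hypothetical helper standing in for your training code):

for epoch in range(20):
    train_one_epoch(model, dataloader, optimizer, loss_function)  # hypothetical helper
    scheduler.step()                        # decay the learning rate every 5 epochs (gamma=0.1)
    print(epoch, scheduler.get_last_lr())   # inspect the current learning rate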

Model evaluation

After training is complete, you need to evaluate the model's performance. When evaluating, switch the model to evaluation mode with model.eval() and run inference inside the torch.no_grad() context manager to avoid computing gradients.

model.eval()
with torch.no_grad():
    # Run the model and compute performance metrics
    ...
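
A sketch of computing classification accuracy over a validation set; val_dataloader is an assumed DataLoader yielding (inputs, labels) batches:

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in val_dataloader:    # assumed validation DataLoader
        outputs = model(inputs)
        predictions = outputs.argmax(dim=1)  # predicted class per sample
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print('Accuracy:', correct / total)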

