>기술 주변기기 >일체 포함 >PyTorch의 9가지 주요 동작!

PyTorch의 9가지 주요 동작!

PHPz
PHPz앞으로
2024-01-06 16:45:47682검색

오늘은 PyTorch에 대해 이야기하겠습니다. 전반적인 개념을 알려줄 가장 중요한 PyTorch 작업 9가지를 요약했습니다.

PyTorch의 9가지 주요 동작!

Tensor 생성 및 기본 작업

PyTorch 텐서는 NumPy 배열과 유사하지만 GPU 가속 및 자동 파생 기능이 있습니다. torch.tensor 함수를 사용하여 텐서를 생성하거나 torch.zeros, torch.ones 및 기타 함수를 사용하여 텐서를 생성할 수 있습니다. 이러한 기능을 사용하면 텐서를 보다 편리하게 생성할 수 있습니다.

import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)

Autograd(Autograd)

torch.autograd 모듈은 자동 파생 메커니즘을 제공하여 작업을 기록하고 경사도를 계산할 수 있습니다.

x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)

신경망 계층(nn.Module)

torch.nn.Module은 신경망을 구축하기 위한 기본 구성 요소로 선형 계층(nn.Linear), 컨볼루션 계층(nn.Conv2d) 등 다양한 계층을 포함할 수 있습니다. ) 기다리다.

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()

Optimizer

옵티마이저는 모델 매개변수를 조정하여 손실 함수를 줄이는 데 사용됩니다. 아래는 SGD(Stochastic Gradient Descent) 최적화 프로그램을 사용한 예입니다.

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)

손실 함수

손실 함수는 모델 출력과 목표 사이의 간격을 측정하는 데 사용됩니다. 예를 들어 교차 엔트로피 손실은 분류 문제에 적합합니다.

loss_function = nn.CrossEntropyLoss()

데이터 로드 및 사전 처리

PyTorch의 torch.utils.data 모듈은 데이터 로드 및 사전 처리를 위한 Dataset 및 DataLoader 클래스를 제공합니다. 데이터 세트 클래스는 다양한 데이터 형식 및 작업에 맞게 사용자 정의할 수 있습니다.

from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

모델 저장 및 로드

torch.save를 사용하여 모델의 상태 사전을 저장할 수 있고, torch.load를 사용하여 모델을 로드할 수 있습니다.

# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))

학습률 조정

torch.optim.lr_scheduler 모듈은 학습률 조정을 위한 도구를 제공합니다. 예를 들어 StepLR을 사용하면 각 에포크 이후 학습 속도를 줄일 수 있습니다.

from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

모델 평가

모델 학습이 완료된 후 모델 성능을 평가해야 합니다. 평가할 때 모델을 평가 모드(model.eval())로 전환하고 기울기 계산을 피하기 위해 torch.no_grad() 컨텍스트 관리자를 사용해야 합니다.

아아아아

위 내용은 PyTorch의 9가지 주요 동작!의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
이 기사는 51cto.com에서 복제됩니다. 침해가 있는 경우 admin@php.cn으로 문의하시기 바랍니다. 삭제