>  기사  >  기술 주변기기  >  PyTorch를 배우는 방법? 너무 쉽다

PyTorch를 배우는 방법? 너무 쉽다

WBOY
WBOY앞으로
2024-03-07 19:46:11574검색

많은 친구들이 저에게 PyTorch를 배우는 방법을 물었습니다. 초보자는 몇 가지 개념과 사용법만 익히면 된다는 것이 입증되었습니다. 이 간결한 가이드의 요약을 살펴보겠습니다!

PyTorch 该怎么学?太简单了

Building Tensors

PyTorch의 Tensors는 NumPy의 ndarray와 유사한 다차원 배열이지만 GPU에서 실행할 수 있습니다.

import torch# Create a 2x3 tensortensor = torch.tensor([[1, 2, 3], [4, 5, 6]])print(tensor)

동적 계산 그래프

PyTorch는 동적 계산 그래프를 사용하여 작업을 수행합니다. 연산 런타임에 그래프를 수정할 수 있는 유연성을 제공하는 즉석에서 계산 그래프를 구축합니다.

# Define two tensorsa = torch.tensor([2.], requires_grad=True)b = torch.tensor([3.], requires_grad=True)# Compute resultc = a * bc.backward()# Gradientsprint(a.grad)# Gradient w.r.t a

GPU 가속

PyTorch를 사용하면 CPU와 GPU 간을 쉽게 전환할 수 있습니다. .to(장치)를 사용하세요:

device = "cuda" if torch.cuda.is_available() else "cpu"tensor = tensor.to(device)

Autograd: 자동 미분

PyTorch의 autograd는 텐서의 모든 작업에 대해 자동 미분 기능을 제공합니다. require_grad=True로 설정하면 계산을 추적할 수 있습니다:

x = torch.tensor([2.], requires_grad=True)y = x**2y.backward()print(x.grad)# Gradient of y w.r.t x

Modular Neural Network

PyTorch 신경망 아키텍처를 정의하고 서브클래싱을 통해 사용자 정의 레이어를 생성하기 위한 nn.Module 클래스 제공:

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(1, 1)def forward(self, x):return self.fc(x)

사전 정의된 레이어 및 손실 함수

PyTorch는 nn 모듈, 손실 함수 및 최적화 알고리즘에서 사전 정의된 다양한 레이어를 제공합니다:

loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Dataset 및 DataLoader

효율적인 데이터 처리 및 일괄 처리를 달성하기 위해 PyTorch는 Dataset 및 DataLoader 클래스를 제공합니다.

from torch.utils.data import Dataset, DataLoaderclass CustomDataset(Dataset):# ... (methods to define)data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

모델 훈련(루프)

보통 PyTorch 훈련은 다음 패턴을 따릅니다: 순방향 전달, 계산 손실, 역방향 전달 및 매개변수 업데이트:

for epoch in range(epochs):for data, target in data_loader:optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()

모델 직렬화

torch.save() 및 torch.load()를 사용하여 모델을 저장하고 로드합니다.

# Savetorch.save(model.state_dict(), 'model_weights.pth')# Loadmodel.load_state_dict(torch.load('model_weights.pth'))

JIT

PyTorch는 기본적으로 Eager 모드에서 실행되지만 또한 모델을 위한 JIT(Just-In-Time) 컴파일도 제공합니다.

scripted_model = torch.jit.script(model)scripted_model.save("model_jit.pt")

위 내용은 PyTorch를 배우는 방법? 너무 쉽다의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
이 기사는 51cto.com에서 복제됩니다. 침해가 있는 경우 admin@php.cn으로 문의하시기 바랍니다. 삭제