많은 친구들이 저에게 PyTorch를 배우는 방법을 물었습니다. 초보자는 몇 가지 개념과 사용법만 익히면 된다는 것이 입증되었습니다. 이 간결한 가이드의 요약을 살펴보겠습니다!
PyTorch의 Tensors는 NumPy의 ndarray와 유사한 다차원 배열이지만 GPU에서 실행할 수 있습니다.
import torch# Create a 2x3 tensortensor = torch.tensor([[1, 2, 3], [4, 5, 6]])print(tensor)
PyTorch는 동적 계산 그래프를 사용하여 작업을 수행합니다. 연산 런타임에 그래프를 수정할 수 있는 유연성을 제공하는 즉석에서 계산 그래프를 구축합니다.
# Define two tensorsa = torch.tensor([2.], requires_grad=True)b = torch.tensor([3.], requires_grad=True)# Compute resultc = a * bc.backward()# Gradientsprint(a.grad)# Gradient w.r.t a
PyTorch를 사용하면 CPU와 GPU 간을 쉽게 전환할 수 있습니다. .to(장치)를 사용하세요:
device = "cuda" if torch.cuda.is_available() else "cpu"tensor = tensor.to(device)
PyTorch의 autograd는 텐서의 모든 작업에 대해 자동 미분 기능을 제공합니다. require_grad=True로 설정하면 계산을 추적할 수 있습니다:
x = torch.tensor([2.], requires_grad=True)y = x**2y.backward()print(x.grad)# Gradient of y w.r.t x
PyTorch 신경망 아키텍처를 정의하고 서브클래싱을 통해 사용자 정의 레이어를 생성하기 위한 nn.Module 클래스 제공:
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(1, 1)def forward(self, x):return self.fc(x)
PyTorch는 nn 모듈, 손실 함수 및 최적화 알고리즘에서 사전 정의된 다양한 레이어를 제공합니다:
loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
효율적인 데이터 처리 및 일괄 처리를 달성하기 위해 PyTorch는 Dataset 및 DataLoader 클래스를 제공합니다.
from torch.utils.data import Dataset, DataLoaderclass CustomDataset(Dataset):# ... (methods to define)data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
보통 PyTorch 훈련은 다음 패턴을 따릅니다: 순방향 전달, 계산 손실, 역방향 전달 및 매개변수 업데이트:
for epoch in range(epochs):for data, target in data_loader:optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()
torch.save() 및 torch.load()를 사용하여 모델을 저장하고 로드합니다.
# Savetorch.save(model.state_dict(), 'model_weights.pth')# Loadmodel.load_state_dict(torch.load('model_weights.pth'))
PyTorch는 기본적으로 Eager 모드에서 실행되지만 또한 모델을 위한 JIT(Just-In-Time) 컴파일도 제공합니다.
scripted_model = torch.jit.script(model)scripted_model.save("model_jit.pt")
위 내용은 PyTorch를 배우는 방법? 너무 쉽다의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!