Rumah  >  Artikel  >  Peranti teknologi  >  Sembilan operasi utama PyTorch!

Sembilan operasi utama PyTorch!

PHPz
PHPzke hadapan
2024-01-06 16:45:47408semak imbas

Hari ini kita akan bercakap tentang PyTorch saya telah meringkaskan sembilan operasi PyTorch yang paling penting, yang akan memberi anda konsep keseluruhan.

Sembilan operasi utama PyTorch!

Penciptaan tensor dan operasi asas

Tensor PyTorch adalah serupa dengan tatasusunan NumPy, tetapi ia mempunyai pecutan GPU dan fungsi terbitan automatik. Kita boleh menggunakan fungsi torch.tensor untuk mencipta tensor, atau kita boleh menggunakan torch.zeros, torch.ones dan fungsi lain untuk mencipta tensor. Fungsi ini boleh membantu kami mencipta tensor dengan lebih mudah.

import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)

Penerbitan automatik (Autograd)

modul torch.autograd menyediakan mekanisme terbitan automatik, membenarkan operasi rakaman dan mengira kecerunan.

x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)

Lapisan rangkaian saraf (nn.Module)

torch.nn.Modul ialah komponen asas untuk membina rangkaian saraf Ia boleh merangkumi pelbagai lapisan, seperti lapisan linear (nn.Linear), lapisan convolutional (nn.Conv2d. ) tunggu.

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()

Optimizer

Pengoptimum digunakan untuk melaraskan parameter model untuk mengurangkan fungsi kehilangan. Di bawah ialah contoh menggunakan pengoptimum Stochastic Gradient Descent (SGD).

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)

Fungsi Kehilangan

Fungsi kerugian digunakan untuk mengukur jurang antara output model dan sasaran. Sebagai contoh, kehilangan entropi silang sesuai untuk masalah pengelasan.

loss_function = nn.CrossEntropyLoss()

Pemuatan dan prapemprosesan data

Modul torch.utils.data PyTorch menyediakan kelas Set Data dan DataLoader untuk memuatkan dan prapemprosesan data. Kelas set data boleh disesuaikan untuk disesuaikan dengan format data dan tugasan yang berbeza.

from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

Penyimpanan dan pemuatan model

Anda boleh menggunakan torch.save untuk menyimpan kamus keadaan model, dan gunakan torch.load untuk memuatkan model.

# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))

Pelarasan kadar pembelajaran

modul torch.optim.lr_scheduler menyediakan alat untuk pelarasan kadar pembelajaran. Sebagai contoh, StepLR boleh digunakan untuk mengurangkan kadar pembelajaran selepas setiap zaman.

from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

Penilaian model

Selepas latihan model selesai, prestasi model perlu dinilai. Semasa menilai, anda perlu menukar model kepada mod penilaian (model.eval()) dan menggunakan pengurus konteks torch.no_grad() untuk mengelakkan pengiraan kecerunan.

rreeee

Atas ialah kandungan terperinci Sembilan operasi utama PyTorch!. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:51cto.com. Jika ada pelanggaran, sila hubungi admin@php.cn Padam