ホームページ >テクノロジー周辺機器 >AI >PyTorch の 9 つの主要な操作!
今日は PyTorch について話します。全体的な概念を示す 9 つの最も重要な PyTorch 操作をまとめました。
PyTorch テンソルは NumPy 配列に似ていますが、GPU アクセラレーションと自動導出機能を備えています。 torch.tensor 関数を使用してテンソルを作成することも、torch.zeros、torch.ones およびその他の関数を使用してテンソルを作成することもできます。これらの関数は、テンソルをより簡単に作成するのに役立ちます。
import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)
torch.autograd モジュールは自動導出メカニズムを提供し、操作の記録と勾配の計算を可能にします。
x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)
torch.nn.Module は、ニューラル ネットワークを構築するための基本コンポーネントです。線形層 (nn.Linear など) など、さまざまな層を含めることができます。 )、畳み込み層(nn.Conv2d)など。
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()
オプティマイザーは、モデル パラメーターを調整して損失関数を削減するために使用されます。以下は、確率的勾配降下 (SGD) オプティマイザーを使用した例です。
import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)
損失関数は、モデルの出力とターゲットの差を測定するために使用されます。たとえば、クロスエントロピー損失は分類問題に適しています。
loss_function = nn.CrossEntropyLoss()
PyTorch の torch.utils.data モジュールは、データのロードと前処理のための Dataset クラスと DataLoader クラスを提供します。データセット クラスは、さまざまなデータ形式やタスクに合わせてカスタマイズできます。
from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
torch.save を使用してモデルの状態辞書を保存し、torch.load を使用してモデルをロードできます。
# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))
torch.optim.lr_scheduler モジュールは、学習率調整のためのツールを提供します。たとえば、StepLR を使用すると、各エポック後の学習率を下げることができます。
from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
モデルのトレーニングが完了したら、モデルのパフォーマンスを評価する必要があります。評価するときは、モデルを評価モード (model.eval()) に切り替え、torch.no_grad() コンテキスト マネージャーを使用して勾配計算を回避する必要があります。
rree以上がPyTorch の 9 つの主要な操作!の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。