ホームページ  >  記事  >  バックエンド開発  >  PyTorchでニューラルネットワークを素早く構築し、保存・抽出する方法を詳しく解説

PyTorchでニューラルネットワークを素早く構築し、保存・抽出する方法を詳しく解説

不言
不言オリジナル
2018-04-28 10:56:062516ブラウズ

この記事では主にニューラルネットワークを素早く構築するためのPyTorchとその保存方法と抽出方法の詳細な説明を紹介し、参考にさせていただきます。一緒に見てみましょう

モデルをトレーニングし、次回トレーニングに時間を費やすことなく直接使用できるように保存したい場合があります。このセクションでは、PyTorch でニューラル ネットワークをすばやく構築する方法と保存方法について説明します。

1. PyTorch でニューラル ネットワークをすばやく構築する方法

まずは実験用のコードを見てみましょう:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

以前は、次のように定義してニューラル ネットワークを構築する方法を学びました。 net クラスでは、最初に super を渡します。この関数は torch.nn.Module モジュールの構築メソッドを継承し、属性を追加してニューラル ネットワークの各層の構造情報を構築し、各層間の接続情報を改善します。 forward メソッドでニューラル ネットワークを定義し、Net クラス オブジェクト メソッドを定義してニューラル ネットワーク構造の構築を完了します。

ニューラル ネットワークを構築するもう 1 つの方法は、簡単な構築方法とも言えますが、torch.nn.Sequential を通じてニューラル ネットワークの構築を直接完了することです。

2 つの方法で構築されたニューラル ネットワークの構造はまったく同じであり、ネットワーク情報は print 関数を通じて印刷できますが、印刷結果は若干異なります。

2. PyTorch ニューラル ネットワークの保存と抽出

ディープ ラーニングを学習および研究するとき、一定期間のトレーニング後により良いモデルが得られたら、当然このモデルを使用したいと考え、モデル パラメーターは保存されます。後で使用するため、ニューラル ネットワークを保存し、モデル パラメーターを抽出して再ロードする必要があります。

まず、ネットワーク構造とそのモデルパラメーターを保存する必要があるニューラルネットワークの定義とトレーニング部分の後に、torch.save() を通じてネットワーク構造とモデルパラメーターを保存する必要があります。保存方法には、ニューラルネットワーク全体の構造情報とモデルパラメータ情報を保存対象とし、ネットワークネットを保存する方法と、ニューラルネットワークの学習モデルパラメータのみを保存する方法があります。保存のオブジェクトは net.state_dict() です。保存された結果は .pkl ファイルの形式で保存されます。

上記2つのセーブ方法に対応しており、リロード方法も2つあります。最初の完全なネットワーク構造情報に対応して、リロード時に torch.load(‘.pkl’) を通じて新しいニューラル ネットワーク オブジェクトを直接初期化できます。モデル パラメーター情報のみを保存する 2 番目の方法に対応して、最初に同じニューラル ネットワーク構造を構築し、net.load_state_dict(torch.load('.pkl')) を通じてモデル パラメーターの再読み込みを完了する必要があります。ネットワークが比較的大きい場合、最初の方法の方が時間がかかります。

コードの実装:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

実験結果:

関連する推奨事項:

PyTorchで畳み込みニューラルネットワークCNNを実装する方法

Pyの詳細な説明トーチのバッチ トレーニングとオプティマイザーの比較

Pytorch の紹介 - mnist 分類の例

以上がPyTorchでニューラルネットワークを素早く構築し、保存・抽出する方法を詳しく解説の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。