ホームページ  >  記事  >  バックエンド開発  >  Tensorflow で Saver を使用する方法

Tensorflow で Saver を使用する方法

不言
不言オリジナル
2018-04-23 15:46:311993ブラウズ

この記事では主に Tensorflow の Saver の詳しい使い方を紹介しますので、参考にしてください。一緒に見てみましょう

Saverの使い方

1. Saverの背景紹介

モデルをトレーニングした後、これらの結果はモデルのパラメータを参照することがよくあります。次の反復のためのトレーニング、またはテストに使用されます。 Tensorflow は、この要件に対応する Saver クラスを提供します。


Saver クラスは、チェックポイント ファイルから変数を保存および復元するための関連メソッドを提供します。チェックポイント ファイルは、変数名を対応するテンソル値にマップするバイナリ ファイルです。


カウンターが提供されている限り、カウンターがトリガーされたときに Saver クラスはチェックポイント ファイルを自動的に生成できます。これにより、トレーニング中に複数の中間結果を保存できます。たとえば、各トレーニング ステップの結果を保存できます。


ディスク全体がいっぱいになるのを避けるために、Saver はチェックポイント ファイルを自動的に管理できます。たとえば、最新の N 個のチェックポイント ファイルを保存するように指定できます。


2. Saver インスタンス

以下は Saver クラスの使用方法の例です


import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b))

  1. : トレーニングフェーズとテストフェーズを区別するために使用されます。True はトレーニングを意味します。 False はテストを意味します

  2. train_steps: トレーニング回数を示します。例では 100 が使用されます。

  3. checkpoint_steps: トレーニング中にチェックポイントを保存する回数を示します。

    checkpoint_dir: チェックポイント ファイルの保存パスを示します。この例では、現在のパス
  4. 2.1 トレーニング フェーズ
Saver.save() メソッドを使用してモデルを保存します。現在の変数値を記録する現在のセッション

checkpoint_dir + 'model.ckpt': 保存されているファイル名を表します
  1. global_step: 現在のステップを示します
  2. トレーニングが完了すると、現在のディレクトリの下にさらに 5 つのファイルが存在します。
  3. 「checkpoint」という名前のファイルを開くと、保存記録と最新モデルの保存場所が表示されます。

2.1 テストフェーズ

テストフェーズでは、saver.restore() メソッドを使用して変数を復元します:

sess: 現在のセッションを表し、以前に保存された結果がこのセッションにロードされます

ckpt .model_checkpoint_path: モデルを表します。 保存場所にモデルの名前を指定する必要はありません。チェックポイント ファイルをチェックして、最新のファイルがどれであるか、またその名前が何であるかを確認します。


実行結果は、以前にトレーニングされたパラメーター w と b の結果をロードして、以下の図に示されています



関連推奨事項:


tensorflow フラグを使用してコマンドラインパラメーターを定義する方法


tensorflow1.0 学習モデルの保存と復元(Saver)_python

以上がTensorflow で Saver を使用する方法の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。