この記事では主に Tensorflow の Saver の詳しい使い方を紹介しますので、参考にしてください。一緒に見てみましょう
Saverの使い方
1. Saverの背景紹介
モデルをトレーニングした後、これらの結果はモデルのパラメータを参照することがよくあります。次の反復のためのトレーニング、またはテストに使用されます。 Tensorflow は、この要件に対応する Saver クラスを提供します。2. Saver インスタンス
以下は Saver クラスの使用方法の例ですimport tensorflow as tf import numpy as np x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) isTrain = False train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) else: ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b))
2.1 テストフェーズ
テストフェーズでは、saver.restore() メソッドを使用して変数を復元します:
sess: 現在のセッションを表し、以前に保存された結果がこのセッションにロードされます
ckpt .model_checkpoint_path: モデルを表します。 保存場所にモデルの名前を指定する必要はありません。チェックポイント ファイルをチェックして、最新のファイルがどれであるか、またその名前が何であるかを確認します。tensorflow フラグを使用してコマンドラインパラメーターを定義する方法
tensorflow1.0 学習モデルの保存と復元(Saver)_python
以上がTensorflow で Saver を使用する方法の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。