ホームページ >バックエンド開発 >Python チュートリアル >tensorflow1.0で学習したモデルの保存と復元(Saver)_python
この記事では、tensorflow1.0 学習モデルの保存と回復 (Saver) を主に紹介しますので、参考として共有します。ぜひ一緒に見てみましょう
後で検証またはテストできるように、トレーニング済みのモデルのパラメーターを保存します。 tf.train.Saver() モジュールは、tf でのモデルの保存を提供します。
モデルを保存するには、まず Saver オブジェクトを作成する必要があります:
saver=tf.train.Saver()
この Saver オブジェクトを作成するときに、よく使用するパラメータがあります。これは、max_to_keep パラメータを設定するために使用されます。モデルを保存するためのパラメーターの数。デフォルトは 5、つまり max_to_keep=5 で、最新の 5 つのモデルを保存します。トレーニング世代 (エポック) ごとにモデルを保存したい場合は、次のように max_to_keep を None または 0 に設定できます:
saver=tf.train.Saver(max_to_keep=0)
ただし、これはより多くのハードディスクを占有すること以外に実用的ではないため、お勧めできません。
もちろん、最後の世代のモデルのみを保存したい場合は、max_to_keep を 1 に設定するだけです。つまり、
saver=tf.train.Saver(max_to_keep=1)
セーバー オブジェクトを作成した後、トレーニングされたモデルを次のように保存できます。 :
saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
最初のパラメータsess、これは言うまでもありません。 2 番目のパラメーターは保存されたパスと名前を設定し、3 番目のパラメーターはトレーニング回数をサフィックスとしてモデル名に追加します。
saver.save(sess, 'my-model', global_step=0) ==> ファイル名: 'my-model-0'
...
saver(sess, 'my-model', global_step= 1000) ==> ファイル名: 'my-model-1000'
mnist の例を見てください:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017 @author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False) x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,]) dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) saver=tf.train.Saver(max_to_keep=1) for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
コードの赤い部分はモデルを保存するコードです。後で保存されたモデルは前のモデルを上書きし、最後のモデルのみが保存されます。したがって、時間を節約し、保存コードをループの外に置くことができます (max_to_keep=1 にのみ適用されます。それ以外の場合は、ループ内に置く必要があります
実験では、最後の世代が、その世代ではない可能性があります)。デフォルトでは最後の世代を保存せず、最も検証精度の高い世代を保存したい場合は、中間変数と判定ステートメントを追加するだけです。
saver=tf.train.Saver(max_to_keep=1) max_acc=0 for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
最も検証精度の高い3世代を保存し、各回の検証精度も保存したい場合は、保存用のtxtファイルを生成できます。
saver=tf.train.Saver(max_to_keep=3) max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() sess.close()
restore() 関数はモデルを復元するために使用されます。これには 2 つのパラメーターが必要です。restore(sess、save_path)、save_path は保存されたモデル パスを指します。 tf.train.latest_checkpoint() を使用して、最後に保存されたモデルを自動的に取得できます。例:
model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file)
次に、プログラムの後半のコードを次のように変更できます:
sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) is_train=False saver=tf.train.Saver(max_to_keep=3) #训练阶段 if is_train: max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() #验证阶段 else: model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
赤でマークされた場所は、モデルの保存と復元に関連するコードです。ブール変数 is_train を使用して、トレーニング フェーズと検証フェーズを制御します。
ソースプログラム全体:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017 @author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False) x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,]) dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) is_train=True saver=tf.train.Saver(max_to_keep=3) #训练阶段 if is_train: max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() #验证阶段 else: model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
関連する推奨事項:
TensorFlow のモデルネットワークを単一のファイルにエクスポートする方法
以上がtensorflow1.0で学習したモデルの保存と復元(Saver)_pythonの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。