この記事では、tensorflow1.0 学習モデルの保存と回復 (Saver) を主に紹介しますので、参考として共有します。ぜひ一緒に見てみましょう
後で検証またはテストできるように、トレーニング済みのモデルのパラメーターを保存します。 tf.train.Saver() モジュールは、tf でのモデルの保存を提供します。
モデルを保存するには、まず Saver オブジェクトを作成する必要があります:
saver=tf.train.Saver()
この Saver オブジェクトを作成するときに、よく使用するパラメータがあります。これは、max_to_keep パラメータを設定するために使用されます。モデルを保存するためのパラメーターの数。デフォルトは 5、つまり max_to_keep=5 で、最新の 5 つのモデルを保存します。トレーニング世代 (エポック) ごとにモデルを保存したい場合は、次のように max_to_keep を None または 0 に設定できます:
saver=tf.train.Saver(max_to_keep=0)
ただし、これはより多くのハードディスクを占有すること以外に実用的ではないため、お勧めできません。
もちろん、最後の世代のモデルのみを保存したい場合は、max_to_keep を 1 に設定するだけです。つまり、
saver=tf.train.Saver(max_to_keep=1)
セーバー オブジェクトを作成した後、トレーニングされたモデルを次のように保存できます。 :
saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
最初のパラメータsess、これは言うまでもありません。 2 番目のパラメーターは保存されたパスと名前を設定し、3 番目のパラメーターはトレーニング回数をサフィックスとしてモデル名に追加します。
saver.save(sess, 'my-model', global_step=0) ==> ファイル名: 'my-model-0'
...
saver(sess, 'my-model', global_step= 1000) ==> ファイル名: 'my-model-1000'
mnist の例を見てください:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017 @author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False) x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,]) dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) saver=tf.train.Saver(max_to_keep=1) for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
コードの赤い部分はモデルを保存するコードです。後で保存されたモデルは前のモデルを上書きし、最後のモデルのみが保存されます。したがって、時間を節約し、保存コードをループの外に置くことができます (max_to_keep=1 にのみ適用されます。それ以外の場合は、ループ内に置く必要があります
実験では、最後の世代が、その世代ではない可能性があります)。デフォルトでは最後の世代を保存せず、最も検証精度の高い世代を保存したい場合は、中間変数と判定ステートメントを追加するだけです。
saver=tf.train.Saver(max_to_keep=1) max_acc=0 for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
最も検証精度の高い3世代を保存し、各回の検証精度も保存したい場合は、保存用のtxtファイルを生成できます。
saver=tf.train.Saver(max_to_keep=3) max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() sess.close()
restore() 関数はモデルを復元するために使用されます。これには 2 つのパラメーターが必要です。restore(sess、save_path)、save_path は保存されたモデル パスを指します。 tf.train.latest_checkpoint() を使用して、最後に保存されたモデルを自動的に取得できます。例:
model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file)
次に、プログラムの後半のコードを次のように変更できます:
sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) is_train=False saver=tf.train.Saver(max_to_keep=3) #训练阶段 if is_train: max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() #验证阶段 else: model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
赤でマークされた場所は、モデルの保存と復元に関連するコードです。ブール変数 is_train を使用して、トレーニング フェーズと検証フェーズを制御します。
ソースプログラム全体:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017 @author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False) x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,]) dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) is_train=True saver=tf.train.Saver(max_to_keep=3) #训练阶段 if is_train: max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() #验证阶段 else: model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
関連する推奨事項:
TensorFlow のモデルネットワークを単一のファイルにエクスポートする方法
以上がtensorflow1.0で学習したモデルの保存と復元(Saver)_pythonの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

Numpyを使用して多次元配列を作成すると、次の手順を通じて実現できます。1)numpy.array()関数を使用して、np.array([[1,2,3]、[4,5,6]])などの配列を作成して2D配列を作成します。 2)np.zeros()、np.ones()、np.random.random()およびその他の関数を使用して、特定の値で満たされた配列を作成します。 3)アレイの形状とサイズの特性を理解して、サブアレイの長さが一貫していることを確認し、エラーを回避します。 4)np.reshape()関数を使用して、配列の形状を変更します。 5)コードが明確で効率的であることを確認するために、メモリの使用に注意してください。

BroadcastinginNumPyisamethodtoperformoperationsonarraysofdifferentshapesbyautomaticallyaligningthem.Itsimplifiescode,enhancesreadability,andboostsperformance.Here'showitworks:1)Smallerarraysarepaddedwithonestomatchdimensions.2)Compatibledimensionsare

Forpythondatastorage、chooseLists forfficability withmixeddatypes、array.arrayformemory-efficienthogeneousnumericaldata、およびnumpyArrays foradvancednumericalcomputing.listSareversatilebuteficient efficient forlargeNumericaldatates;

pythonlistsarebetterthanarrays formangingdiversedatypes.1)listscanholdelementsofdifferenttypes、2)adearedditionsandremovals、3)theeofferintutiveoperation likeslicing、but4)theearlessememory-effice-hemory-hemory-hemory-hemory-hemory-adlower-dslorededatas。

toaccesselementsinapythonarray、useindexing:my_array [2] Accessesthirderement、Returning3.pythonuseszero basedIndexing.1)usepositiveandnegativeindexing:my_list [0] forteefirstelement、my_list [-1] exterarast.2)

記事では、構文のあいまいさのためにPythonにおけるタプル理解の不可能性について説明します。 Tupple式を使用してTuple()を使用するなどの代替は、Tuppleを効率的に作成するためにお勧めします。(159文字)

この記事では、Pythonのモジュールとパッケージ、その違い、および使用について説明しています。モジュールは単一のファイルであり、パッケージは__init__.pyファイルを備えたディレクトリであり、関連するモジュールを階層的に整理します。

記事では、PythonのDocstrings、それらの使用、および利点について説明します。主な問題:コードのドキュメントとアクセシビリティに関するドキュストリングの重要性。


ホットAIツール

Undresser.AI Undress
リアルなヌード写真を作成する AI 搭載アプリ

AI Clothes Remover
写真から衣服を削除するオンライン AI ツール。

Undress AI Tool
脱衣画像を無料で

Clothoff.io
AI衣類リムーバー

Video Face Swap
完全無料の AI 顔交換ツールを使用して、あらゆるビデオの顔を簡単に交換できます。

人気の記事

ホットツール

メモ帳++7.3.1
使いやすく無料のコードエディター

SublimeText3 Linux 新バージョン
SublimeText3 Linux 最新バージョン

WebStorm Mac版
便利なJavaScript開発ツール

MinGW - Minimalist GNU for Windows
このプロジェクトは osdn.net/projects/mingw に移行中です。引き続きそこでフォローしていただけます。 MinGW: GNU Compiler Collection (GCC) のネイティブ Windows ポートであり、ネイティブ Windows アプリケーションを構築するための自由に配布可能なインポート ライブラリとヘッダー ファイルであり、C99 機能をサポートする MSVC ランタイムの拡張機能が含まれています。すべての MinGW ソフトウェアは 64 ビット Windows プラットフォームで実行できます。

PhpStorm Mac バージョン
最新(2018.2.1)のプロフェッショナル向けPHP統合開発ツール

ホットトピック









