この記事では主に Tensorflow モデルの保存と復元について紹介しますので、参考にしてください。ぜひ見てみましょう
最近、私たちは共通ルール、マッチング、フィルタリングの使用に加えて、分類予測にいくつかの機械学習手法も使用しています。 TensorFlow を使用してモデルをトレーニングします。トレーニングされたモデルは、予測フェーズで保存する必要があります。これには、TensorFlow モデルの保存と復元が含まれます。
Tensorflow の一般的に使用されるモデル保存方法をまとめます。
チェックポイントモデルファイル(.ckpt)の保存
まず第一に、TensorFlowは、機械学習モデルを保存および復元するための非常に便利なAPI、tf.train.Saver()を提供します。
モデルの保存
モデル ファイルを保存するには tf.train.Saver() を使用すると非常に便利です。簡単な例を示します:
import tensorflow as tf import os def save_model_ckpt(ckpt_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) op = tf.add(xy, b, name='op_to_store') sess = tf.Session() sess.run(tf.global_variables_initializer()) path = os.path.dirname(os.path.abspath(ckpt_file_path)) if os.path.isdir(path) is False: os.makedirs(path) tf.train.Saver().save(sess, ckpt_file_path) # test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict))
ファイル (バージョン 0.11 より前では、checkpoint、model.ckpt、model.ckpt.meta の 3 つのファイルのみが生成されました)
checkpoint テキスト ファイル (モデル ファイルのパス情報リストを記録します)
model.ckpt.data -00000 -of-00001 ネットワーク重み情報
model.ckpt.index 2 つのファイル .data と .index は、モデル内の可変パラメータ (重み) 情報を保存するバイナリ ファイルです
model.ckpt.metaモデルの計算グラフ構造情報(モデルのネットワーク構造)を保存するバイナリファイル protobuf
以上が tf.train.Saver().save() メソッドの基本的な使い方です。
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)
global_step パラメーターを追加すると、1000 回の反復ごとにモデルが保存され、モデル ファイル model.ckpt-1000 の末尾に「-1000」が追加されます。 Index,model.ckpt-1000.meta,model.ckpt.data-1000-00000-of-00001
モデルは1000回の反復ごとに保存されますが、モデルの構造情報ファイルは変更されないだけになります。対応する 1000 回ごとに保存する必要がないため、メタ ファイルを保存する必要がない場合は、次のように write_meta_graph=False パラメーターを追加できます:
コードをコピーしますコードは次のとおりです:
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
2 時間ごとにモデルを保存し、最新の 4 つのモデルのみを保存したい場合は、max_to_keep を使用できます (デフォルト値は 5 ですが、トレーニングのエポックごとに保存したい場合は、 None または 0 に設定できますが、役に立たず、推奨されません)、keep_checkpoint_every_n_hours パラメーター、次のように:
コードをコピーします コードは次のとおりです:
tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)
同時に, tf.train.Saver() クラスでは、何も指定しない場合、すべてのパラメータ情報が保存されます。また、保存したい内容の一部を指定することもできます。たとえば、x、y のみを保存するなどです。パラメーター (パラメーター リストまたは辞書を渡すことができます):
tf.train.Saver([x, y]).save(sess, ckpt_file_path)
ps。モデルのトレーニング プロセス中に、変数またはパラメーター名の属性名が失われることはありません。復元後に get_tensor_by_name() を介してモデルを取得することはできません。
モデルのロードと復元
上記のモデル保存の例では、モデルを復元するプロセスは次のとおりです:
import tensorflow as tf def restore_model_ckpt(ckpt_file_path): sess = tf.Session() saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构 saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息 # 直接获取保存的变量 print(sess.run('b:0')) # 获取placeholder变量 input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') # 获取需要进行计算的operator op = sess.graph.get_tensor_by_name('op_to_store:0') # 加入新的操作 add_on_op = tf.multiply(op, 2) ret = sess.run(add_on_op, {input_x: 5, input_y: 5}) print(ret)
最初にモデル構造を復元し、次に変数(トレーニング済みモデル内のさまざまな情報 (保存された変数、プレースホルダー変数、演算子など) を取得でき、取得した変数にさまざまな新しい操作を追加できます (上記のコードのコメントを参照)。
また、これに基づいていくつかのモデルをロードし、他の操作を追加することもできます。詳細については、公式ドキュメントとデモを参照してください。
ckpt モデル ファイルの保存と復元については、stackoverflow に明確に説明された回答があります。それを参照してください。
同時に、cv-tricks.com にある TensorFlow モデルの保存と復元に関するチュートリアルも非常に優れているので、参照してください。
「Tensorflow 1.0 Learning: Model Saving and Restoration (Saver)」には、Saver の使用に関するヒントがいくつか記載されています。
単一のモデル ファイル (.pb) を保存します
Tensorflow の inception-v3 のデモを自分で実行したところ、実行の完了後に .pb モデル ファイルが生成されることがわかりました。このファイルは後続のファイルに使用されます。はい、これは 1 つのファイルであり、非常に優れており、非常に便利です。
このプロセスの主な考え方は、graph_def ファイルにはネットワーク内の変数値が含まれていない (通常は重みが保存されている) が、定数値は含まれているため、変換できれば変数を定数に変換すると (graph_util.convert_variables_to_constants() 関数を使用)、1 つのファイルを使用してネットワーク アーキテクチャと重みの両方を保存するという目標を達成できます。
ps: ここで .pb はモデル ファイルのサフィックス名です。 もちろん、他のサフィックスも使用できます (Google との一貫性を保つために .pb を使用します╮(╯▽╰)╭)
モデルを保存します。
同様に、上記の例に基づいた簡単なデモ:
import tensorflow as tf import os from tensorflow.python.framework import graph_util def save_mode_pb(pb_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) # 这里的输出需要加上name属性 op = tf.add(xy, b, name='op_to_store') sess = tf.Session() sess.run(tf.global_variables_initializer()) path = os.path.dirname(os.path.abspath(pb_file_path)) if os.path.isdir(path) is False: os.makedirs(path) # convert_variables_to_constants 需要指定output_node_names,list(),可以多个 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store']) with tf.gfile.FastGFile(pb_file_path, mode='wb') as f: f.write(constant_graph.SerializeToString()) # test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict))
程序生成并保存一个文件
model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息
模型加载还原
针对上面的模型保存例子,还原模型的过程如下:
import tensorflow as tf from tensorflow.python.platform import gfile def restore_mode_pb(pb_file_path): sess = tf.Session() with gfile.FastGFile(pb_file_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') print(sess.run('b:0')) input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') op = sess.graph.get_tensor_by_name('op_to_store:0') ret = sess.run(op, {input_x: 5, input_y: 5}) print(ret)
模型的还原过程与checkpoint差不多一样。
《将TensorFlow的网络导出为单个文件》上介绍了TensorFlow保存单个模型文件的方式,大同小异,可以看看。
思考
模型的保存与加载只是TensorFlow中最基础的部分之一,虽然简单但是也必不可少,在实际运用中还需要注意模型何时保存,哪些变量需要保存,如何设计加载实现迁移学习等等问题。
同时TensorFlow的函数和类都在一直变化更新,以后也有可能出现更丰富的模型保存和还原的方法。
选择保存为checkpoint或单个pb文件视业务情况而定,没有特别大的差别。checkpoint保存感觉会更加灵活一些,pb文件更适合线上部署吧(个人看法)。
以上完整代码:github https://github.com/liuyan731/tf_demo
相关推荐:
以上がTensorflow モデルの読み込みの保存と復元に関する簡単な説明の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

slicingapythonlistisdoneusingtheyntaxlist [start:stop:step] .hore'showitworks:1)startisthe indexofthefirstelementtoinclude.2)spotisthe indexofthefirmenttoeexclude.3)staptistheincrementbetbetinelements

numpyallows forvariousoperationsonarrays:1)basicarithmeticlikeaddition、減算、乗算、および分割; 2)AdvancedperationssuchasmatrixMultiplication;

Arraysinpython、特にnumpyandpandas、aresentialfordataanalysis、offeringspeedandeficiency.1)numpyarraysenable numpyarraysenable handling forlaredatasents andcomplexoperationslikemoverages.2)Pandasextendsnumpy'scapabivitieswithdataframesfortruc

listsandnumpyarraysinpythonhavedifferentmemoryfootprints:listsaremoreflexiblellessmemory-efficient、whileenumpyarraysaraysareoptimizedfornumericaldata.1)listsstorereferencesto objects、with whowedaround64byteson64-bitedatigu

toensurepythonscriptsbehaveCorrectlyAcrossDevelosment、staging、and Production、usetheseStrategies:1)環境variablesforsimplestetings、2)configurationfilesforcomplexsetups、and3)dynamicloadingforadaptability.eachtododododododofersuniquebentandrequiresca

Pythonリストスライスの基本的な構文はリストです[start:stop:step]。 1.STARTは最初の要素インデックス、2。ストップは除外された最初の要素インデックスであり、3.ステップは要素間のステップサイズを決定します。スライスは、データを抽出するためだけでなく、リストを変更および反転させるためにも使用されます。

ListSoutPerformArraysIn:1)ダイナミシジョンアンドフレーケンティオン/削除、2)ストーリングヘテロゼンダタ、および3)メモリ効率の装飾、ButmayhaveslightPerformancostsinceNASOPERATIONS。

toconvertapythonarraytoalist、usetheList()constructororageneratorexpression.1)importhearraymoduleandcreateanarray.2)useList(arr)または[xforxinarr] toconvertoalistは、largedatatessを変えることを伴うものです。


ホットAIツール

Undresser.AI Undress
リアルなヌード写真を作成する AI 搭載アプリ

AI Clothes Remover
写真から衣服を削除するオンライン AI ツール。

Undress AI Tool
脱衣画像を無料で

Clothoff.io
AI衣類リムーバー

Video Face Swap
完全無料の AI 顔交換ツールを使用して、あらゆるビデオの顔を簡単に交換できます。

人気の記事

ホットツール

SublimeText3 中国語版
中国語版、とても使いやすい

VSCode Windows 64 ビットのダウンロード
Microsoft によって発売された無料で強力な IDE エディター

ドリームウィーバー CS6
ビジュアル Web 開発ツール

Dreamweaver Mac版
ビジュアル Web 開発ツール

SublimeText3 Linux 新バージョン
SublimeText3 Linux 最新バージョン

ホットトピック









