ホームページ  >  記事  >  バックエンド開発  >  Tensorflow モデルの読み込みの保存と復元に関する簡単な説明

Tensorflow モデルの読み込みの保存と復元に関する簡単な説明

不言
不言オリジナル
2018-04-26 16:40:542261ブラウズ

この記事では主に Tensorflow モデルの保存と復元について紹介しますので、参考にしてください。ぜひ見てみましょう

最近、私たちは共通ルール、マッチング、フィルタリングの使用に加えて、分類予測にいくつかの機械学習手法も使用しています。 TensorFlow を使用してモデルをトレーニングします。トレーニングされたモデルは、予測フェーズで保存する必要があります。これには、TensorFlow モデルの保存と復元が含まれます。

Tensorflow の一般的に使用されるモデル保存方法をまとめます。

チェックポイントモデルファイル(.ckpt)の保存

まず第一に、TensorFlowは、機械学習モデルを保存および復元するための非常に便利なAPI、tf.train.Saver()を提供します。

モデルの保存

モデル ファイルを保存するには tf.train.Saver() を使用すると非常に便利です。簡単な例を示します:


import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))


ファイル (バージョン 0.11 より前では、checkpoint、model.ckpt、model.ckpt.meta の 3 つのファイルのみが生成されました)

  1. checkpoint テキスト ファイル (モデル ファイルのパス情報リストを記録します)

  2. model.ckpt.data -00000 -of-00001 ネットワーク重み情報

  3. model.ckpt.index 2 つのファイル .data と .index は、モデル内の可変パラメータ (重み) 情報を保存するバイナリ ファイルです

  4. model.ckpt.metaモデルの計算グラフ構造情報(モデルのネットワーク構造)を保存するバイナリファイル protobuf

以上が tf.train.Saver().save() メソッドの基本的な使い方です。


tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)


global_step パラメーターを追加すると、1000 回の反復ごとにモデルが保存され、モデル ファイル model.ckpt-1000 の末尾に「-1000」が追加されます。 Index,model.ckpt-1000.meta,model.ckpt.data-1000-00000-of-00001

モデルは1000回の反復ごとに保存されますが、モデルの構造情報ファイルは変更されないだけになります。対応する 1000 回ごとに保存する必要がないため、メタ ファイルを保存する必要がない場合は、次のように write_meta_graph=False パラメーターを追加できます:


コードをコピーしますコードは次のとおりです:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

2 時間ごとにモデルを保存し、最新の 4 つのモデルのみを保存したい場合は、max_to_keep を使用できます (デフォルト値は 5 ですが、トレーニングのエポックごとに保存したい場合は、 None または 0 に設定できますが、役に立たず、推奨されません)、keep_checkpoint_every_n_hours パラメーター、次のように:


コードをコピーします コードは次のとおりです:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)


同時に, tf.train.Saver() クラスでは、何も指定しない場合、すべてのパラメータ情報が保存されます。また、保存したい内容の一部を指定することもできます。たとえば、x、y のみを保存するなどです。パラメーター (パラメーター リストまたは辞書を渡すことができます):


tf.train.Saver([x, y]).save(sess, ckpt_file_path)


ps。モデルのトレーニング プロセス中に、変数またはパラメーター名の属性名が失われることはありません。復元後に get_tensor_by_name() を介してモデルを取得することはできません。

モデルのロードと復元

上記のモデル保存の例では、モデルを復元するプロセスは次のとおりです:


import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)


最初にモデル構造を復元し、次に変数(トレーニング済みモデル内のさまざまな情報 (保存された変数、プレースホルダー変数、演算子など) を取得でき、取得した変数にさまざまな新しい操作を追加できます (上記のコードのコメントを参照)。
また、これに基づいていくつかのモデルをロードし、他の操作を追加することもできます。詳細については、公式ドキュメントとデモを参照してください。

ckpt モデル ファイルの保存と復元については、stackoverflow に明確に説明された回答があります。それを参照してください。

同時に、cv-tricks.com にある TensorFlow モデルの保存と復元に関するチュートリアルも非常に優れているので、参照してください。

「Tensorflow 1.0 Learning: Model Saving and Restoration (Saver)」には、Saver の使用に関するヒントがいくつか記載されています。

単一のモデル ファイル (.pb) を保存します

Tensorflow の inception-v3 のデモを自分で実行したところ、実行の完了後に .pb モデル ファイルが生成されることがわかりました。このファイルは後続のファイルに使用されます。はい、これは 1 つのファイルであり、非常に優れており、非常に便利です。

このプロセスの主な考え方は、graph_def ファイルにはネットワーク内の変数値が含まれていない (通常は重みが保存されている) が、定数値は含まれているため、変換できれば変数を定数に変換すると (graph_util.convert_variables_to_constants() 関数を使用)、1 つのファイルを使用してネットワーク アーキテクチャと重みの両方を保存するという目標を達成できます。

ps: ここで .pb はモデル ファイルのサフィックス名です。 もちろん、他のサフィックスも使用できます (Google との一貫性を保つために .pb を使用します╮(╯▽╰)╭)

モデルを保存します。

同様に、上記の例に基づいた簡単なデモ:


import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

def save_mode_pb(pb_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  # 这里的输出需要加上name属性
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(pb_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
  constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
  with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))


程序生成并保存一个文件

model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:


import tensorflow as tf
from tensorflow.python.platform import gfile

def restore_mode_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

  print(sess.run('b:0'))

  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')

  op = sess.graph.get_tensor_by_name('op_to_store:0')

  ret = sess.run(op, {input_x: 5, input_y: 5})
  print(ret)


模型的还原过程与checkpoint差不多一样。

《将TensorFlow的网络导出为单个文件》上介绍了TensorFlow保存单个模型文件的方式,大同小异,可以看看。

思考

模型的保存与加载只是TensorFlow中最基础的部分之一,虽然简单但是也必不可少,在实际运用中还需要注意模型何时保存,哪些变量需要保存,如何设计加载实现迁移学习等等问题。

同时TensorFlow的函数和类都在一直变化更新,以后也有可能出现更丰富的模型保存和还原的方法。

选择保存为checkpoint或单个pb文件视业务情况而定,没有特别大的差别。checkpoint保存感觉会更加灵活一些,pb文件更适合线上部署吧(个人看法)。

以上完整代码:github https://github.com/liuyan731/tf_demo

相关推荐:

TensorFlow模型保存和提取方法示例


以上がTensorflow モデルの読み込みの保存と復元に関する簡単な説明の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。