
A brief discussion on saving and restoring loading of Tensorflow models

不言 (original), 2018-04-26 16:40:54

This article introduces how to save and restore TensorFlow models. It is shared here for reference.

Recently we have been working on anti-spam. In addition to the usual rule matching and filtering, we also used machine learning methods for classification. We train the models with TensorFlow. The trained model has to be saved, and in the prediction phase it has to be loaded and restored, so we need to save and restore TensorFlow models.

Below is a summary of the commonly used ways to save TensorFlow models.

Save checkpoint model file (.ckpt)

First, TensorFlow provides a very convenient API, tf.train.Saver(), to save and restore a machine learning model.

Model Saving

It is very convenient to use tf.train.Saver() to save model files. Here is a simple example:


import os
import tensorflow as tf  # TensorFlow 1.x API

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  # name the output op so it can be looked up after restoring
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  # make sure the target directory exists
  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if not os.path.isdir(path):
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)

  # test: 2 * 3 + 1
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))  # prints 7


The program generates four files (before version 0.11, only three files were generated: checkpoint, model.ckpt and model.ckpt.meta):

  1. checkpoint: a text file recording the list of saved model checkpoint paths

  2. model.ckpt.data-00000-of-00001: the network weight values

  3. model.ckpt.index: the .data and .index files are binary files that together store the variable (weight) values; the .index file records where each variable's data lives in the .data files

  4. model.ckpt.meta: a binary protobuf file storing the computation graph structure of the model (the network structure)

That is the basic usage of tf.train.Saver().save(). The save() method also has many configurable parameters:


tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)


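The global_step parameter appends the step number to the file names. Here "-1000" is added, giving model.ckpt-1000.index, model.ckpt-1000.meta and model.ckpt-1000.data-00000-of-00001. It does not schedule anything by itself. To save every 1000 iterations, call save() inside the training loop whenever the step counter reaches a multiple of 1000, and pass that step as global_step.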

When the model is saved every 1000 iterations, its structure (the graph) does not change. There is therefore no need to write the .meta file again at every save. When the meta file is not needed, add the write_meta_graph=False parameter, as follows:


tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
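
The naming behaviour described above can be sketched in plain Python without TensorFlow. This is a hypothetical model of the naming scheme, assuming a single data shard as in the example. It is not a TensorFlow API:

- checkpoint_files() lists the files one save() call would write for a given prefix, global_step and write_meta_graph.
- save_steps() lists the steps at which a training loop that saves every N iterations would call save().

```python
def checkpoint_files(prefix, global_step=None, write_meta_graph=True):
    """Hypothetical helper: files one save() call would write (single data shard)."""
    if global_step is not None:
        # global_step is appended to the prefix: model.ckpt -> model.ckpt-1000
        prefix = f"{prefix}-{global_step}"
    files = [f"{prefix}.index", f"{prefix}.data-00000-of-00001"]
    if write_meta_graph:
        files.append(f"{prefix}.meta")
    return files


def save_steps(total_steps, every):
    """Steps at which a loop calling save() every `every` iterations would save."""
    return [step for step in range(1, total_steps + 1) if step % every == 0]


# Write the graph structure (.meta) only on the first save, weights every time.
for i, step in enumerate(save_steps(3000, 1000)):
    print(checkpoint_files("model.ckpt", step, write_meta_graph=(i == 0)))
```

If only the first save writes a .meta file, the graph can still be rebuilt from that first .meta file. The weights can then be restored from any later prefix, because the graph structure is the same.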

To keep only the 4 most recent checkpoints, and additionally preserve one checkpoint for every two hours of training, use the max_to_keep and keep_checkpoint_every_n_hours parameters. Note that these are arguments of the tf.train.Saver() constructor, not of save(). They control which checkpoints are kept, not how often save() is called. max_to_keep defaults to 5. Setting it to None or 0 keeps every checkpoint (for example, one per training epoch), but this is rarely useful and not recommended. The code is as follows:


tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2).save(sess, ckpt_file_path)
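
A small simulation can illustrate how the two retention parameters interact. This is a simplified pure-Python model written for illustration. simulate_retention is hypothetical, and TensorFlow's own bookkeeping differs in details. The model works as follows:

- Each save adds a checkpoint to a queue of recent ones.
- When more than max_to_keep are queued, the oldest is evicted.
- An evicted checkpoint is preserved instead of deleted if it was written at or after the next N-hour mark since training started.

```python
def simulate_retention(save_hours, max_to_keep=5, keep_every_n_hours=None):
    """Simplified model: return the save times (in hours) whose checkpoints survive."""
    recent = []     # newest checkpoints, subject to max_to_keep
    preserved = []  # checkpoints exempted by keep_every_n_hours
    next_mark = keep_every_n_hours  # next N-hour mark, training starts at t=0
    for t in save_hours:
        recent.append(t)
        if not max_to_keep:  # None or 0: keep every checkpoint
            continue
        while len(recent) > max_to_keep:
            oldest = recent.pop(0)
            if next_mark is not None and oldest >= next_mark:
                preserved.append(oldest)
                next_mark += keep_every_n_hours
            # otherwise the oldest checkpoint's files would be deleted
    return preserved + recent


# One save per hour for 10 hours, keep the latest 4 plus one every 2 hours.
print(simulate_retention(range(1, 11), max_to_keep=4, keep_every_n_hours=2))
```

In this model, saving once per hour for 10 hours with max_to_keep=4 and keep_checkpoint_every_n_hours=2 keeps the saves from hours 2, 4 and 6 plus the 4 most recent ones (hours 7 to 10).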


If we pass nothing to tf.train.Saver(), all variables are saved. We can also specify which variables to save, either as a list or as a dict mapping names to variables (e.g. tf.train.Saver({'b': b})). Only variables can be saved, not placeholders such as x and y. For example, to save only the variable b:


tf.train.Saver([b]).save(sess, ckpt_file_path)


ps. Give a name attribute to every variable or op you will need after restoring, and keep those names unchanged during training. Otherwise it cannot be retrieved with get_tensor_by_name() once the model is restored.
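
The ":0" suffix, and the risk of relying on generated names, can be illustrated with a small pure-Python model:

- OpNamer is hypothetical and only approximates how a TF1 graph names ops. An unnamed op gets a default name derived from its op type, such as Add, with _1, _2, ... appended on collisions.
- tensor_name() shows the "op name:output index" convention behind names like 'b:0' and 'op_to_store:0'.

```python
from collections import defaultdict


class OpNamer:
    """Hypothetical, simplified model of how a TF1 graph names ops."""

    def __init__(self):
        self._used = defaultdict(int)

    def unique_name(self, name):
        # The first use keeps the name; later uses get _1, _2, ... suffixes.
        count = self._used[name]
        self._used[name] += 1
        return name if count == 0 else f"{name}_{count}"


def tensor_name(op_name, output_index=0):
    # A tensor is addressed as "<op name>:<output index>".
    return f"{op_name}:{output_index}"


namer = OpNamer()
print(namer.unique_name("Add"), namer.unique_name("Add"), namer.unique_name("op_to_store"))
print(tensor_name("op_to_store"))
```

A generated name such as Add_1 depends on how many similar ops were created before it, so it is fragile to rely on. An explicit name='op_to_store' stays stable.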


Model loading and restoration


For the above model saving example, the process of restoring the model is as follows:


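import os
import tensorflow as tf  # TensorFlow 1.x API

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  # load the graph structure from the .meta file
  saver = tf.train.import_meta_graph(ckpt_file_path + '.meta')
  # only the directory is needed: latest_checkpoint reads the 'checkpoint' file there
  # and returns the newest checkpoint prefix, from which all variables are restored
  ckpt_dir = os.path.dirname(os.path.abspath(ckpt_file_path))
  saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))

  # fetch a saved variable directly
  print(sess.run('b:0'))  # prints 1

  # fetch the placeholders
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # fetch the output of the operation to compute
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # add a new operation on top of the restored graph
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)  # (5 * 5 + 1) * 2 = 52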


First restore the model structure, then restore the variable (parameter) values. After that, we can fetch anything in the trained model (saved variables, placeholders, operations, etc.) and build new operations on top of the fetched tensors (see the comments in the code above).

We can also load only part of a model and add other operations on top of it. For details, please refer to the official documentation and demos.
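
tf.train.latest_checkpoint() in the restore example only needs a directory because it reads the checkpoint text file saved there. A minimal pure-Python sketch of that lookup follows (no TensorFlow needed). It assumes the usual text layout:

- one model_checkpoint_path: "..." line for the newest checkpoint;
- one all_model_checkpoint_paths: "..." line per kept checkpoint.

parse_checkpoint_file and resolve_latest are hypothetical helpers written for illustration, not TensorFlow APIs. Other fields, such as timestamps written by some versions, are simply skipped.

```python
import os
import re

# Matches lines such as: model_checkpoint_path: "model.ckpt-3000"
_LINE = re.compile(r'^(\w+):\s*"(.*)"\s*$')


def parse_checkpoint_file(text):
    """Parse the contents of a TF1 'checkpoint' text file (simplified sketch)."""
    latest, all_paths = None, []
    for line in text.splitlines():
        m = _LINE.match(line.strip())
        if not m:
            continue  # skip blank lines and fields not modelled here
        key, value = m.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


def resolve_latest(checkpoint_dir, text):
    """Newest checkpoint prefix; relative paths are joined with the directory."""
    latest, _ = parse_checkpoint_file(text)
    if latest is None:
        return None
    return latest if os.path.isabs(latest) else os.path.join(checkpoint_dir, latest)


sample = (
    'model_checkpoint_path: "model.ckpt-3000"\n'
    'all_model_checkpoint_paths: "model.ckpt-2000"\n'
    'all_model_checkpoint_paths: "model.ckpt-3000"\n'
)
print(resolve_latest("ckpt", sample))
```

The result is a prefix such as ckpt/model.ckpt-3000 rather than a single file. That prefix is what Saver.restore() expects, and the .index and .data files are found next to it.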


For saving and restoring ckpt model files, there is a clearly explained answer on Stack Overflow that is worth reading.

The tutorial on saving and restoring TensorFlow models on cv-tricks.com is also very good.

"Tensorflow 1.0 Learning: Model Saving and Restoration (Saver)" contains some tips on using Saver.

Save a single model file (.pb)


I ran TensorFlow's inception-v3 demo myself and found that, once the run completes, it produces a .pb model file. This file is used for subsequent prediction or transfer learning. Being just one file, it is very neat and convenient.


The main idea is that a graph_def does not contain the values of the Variables in the network (which is usually where the weights are stored), but it does contain the values of constants. So if we convert the Variables to constants with the graph_util.convert_variables_to_constants() function, a single file can store both the network architecture and the weights.


ps: .pb is simply the suffix of the model file. Other suffixes work too (.pb is used here to be consistent with Google ╮(╯▽╰)╭)
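
The idea of freezing a model can be illustrated without TensorFlow. The sketch below uses a hypothetical dict-based graph format. It is not a real GraphDef, and it omits the intermediate read node that TF1 places after each variable. It works in two steps:

- freeze_graph() keeps only the nodes the requested outputs depend on, and replaces each Variable node with a Const node of the same name holding its value. This mirrors the conversion described above, and convert_variables_to_constants() likewise keeps only the subgraph needed for its output node names.
- evaluate() then runs the frozen graph without any separate variable values.

```python
# Toy graph mirroring the example: op_to_store = x * y + b
EXAMPLE_NODES = {
    "x": {"op": "Placeholder", "inputs": []},
    "y": {"op": "Placeholder", "inputs": []},
    "b": {"op": "Variable", "inputs": []},
    "b/initial_value": {"op": "Const", "inputs": [], "value": 1},
    "b/Assign": {"op": "Assign", "inputs": ["b", "b/initial_value"]},
    "Mul": {"op": "Mul", "inputs": ["x", "y"]},
    "op_to_store": {"op": "Add", "inputs": ["Mul", "b"]},
}


def freeze_graph(nodes, variable_values, output_names):
    """Keep what the outputs need and turn Variable nodes into Const nodes."""
    # Walk backwards from the outputs to find every node they depend on.
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(nodes[name]["inputs"])
    frozen = {}
    for name in needed:
        node = nodes[name]
        if node["op"] == "Variable":
            # Bake the current value into the graph under the same name.
            frozen[name] = {"op": "Const", "inputs": [], "value": variable_values[name]}
        else:
            frozen[name] = dict(node)
    return frozen


def evaluate(graph, name, feeds):
    """Recursively evaluate one node of the toy graph."""
    node = graph[name]
    op = node["op"]
    if op == "Placeholder":
        return feeds[name]
    if op == "Const":
        return node["value"]
    args = [evaluate(graph, i, feeds) for i in node["inputs"]]
    if op == "Mul":
        return args[0] * args[1]
    if op == "Add":
        return args[0] + args[1]
    raise ValueError(f"cannot evaluate op {op!r} without a session")


frozen = freeze_graph(EXAMPLE_NODES, {"b": 1}, ["op_to_store"])
print(sorted(frozen))
print(evaluate(frozen, "op_to_store", {"x": 2, "y": 3}))
```

In this toy model, the Const node keeps the name b, so it can still be looked up by that name after freezing. Initializer nodes such as b/Assign are dropped because op_to_store does not depend on them.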


Model saving


Again based on the example above, here is a simple demo:


import os
import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.framework import graph_util

def save_mode_pb(pb_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  # the output op must have a name so it can be listed in output_node_names
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(pb_file_path))
  if not os.path.isdir(path):
    os.makedirs(path)

  # convert_variables_to_constants takes output_node_names as a list (several are allowed);
  # the variables feeding those outputs become constants holding their current values
  constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
  with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

  # test: 2 * 3 + 1
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))  # prints 7

The program generates a single file:

model.pb: a binary file that stores both the network structure and the parameter (weight) values of the model

Model loading and restoration

For the model saving example above, the process of restoring the model is as follows:


import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.platform import gfile

def restore_mode_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  # import into the session's graph; name='' keeps the original node names
  # (otherwise every name would get an 'import/' prefix)
  with sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')

  # b is now a constant in the graph, so no separate restore step is needed
  print(sess.run('b:0'))  # prints 1

  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')

  op = sess.graph.get_tensor_by_name('op_to_store:0')

  ret = sess.run(op, {input_x: 5, input_y: 5})
  print(ret)  # 5 * 5 + 1 = 26


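Restoring the model works almost the same way as with a checkpoint. The difference is that there is no separate step for restoring variables, because the weights are already stored as constants in the graph.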

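The article "Exporting a TensorFlow Network as a Single File" (《将TensorFlow的网络导出为单个文件》) describes a broadly similar way of saving a single TensorFlow model file and is worth a look.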

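Thoughts

Saving and loading models is one of the most basic parts of TensorFlow. It is simple but indispensable. In practice you also need to consider when to save the model, which variables need to be saved, and how to design loading to support transfer learning.

Also, TensorFlow's functions and classes keep changing, and richer ways to save and restore models may appear in the future.

Whether to save a checkpoint or a single .pb file depends on the business case; there is no major difference. Checkpoints feel somewhat more flexible, while .pb files seem better suited to online deployment (personal opinion).

Full code for the above: GitHub https://github.com/liuyan731/tf_demo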

