A brief discussion on saving and restoring TensorFlow models
This article introduces how to save and restore TensorFlow models. I am sharing it here as a reference.
Recently we have done some anti-spam work. In addition to using commonly used rule matching and filtering methods, we also used some machine learning methods for classification prediction. We use TensorFlow to train the model. The trained model needs to be saved. In the prediction phase, we need to load and restore the model for use, which involves saving and restoring the TensorFlow model.
This article summarizes the commonly used ways to save TensorFlow models.
Save checkpoint model file (.ckpt)
First of all, TensorFlow provides a very convenient API, tf.train.Saver() to save and restore a machine learning model.
Model Saving
It is very convenient to use tf.train.Saver() to save model files. Here is a simple example:
import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    op = tf.add(xy, b, name='op_to_store')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    path = os.path.dirname(os.path.abspath(ckpt_file_path))
    if not os.path.isdir(path):
        os.makedirs(path)

    tf.train.Saver().save(sess, ckpt_file_path)

    # test
    feed_dict = {x: 2, y: 3}
    print(sess.run(op, feed_dict))
The program generates and saves four files (before version 0.11, only three files were generated: checkpoint, model.ckpt, model.ckpt.meta):
checkpoint: a text file that records the list of model file paths
model.ckpt.data-00000-of-00001: the network weight information
model.ckpt.index: together with the .data file, a binary file that stores the variable (weight) values of the model
model.ckpt.meta: a binary protobuf file that stores the computation graph structure of the model (the network structure)
The above is the basic usage of tf.train.Saver().save(). The save() method also has many configurable parameters:
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)
The global_step parameter appends the training step number to the checkpoint file names, which is useful when the model is saved every 1000 iterations. With global_step=1000, "-1000" is added after the model file prefix: model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt-1000.data-00000-of-00001
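The naming scheme can be sketched in plain Python. The helper below is hypothetical (it is not part of TensorFlow) and only builds the file names described above, assuming a single data shard; it does not touch the file system.

```python
def expected_checkpoint_files(prefix, global_step=None):
    """Return the file names a single-shard checkpoint save is expected to produce.

    Hypothetical helper for illustration only; it builds names and writes nothing.
    """
    if global_step is not None:
        # global_step is appended to the prefix, e.g. model.ckpt-1000
        prefix = "{}-{}".format(prefix, global_step)
    return [
        "checkpoint",                      # text file listing the saved checkpoints
        prefix + ".data-00000-of-00001",   # variable values (assumes one shard)
        prefix + ".index",                 # index into the .data file
        prefix + ".meta",                  # graph structure
    ]

for name in expected_checkpoint_files("model.ckpt", global_step=1000):
    print(name)
```

Calling it without global_step gives the four names listed earlier for the basic save() call.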
When the model is saved every 1000 iterations, the graph structure file does not change between saves, so it only needs to be written once (for example at iteration 1000) rather than on every save. When the meta file is not needed, add the write_meta_graph=False parameter, as follows:
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
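One way to write the meta file only on the first save is sketched below. The function and callback names (run_training, train_step, save) are made up for this illustration. In a real program, save would presumably wrap saver.save(sess, ckpt_file_path, global_step=step, write_meta_graph=write_meta_graph), and train_step would run the training op.

```python
def run_training(num_steps, save_every, train_step, save):
    """Run train_step for num_steps and call save() every save_every steps.

    save(step, write_meta_graph=...) is a hypothetical callback, not a
    TensorFlow function.
    """
    meta_written = False
    for step in range(1, num_steps + 1):
        train_step(step)
        if step % save_every == 0:
            # The graph structure does not change, so write .meta only once.
            save(step, write_meta_graph=not meta_written)
            meta_written = True

calls = []
run_training(
    num_steps=3000,
    save_every=1000,
    train_step=lambda step: None,  # stand-in for one training iteration
    save=lambda step, write_meta_graph: calls.append((step, write_meta_graph)),
)
print(calls)  # [(1000, True), (2000, False), (3000, False)]
```

Passing the save action in as a callback keeps the loop logic separate from TensorFlow, so the schedule can be checked without a session.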
If you want to keep one checkpoint for every two hours of training and otherwise retain only the latest 4 models, use the max_to_keep and keep_checkpoint_every_n_hours parameters. Both are arguments of the tf.train.Saver() constructor, not of save(). max_to_keep defaults to 5; setting it to None or 0 keeps every checkpoint (for example one per training epoch), but this is rarely useful and not recommended. For example:
tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2).save(sess, ckpt_file_path)
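If nothing is passed to tf.train.Saver(), all variables are saved. A list or dict of variables can be passed to save only some of them (here x and y stand for tf.Variable objects, not the placeholders from the example above):
tf.train.Saver([x, y]).save(sess, ckpt_file_path)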
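To picture the effect of the max_to_keep parameter mentioned above, here is a toy simulation of the rotation in plain Python. It is not TensorFlow's implementation: the function name is invented, it ignores keep_checkpoint_every_n_hours, and it does not touch the file system.

```python
from collections import deque

def simulate_max_to_keep(steps, max_to_keep=5):
    """Return the checkpoint names that remain after saving at each step.

    Toy model only: the oldest checkpoint is dropped once the limit is exceeded.
    """
    kept = deque()
    for step in steps:
        kept.append("model.ckpt-{}".format(step))
        # None or 0 means no limit, so nothing is dropped.
        if max_to_keep and len(kept) > max_to_keep:
            kept.popleft()
    return list(kept)

# Six saves at steps 1000..6000 with max_to_keep=4
print(simulate_max_to_keep(range(1000, 7000, 1000), max_to_keep=4))
# ['model.ckpt-3000', 'model.ckpt-4000', 'model.ckpt-5000', 'model.ckpt-6000']
```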
PS: Any variable or tensor that needs to be retrieved after restoring must be given a name attribute when the model is built. Otherwise it cannot be fetched with get_tensor_by_name() after restoration.
Model loading and restoring
For the model saving example above, the process of restoring the model is as follows:
import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
    sess = tf.Session()
    # Load the model structure (the paths are hard-coded in this example;
    # the ckpt_file_path argument is not used)
    saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')
    # Only the directory needs to be specified to restore all variable values
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))

    # Read a saved variable directly
    print(sess.run('b:0'))

    # Get the placeholders
    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')
    # Get the operator to be computed
    op = sess.graph.get_tensor_by_name('op_to_store:0')

    # Add a new operation
    add_on_op = tf.multiply(op, 2)

    ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
    print(ret)
First restore the model structure, then restore the variable (parameter) values. After that we can obtain the contents of the trained model (saved variables, placeholders, operators, and so on) and add new operations on top of the tensors we obtained (see the code comments above).
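The strings passed to get_tensor_by_name() above, such as 'op_to_store:0', follow the pattern "<op_name>:<output_index>": the name given when the graph was built, plus the index of the operation's output. The small helper below is only an illustration (split_tensor_name is not a TensorFlow function) of how such a name breaks down.

```python
def split_tensor_name(tensor_name):
    """Split a tensor name such as 'op_to_store:0' into ('op_to_store', 0)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError(
            "expected '<op_name>:<output_index>', got %r" % tensor_name
        )
    return op_name, int(index)

for name in ["x:0", "y:0", "b:0", "op_to_store:0"]:
    print(split_tensor_name(name))
```

A name without the ":0" suffix refers to the operation rather than to its output tensor, which is why the helper rejects it.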
Regarding the saving and restoring of ckpt model files, there is an answer on Stack Overflow with a clear explanation that is worth reading.
The tutorial on saving and restoring TensorFlow models on cv-tricks.com is also very good.
"Tensorflow 1.0 Learning: Model Saving and Restoration (Saver)" has some Saver usage tips.
Save a single model file (.pb)
I ran the demo of TensorFlow's inception-v3 myself and found that a .pb model file is generated after the run completes. This file is used for subsequent prediction or transfer learning. It is just one file, which is very convenient.
The main idea of this process is that the graph_def file does not contain the Variable values of the network (which usually store the weights), but it does contain constant values. So if we convert each Variable to a constant (using the graph_util.convert_variables_to_constants() function), we can store both the network architecture and the weights in one file.
ps: Here .pb is the suffix name of the model file. Of course, we can also use other suffixes (use .pb to be consistent with Google ╮(╯▽╰)╭)
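The idea of turning variables into constants can be illustrated with a toy graph in plain Python. This is a sketch only: the dict layout and the freeze() function are invented for illustration and are not TensorFlow's GraphDef format or its implementation. It assumes that only nodes needed by the requested outputs are kept.

```python
def freeze(nodes, variable_values, output_names):
    """Return a copy of the toy graph with variables replaced by constants.

    nodes maps a node name to {"op": ..., "inputs": [...]}. Only nodes that
    the outputs depend on are kept.
    """
    frozen = {}
    pending = list(output_names)
    while pending:
        name = pending.pop()
        if name in frozen:
            continue
        node = dict(nodes[name])
        if node["op"] == "Variable":
            # Bake the current value into the graph as a constant.
            node = {"op": "Const", "inputs": [], "value": variable_values[name]}
        frozen[name] = node
        pending.extend(node["inputs"])
    return frozen

graph = {
    "x": {"op": "Placeholder", "inputs": []},
    "y": {"op": "Placeholder", "inputs": []},
    "b": {"op": "Variable", "inputs": []},
    "init": {"op": "Assign", "inputs": ["b"]},
    "xy": {"op": "Mul", "inputs": ["x", "y"]},
    "op_to_store": {"op": "Add", "inputs": ["xy", "b"]},
}
frozen = freeze(graph, {"b": 1}, ["op_to_store"])
print(sorted(frozen))  # ['b', 'op_to_store', 'x', 'xy', 'y']
print(frozen["b"])     # {'op': 'Const', 'inputs': [], 'value': 1}
```

The "init" node is dropped because the output does not depend on it, and "b" now carries its value inside the graph, so a single file is enough.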
Here is a simple demo, again based on the example above:
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

def save_mode_pb(pb_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # The output needs a name attribute here
    op = tf.add(xy, b, name='op_to_store')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    path = os.path.dirname(os.path.abspath(pb_file_path))
    if not os.path.isdir(path):
        os.makedirs(path)

    # convert_variables_to_constants needs output_node_names,
    # a list that may contain several names
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
    with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # test
    feed_dict = {x: 2, y: 3}
    print(sess.run(op, feed_dict))
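The program generates and saves one file:
model.pb: a binary file that stores both the network structure and the parameter (weight) information of the model
Model loading and restoring
For the model saving example above, the process of restoring the model is as follows: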
import tensorflow as tf
from tensorflow.python.platform import gfile

def restore_mode_pb(pb_file_path):
    sess = tf.Session()
    # Read the serialized graph and import it into the default graph
    with gfile.FastGFile(pb_file_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')

    # b was converted to a constant when the model was saved
    print(sess.run('b:0'))

    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')

    op = sess.graph.get_tensor_by_name('op_to_store:0')

    ret = sess.run(op, {input_x: 5, input_y: 5})
    print(ret)
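The restoring process is almost the same as for a checkpoint.
The article "Exporting a TensorFlow network to a single file" (《将TensorFlow的网络导出为单个文件》) describes a way to save a single TensorFlow model file. It is much the same and worth a look.
Thoughts
Saving and loading models is one of the most basic parts of TensorFlow. Although simple, it is essential. In practice you also need to consider when to save the model, which variables need to be saved, and how to design the loading step to support transfer learning.
TensorFlow's functions and classes are also changing and being updated all the time, so richer ways to save and restore models may appear in the future.
Whether to save a checkpoint or a single pb file depends on the business situation, and the difference is not large. Checkpoints feel more flexible, while a pb file is better suited to online deployment (my personal view).
Full code for the above: GitHub https://github.com/liuyan731/tf_demo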