How to export TensorFlow's model network to a single file
This article shows how to export a TensorFlow network, architecture and weights together, into a single file.
Sometimes we need to export a TensorFlow model to a single file that contains both the architecture definition and the weights, so it can be used elsewhere (for example, to deploy the network from C). By default, tf.train.write_graph() exports only the network definition (without weights), and tf.train.Saver().save() writes the graph definition separately from the weights, so another approach is needed.
We know that a GraphDef does not contain the values of the network's Variables (which is usually where the weights are stored), but it does contain the values of constants. So if we convert the Variables into constants, a single file can store both the network architecture and the weights.
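To make the Variable-to-constant idea concrete, here is a minimal sketch that builds the same tiny network and compares the node types in the plain GraphDef with those in the frozen one. It assumes a TensorFlow release that provides the `tf.compat.v1` module (1.13+ or 2.x) so it runs under either major version; the alias `tf1` and the `VARIABLE_OPS` set are illustrative names, not part of the original article. The variable op type differs by version (`VariableV2` in TF1-style graphs, `VarHandleOp` for resource variables), which is why the sketch checks for several names.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (also available in TF 2.x)

# Op types that hold Variable state; the exact type depends on the TF version.
VARIABLE_OPS = {"Variable", "VariableV2", "VarHandleOp"}

graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name="a")
    b = tf1.Variable(4.0, name="b")
    tf.add(a, b, name="out")
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    # Plain graph definition: variable nodes, whose current values live in the session.
    plain = sess.graph_def
    # Frozen definition: each variable read from the session and stored as a Const node.
    frozen = tf1.graph_util.convert_variables_to_constants(
        sess, plain, ["out"])  # output *node* name, not "out:0"

plain_ops = {node.op for node in plain.node}
frozen_ops = {node.op for node in frozen.node}
print("variable ops before freezing:", sorted(plain_ops & VARIABLE_OPS))
print("variable ops after freezing: ", sorted(frozen_ops & VARIABLE_OPS))
print("frozen node types:", sorted(frozen_ops))
```

The frozen GraphDef keeps only the nodes needed to compute `out`, so initializers and assign ops are dropped along with the variables.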
We can freeze the weights and save the network in the following way:
```python
import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# Build the network
a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# Be sure to give the output tensor a name!!
output = tf.add(a, b, name='out')

# Convert the Variables to constants and write the network to a file
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Pass the name of the output node here ('out', not 'out:0')
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
```
When restoring the network, we can use the following way:
```python
import tensorflow as tf

with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))
```
The output is:
```
[array([[ 7.],
       [ 8.]], dtype=float32)]
```
As you can see, the weights were indeed saved!
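Another way to confirm that the weights travel inside the file is to read them straight out of the serialized bytes without running any computation. The sketch below, which assumes the same `tf.compat.v1` setup as before, freezes the network in memory, round-trips it through `SerializeToString()` / `ParseFromString()` (the same bytes `graph.pb` would hold), and decodes every `Const` node's value with `tf.make_ndarray`. The variable names are illustrative.

```python
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name="a")
    b = tf1.Variable(4.0, name="b")
    tf.add(a, b, name="out")
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    frozen = tf1.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["out"])

# Round-trip through bytes, as writing and re-reading graph.pb would.
restored = tf1.GraphDef()
restored.ParseFromString(frozen.SerializeToString())

# Each Const node stores its value as a TensorProto in the "value" attribute.
weights = {
    node.name: tf.make_ndarray(node.attr["value"].tensor)
    for node in restored.node
    if node.op == "Const"
}
for name, value in sorted(weights.items()):
    print(name, value.tolist())
```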
The problem is that our network also needs an interface for feeding in custom data; otherwise, what use is it? Don't worry, there is of course a way.
```python
import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.framework.graph_util import convert_variables_to_constants

a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# Named placeholder that serves as the network's input interface
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a + b), input_tensor, name='out')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
```
Use the code above to re-save the network to graph.pb. This time the network has an input placeholder. Let's see how to restore the network and feed in custom data.
```python
import tensorflow as tf

with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Map the saved placeholder 'input:0' to the value 4.
        # A non-Tensor value in input_map requires a non-empty name.
        output = tf.import_graph_def(graph_def, input_map={'input:0': 4.},
                                     return_elements=['out:0'], name='a')
        print(sess.run(output))
```
The output is:
```
[array([[ 11.],
       [ 12.]], dtype=float32)]
```
As you can see, the result is correct. Of course, the value in input_map can also be replaced with a new custom placeholder, as shown below:
```python
import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())
with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'input:0': new_input},
                                     return_elements=['out:0'], name='a')
        print(sess.run(output, feed_dict={new_input: 4}))
```
The output is again correct:
```
[array([[ 11.],
       [ 12.]], dtype=float32)]
```
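input_map is not the only option: after an import, the original placeholder still exists in the new graph under a prefixed name and can be fed directly through feed_dict. The sketch below assumes the `tf.compat.v1` API. It builds and freezes the placeholder network in memory, imports it into a fresh graph under the illustrative prefix `frozen`, looks up the tensors by name, and feeds the imported placeholder.

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Build and freeze the network with an input placeholder, all in memory.
build = tf.Graph()
with build.as_default():
    a = tf1.Variable([[3.0], [4.0]], name="a")
    b = tf1.Variable(4.0, name="b")
    inp = tf1.placeholder(tf.float32, name="input")
    tf.add(a + b, inp, name="out")
    init = tf1.global_variables_initializer()

with tf1.Session(graph=build) as sess:
    sess.run(init)
    frozen = tf1.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["out"])

# Import into a fresh graph without input_map; node names get the "frozen/" prefix.
serve = tf.Graph()
with serve.as_default():
    tf1.import_graph_def(frozen, name="frozen")
inp_t = serve.get_tensor_by_name("frozen/input:0")
out_t = serve.get_tensor_by_name("frozen/out:0")

# Feed the imported placeholder directly.
with tf1.Session(graph=serve) as sess:
    result = sess.run(out_t, feed_dict={inp_t: 4.0})
print(result)
```

Feeding the imported placeholder keeps the graph unchanged, whereas input_map rewires the input at import time; which one fits better depends on whether the caller already has a tensor to plug in.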
One more point worth explaining: if the network architecture was written with tf.train.write_graph using as_text=True, importing it needs a small modification:
```python
import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
    # graph.pb must have been written with as_text=True.
    # Open in text mode instead of 'rb'
    with open('./graph.pb', 'r') as f:
        graph_def = tf.GraphDef()
        # Instead of graph_def.ParseFromString(f.read()):
        text_format.Merge(f.read(), graph_def)
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))
```
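For completeness, here is a minimal sketch of the whole text-format round trip, writing with as_text=True and reading back with text_format.Merge, assuming the `tf.compat.v1` API and a temporary directory. The file name `graph.pbtxt` is an illustrative choice (a `.pbtxt` extension makes the text format obvious); the article itself keeps `graph.pb`.

```python
import tempfile

import tensorflow as tf
from google.protobuf import text_format

tf1 = tf.compat.v1

build = tf.Graph()
with build.as_default():
    a = tf1.Variable([[3.0], [4.0]], name="a")
    b = tf1.Variable(4.0, name="b")
    tf.add(a, b, name="out")
    init = tf1.global_variables_initializer()

with tf1.Session(graph=build) as sess:
    sess.run(init)
    frozen = tf1.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["out"])

# Write the frozen graph as human-readable protobuf text; returns the file path.
tmpdir = tempfile.mkdtemp()
path = tf1.train.write_graph(frozen, tmpdir, "graph.pbtxt", as_text=True)

# Text files are parsed with text_format.Merge instead of ParseFromString.
graph_def = tf1.GraphDef()
with open(path, "r") as f:
    text_format.Merge(f.read(), graph_def)

serve = tf.Graph()
with serve.as_default():
    (out_t,) = tf1.import_graph_def(graph_def, return_elements=["out:0"])
with tf1.Session(graph=serve) as sess:
    result = sess.run(out_t)
print(result)
```

The text form is convenient for inspecting or diffing a graph, but it is larger and slower to parse than the binary form, so as_text=False is the usual choice for deployment.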
The above is the detailed content of How to export TensorFlow's model network to a single file. For more information, please follow other related articles on the PHP Chinese website!