ホームページ >バックエンド開発 >Python チュートリアル >TensorFlow のモデル ネットワークを単一のファイルにエクスポートする方法
この記事では主に TensorFlow ネットワークを単一のファイルにエクスポートする方法を紹介し、参考として提供します。一緒に見てみましょう
場合によっては、他の場所 (C++ でのネットワークのデプロイなど) で簡単に使用できるように、TensorFlow モデルを単一のファイル (モデル アーキテクチャの定義と重みを含む) にエクスポートする必要があります。 tf.train.write_graph() を使用すると、デフォルトではネットワークの定義 (重みなし) のみがエクスポートされますが、tf.train.Saver().save() を使用してエクスポートされたファイルgraph_def は重みから分離されるため、他のメソッドは次のことを行う必要があります。方法が使用されます。
graph_def ファイルにはネットワーク内の変数値が含まれていないことがわかっています (通常は重みが保存されています)。ただし、定数値は含まれているため、変数を定数に変換できれば、1 つのファイルを使用して保存できます。ネットワーク アーキテクチャと同時に重みを付けてターゲットを設定します。
次の方法で重みをフリーズしてネットワークを保存できます:
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants # 构造网络 a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') # 一定要给输出tensor取一个名字!! output = tf.add(a, b, name='out') # 转换Variable为constant,并将网络写入到文件 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 这里需要填入输出tensor的名字 graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
ネットワークを復元するときは、次の方法を使用できます:
import tensorflow as tf with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, return_elements=['out:0']) print(sess.run(output))
出力結果は次のとおりです:
[array([[ 7 .],
[ 8.]], dtype=float32)]
以前の重みが実際に保存されていることがわかります!!
問題は、ネットワークにインターフェイスが必要であることです。カスタムデータを入力するためです!そうでなければ、これは何の役に立つのでしょう。 。心配しないでください、もちろん方法はあります。
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') input_tensor = tf.placeholder(tf.float32, name='input') output = tf.add((a+b), input_tensor, name='out') with tf.Session() as sess: sess.run(tf.global_variables_initializer()) graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
上記のコードを使用して、ネットワークをgraph.pbに再保存します。今回は、ネットワークを復元してカスタム データを入力する方法を見てみましょう。
import tensorflow as tf with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') print(sess.run(output))
出力結果は、
[array([[ 11.],
[ 12.]], dtype=float32)]
で問題ないことがわかります。もちろん、input_map 内の結果です。以下に示すように、新しいカスタム プレースホルダーに置き換えることができます。
import tensorflow as tf new_input = tf.placeholder(tf.float32, shape=()) with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') print(sess.run(output, feed_dict={new_input:4}))
出力を見てください。問題はありません。
[array([[ 11.],
[ 12.]], dtype=float32)]
もう 1 つ説明する必要がある点は、 tf.train.write_graph を使用してネットワーク アーキテクチャを記述する場合、 as_text=True の場合、ネットワークにインポートするときに若干の変更を加える必要があります。
import tensorflow as tf from google.protobuf import text_format with tf.Session() as sess: # 不使用'rb'模式 with open('./graph.pb', 'r') as f: graph_def = tf.GraphDef() # 不使用graph_def.ParseFromString(f.read()) text_format.Merge(f.read(), graph_def) output = tf.import_graph_def(graph_def, return_elements=['out:0']) print(sess.run(output))
関連する推奨事項:
TensorFlow のインストールと jupyter Notebook 構成の詳細な説明
以上がTensorFlow のモデル ネットワークを単一のファイルにエクスポートする方法の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。