이 글에서는 TensorFlow 네트워크를 단일 파일로 내보내는 방법을 주로 소개하고 있으니 참고용으로 올려보겠습니다. 함께 살펴보겠습니다
때로는 다른 곳에서 쉽게 사용할 수 있도록(예: C++에서 네트워크 배포) TensorFlow 모델을 단일 파일(모델 아키텍처 정의 및 가중치 포함)로 내보내야 할 때가 있습니다. tf.train.write_graph()를 사용하면 기본적으로 네트워크 정의(가중치 없이)만 내보내는 반면, tf.train.Saver().save()를 사용하여 내보낸 graph_def 파일은 가중치와 분리되므로 다른 방법을 사용해야 합니다. 방법을 사용합니다.
graph_def 파일에는 네트워크의 변수 값이 포함되어 있지 않지만(보통 가중치가 저장됨) 상수 값이 포함되어 있으므로 변수를 상수로 변환할 수 있으면 하나의 파일을 사용하여 저장할 수 있습니다. 네트워크 아키텍처를 동시에 목표로 가중치를 부여합니다.
다음과 같은 방법으로 가중치를 동결하고 네트워크를 저장할 수 있습니다.
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants # 构造网络 a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') # 一定要给输出tensor取一个名字!! output = tf.add(a, b, name='out') # 转换Variable为constant,并将网络写入到文件 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 这里需要填入输出tensor的名字 graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
네트워크를 복원할 때 다음 방법을 사용할 수 있습니다.
import tensorflow as tf with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, return_elements=['out:0']) print(sess.run(output))
출력 결과는 다음과 같습니다.
[array([[ 7 .],
[ 8.]], dtype=float32)]
이전 가중치가 실제로 저장되는 것을 볼 수 있습니다!!
문제는 우리 네트워크에 인터페이스가 필요하다는 것입니다. 맞춤 데이터 입력을 위해 아! 그렇지 않으면 이 일이 무슨 소용이 있겠습니까? . 걱정하지 마십시오. 물론 방법이 있습니다.
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') input_tensor = tf.placeholder(tf.float32, name='input') output = tf.add((a+b), input_tensor, name='out') with tf.Session() as sess: sess.run(tf.global_variables_initializer()) graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
위 코드를 사용하여 네트워크를 graph.pb에 다시 저장합니다. 이번에는 네트워크를 복원하고 사용자 정의 데이터를 입력하는 방법을 살펴보겠습니다.
import tensorflow as tf with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') print(sess.run(output))
출력 결과는 다음과 같습니다.
[array([[ 11.],
[ 12.]], dtype=float32)]
에 문제가 없음을 확인할 수 있습니다. 결과는 물론 input_map에서 아래와 같이 새로운 사용자 정의 자리 표시자로 대체될 수 있습니다.
import tensorflow as tf new_input = tf.placeholder(tf.float32, shape=()) with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') print(sess.run(output, feed_dict={new_input:4}))
출력을 보면 문제가 없습니다.
[array([[ 11.],
[ 12.]], dtype=float32)]
설명해야 할 또 다른 요점은 tf.train.write_graph를 사용하여 네트워크 아키텍처를 작성할 때, as_text=True인 경우 네트워크로 가져올 때 약간의 수정이 필요합니다.
import tensorflow as tf from google.protobuf import text_format with tf.Session() as sess: # 不使用'rb'模式 with open('./graph.pb', 'r') as f: graph_def = tf.GraphDef() # 不使用graph_def.ParseFromString(f.read()) text_format.Merge(f.read(), graph_def) output = tf.import_graph_def(graph_def, return_elements=['out:0']) print(sess.run(output))
관련 추천:
TensorFlow 설치 및 jupyter 노트북 구성에 대한 자세한 설명
위 내용은 TensorFlow의 모델 네트워크를 단일 파일로 내보내는 방법의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!