>  기사  >  백엔드 개발  >  TensorFlow의 모델 네트워크를 단일 파일로 내보내는 방법

TensorFlow의 모델 네트워크를 단일 파일로 내보내는 방법

不言
不言원래의
2018-04-23 15:39:491710검색

이 글에서는 TensorFlow 네트워크를 단일 파일로 내보내는 방법을 주로 소개하고 있으니 참고용으로 올려보겠습니다. 함께 살펴보겠습니다

때로는 다른 곳에서 쉽게 사용할 수 있도록(예: C++에서 네트워크 배포) TensorFlow 모델을 단일 파일(모델 아키텍처 정의 및 가중치 포함)로 내보내야 할 때가 있습니다. tf.train.write_graph()를 사용하면 기본적으로 네트워크 정의(가중치 없이)만 내보내는 반면, tf.train.Saver().save()를 사용하여 내보낸 graph_def 파일은 가중치와 분리되므로 다른 방법을 사용해야 합니다. 방법을 사용합니다.

graph_def 파일에는 네트워크의 변수 값이 포함되어 있지 않지만(보통 가중치가 저장됨) 상수 값이 포함되어 있으므로 변수를 상수로 변환할 수 있으면 하나의 파일을 사용하여 저장할 수 있습니다. 네트워크 아키텍처를 동시에 목표로 가중치를 부여합니다.

다음과 같은 방법으로 가중치를 동결하고 네트워크를 저장할 수 있습니다.

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

네트워크를 복원할 때 다음 방법을 사용할 수 있습니다.

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

출력 결과는 다음과 같습니다.

[array([[ 7 .],
                 [ 8.]], dtype=float32)]

이전 가중치가 실제로 저장되는 것을 볼 수 있습니다!!

문제는 우리 네트워크에 인터페이스가 필요하다는 것입니다. 맞춤 데이터 입력을 위해 아! 그렇지 않으면 이 일이 무슨 소용이 있겠습니까? . 걱정하지 마십시오. 물론 방법이 있습니다.

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

위 코드를 사용하여 네트워크를 graph.pb에 다시 저장합니다. 이번에는 네트워크를 복원하고 사용자 정의 데이터를 입력하는 방법을 살펴보겠습니다.

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

출력 결과는 다음과 같습니다.

[array([[ 11.],
 [ 12.]], dtype=float32)]

에 문제가 없음을 확인할 수 있습니다. 결과는 물론 input_map에서 아래와 같이 새로운 사용자 정의 자리 표시자로 대체될 수 있습니다.

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

출력을 보면 문제가 없습니다.

[array([[ 11.],
 [ 12.]], dtype=float32)]

설명해야 할 또 다른 요점은 tf.train.write_graph를 사용하여 네트워크 아키텍처를 작성할 때, as_text=True인 경우 네트워크로 가져올 때 약간의 수정이 필요합니다.

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

관련 추천:

TensorFlow 설치 및 jupyter 노트북 구성에 대한 자세한 설명


위 내용은 TensorFlow의 모델 네트워크를 단일 파일로 내보내는 방법의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.