>백엔드 개발 >파이썬 튜토리얼 >Tensorflow 모델 저장 및 복원에 대한 간략한 토론

Tensorflow 모델 저장 및 복원에 대한 간략한 토론

不言
不言원래의
2018-04-26 16:40:542398검색

이 글은 주로 Tensorflow 모델의 저장과 복원을 소개하고 있으니 참고용으로 올려보겠습니다. 와서 살펴보세요

최근에는 일반적인 규칙, 매칭 및 필터링 외에도 분류 예측을 위해 몇 가지 기계 학습 방법을 사용하여 스팸 방지 작업을 수행했습니다. TensorFlow를 사용하여 모델을 훈련합니다. 훈련된 모델을 저장해야 합니다. 예측 단계에서는 TensorFlow 모델을 저장하고 복원하는 작업을 포함하여 사용할 모델을 로드하고 복원해야 합니다.

Tensorflow에서 일반적으로 사용되는 모델 저장 방법을 요약합니다.

체크포인트 모델 파일(.ckpt) 저장

먼저 TensorFlow는 머신러닝 모델을 저장하고 복원할 수 있는 매우 편리한 API인 tf.train.Saver()를 제공합니다.

모델 저장

모델 파일을 저장하려면 tf.train.Saver()를 사용하는 것이 매우 편리합니다. 다음은 간단한 예입니다.


import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))


프로그램은 4개를 생성하고 저장합니다. files (0.11 이전 버전에서는 checkpoint, model.ckpt, model.ckpt.meta 3개의 파일만 생성되었습니다.)

  1. 모델 파일의 경로 정보 목록을 기록하는 checkpoint 텍스트 파일

  2. model.ckpt.data -00000 -of-00001 네트워크 가중치 정보

  3. model.ckpt.index .data, .index 두 파일은 모델의 가변 매개변수(가중치) 정보를 저장하는 바이너리 파일입니다

  4. model.ckpt.meta 모델의 계산 그래프 구조 정보(모델의 네트워크 구조)를 저장하는 바이너리 파일 protobuf

위는 tf.train.Saver().save()의 기본 사용법입니다. 또한 많은 구성 가능한 매개변수가 있습니다:


tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)


global_step 매개변수를 추가한다는 것은 매 1000회 반복 후에 모델을 저장하는 것을 의미하며 모델 파일 model.ckpt-1000의 끝에 "-1000"이 추가됩니다. .index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

모델은 1000회 반복마다 저장되지만 모델의 구조 정보 파일만 변경되지는 않습니다. 해당하는 모든 반복 없이 1000회마다 저장됩니다. 1000회마다 한 번씩 저장하므로 메타 파일을 저장할 필요가 없는 경우 다음과 같이 write_meta_graph=False 매개변수를 추가할 수 있습니다. the code
코드는 다음과 같습니다.

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

2시간마다 모델을 저장하고 최신 4개 모델만 저장하려면 max_to_keep을 사용하면 됩니다(매 에포크마다 저장하려면 기본값은 5입니다). 훈련의 경우 None 또는 0으로 설정할 수 있지만 쓸모가 없으며 권장되지 않습니다), keep_checkpoint_every_n_hours 매개변수, 다음과 같습니다:

코드를 복사합니다
코드는 다음과 같습니다:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)


동시에 tf.train.Saver() 클래스에서 정보를 지정하지 않으면 모든 매개변수 정보가 저장됩니다. 또한 저장하려는 콘텐츠의 일부를 지정할 수도 있습니다. x, y 매개변수 저장(매개변수 목록 또는 dict가 전달될 수 있음):


tf.train.Saver([x, y]).save(sess, ckpt_file_path)


ps. 변수 또는 매개변수 이름 속성 이름은 가져올 수 없습니다. 그렇지 않으면 복원 후 get_tensor_by_name()을 통해 모델을 얻을 수 없습니다.


모델 로드 및 복원


위의 모델 저장 예시에서 모델을 복원하는 과정은 다음과 같습니다.


import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)


먼저 모델 구조를 복원한 다음 변수( 매개변수) 정보, 그리고 마지막으로 훈련된 모델의 다양한 정보(저장된 변수, 자리 표시자 변수, 연산자 등)를 얻을 수 있으며, 얻은 변수에 다양한 새로운 작업을 추가할 수 있습니다(위 코드 설명 참조).

또한 이를 기반으로 일부 모델을 로드하고 다른 작업을 추가할 수도 있습니다. 자세한 내용은 공식 문서와 데모를 참조하세요.

ckpt 모델 파일 저장 및 복원에 대해서는 stackoverflow에 명확하게 설명되어 있는 답변이 있으니 참고하시면 됩니다.


동시에 cv-tricks.com의 TensorFlow 모델 저장 및 복원에 대한 튜토리얼도 매우 훌륭하므로 참고하실 수 있습니다.

"Tensorflow 1.0 학습: 모델 저장 및 복원(Saver)"에는 Saver 사용 팁이 있습니다.

단일 모델 파일(.pb) 저장

Tensorflow의 inception-v3 데모를 직접 실행했는데 실행이 완료된 후 .pb 모델 파일이 생성되는 것을 확인했습니다. 이 파일은 후속 작업에 사용됩니다. 예측 또는 마이그레이션 학습입니다. 단 하나의 파일이므로 매우 멋지고 편리합니다.

이 프로세스의 주요 아이디어는 graph_def 파일에 네트워크의 변수 값이 포함되어 있지 않지만(일반적으로 가중치가 저장됨) 상수 값이 포함되어 있으므로 변환할 수 있다면 변수를 상수로 변경(graph_util.convert_variables_to_constants() 함수 사용)하면 하나의 파일을 사용하여 네트워크 아키텍처와 가중치를 모두 저장한다는 목표를 달성할 수 있습니다.

ps: 여기서 .pb는 모델 파일의 접미사 이름입니다. 물론 다른 접미사도 사용할 수 있습니다(Google과 일관성을 유지하려면 .pb를 사용하세요╮(╯▽╰)╭)


모델을 저장하세요.


마찬가지로 위의 예를 기반으로 한 간단한 데모:


import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

def save_mode_pb(pb_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  # 这里的输出需要加上name属性
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(pb_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
  constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
  with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))


程序生成并保存一个文件

model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:


import tensorflow as tf
from tensorflow.python.platform import gfile

def restore_mode_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

  print(sess.run('b:0'))

  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')

  op = sess.graph.get_tensor_by_name('op_to_store:0')

  ret = sess.run(op, {input_x: 5, input_y: 5})
  print(ret)


模型的还原过程与checkpoint差不多一样。

《将TensorFlow的网络导出为单个文件》上介绍了TensorFlow保存单个模型文件的方式,大同小异,可以看看。

思考

模型的保存与加载只是TensorFlow中最基础的部分之一,虽然简单但是也必不可少,在实际运用中还需要注意模型何时保存,哪些变量需要保存,如何设计加载实现迁移学习等等问题。

同时TensorFlow的函数和类都在一直变化更新,以后也有可能出现更丰富的模型保存和还原的方法。

选择保存为checkpoint或单个pb文件视业务情况而定,没有特别大的差别。checkpoint保存感觉会更加灵活一些,pb文件更适合线上部署吧(个人看法)。

以上完整代码:github https://github.com/liuyan731/tf_demo

相关推荐:

TensorFlow模型保存和提取方法示例


위 내용은 Tensorflow 모델 저장 및 복원에 대한 간략한 토론의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.