>  기사  >  백엔드 개발  >  tensorflow1.0(Saver)_python으로 학습한 모델 저장 및 복원

tensorflow1.0(Saver)_python으로 학습한 모델 저장 및 복원

不言
不言원래의
2018-04-23 15:42:501820검색

이 글에서는 tensorflow1.0 학습 모델의 저장 및 복구(Saver)를 주로 소개하고 참고용으로 올려드립니다. 함께 살펴보겠습니다

학습된 모델 매개변수를 나중에 검증하거나 테스트하기 위해 저장하는 것은 우리가 자주 하는 일입니다. tf에 모델 저장을 제공하는 tf.train.Saver() 모듈입니다.

모델을 저장하려면 먼저 Saver 개체를 만들어야 합니다. 예:

saver=tf.train.Saver()

이 Saver 개체를 만들 때 자주 사용하는 매개 변수가 있는데, 이는 max_to_keep 매개 변수를 설정하는 데 사용됩니다. 모델 저장을 위한 매개변수입니다. 기본값은 5입니다. 즉, max_to_keep=5, 최신 5개 모델을 저장합니다. 훈련 세대(에포크)마다 모델을 저장하려면 max_to_keep을 None 또는 0으로 설정할 수 있습니다. 예:

saver=tf.train.Saver(max_to_keep=0)

그러나 이것은 더 많은 하드 디스크를 차지하는 것 외에는 실용적이지 않으므로 권장되지 않습니다. .

물론, 지난 세대 모델만 저장하고 싶다면 max_to_keep을 1로 설정하기만 하면 됩니다. 즉,

saver=tf.train.Saver(max_to_keep=1)

saver 객체를 생성한 후 다음과 같은 훈련된 모델을 저장할 수 있습니다. :

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

첫 번째 매개변수 세션은 말할 필요도 없습니다. 두 번째 파라미터는 저장 경로와 이름을 설정하고, 세 번째 파라미터는 모델명 뒤에 훈련 횟수를 접미사로 추가합니다.

saver.save(sess, 'my-model', global_step=0) ==> 파일 이름: 'my-model-0'
...
saver.save(sess, 'my-model', global_step= 1000) ==> filename: 'my-model-1000'

mnist 예를 보세요:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

코드의 빨간색 부분은 모델을 저장했지만, 나중에 저장된 모델은 이전 모델을 덮어쓰며 마지막 모델만 저장됩니다. 따라서 시간을 절약하고 저장 코드를 루프 외부에 배치할 수 있습니다(max_to_keep=1에만 적용되며 그렇지 않으면 여전히 루프 내부에 배치해야 합니다).

실험에서 마지막 세대는 다음 세대가 아닐 수도 있습니다. 기본적으로 마지막 세대를 저장하고 싶지 않지만 가장 높은 검증 정확도로 세대를 저장하고 싶다면 중간 변수와 판단문을 추가하면 됩니다.

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

최고의 검증 정확도로 3대를 저장하고, 매번 검증 정확도도 저장하려면 txt 파일을 생성하여 저장할 수 있습니다.

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

restore() 함수는 모델을 복원하는 데 사용되며, 여기에는 두 개의 매개변수 복원(sess, save_path)이 필요하며, save_path는 저장된 모델 경로를 나타냅니다. tf.train.latest_checkpoint()를 사용하여 마지막으로 저장된 모델을 자동으로 가져올 수 있습니다. 예:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

그런 다음 프로그램 후반부의 코드를 다음과 같이 변경할 수 있습니다.

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

빨간색으로 표시된 부분은 모델 저장 및 복원과 관련된 코드입니다. 훈련 및 검증 단계를 제어하려면 부울 변수 is_train을 사용하십시오.

전체 소스 프로그램:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

관련 권장 사항:

TensorFlow의 모델 네트워크를 단일 파일로 내보내는 방법

위 내용은 tensorflow1.0(Saver)_python으로 학습한 모델 저장 및 복원의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.