>  기사  >  백엔드 개발  >  Tensorflow의 Saver를 사용하는 방법

Tensorflow의 Saver를 사용하는 방법

不言
不言원래의
2018-04-23 15:46:311991검색

이 글은 Tensorflow의 Saver에 대한 자세한 사용법을 주로 소개하고 참고용으로 올려드립니다. 함께 살펴볼까요

Saver의 사용법

1. Saver의 배경 소개

우리는 모델을 훈련한 후 훈련 결과를 저장하고 싶을 때가 많습니다. 다음 반복을 위한 교육 또는 테스트에 사용됩니다. Tensorflow는 이 요구 사항에 대해 Saver 클래스를 제공합니다.

Saver 클래스는 체크포인트 파일에서 변수를 저장하고 복원하는 관련 메서드를 제공합니다. 체크포인트 파일은 변수 이름을 해당 텐서 값에 매핑하는 바이너리 파일입니다.

카운터가 제공되는 한 Saver 클래스는 카운터가 트리거될 때 자동으로 체크포인트 파일을 생성할 수 있습니다. 이를 통해 훈련 중에 여러 중간 결과를 저장할 수 있습니다. 예를 들어 각 훈련 단계의 결과를 저장할 수 있습니다.

전체 디스크가 가득 차지 않도록 Saver는 체크포인트 파일을 자동으로 관리할 수 있습니다. 예를 들어, 가장 최근 N개의 체크포인트 파일을 저장하도록 지정할 수 있습니다.

2. Saver 인스턴스

다음은 Saver 클래스

import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b))

  1. isTrain을 사용하는 방법의 예입니다. 학습 단계와 테스트 단계를 구분하는 데 사용됩니다. True는 학습을 의미합니다. False는 테스트를 의미합니다.

  2. train_steps: 훈련 횟수를 나타냅니다.

  3. checkpoint_steps: 훈련 중에 체크포인트를 저장할 횟수를 나타냅니다.

  4. checkpoint_dir: 체크포인트 파일의 저장 경로를 나타냅니다. 예에서 현재 경로는

2.1 교육 단계

Saver.save() 메서드를 사용하여 모델을 저장합니다.

  1. sess: 현재 변수 값을 기록하는 현재 세션 ​​

  2. checkpoint_dir + 'model.ckpt': 저장된 파일 이름을 나타냅니다.

  3. global_step: 현재 단계를 나타냅니다.

훈련이 완료된 후 현재 디렉토리 아래에 5개의 파일이 더 있어야 합니다.

Tensorflow의 Saver를 사용하는 방법

Tensorflow의 Saver를 사용하는 방법2.1 테스트 단계 </p><p>테스트 단계에서는 saver.restore() 메서드를 사용하여 변수를 복원합니다. <br/></p><p>sess: 현재 세션을 나타내며 이전에 저장된 결과가 이 세션에 로드됩니다. <br/></p><p>ckpt .model_checkpoint_path: 모델을 나타냅니다. 저장 위치는 모델 이름을 제공할 필요가 없습니다. 체크포인트 파일을 확인하여 최신 모델이 누구인지, 이름이 무엇인지 확인합니다. <br/></p><p>실행 결과는 아래 그림에 표시되며, 이전에 학습된 매개변수 w 및 b의 결과를 로드합니다. <br/>tensorflow1.0 학습 모델 저장 및 복원(Saver)_python</p><p   style=

위 내용은 Tensorflow의 Saver를 사용하는 방법의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.