이 글은 Tensorflow의 Saver에 대한 자세한 사용법을 주로 소개하고 참고용으로 올려드립니다. 함께 살펴볼까요
Saver의 사용법
1. Saver의 배경 소개
우리는 모델을 훈련한 후 훈련 결과를 저장하고 싶을 때가 많습니다. 다음 반복을 위한 교육 또는 테스트에 사용됩니다. Tensorflow는 이 요구 사항에 대해 Saver 클래스를 제공합니다.
Saver 클래스는 체크포인트 파일에서 변수를 저장하고 복원하는 관련 메서드를 제공합니다. 체크포인트 파일은 변수 이름을 해당 텐서 값에 매핑하는 바이너리 파일입니다.
카운터가 제공되는 한 Saver 클래스는 카운터가 트리거될 때 자동으로 체크포인트 파일을 생성할 수 있습니다. 이를 통해 훈련 중에 여러 중간 결과를 저장할 수 있습니다. 예를 들어 각 훈련 단계의 결과를 저장할 수 있습니다.
전체 디스크가 가득 차지 않도록 Saver는 체크포인트 파일을 자동으로 관리할 수 있습니다. 예를 들어, 가장 최근 N개의 체크포인트 파일을 저장하도록 지정할 수 있습니다.
2. Saver 인스턴스
다음은 Saver 클래스
import tensorflow as tf import numpy as np x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) isTrain = False train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) else: ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b))
isTrain을 사용하는 방법의 예입니다. 학습 단계와 테스트 단계를 구분하는 데 사용됩니다. True는 학습을 의미합니다. False는 테스트를 의미합니다.
train_steps: 훈련 횟수를 나타냅니다.
checkpoint_steps: 훈련 중에 체크포인트를 저장할 횟수를 나타냅니다.
checkpoint_dir: 체크포인트 파일의 저장 경로를 나타냅니다. 예에서 현재 경로는
2.1 교육 단계
Saver.save() 메서드를 사용하여 모델을 저장합니다.
sess: 현재 변수 값을 기록하는 현재 세션
checkpoint_dir + 'model.ckpt': 저장된 파일 이름을 나타냅니다.
global_step: 현재 단계를 나타냅니다.
훈련이 완료된 후 현재 디렉토리 아래에 5개의 파일이 더 있어야 합니다.
위 내용은 Tensorflow의 Saver를 사용하는 방법의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!