>  기사  >  백엔드 개발  >  TensorFlow 시작하기 및 tf.train.Saver()를 사용하여 모델 저장

TensorFlow 시작하기 및 tf.train.Saver()를 사용하여 모델 저장

不言
不言원래의
2018-04-24 14:15:074074검색

이 글에서는 TensorFlow를 시작할 때 tf.train.Saver()를 사용하여 모델을 저장하는 방법을 주로 소개하고 참고용으로 제공합니다. 함께 살펴볼까요

모델 저장에 대한 생각

saver = tf.train.Saver(max_to_keep=3)

보통 저장기를 정의할 때 저장되는 모델의 최대 개수를 정의하는 경우가 많습니다. 일반적으로 모델 자체가 크다면 하드를 고려해야 합니다. 디스크 크기. 현재 학습된 모델을 기반으로 미세 조정을 수행해야 하는 경우 가능한 한 많은 모델을 저장해야 하며, 한 번에 과적합될 수 있으므로 반드시 최상의 ckpt에서 후속 미세 조정을 수행할 필요는 없습니다. 하지만 너무 많은 파일을 저장하면 하드 디스크에 부담이 가해집니다. 최상의 모델만 유지하려는 경우 특정 단계까지 반복할 때마다 검증 세트에 대한 정확도 또는 f1 값을 계산하는 방법이 있습니다. 이번 결과가 지난번보다 좋으면 새로 저장합니다. 그렇지 않으면 저장할 필요가 없습니다.

다른 시대에 저장된 모델을 융합에 사용하려면 3~5개의 모델이면 충분합니다. 융합된 모델이 M이 되고, 가장 좋은 단일 모델이 M_best라고 가정하면 실제로 더 좋을 수 있습니다. m_best보다. 그러나 이 모델을 다른 구조의 모델과 융합하면 M의 효과는 m_best만큼 좋지 않습니다. M은 모델의 "특성"을 감소시키는 평균 연산과 동일하기 때문입니다.

하지만 학습률을 조정하여 여러 로컬 최적점을 얻는 새로운 융합 방법이 있습니다. 즉, 손실을 줄일 수 없는 경우 ckpt를 저장한 후 학습률을 높여 다음 항목을 계속 찾는 것입니다. 아직까지 이러한 ckpts를 융합해 본 적은 없지만, 단일 모델은 확실히 개선되겠지만, 위의 다른 모델과의 융합이 개선되지 않는 상황이 있을지는 모르겠습니다.

모델을 저장하기 위해 tf.train.Saver()를 사용하는 방법

이전에 주로 코딩 문제로 인해 오류가 발생했습니다. 따라서 파일 경로에 한자가 들어가지 않도록 주의하세요.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

파일에 저장된 모델: ./ckpt/test-model.ckpt-1

위의 모델을 저장한 후 참고하세요. 다음 모델을 사용하여 가져오기 전에 커널을 다시 시작해야 합니다. 그렇지 않으면 "v1"이라는 이름을 두 번 지정하여 이름이 잘못됩니다.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:./ckpt/test-model.ckpt-1
모델에서 매개변수 복원 중.
[ 1.       2.29999995]
55.5

모델을 가져오기 전 , 재정의해야 함 다시 변수.

하지만 모든 변수를 재정의할 필요는 없으며 필요한 변수만 정의하면 됩니다.

즉, 정의한 변수는 체크포인트에 존재해야 하지만 체크포인트의 모든 변수를 다시 정의해야 하는 것은 아닙니다.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:./ckpt/test-model.ckpt-1
모델에서 매개변수 복원 중.
[ 1.       2.29999995]

tf.Saver([tensors_to _be _저장됨]) 확인 목록과 저장할 텐서를 전달합니다. 목록이 제공되지 않으면 기본적으로 현재 모든 텐서를 저장합니다. 일반적으로 tf.Saver는 tf.variable_scope()와 영리하게 결합할 수 있습니다. [전이 학습] 이미 저장된 모델에 새 변수 추가 및 미세 조정

관련 권장 사항:

Tensorflow 정보 tf.train.batch 함수

위 내용은 TensorFlow 시작하기 및 tf.train.Saver()를 사용하여 모델 저장의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.