Heim  >  Artikel  >  Backend-Entwicklung  >  Speichern und Wiederherstellen des von tensorflow1.0 (Saver)_python gelernten Modells

Speichern und Wiederherstellen des von tensorflow1.0 (Saver)_python gelernten Modells

不言
不言Original
2018-04-23 15:42:501748Durchsuche

In diesem Artikel wird hauptsächlich das Speichern und Wiederherstellen (Saver) des Tensorflow1.0-Lernmodells vorgestellt. Jetzt teile ich es mit Ihnen und gebe es als Referenz. Werfen wir gemeinsam einen Blick darauf

Speichern Sie die trainierten Modellparameter zur späteren Überprüfung oder zum Testen. Dies ist etwas, was wir oft tun. Das Modul tf.train.Saver(), das das Speichern von Modellen in tf ermöglicht.

Um das Modell zu speichern, müssen Sie zunächst ein Saver-Objekt erstellen: z. B.

saver=tf.train.Saver()

Beim Erstellen dieses Saver-Objekts gibt es Folgendes Ein Parameter wir Der Parameter max_to_keep wird häufig verwendet, um die Anzahl der gespeicherten Modelle festzulegen. Der Standardwert ist 5, dh max_to_keep = 5, wodurch die letzten 5 Modelle gespeichert werden. Wenn Sie das Modell in jeder Trainingsgeneration (Epoche) speichern möchten, können Sie max_to_keep auf None oder 0 setzen, wie zum Beispiel:

saver=tf.train.Saver(max_to_keep=0)

Aber so Es belegt nicht nur mehr Festplatte, sondern hat auch keinen praktischen Nutzen und wird daher nicht empfohlen.

Wenn Sie nur das Modell der letzten Generation speichern möchten, müssen Sie max_to_keep natürlich nur auf 1 setzen, also

saver=tf.train.Saver(max_to_keep=1)

Nachdem Sie das Speicherobjekt erstellt haben, können Sie das trainierte Modell speichern, z. B.:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

Die erste Parametersitzung, natürlich Das . Der zweite Parameter legt den gespeicherten Pfad und Namen fest und der dritte Parameter fügt die Anzahl der Trainingszeiten als Suffix zum Modellnamen hinzu.

saver.save(sess, 'my-model', global_step=0) ==> Dateiname: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> Dateiname: 'my-model-1000'

Sehen Sie sich ein Mnist-Beispiel an:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

Der rote Teil im Code ist der Code zum Speichern des Modells. Obwohl ich es nach jeder Trainingsgeneration speichere, überschreibt das beim nächsten Mal gespeicherte Modell das vorherige und nur das letzte Mal . Daher können wir Zeit sparen und den Speichercode außerhalb der Schleife platzieren (gilt nur für max_to_keep=1, andernfalls muss er noch innerhalb der Schleife platziert werden).

Im Experiment ist die letzte Generation möglicherweise nicht vorhanden Die Generation mit der höchsten Überprüfungsgenauigkeit. Daher möchten wir nicht standardmäßig die letzte Generation speichern, sondern die Generation mit der höchsten Überprüfungsgenauigkeit. Fügen Sie daher einfach eine Zwischenvariable und eine Beurteilungsaussage hinzu.

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

Wenn wir die drei Generationen mit der höchsten Verifizierungsgenauigkeit speichern und auch die Verifizierungsgenauigkeit jedes Mal speichern möchten, können wir einen TXT generieren Datei zum Speichern.

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

Das Modell wird mit der Funktion „restore()“ wiederhergestellt, die zwei Parameter „restore(sess, save_path)“ erfordert. save_path bezieht sich auf den gespeicherten Modellpfad . Wir können tf.train.latest_checkpoint() verwenden, um automatisch das zuletzt gespeicherte Modell abzurufen. Zum Beispiel:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

Dann können wir die zweite Hälfte des Programms ändern in:

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

Der rot markierte Bereich ist der Code zum Speichern und Wiederherstellen des Modells. Verwenden Sie die Bool-Variable is_train, um die Trainings- und Verifizierungsphasen zu steuern.

Gesamtes Quellprogramm:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

Verwandte Empfehlungen:

Modellnetzwerk von TensorFlow als einzelne Datei exportieren Methode

Das obige ist der detaillierte Inhalt vonSpeichern und Wiederherstellen des von tensorflow1.0 (Saver)_python gelernten Modells. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn