Maison >développement back-end >Tutoriel Python >Sauvegarde et restauration du modèle appris par tensorflow1.0 (Saver)_python

Sauvegarde et restauration du modèle appris par tensorflow1.0 (Saver)_python

不言
不言original
2018-04-23 15:42:501878parcourir

Cet article présente principalement la sauvegarde et la récupération (Saver) du modèle d'apprentissage tensorflow1.0. Maintenant, je le partage avec vous et le donne comme référence. Jetons un coup d'œil ensemble

Enregistrons les paramètres du modèle entraîné pour une vérification ou des tests ultérieurs. C'est quelque chose que nous faisons souvent. Le module tf.train.Saver() qui permet la sauvegarde du modèle dans tf.

Pour enregistrer le modèle, vous devez d'abord créer un objet Saver : tel que

saver=tf.train.Saver()

Lors de la création de cet objet Saver, il y a un paramètre we Le paramètre max_to_keep est souvent utilisé pour définir le nombre de modèles enregistrés. La valeur par défaut est 5, c'est-à-dire max_to_keep=5, qui enregistre les 5 derniers modèles. Si vous souhaitez enregistrer le modèle à chaque génération (époque) d'entraînement, vous pouvez définir max_to_keep sur Aucun ou 0, par exemple :

saver=tf.train.Saver(max_to_keep=0)

Mais comme ça En plus d’occuper plus de disque dur, il n’a aucune utilité pratique, il n’est donc pas recommandé.

Bien sûr, si vous souhaitez uniquement enregistrer le modèle de dernière génération, il vous suffit de définir max_to_keep sur 1, c'est-à-dire

saver=tf.train.Saver(max_to_keep=1)

Après avoir créé l'objet économiseur, vous pouvez enregistrer le modèle entraîné, tel que :

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

La première session de paramètres, il va sans dire. Le deuxième paramètre définit le chemin et le nom enregistrés, et le troisième paramètre ajoute le nombre de temps de formation comme suffixe au nom du modèle.

saver.save(sess, 'mon-modèle', global_step=0) ==> nom de fichier : 'mon-modèle-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> nom de fichier : 'my-model-1000'

Regardez un exemple de mnist :

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

La partie rouge dans le code est le code pour sauvegarder le modèle. Bien que je le sauvegarde après chaque génération d'entraînement, le modèle enregistré la prochaine fois écrasera le précédent, et seule la dernière fois sera enregistrée. . Par conséquent, nous pouvons gagner du temps et mettre le code de sauvegarde en dehors de la boucle (ne s'applique qu'à max_to_keep=1, sinon il doit quand même être placé à l'intérieur de la boucle).

Dans l'expérience, la dernière génération peut ne pas être la génération avec la précision de vérification la plus élevée, nous ne voulons donc pas enregistrer la dernière génération par défaut, mais nous voulons enregistrer la génération avec la précision de vérification la plus élevée, il suffit donc d'ajouter une variable intermédiaire et une déclaration de jugement.

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

Si nous voulons enregistrer les trois générations avec la précision de vérification la plus élevée, et également enregistrer la précision de vérification de chaque fois, nous pouvons générer un txt fichier à sauvegarder.

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

Le modèle est restauré à l'aide de la fonction restaurer(), qui nécessite deux paramètres de restauration (sess, save_path), save_path fait référence au chemin du modèle enregistré . Nous pouvons utiliser tf.train.latest_checkpoint() pour obtenir automatiquement le dernier modèle enregistré. Par exemple :

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

Ensuite, nous pouvons changer la seconde moitié du programme en :

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

La zone marquée en rouge est le code lié à la sauvegarde et à la restauration du modèle. Utilisez une variable booléenne is_train pour contrôler les phases de formation et de vérification.

Programme source complet :

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

Recommandations associées :

Exporter le réseau de modèles TensorFlow en tant que fichier unique méthode

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn