Maison  >  Article  >  développement back-end  >  Démarrer avec TensorFlow et utiliser tf.train.Saver() pour enregistrer le modèle

Démarrer avec TensorFlow et utiliser tf.train.Saver() pour enregistrer le modèle

不言
不言original
2018-04-24 14:15:074035parcourir

Cet article explique principalement comment enregistrer le modèle à l'aide de tf.train.Saver() pour démarrer avec TensorFlow. Maintenant, je le partage avec vous et le donne comme référence. Jetons un coup d'œil ensemble

Quelques réflexions sur la sauvegarde du modèle

saver = tf.train.Saver(max_to_keep=3)

Lors de la définition d'un économiseur, vous définirez généralement celui qui enregistre le plus de modèles. Quantité, d'une manière générale, si le modèle lui-même est grand, il faut tenir compte de la taille du disque dur. Si vous devez effectuer un réglage fin sur la base du modèle actuellement formé, enregistrez autant de modèles que possible. Les réglages fins ultérieurs ne seront pas nécessairement effectués à partir du meilleur ckpt, car il peut être surajusté en même temps. Mais si vous enregistrez trop de fichiers, le disque dur sera sous pression. Si vous souhaitez conserver uniquement le meilleur modèle, la méthode consiste à calculer la précision ou la valeur f1 sur l'ensemble de vérification à chaque itération sur un certain nombre d'étapes. Si le résultat cette fois est meilleur que la dernière fois, enregistrez le nouveau. Sinon, il n’est pas nécessaire de le sauvegarder.

Si vous souhaitez utiliser des modèles enregistrés à différentes époques pour la fusion, 3 à 5 modèles suffisent. Supposons que les modèles fusionnés deviennent M, et que le meilleur modèle unique s'appelle m_best, donc Fusion peut en effet être meilleur que. m_meilleur pour M. Mais si vous fusionnez ce modèle avec des modèles d'autres structures, l'effet de M n'est pas aussi bon que m_best, car M équivaut à une opération moyenne, ce qui réduit les « caractéristiques » du modèle.

Mais il existe une nouvelle méthode de fusion, qui consiste à utiliser l'ajustement du taux d'apprentissage pour obtenir plusieurs points optimaux locaux, c'est-à-dire que lorsque la perte ne peut pas être réduite, enregistrer un ckpt, puis augmenter le taux d'apprentissage à. continuez à trouver le prochain point optimal local, puis utilisez ces ckpt pour la fusion. Je ne l'ai pas encore essayé. Le modèle unique sera certainement amélioré, mais je ne sais pas s'il y aura une situation où l'amélioration. ne sera pas amélioré lorsqu’il sera combiné avec d’autres modèles.

Comment utiliser tf.train.Saver() pour enregistrer le modèle

J'ai déjà reçu des erreurs, principalement à cause de problèmes de codage de triche. Faites donc attention à ne pas avoir de caractères chinois dans le chemin du fichier.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Modèle enregistré dans le fichier : ./ckpt/test-model.ckpt-1

Remarque, après avoir enregistré le modèle ci-dessus. Vous devez redémarrer le noyau avant d'utiliser le modèle suivant pour importer. Sinon, le nom sera erroné en nommant "v1" deux fois.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restauration des paramètres à partir de ./ckpt/test-model.ckpt-1
Modèle restauré.
[ 1.           2.29999995]
55.5

Avant d'importer le modèle, vous devez redéfinir les variables.

Mais il n'est pas nécessaire de redéfinir toutes les variables, il suffit de définir les variables dont nous avons besoin.

En d'autres termes, les variables que vous définissez doivent exister dans le point de contrôle mais toutes les variables du point de contrôle ne doivent pas être redéfinies.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restauration des paramètres à partir de ./ckpt/test-model.ckpt-1
Modèle restauré.
[ 1.             2.29999995]

tf.Saver([tensors_to_be_saved]) Vous pouvez transmettre une liste et transmettre les tenseurs à enregistrer. Si cette liste n'est pas donnée, il Tous. les tenseurs actuels seront enregistrés par défaut. D'une manière générale, tf.Saver peut être intelligemment combiné avec tf.variable_scope() Vous pouvez vous référer à : [Transfer Learning] Ajouter de nouvelles variables à un modèle déjà enregistré et l'affiner

Recommandations associées :

À propos de la fonction tf.train.batch dans Tensorflow

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn