Maison > Article > développement back-end > Une brève discussion sur la sauvegarde et la restauration du chargement des modèles Tensorflow
Cet article présente principalement la sauvegarde et la restauration du modèle Tensorflow. Maintenant, je le partage avec vous et le donne comme référence. Venez y jeter un coup d'œil
Récemment, nous avons effectué un travail anti-spam. En plus d'utiliser des méthodes de correspondance et de filtrage de règles couramment utilisées, nous utilisons également certaines méthodes d'apprentissage automatique pour la prédiction de classification. Nous utilisons TensorFlow pour entraîner le modèle. Le modèle entraîné doit être enregistré. Dans la phase de prédiction, nous devons charger et restaurer le modèle pour l'utiliser, ce qui implique la sauvegarde et la restauration du modèle TensorFlow.
Résumez les méthodes d'enregistrement de modèle couramment utilisées par Tensorflow.
Enregistrer le fichier de modèle de point de contrôle (.ckpt)
Tout d'abord, TensorFlow fournit une API très pratique, tf.train.Saver() pour enregistrer et restaurer un modèle d’apprentissage automatique.
Sauvegarde du modèle
Il est très pratique d'utiliser tf.train.Saver() pour enregistrer les fichiers de modèle. Voici un exemple simple :
import tensorflow as tf import os def save_model_ckpt(ckpt_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) op = tf.add(xy, b, name='op_to_store') sess = tf.Session() sess.run(tf.global_variables_initializer()) path = os.path.dirname(os.path.abspath(ckpt_file_path)) if os.path.isdir(path) is False: os.makedirs(path) tf.train.Saver().save(sess, ckpt_file_path) # test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict))
Le programme génère et enregistre quatre fichiers (avant la version 0.11, seuls trois fichiers étaient générés : point de contrôle, model .ckpt, model.ckpt.meta)
fichier texte de point de contrôle, qui enregistre la liste des informations de chemin du fichier modèle
model. ckpt.data -00000-of-00001 Informations sur le poids du réseau
model.ckpt.index Les deux fichiers .data et .index sont des fichiers binaires qui enregistrent les informations sur les paramètres variables (poids) dans le modèle
model.ckpt.meta fichier binaire, qui enregistre les informations sur la structure du graphe informatique du modèle (la structure de réseau du modèle) protobuf
Ce qui précède est l'utilisation tf Basic de .train.Saver().save(), la méthode save() a également de nombreux paramètres configurables :
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)
L'ajout du paramètre global_step signifie enregistrer le modèle toutes les 1000 itérations. "-1000" sera ajouté après le fichier modèle, model.ckpt-1000.index, model.ckpt-1000. .meta, model.ckpt.data-1000-00000-of-00001
Le modèle est enregistré toutes les 1000 itérations, mais le fichier d'informations structurelles du modèle ne changera pas. Il sera uniquement enregistré. après 1000 itérations sans correspondance est sauvegardé toutes les 1000 fois, donc quand on n'a pas besoin de sauvegarder le méta-fichier, on peut ajouter le paramètre write_meta_graph=False, comme suit :
Copier le code Le code est le suivant :
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
Si vous souhaitez enregistrer le modèle toutes les deux heures et enregistrer uniquement les 4 derniers modèles, vous pouvez ajouter max_to_keep (le la valeur par défaut est 5, si vous souhaitez entraîner une époque tous les enregistrez-la simplement une fois, vous pouvez la définir sur Aucun ou 0, mais c'est inutile et déconseillé), le paramètre keep_checkpoint_every_n_hours est le suivant :
Copier le code Le code est le suivant :
tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)
En même temps, dans le tf.train.Saver(), si nous ne spécifions aucune information, toutes les informations sur les paramètres seront enregistrées. Nous pouvons également spécifier une partie du contenu que nous souhaitons enregistrer, par exemple, enregistrer uniquement les paramètres x, y (le paramètre). list ou dict peut être transmis) :
tf.train.Saver([x, y]).save(sess, ckpt_file_path)
ps. Nom du paramètre Le nom de l'attribut qui doit être obtenu après l'enregistrement ne peut pas être perdu, sinon le modèle ne peut pas être obtenu via get_tensor_by_name() après la restauration.
Chargement et restauration du modèle
Pour l'exemple d'enregistrement de modèle ci-dessus, le processus de restauration du modèle est le suivant :
import tensorflow as tf def restore_model_ckpt(ckpt_file_path): sess = tf.Session() saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构 saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息 # 直接获取保存的变量 print(sess.run('b:0')) # 获取placeholder变量 input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') # 获取需要进行计算的operator op = sess.graph.get_tensor_by_name('op_to_store:0') # 加入新的操作 add_on_op = tf.multiply(op, 2) ret = sess.run(add_on_op, {input_x: 5, input_y: 5}) print(ret)
Restaurez d'abord la structure du modèle, puis restaurez les informations sur la variable (paramètre), et enfin nous pouvons obtenir diverses informations dans le système formé modèle (variables de sauvegarde, variables d'espace réservé, opérateurs, etc.), et diverses nouvelles opérations peuvent être ajoutées aux variables obtenues (voir les commentaires de code ci-dessus).
De plus, nous pouvons également charger certains modèles et ajouter d'autres opérations sur cette base. Pour plus de détails, veuillez vous référer à la documentation officielle et à la démo.
Concernant la sauvegarde et la restauration des fichiers de modèle ckpt, il existe une réponse sur stackoverflow avec une explication claire, à laquelle vous pouvez vous référer.
En parallèle, le tutoriel sur la sauvegarde et la restauration des modèles TensorFlow sur cv-tricks.com est également très bien, vous pouvez vous y référer.
"Apprentissage Tensorflow 1.0 : sauvegarde et restauration du modèle (Saver)" contient quelques conseils d'utilisation de Saver.
Enregistrer un seul fichier de modèle (.pb)
J'ai moi-même exécuté la démo d'inception-v3 de Tensorflow et j'ai constaté qu'un .pb serait généré après l'exécution du fichier Model, ce fichier est utilisé pour l'apprentissage ultérieur de la prédiction ou de la migration. Ce n'est qu'un seul fichier, très cool et très pratique.
L'idée principale de ce processus est que le fichier graph_def ne contient pas la valeur de la variable dans le réseau (généralement le poids est stocké), mais il contient la valeur constante, donc si nous peut convertir la variable en constante (à l'aide de la fonction graph_util.convert_variables_to_constants()), vous pouvez atteindre l'objectif d'utiliser un seul fichier pour stocker à la fois l'architecture du réseau et les poids.
ps : Ici .pb est le nom du suffixe du fichier modèle. Bien entendu, nous pouvons également utiliser d'autres suffixes (utilisez .pb pour être cohérent avec Google╮(╯▽╰)╭)<.>
Sauvegarde du modèle
import tensorflow as tf import os from tensorflow.python.framework import graph_util def save_mode_pb(pb_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) # 这里的输出需要加上name属性 op = tf.add(xy, b, name='op_to_store') sess = tf.Session() sess.run(tf.global_variables_initializer()) path = os.path.dirname(os.path.abspath(pb_file_path)) if os.path.isdir(path) is False: os.makedirs(path) # convert_variables_to_constants 需要指定output_node_names,list(),可以多个 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store']) with tf.gfile.FastGFile(pb_file_path, mode='wb') as f: f.write(constant_graph.SerializeToString()) # test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict))
程序生成并保存一个文件
model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息
模型加载还原
针对上面的模型保存例子,还原模型的过程如下:
import tensorflow as tf from tensorflow.python.platform import gfile def restore_mode_pb(pb_file_path): sess = tf.Session() with gfile.FastGFile(pb_file_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') print(sess.run('b:0')) input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') op = sess.graph.get_tensor_by_name('op_to_store:0') ret = sess.run(op, {input_x: 5, input_y: 5}) print(ret)
模型的还原过程与checkpoint差不多一样。
《将TensorFlow的网络导出为单个文件》上介绍了TensorFlow保存单个模型文件的方式,大同小异,可以看看。
思考
模型的保存与加载只是TensorFlow中最基础的部分之一,虽然简单但是也必不可少,在实际运用中还需要注意模型何时保存,哪些变量需要保存,如何设计加载实现迁移学习等等问题。
同时TensorFlow的函数和类都在一直变化更新,以后也有可能出现更丰富的模型保存和还原的方法。
选择保存为checkpoint或单个pb文件视业务情况而定,没有特别大的差别。checkpoint保存感觉会更加灵活一些,pb文件更适合线上部署吧(个人看法)。
以上完整代码:github https://github.com/liuyan731/tf_demo
相关推荐:
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!