这篇文章主要介绍了tensorflow1.0学习之模型的保存与恢复(Saver) ,现在分享给大家,也给大家做个参考。一起过来看看吧
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。
模型保存,先要创建一个Saver对象:如
saver=tf.train.Saver()
在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:
saver=tf.train.Saver(max_to_keep=0)
但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。
当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即
saver=tf.train.Saver(max_to_keep=1)
创建完saver对象后,就可以保存训练好的模型了,如:
saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
看一个mnist实例:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017 @author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False) x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,]) dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) saver=tf.train.Saver(max_to_keep=1) for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).
在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。
saver=tf.train.Saver(max_to_keep=1) max_acc=0 for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。
saver=tf.train.Saver(max_to_keep=3) max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() sess.close()
模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file)
则程序后半段代码我们可以改为:
sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) is_train=False saver=tf.train.Saver(max_to_keep=3) #训练阶段 if is_train: max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() #验证阶段 else: model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。
整个源程序:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017 @author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False) x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,]) dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None, kernel_initializer=tf.truncated_normal_initializer(stddev=0.01), kernel_regularizer=tf.nn.l2_loss) loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer()) is_train=True saver=tf.train.Saver(max_to_keep=3) #训练阶段 if is_train: max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc: max_acc=val_acc saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() #验证阶段 else: model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
相关推荐:
以上是tensorflow1.0学习之模型的保存与恢复(Saver)_python的详细内容。更多信息请关注PHP中文网其他相关文章!

使用NumPy创建多维数组可以通过以下步骤实现:1)使用numpy.array()函数创建数组,例如np.array([[1,2,3],[4,5,6]])创建2D数组;2)使用np.zeros(),np.ones(),np.random.random()等函数创建特定值填充的数组;3)理解数组的shape和size属性,确保子数组长度一致,避免错误;4)使用np.reshape()函数改变数组形状;5)注意内存使用,确保代码清晰高效。

播放innumpyisamethodtoperformoperationsonArraySofDifferentsHapesbyAutapityallate AligningThem.itSimplifififiesCode,增强可读性,和Boostsperformance.Shere'shore'showitworks:1)较小的ArraySaraySaraysAraySaraySaraySaraySarePaddedDedWiteWithOnestOmatchDimentions.2)

forpythondataTastorage,choselistsforflexibilityWithMixedDatatypes,array.ArrayFormeMory-effficityHomogeneousnumericalData,andnumpyArraysForAdvancedNumericalComputing.listsareversareversareversareversArversatilebutlessEbutlesseftlesseftlesseftlessforefforefforefforefforefforefforefforefforefforlargenumerdataSets; arrayoffray.array.array.array.array.array.ersersamiddreddregro

Pythonlistsarebetterthanarraysformanagingdiversedatatypes.1)Listscanholdelementsofdifferenttypes,2)theyaredynamic,allowingeasyadditionsandremovals,3)theyofferintuitiveoperationslikeslicing,but4)theyarelessmemory-efficientandslowerforlargedatasets.

toAccesselementsInapyThonArray,useIndIndexing:my_array [2] accessEsthethEthErlement,returning.3.pythonosezero opitedEndexing.1)usepositiveandnegativeIndexing:my_list [0] fortefirstElment,fortefirstelement,my_list,my_list [-1] fornelast.2] forselast.2)

文章讨论了由于语法歧义而导致的Python中元组理解的不可能。建议使用tuple()与发电机表达式使用tuple()有效地创建元组。(159个字符)

本文解释了Python中的模块和包装,它们的差异和用法。模块是单个文件,而软件包是带有__init__.py文件的目录,在层次上组织相关模块。

文章讨论了Python中的Docstrings,其用法和收益。主要问题:Docstrings对于代码文档和可访问性的重要性。


热AI工具

Undresser.AI Undress
人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover
用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool
免费脱衣服图片

Clothoff.io
AI脱衣机

Video Face Swap
使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热门文章

热工具

SublimeText3 Linux新版
SublimeText3 Linux最新版

SecLists
SecLists是最终安全测试人员的伙伴。它是一个包含各种类型列表的集合,这些列表在安全评估过程中经常使用,都在一个地方。SecLists通过方便地提供安全测试人员可能需要的所有列表,帮助提高安全测试的效率和生产力。列表类型包括用户名、密码、URL、模糊测试有效载荷、敏感数据模式、Web shell等等。测试人员只需将此存储库拉到新的测试机上,他就可以访问到所需的每种类型的列表。

SublimeText3汉化版
中文版,非常好用

VSCode Windows 64位 下载
微软推出的免费、功能强大的一款IDE编辑器

PhpStorm Mac 版本
最新(2018.2.1 )专业的PHP集成开发工具