Saving and restoring the model learned by tensorflow1.0 (Saver)

This article introduces how to save and restore a model trained with tensorflow1.0, using the Saver class. It is shared here as a reference.

We often save the parameters of a trained model so that they can be used later for validation or testing. TensorFlow provides tf.train.Saver() for saving models.

To save a model, first create a Saver object, for example:

saver=tf.train.Saver()

When creating the Saver object, one parameter we often use is max_to_keep, which sets how many saved models are kept. The default is 5 (max_to_keep=5), which keeps the 5 most recent models. If you want to keep the model from every training epoch, set max_to_keep to None or 0, for example:

saver=tf.train.Saver(max_to_keep=0)

However, apart from taking up more disk space, this has little practical use, so it is not recommended.

If you only want to keep the model from the last epoch, set max_to_keep to 1:

saver=tf.train.Saver(max_to_keep=1)
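
The effect of max_to_keep can be pictured with a small pure-Python sketch. This is only an illustration of the retention rule described above, not TensorFlow's implementation, and the helper name kept_checkpoints is hypothetical.

```python
from collections import deque

def kept_checkpoints(steps, max_to_keep):
    """Return the checkpoint steps that remain after saving at every step.

    Hypothetical helper that mimics the retention rule only.
    """
    if not max_to_keep:  # None or 0: nothing is ever deleted
        return list(steps)
    recent = deque(maxlen=max_to_keep)  # the oldest entry drops out automatically
    for step in steps:
        recent.append(step)
    return list(recent)

print(kept_checkpoints(range(1, 11), 5))     # [6, 7, 8, 9, 10]
print(kept_checkpoints(range(1, 11), 1))     # [10]
print(kept_checkpoints(range(1, 11), None))  # all ten steps
```

With ten saves, the default of 5 leaves steps 6 to 10, max_to_keep=1 leaves only step 10, and None or 0 leaves all ten.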

After creating the saver object, you can save the trained model, for example:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

The first argument, sess, is the session. The second argument sets the path and name to save under, and the third argument appends the number of training steps to the model name as a suffix:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
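
The naming rule shown above can be written as a one-line helper. This is a minimal sketch of the rule only; checkpoint_prefix is a hypothetical name, not a TensorFlow function.

```python
def checkpoint_prefix(save_path, global_step=None):
    """Build the name a checkpoint is saved under (hypothetical helper)."""
    if global_step is None:
        return save_path  # no suffix when global_step is not given
    return "%s-%d" % (save_path, global_step)

print(checkpoint_prefix("my-model", 0))           # my-model-0
print(checkpoint_prefix("my-model", 1000))        # my-model-1000
print(checkpoint_prefix("ckpt/mnist.ckpt", 100))  # ckpt/mnist.ckpt-100
```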

Here is an MNIST example:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

The saver.save call inside the loop is the code that saves the model. Although it runs after every training epoch, with max_to_keep=1 each save replaces the previous one, so only the last model is kept. To save time, the call can therefore be moved outside the loop (this applies only when max_to_keep=1; otherwise it must stay inside the loop).

In practice, the last epoch is not necessarily the one with the highest validation accuracy, so instead of keeping the last epoch we want to keep the epoch with the highest validation accuracy. This only requires an extra variable and an if statement:

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()
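
To see which iterations actually trigger a save under this rule, the comparison can be run on its own. The sketch below uses made-up accuracy values, and steps_that_save is a hypothetical helper written for illustration.

```python
def steps_that_save(val_accs):
    """Return the 1-based steps at which a new best accuracy triggers a save."""
    max_acc = 0
    saved = []
    for i, val_acc in enumerate(val_accs):
        if val_acc > max_acc:  # strictly greater: a tie does not save again
            max_acc = val_acc
            saved.append(i + 1)
    return saved

accs = [0.60, 0.72, 0.70, 0.81, 0.81, 0.79]  # made-up values
print(steps_that_save(accs))  # [1, 2, 4]
```

Because the comparison is strict, a later step that only ties the best accuracy does not overwrite the saved model.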

If we want to keep the three epochs with the highest validation accuracy, and also record the validation accuracy of every epoch, we can write the accuracies to a txt file:

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
tf.gfile.MakeDirs('ckpt')  # make sure the directory exists before acc.txt is opened
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()
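
The accuracy log can later be read back to find the best epoch. The sketch below assumes the line format written by f.write above ("step, val_acc: value"); best_step is a hypothetical helper, and the sample lines are made up.

```python
def best_step(lines):
    """Return (step, val_acc) for the highest accuracy in acc.txt-style lines."""
    best = None
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        step_text, acc_text = line.split(", val_acc: ")
        record = (int(step_text), float(acc_text))
        if best is None or record[1] > best[1]:
            best = record
    return best

sample = ["1, val_acc: 0.6\n", "2, val_acc: 0.72\n", "3, val_acc: 0.7\n"]
print(best_step(sample))  # (2, 0.72)
```

An open file object can be passed in place of the list, for example best_step(open('ckpt/acc.txt')).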

To restore a model, use the saver's restore() function, which takes two arguments: restore(sess, save_path). save_path is the path of the saved model. We can use tf.train.latest_checkpoint() to get the most recently saved model automatically. For example:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)
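
Each save also updates a small text file named checkpoint in the same directory, which records the most recent model, and this record is what lets the latest model be found. The pure-Python sketch below only illustrates reading such a record; it assumes the text layout shown in the sample, latest_from_state is a hypothetical helper, and the real tf.train.latest_checkpoint() does more than this.

```python
import re

def latest_from_state(state_text):
    """Return the model_checkpoint_path entry from 'checkpoint' file text."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"',
                      state_text, re.MULTILINE)
    return match.group(1) if match else None

# Assumed layout of the 'checkpoint' file after three saves with max_to_keep=3
sample = (
    'model_checkpoint_path: "mnist.ckpt-100"\n'
    'all_model_checkpoint_paths: "mnist.ckpt-98"\n'
    'all_model_checkpoint_paths: "mnist.ckpt-99"\n'
    'all_model_checkpoint_paths: "mnist.ckpt-100"\n'
)
print(latest_from_state(sample))  # mnist.ckpt-100
```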

Then we can change the second half of the program to:

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

# Training phase
if is_train:
  max_acc=0
  tf.gfile.MakeDirs('ckpt')  # make sure the directory exists before acc.txt is opened
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

# Validation phase
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

The saver.save, tf.train.latest_checkpoint and saver.restore calls are the code related to saving and restoring the model. A boolean variable, is_train, switches between the training and validation phases.

The complete source program:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

# Training phase
if is_train:
  max_acc=0
  tf.gfile.MakeDirs('ckpt')  # make sure the directory exists before acc.txt is opened
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

# Validation phase
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
