
How to use Saver in Tensorflow

不言 (Original)
2018-04-23

This article introduces the usage of TensorFlow's Saver class in detail and is shared here for your reference. Let's take a look.

How to use Saver

1. Background on Saver

After training a model, we often want to save the results, that is, the model's parameters, so they can be reused for further training or for testing. TensorFlow provides the Saver class for this purpose.

The Saver class provides methods for saving variables to checkpoint files and restoring variables from them. A checkpoint file is a binary file that maps variable names to the corresponding tensor values.
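Conceptually, the mapping stored in a checkpoint can be pictured as a plain dictionary from variable names to values. The names and numbers below are made up purely for illustration; a real checkpoint is a binary file, not Python source:

```python
# Illustration only: a checkpoint conceptually maps variable names
# to tensor values. These values are hypothetical.
checkpoint_contents = {
    "w": [3.99],  # weight variable, close to the true slope 4
    "b": [4.01],  # bias variable, close to the true intercept 4
}
print(checkpoint_contents["w"])  # → [3.99]
```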

Given a step counter, the Saver class can automatically generate a checkpoint file each time the counter reaches a trigger point. This lets us keep multiple intermediate results during training, for example a snapshot every fixed number of steps.

To avoid filling up the disk, Saver can also manage the checkpoint files automatically. For example, we can keep only the N most recent checkpoint files.
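TensorFlow exposes this through the max_to_keep argument of tf.train.Saver (e.g. tf.train.Saver(max_to_keep=3); the default is 5). The rotation policy itself can be illustrated without TensorFlow; the sketch below only mimics the "keep the N most recent" behaviour and is not the real implementation:

```python
from collections import deque

def keep_recent(paths, max_to_keep):
    """Sketch of checkpoint rotation: only the most recent
    max_to_keep paths survive, as with tf.train.Saver(max_to_keep=N)."""
    kept = deque(maxlen=max_to_keep)  # old entries fall off the left
    for p in paths:
        kept.append(p)
    return list(kept)

# Saving at steps 50, 100, 150, 200 with max_to_keep=3 drops the oldest:
saves = ["model.ckpt-50", "model.ckpt-100", "model.ckpt-150", "model.ckpt-200"]
print(keep_recent(saves, 3))
# → ['model.ckpt-100', 'model.ckpt-150', 'model.ckpt-200']
```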

2. Saver instance

The following example shows how to use the Saver class.

import tensorflow as tf
import numpy as np

# Linear model: learn w and b so that y_predict approximates y = 4x + 4
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        for i in range(train_steps):
            sess.run(train, feed_dict={x: x_data})
            # Save a checkpoint every checkpoint_steps steps
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
    else:
        # Restore the latest checkpoint, if one exists
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        print(sess.run(w))
        print(sess.run(b))

  1. isTrain: distinguishes the training phase from the testing phase; True means training, False means testing

  2. train_steps: the number of training iterations; 100 in the example

  3. checkpoint_steps: a checkpoint is saved every this many training steps; 50 in the example

  4. checkpoint_dir: the path where the checkpoint files are saved; the current directory in the example

2.1 Training phase

Use the Saver.save() method to save the model:

  1. sess: the current session, which holds the current variable values

  2. checkpoint_dir + 'model.ckpt': the path prefix for the saved files

  3. global_step: the current training step, which is appended to the file name

After training completes, there will be five more files in the current directory.

Open the file named "checkpoint" to see the save history and the path of the most recent model.
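The "checkpoint" index file is a small text file. With checkpoint_steps = 50 and train_steps = 100 as above, its contents would look roughly like this (the paths are illustrative):

```
model_checkpoint_path: "model.ckpt-100"
all_model_checkpoint_paths: "model.ckpt-50"
all_model_checkpoint_paths: "model.ckpt-100"
```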

2.2 Testing phase

The saver.restore() method is used to restore variables during the test phase:

  1. sess: the current session; the previously saved variable values are loaded into it

  2. ckpt.model_checkpoint_path: the location of the stored model. There is no need to hard-code the model's file name; get_checkpoint_state() reads the checkpoint index file to determine which save is the latest and what it is called.
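What tf.train.get_checkpoint_state() extracts from the "checkpoint" index file can be sketched in plain Python. The parser below is a simplified illustration of the idea, not the real implementation:

```python
def latest_checkpoint_path(checkpoint_file_text):
    """Simplified sketch: find the model_checkpoint_path entry in the
    text of a 'checkpoint' index file, which names the newest save."""
    for line in checkpoint_file_text.splitlines():
        if line.startswith("model_checkpoint_path:"):
            return line.split('"')[1]  # value between the quotes
    return None

text = '''model_checkpoint_path: "model.ckpt-100"
all_model_checkpoint_paths: "model.ckpt-50"
all_model_checkpoint_paths: "model.ckpt-100"'''
print(latest_checkpoint_path(text))  # → model.ckpt-100
```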

Running the testing phase loads the previously trained parameters w and b and prints their values.


