首頁  >  文章  >  後端開發  >  Tensorflow之Saver的用法

Tensorflow之Saver的用法

不言
不言原創
2018-04-23 15:46:311943瀏覽

這篇文章主要介紹了Tensorflow之Saver的用法詳解,現在分享給大家,也給大家做個參考。一起來看看吧

Saver的用法

1. Saver的背景介紹

我們經常在訓練完一個模型之後希望保存訓練的結果,這些結果指的是模型的參數,以便下次迭代的訓練或用作測試。 Tensorflow針對此需求提供了Saver類別。

Saver類別提供了向checkpoints檔案保存和從checkpoints檔案中復原變數的相關方法。 Checkpoints檔案是一個二進位文件,它把變數名稱映射到對應的tensor值 。

只要提供一個計數器,當計數器觸發時,Saver類別可以自動的產生checkpoint檔案。這讓我們可以在訓練過程中保存多個中間結果。例如,我們可以儲存每一步訓練的結果。

為了避免填滿整個磁碟,Saver可以自動的管理Checkpoints檔案。例如,我們可以指定儲存最近的N個Checkpoints檔案。

2. Saver的實例

下面以範例來敘述如何使用Saver類別 

import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b))

  1. isTrain:用來區分訓練階段和測試階段,True表示訓練,False表示測試

  2. train_steps:表示訓練的次數,例子中使用100

  3. checkpoint_steps:表示訓練多少次保存一下checkpoints,例子中使用50

  4. checkpoint_dir:表示checkpoints檔案的保存路徑,範例中使用目前路徑

2.1 訓練階段

#使用Saver.save()方法儲存模型:

  1. sess:表示目前會話,目前會話記錄了目前的變數值

  2. checkpoint_dir 'model.ckpt':表示儲存的檔案名稱

  3. global_step:表示目前是第幾步

訓練完成後,目前目錄底下會多出5個檔案。

開啟名為「checkpoint」的文件,可以看到儲存記錄,和最新的模型儲存位置。

2.1測試階段

測試階段使用saver.restore()方法還原變數:

sess:表示目前會話,先前儲存的結果將會載入這個會話

ckpt.model_checkpoint_path:表示模型儲存的位置,不需要提供模型的名字,它會去查看checkpoint文件,看看最新的是誰,叫做什麼。

運行結果如下圖所示,載入了先前訓練的參數w和b的結果

相關推薦:

tensorflow 使用flags定義指令列參數的方法

#tensorflow1.0學習之模型的儲存與復原(Saver)_python

以上是Tensorflow之Saver的用法的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn