本篇文章主要介紹了淺談Tensorflow模型的保存與恢復加載,現在分享給大家,也給大家做個參考。一起過來看看吧
近期做了一些反垃圾的工作,除了使用常用的規則匹配過濾等手段,也採用了一些機器學習方法進行分類預測。我們使用TensorFlow進行模型的訓練,訓練好的模型需要保存,預測階段我們需要將模型進行載入還原使用,這涉及TensorFlow模型的保存與復原載入。
總結一下Tensorflow常用的模型保存方式。
儲存checkpoint模型檔(.ckpt)
#首先,TensorFlow提供了一個非常方便的api,tf.train.Saver()來保存和還原一個機器學習模型。
模型保存
使用tf.train.Saver()來保存模型檔案非常方便,以下是一個簡單的範例:
import tensorflow as tf import os def save_model_ckpt(ckpt_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) op = tf.add(xy, b, name='op_to_store') sess = tf.Session() sess.run(tf.global_variables_initializer()) path = os.path.dirname(os.path.abspath(ckpt_file_path)) if os.path.isdir(path) is False: os.makedirs(path) tf.train.Saver().save(sess, ckpt_file_path) # test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict))
程式產生並儲存四個檔案(在版本0.11之前只會產生三個檔案:checkpoint, model.ckpt , model.ckpt.meta)
checkpoint 文字文件,記錄了模型文件的路徑資訊清單
model.ckpt.data-00000 -of-00001 網路權重資訊
model.ckpt.index .data和.index這兩個文件是二進位文件,保存了模型中的變數參數(權重)資訊
model.ckpt.meta 二進位文件,保存了模型的計算圖結構資訊(模型的網路結構)protobuf
以上是tf.train .Saver().save()的基本用法,save()方法還有很多可設定的參數:
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)
程式碼如下:tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
如果想每兩小時保存一次模型,並且只保存最新的4個模型,可以加上使用max_to_keep(預設值為5,如果想每訓練一個epoch就保存一次,可以將其設定為None或0,但沒啥用不推薦), keep_checkpoint_every_n_hours參數,如下:
程式碼如下:tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)
同時在tf.train.Saver()類別中,如果我們不指定任何訊息,則會保存所有的參數信息,我們也可以指定部分想要保存的內容,例如只儲存x, y參數(可傳入參數list或dict):
#
tf.train.Saver([x, y]).save(sess, ckpt_file_path)
##ps.在模型訓練過程中需要在儲存後拿到的變數或參數名稱屬性name不能丟,不然模型還原後不能透過get_tensor_by_name()取得。
模型載入還原
針對上面的模型保存例子,還原模型的過程如下:
import tensorflow as tf def restore_model_ckpt(ckpt_file_path): sess = tf.Session() saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构 saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息 # 直接获取保存的变量 print(sess.run('b:0')) # 获取placeholder变量 input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') # 获取需要进行计算的operator op = sess.graph.get_tensor_by_name('op_to_store:0') # 加入新的操作 add_on_op = tf.multiply(op, 2) ret = sess.run(add_on_op, {input_x: 5, input_y: 5}) print(ret)
首先還原模型結構,然後還原變數(參數)信息,最後我們就可以獲得已訓練的模型中的各種信息了(保存的變量、placeholder變數、operator等),同時可以為取得的變數新增各種新的操作(請參閱以上程式碼註解)。
並且,我們也可以載入部分模型,在此基礎上加入其它操作,具體可以參考官方文件和demo。
針對ckpt模型檔案的保存與還原,stackoverflow上有一個回答解釋比較清晰,可以參考。
儲存單一模型檔(.pb)
我自己有執行過Tensorflow的inception-v3的demo,發現運行結束後會產生一個.pb的模型文件,這個文件是作為後續預測或遷移學習使用的,就一個文件,非常酷炫,也十分方便。
這個過程的主要想法是graph_def檔案中沒有包含網路中的Variable值(通常情況儲存了權重),但卻包含了constant值,所以如果我們能把Variable轉換為constant(使用graph_util.convert_variables_to_constants()函式),即可達到使用一個檔案同時儲存網路架構與權重的目標。
ps:這裡.pb是模型檔案的後綴名,當然我們也可以用其它的後綴(使用.pb與google保持一致╮(╯▽╰)╭)
模型保存
同樣根據上面的例子,一個簡單的demo:
import tensorflow as tf import os from tensorflow.python.framework import graph_util def save_mode_pb(pb_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) # 这里的输出需要加上name属性 op = tf.add(xy, b, name='op_to_store') sess = tf.Session() sess.run(tf.global_variables_initializer()) path = os.path.dirname(os.path.abspath(pb_file_path)) if os.path.isdir(path) is False: os.makedirs(path) # convert_variables_to_constants 需要指定output_node_names,list(),可以多个 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store']) with tf.gfile.FastGFile(pb_file_path, mode='wb') as f: f.write(constant_graph.SerializeToString()) # test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict))
程序生成并保存一个文件
model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息
模型加载还原
针对上面的模型保存例子,还原模型的过程如下:
import tensorflow as tf from tensorflow.python.platform import gfile def restore_mode_pb(pb_file_path): sess = tf.Session() with gfile.FastGFile(pb_file_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') print(sess.run('b:0')) input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') op = sess.graph.get_tensor_by_name('op_to_store:0') ret = sess.run(op, {input_x: 5, input_y: 5}) print(ret)
模型的还原过程与checkpoint差不多一样。
《将TensorFlow的网络导出为单个文件》上介绍了TensorFlow保存单个模型文件的方式,大同小异,可以看看。
思考
模型的保存与加载只是TensorFlow中最基础的部分之一,虽然简单但是也必不可少,在实际运用中还需要注意模型何时保存,哪些变量需要保存,如何设计加载实现迁移学习等等问题。
同时TensorFlow的函数和类都在一直变化更新,以后也有可能出现更丰富的模型保存和还原的方法。
选择保存为checkpoint或单个pb文件视业务情况而定,没有特别大的差别。checkpoint保存感觉会更加灵活一些,pb文件更适合线上部署吧(个人看法)。
以上完整代码:github https://github.com/liuyan731/tf_demo
相关推荐:
以上是淺談Tensorflow模型的保存與恢復加載的詳細內容。更多資訊請關注PHP中文網其他相關文章!

Python列表切片的基本語法是list[start:stop:step]。 1.start是包含的第一個元素索引,2.stop是排除的第一個元素索引,3.step決定元素之間的步長。切片不僅用於提取數據,還可以修改和反轉列表。

ListSoutPerformarRaysin:1)DynamicsizicsizingandFrequentInsertions/刪除,2)儲存的二聚體和3)MemoryFeliceFiceForceforseforsparsedata,butmayhaveslightperformancecostsinclentoperations。

toConvertapythonarraytoalist,usEthelist()constructororageneratorexpression.1)intimpthearraymoduleandcreateanArray.2)USELIST(ARR)或[XFORXINARR] to ConconverTittoalist,請考慮performorefformanceandmemoryfformanceandmemoryfformienceforlargedAtasetset。

choosearraysoverlistsinpythonforbetterperformanceandmemoryfliceSpecificScenarios.1)largenumericaldatasets:arraysreducememoryusage.2)績效 - 臨界雜貨:arraysoffersoffersOffersOffersOffersPoostSfoostSforsssfortasssfortaskslikeappensearch orearch.3)testessenforcety:arraysenforce:arraysenforc

在Python中,可以使用for循環、enumerate和列表推導式遍歷列表;在Java中,可以使用傳統for循環和增強for循環遍歷數組。 1.Python列表遍歷方法包括:for循環、enumerate和列表推導式。 2.Java數組遍歷方法包括:傳統for循環和增強for循環。

本文討論了版本3.10中介紹的Python的新“匹配”語句,該語句與其他語言相同。它增強了代碼的可讀性,並為傳統的if-elif-el提供了性能優勢

Python中的功能註釋將元數據添加到函數中,以進行類型檢查,文檔和IDE支持。它們增強了代碼的可讀性,維護,並且在API開發,數據科學和圖書館創建中至關重要。


熱AI工具

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool
免費脫衣圖片

Clothoff.io
AI脫衣器

Video Face Swap
使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱門文章

熱工具

SublimeText3 Mac版
神級程式碼編輯軟體(SublimeText3)

Dreamweaver CS6
視覺化網頁開發工具

EditPlus 中文破解版
體積小,語法高亮,不支援程式碼提示功能

WebStorm Mac版
好用的JavaScript開發工具

ZendStudio 13.5.1 Mac
強大的PHP整合開發環境