ホームページ >バックエンド開発 >Python チュートリアル >TensorFlow でトレーニング済みモデルを効果的に保存および復元するにはどうすればよいですか?

TensorFlow でトレーニング済みモデルを効果的に保存および復元するにはどうすればよいですか?

Linda Hamilton
Linda Hamiltonオリジナル
2024-12-14 12:03:12941ブラウズ

How to Effectively Save and Restore Trained Models in TensorFlow?

Tensorflow でのトレーニング済みモデルの保存と復元

Tensorflow でモデルをトレーニングした後、それを保存して再利用することが重要です。モデル ストレージを効果的に処理する方法は次のとおりです:

トレーニング済みモデルの保存 (Tensorflow バージョン 0.11 以降):

  1. 入力の準備:プレースホルダーを定義し、入力を含むフィード辞書を準備しますdata.
  2. 演算の定義: 加算や乗算など、復元する演算を指定します。
  3. セーバー オブジェクトの作成: セーバー オブジェクトをインスタンス化します。変数ストレージを管理します。
  4. グラフ: saver.save() メソッドを使用して、変数やグラフ構造を含むモデルを保存します。

コード例:

import tensorflow as tf

# Prepare input placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")

# Define test operation
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, tf.Variable(2.0, name="bias"), name="op_to_restore")

# Initialize variables and run session
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create saver object
saver = tf.train.Saver()

# Save the model
saver.save(sess, 'my_test_model', global_step=1000)

保存したものを復元するモデル:

  1. メタ グラフのロード: メタ グラフをインポートして、保存されたモデル構造にアクセスします。
  2. 変数の復元: saver.restore() メソッドを使用して保存されたデータを取得します変数。
  3. プレースホルダーの取得とデータのフィード: 入力プレースホルダーを取得し、新しいデータをフィードします。
  4. 保存された操作へのアクセス: 保存した操作を見つけます。実行して実行したい

コード例:

# Restore model
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Get placeholders and feed data
w1 = sess.graph.get_tensor_by_name("w1:0")
w2 = sess.graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Run saved operation
op_to_restore = sess.graph.get_tensor_by_name("op_to_restore:0")
result = sess.run(op_to_restore, feed_dict)

以上がTensorFlow でトレーニング済みモデルを効果的に保存および復元するにはどうすればよいですか?の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。