ホームページ >バックエンド開発 >Python チュートリアル >トレーニングされた TensorFlow モデルを保存および復元するにはどうすればよいですか?

トレーニングされた TensorFlow モデルを保存および復元するにはどうすればよいですか?

DDD
DDDオリジナル
2024-12-19 17:41:09636ブラウズ

How Can I Save and Restore Trained TensorFlow Models?

トレーニング済み TensorFlow モデルの保存と復元

TensorFlow は、トレーニング済みモデルの保存と復元のためのシームレスな機能を提供し、モデルを永続化して再利用できます。

保存Model

トレーニングされたモデルを TensorFlow に保存するには、tf.train.Saver クラスを使用できます。以下に例を示します。

import tensorflow as tf

# Prepare placeholders and variables
w1 = tf.placeholder(tf.float32, name="w1")
w2 = tf.placeholder(tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define an operation to be restored
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object
saver = tf.train.Saver()

# Run the operation and save the graph
print(sess.run(w4, feed_dict))
saver.save(sess, 'my_test_model', global_step=1000)

モデルの復元

以前に保存したモデルを復元するには、次のプロセスを使用できます:

import tensorflow as tf

sess = tf.Session()

# Load the meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved variables directly
print(sess.run('bias:0'))  # Prints 2 (the bias value)

# Access and create feed-dict for new input data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Access the desired operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))  # Prints 60 ((w1 + w2) * b1)

追加のシナリオとユースケースについては、提供された回答で提供されるリソースを参照してください。保存と復元について詳しく説明されています。 TensorFlow モデル。

以上がトレーニングされた TensorFlow モデルを保存および復元するにはどうすればよいですか?の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。