Rumah >pembangunan bahagian belakang >Tutorial Python >Bagaimana untuk Menyimpan dan Memulihkan Model Terlatih dengan Berkesan dalam TensorFlow?

Bagaimana untuk Menyimpan dan Memulihkan Model Terlatih dengan Berkesan dalam TensorFlow?

Linda Hamilton
Linda Hamiltonasal
2024-12-14 12:03:12900semak imbas

How to Effectively Save and Restore Trained Models in TensorFlow?

Menyimpan dan Memulihkan Model Terlatih dalam Tensorflow

Selepas melatih model dalam Tensorflow, memelihara dan menggunakannya semula adalah penting. Begini cara mengendalikan storan model dengan berkesan:

Menyimpan Model Terlatih (Tensorflow versi 0.11 dan ke atas):

  1. Sediakan Input: Tentukan ruang letak dan sediakan kamus suapan dengan data input.
  2. Tentukan Operasi: Tentukan operasi yang akan dipulihkan, seperti penambahan atau pendaraban.
  3. Buat Objek Penjimat: Buat seketika objek penjimat yang mengurus storan berubah-ubah.
  4. Simpan Graf: Gunakan kaedah saver.save() untuk menyimpan model, termasuk pembolehubah dan graf struktur.

Contoh Kod:

import tensorflow as tf

# Prepare input placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")

# Define test operation
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, tf.Variable(2.0, name="bias"), name="op_to_restore")

# Initialize variables and run session
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create saver object
saver = tf.train.Saver()

# Save the model
saver.save(sess, 'my_test_model', global_step=1000)

Memulihkan Model yang Disimpan:

  1. Muatkan Graf Meta: Import graf meta untuk mengakses model yang disimpan struktur.
  2. Pulihkan Pembolehubah: Gunakan kaedah saver.restore() untuk mendapatkan semula pembolehubah yang disimpan.
  3. Dapatkan Pemegang Tempat dan Data Suapan: Dapatkan input pemegang tempat dan suapkan mereka dengan data baharu.
  4. Akses Disimpan Operasi: Cari operasi yang anda mahu jalankan dan laksanakannya.

Kod Contoh:

# Restore model
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Get placeholders and feed data
w1 = sess.graph.get_tensor_by_name("w1:0")
w2 = sess.graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Run saved operation
op_to_restore = sess.graph.get_tensor_by_name("op_to_restore:0")
result = sess.run(op_to_restore, feed_dict)

Atas ialah kandungan terperinci Bagaimana untuk Menyimpan dan Memulihkan Model Terlatih dengan Berkesan dalam TensorFlow?. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn