Heim  >  Artikel  >  Backend-Entwicklung  >  Beispiel für einen VAE-Algorithmus in Python

Beispiel für einen VAE-Algorithmus in Python

王林
王林Original
2023-06-11 19:58:342190Durchsuche

VAE ist ein generatives Modell, sein vollständiger Name ist Variational Autoencoder und seine chinesische Übersetzung ist Variational Autoencoder. Es handelt sich um einen unbeaufsichtigten Lernalgorithmus, der zur Generierung neuer Daten wie Bilder, Audio, Text usw. verwendet werden kann. Im Vergleich zu herkömmlichen Autoencodern sind VAEs flexibler und leistungsfähiger und können komplexere und realistischere Daten generieren.

Python ist eine der am weitesten verbreiteten Programmiersprachen und eines der wichtigsten Werkzeuge für Deep Learning. In Python gibt es viele hervorragende Frameworks für maschinelles Lernen und Deep Learning, wie TensorFlow, PyTorch, Keras usw., die alle über VAE-Implementierungen verfügen.

In diesem Artikel wird anhand eines Python-Codebeispiels vorgestellt, wie TensorFlow zum Implementieren des VAE-Algorithmus und zum Generieren neuer handschriftlicher Ziffernbilder verwendet wird.

VAE-Modellprinzip

VAE ist eine unbeaufsichtigte Lernmethode, die potenzielle Merkmale aus Daten extrahieren und diese Merkmale zur Generierung neuer Daten verwenden kann. VAE lernt die Verteilung von Daten durch Berücksichtigung der Wahrscheinlichkeitsverteilung latenter Variablen. Es ordnet die Originaldaten einem latenten Raum zu und wandelt den latenten Raum über einen Decoder in rekonstruierte Daten um.

Die Modellstruktur von VAE besteht aus zwei Teilen: Encoder und Decoder. Der Encoder komprimiert die Originaldaten in den latenten Variablenraum und der Decoder ordnet die latenten Variablen wieder dem ursprünglichen Datenraum zu. Zwischen Encoder und Decoder gibt es außerdem eine Reparametrisierungsschicht, um sicherzustellen, dass die Abtastung latenter Variablen differenzierbar ist.

Die Verlustfunktion von VAE besteht aus zwei Teilen: dem Rekonstruktionsfehler, dem Abstand zwischen den Originaldaten und den vom Decoder generierten Daten. Der andere Teil ist der Regularisierungsterm, der zur Begrenzung der Verteilung verwendet wird der latenten Variablen.

Datensatz

Wir werden den MNIST-Datensatz verwenden, um das VAE-Modell zu trainieren und neue handschriftliche Ziffernbilder zu generieren. Der MNIST-Datensatz enthält eine Reihe handgeschriebener Ziffernbilder. Jedes Bild ist ein 28 × 28-Graustufenbild.

Wir können die von TensorFlow bereitgestellte API verwenden, um den MNIST-Datensatz zu laden und das Bild in Vektorform zu konvertieren. Der Code lautet wie folgt:

import tensorflow as tf
import numpy as np

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist

# 加载训练集和测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 将图像转换为向量形式
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.
x_train = x_train.reshape((-1, 28 * 28))
x_test = x_test.reshape((-1, 28 * 28))

VAE-Modellimplementierung

Wir können TensorFlow verwenden, um das VAE-Modell zu implementieren. Sowohl der Encoder als auch der Decoder sind mehrschichtige neuronale Netze, und die Neuparametrisierungsschicht ist eine Zufallsschicht.

Der Implementierungscode des VAE-Modells lautet wie folgt:

import tensorflow_probability as tfp

# 定义编码器
encoder_inputs = tf.keras.layers.Input(shape=(784,))
x = tf.keras.layers.Dense(256, activation='relu')(encoder_inputs)
x = tf.keras.layers.Dense(128, activation='relu')(x)
mean = tf.keras.layers.Dense(10)(x)
logvar = tf.keras.layers.Dense(10)(x)

# 定义重参数化层
def sampling(args):
    mean, logvar = args
    epsilon = tfp.distributions.Normal(0., 1.).sample(tf.shape(mean))
    return mean + tf.exp(logvar / 2) * epsilon

z = tf.keras.layers.Lambda(sampling)([mean, logvar])

# 定义解码器
decoder_inputs = tf.keras.layers.Input(shape=(10,))
x = tf.keras.layers.Dense(128, activation='relu')(decoder_inputs)
x = tf.keras.layers.Dense(256, activation='relu')(x)
decoder_outputs = tf.keras.layers.Dense(784, activation='sigmoid')(x)

# 构建模型
vae = tf.keras.models.Model(encoder_inputs, decoder_outputs)

# 定义损失函数
reconstruction = -tf.reduce_sum(encoder_inputs * tf.math.log(1e-10 + decoder_outputs) + 
                                (1 - encoder_inputs) * tf.math.log(1e-10 + 1 - decoder_outputs), axis=1)
kl_divergence = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=-1)
vae_loss = tf.reduce_mean(reconstruction + kl_divergence)

vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()

Beim Schreiben des Codes müssen Sie auf die folgenden Punkte achten:

  • Lambda-Schicht verwenden, um schwere Parametrisierungsoperationen zu implementieren
  • Die Verlustfunktion umfasst Rekonstruktionsfehler und Regularisierungsbegriff
  • Fügen Sie dem Modell die Verlustfunktion hinzu. Es ist nicht erforderlich, den Gradienten manuell zu berechnen. Sie können den Optimierer direkt zum Training verwenden . Der Code zum Trainieren des Modells lautet wie folgt:
  • vae.fit(x_train, x_train,
            epochs=50,
            batch_size=128,
            validation_data=(x_test, x_test))
Während des Trainings können wir mehrere Epochen und größere Batchgrößen verwenden, um den Trainingseffekt zu verbessern.

Neue handschriftliche Ziffernbilder generieren

Nach Abschluss des Trainings können wir das VAE-Modell verwenden, um neue handschriftliche Ziffernbilder zu generieren. Der Code zum Generieren des Bildes lautet wie folgt:

import matplotlib.pyplot as plt

# 随机生成潜在变量
z = np.random.normal(size=(1, 10))

# 将潜在变量解码为图像
generated = vae.predict(z)

# 将图像转换为灰度图像
generated = generated.reshape((28, 28))
plt.imshow(generated, cmap='gray')
plt.show()

Wir können verschiedene handgeschriebene Ziffernbilder generieren, indem wir den Code mehrmals ausführen. Diese Bilder werden basierend auf der von VAE erlernten Datenverteilung mit Vielfalt und Kreativität generiert.

Zusammenfassung

Dieser Artikel stellt vor, wie der VAE-Algorithmus mithilfe von TensorFlow in Python implementiert wird, und demonstriert seine Anwendung mithilfe des MNIST-Datensatzes und der Generierung neuer handschriftlicher Ziffernbilder. Durch das Erlernen des VAE-Algorithmus können nicht nur neue Daten generiert, sondern auch potenzielle Merkmale in den Daten extrahiert werden, was eine neue Idee für die Datenanalyse und Mustererkennung liefert.

Das obige ist der detaillierte Inhalt vonBeispiel für einen VAE-Algorithmus in Python. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn