Heim  >  Artikel  >  Backend-Entwicklung  >  Beispiel für einen GAN-Algorithmus in Python

Beispiel für einen GAN-Algorithmus in Python

王林
王林Original
2023-06-10 09:53:501188Durchsuche

Generative Adversarial Networks (GAN) ist ein Deep-Learning-Algorithmus, der durch zwei miteinander konkurrierende neuronale Netze neue Daten generiert. GAN wird häufig für Generierungsaufgaben in den Bereichen Bild, Audio, Text und anderen Bereichen verwendet. In diesem Artikel werden wir Python verwenden, um ein Beispiel für einen GAN-Algorithmus zum Generieren von Bildern handgeschriebener Ziffern zu schreiben.

  1. Datensatzvorbereitung

Wir werden den MNIST-Datensatz als unseren Trainingsdatensatz verwenden. Der MNIST-Datensatz enthält 60.000 Trainingsbilder und 10.000 Testbilder, jedes Bild ist ein 28x28-Graustufenbild. Wir werden die TensorFlow-Bibliothek verwenden, um den Datensatz zu laden und zu verarbeiten. Bevor wir den Datensatz laden, müssen wir die TensorFlow-Bibliothek und die NumPy-Bibliothek installieren. Tensorflow als TF importieren shape[0], 28, 28, 1).astype('float32')

train_images = (train_images - 127.5) / 127.5 # Pixelwerte auf den Bereich von [-1, 1] normalisieren


GAN-Architekturdesign und Training

Unser GAN wird zwei neuronale Netze umfassen: ein Generatornetzwerk und ein Diskriminatornetzwerk. Das Generatornetzwerk empfängt den Rauschvektor als Eingabe und gibt ein 28x28-Bild aus. Das Diskriminatornetzwerk empfängt ein 28x28-Bild als Eingabe und gibt die Wahrscheinlichkeit aus, dass es sich bei dem Bild um ein echtes Bild handelt.

Die Architektur sowohl des Generatornetzwerks als auch des Diskriminatornetzwerks wird Faltungs-Neuronale Netzwerke (CNN) verwenden. Im Generatornetzwerk verwenden wir eine Entfaltungsschicht, um den Rauschvektor in ein 28x28-Bild zu dekodieren. Im Diskriminatornetzwerk verwenden wir Faltungsschichten, um die Eingabebilder zu klassifizieren.

Der Eingang zum Generatornetzwerk ist ein Rauschvektor der Länge 100. Wir werden die Netzwerkschichten mithilfe der Funktion tf.keras.Sequential stapeln.
  1. def make_generator_model():
  2. model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    
    model.add(tf.keras.layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size没有限制
    
    model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    
    model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    
    model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)
    
    return model
    
Die Eingabe des Diskriminatornetzwerks ist ein 28x28-Bild. Wir werden die Netzwerkschichten mithilfe der Funktion tf.keras.Sequential stapeln.

def make_discriminator_model():

model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                 input_shape=[28, 28, 1]))
model.add(tf.keras.layers.LeakyReLU())
model.add(tf.keras.layers.Dropout(0.3))

model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(tf.keras.layers.LeakyReLU())
model.add(tf.keras.layers.Dropout(0.3))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1))

return model

Als nächstes schreiben wir den Trainingscode. Wir werden in jeder Charge abwechselnd das Generatornetzwerk und das Diskriminatornetzwerk trainieren. Während des Trainingsprozesses zeichnen wir die Farbverläufe mit der Funktion tf.GradientTape() auf und optimieren dann das Netzwerk mit der Funktion tf.keras.optimizers.Adam().

generator = make_generator_model () real_output, fake_output):

real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss

Generatorverlustfunktion

def generator_loss(fake_output):

return cross_entropy(tf.ones_like(fake_output), fake_output)

optimizer

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Definieren Sie Trainingsfunktionen , werden wir das Generatornetzwerk verwenden, um neue Bilder zu generieren. Wir werden zufällig 100 Rauschvektoren generieren und sie in das Generatornetzwerk einspeisen, um neue Bilder handgeschriebener Ziffern zu erzeugen.

matplotlib.pyplot als plt importieren

def generic_and_save_images(model, epoch, test_input):

noise = tf.random.normal([BATCH_SIZE, 100])

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(noise, training=True)

    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

Rauschenvektoren zufällig generieren

noise = tf.random.normal([16, 100])

generate_and_save_images(generator, 0 , Rauschen)

Die Ergebnisse zeigen, dass der Generator erfolgreich neue handgeschriebene Ziffernbilder generiert hat. Wir können die Leistung des Modells verbessern, indem wir die Anzahl der Trainingsepochen schrittweise erhöhen. Darüber hinaus können wir die Leistung von GAN weiter verbessern, indem wir andere Hyperparameterkombinationen und Netzwerkarchitekturen ausprobieren.

Kurz gesagt ist der GAN-Algorithmus ein sehr nützlicher Deep-Learning-Algorithmus, mit dem verschiedene Arten von Daten generiert werden können. In diesem Artikel haben wir ein Beispiel für einen GAN-Algorithmus zum Generieren von Bildern handgeschriebener Ziffern mit Python geschrieben und gezeigt, wie man ein Generatornetzwerk trainiert und verwendet, um neue Bilder zu generieren.

Das obige ist der detaillierte Inhalt vonBeispiel für einen GAN-Algorithmus in Python. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn