GAN algorithm example in Python

2023-06-10 09:53:50

A generative adversarial network (GAN) is a deep learning model in which two neural networks compete with each other in order to generate new data. GANs are widely used for generation tasks involving images, audio, text, and other kinds of data. In this article, we use Python to write an example GAN that generates images of handwritten digits.

  1. Dataset preparation

We will use the MNIST dataset as our training data. It contains 60,000 training images and 10,000 test images, each a 28x28 grayscale image. We will load and process the dataset with TensorFlow, so the TensorFlow and NumPy libraries must be installed first.

import tensorflow as tf
import numpy as np

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Dataset preprocessing: add a channel dimension and convert to float32
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize pixel values to the range [-1, 1]
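
As a quick check of the scaling formula above, the following minimal sketch applies it to a few sample pixel values in plain Python. The helper name `normalize_pixel` is an assumption made for illustration and is not part of the article's code.

```python
def normalize_pixel(value):
    """Map a pixel intensity in [0, 255] to [-1, 1] (hypothetical helper)."""
    return (value - 127.5) / 127.5

# Endpoints and midpoint of the 8-bit intensity range
samples = [0, 127.5, 255]
normalized = [normalize_pixel(v) for v in samples]
print(normalized)  # [-1.0, 0.0, 1.0]
```

The [-1, 1] range matches the output range of the tanh activation used in the last layer of the generator, so real and generated images share the same scale.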

  2. GAN architecture design and training

Our GAN consists of two neural networks: a generator and a discriminator. The generator receives a noise vector as input and outputs a 28x28 image. The discriminator receives a 28x28 image as input and outputs a score indicating how likely the image is to be real. Because the loss function below is created with from_logits=True, this score is a raw logit rather than a probability.

Both networks are convolutional neural networks (CNNs). The generator uses transposed convolution layers (Conv2DTranspose, sometimes called deconvolution layers) to upsample the noise vector into a 28x28 image. The discriminator uses convolutional layers to classify the input images.

The input to the generator network is a noise vector of length 100. We stack the network layers with tf.keras.Sequential.

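def make_generator_model():
    model = tf.keras.Sequential()
    # Project the 100-dimensional noise vector to 7*7*256 values
    model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    # Batch normalization and a nonlinearity follow each hidden layer
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None means the batch size is not fixed

    model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # tanh keeps output pixels in [-1, 1], the same range as the normalized training images
    model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model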
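
The shape assertions in the generator follow from a simple rule: with padding='same', a Keras transposed convolution multiplies the spatial size by its stride. The sketch below reproduces that arithmetic in plain Python; the helper name `conv_transpose_same_size` is hypothetical and exists only for this illustration.

```python
def conv_transpose_same_size(size, stride):
    """Spatial size after a transposed convolution with padding='same' (hypothetical helper)."""
    return size * stride

size = 7  # side length of the feature map after the Reshape layer
sizes = [size]
for stride in (1, 2, 2):  # strides of the three Conv2DTranspose layers
    size = conv_transpose_same_size(size, stride)
    sizes.append(size)

print(sizes)  # [7, 7, 14, 28]
```

This is why the generator starts from a 7x7 feature map: two stride-2 layers are enough to reach the 28x28 size of an MNIST image.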

The input to the discriminator network is a 28x28 image. We again stack the network layers with tf.keras.Sequential.

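def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))

    model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))

    # Flatten the feature maps and output a single logit per image.
    # No sigmoid is applied because the loss is created with from_logits=True.
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1))

    return model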
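
For the discriminator, the arithmetic runs in the opposite direction: with padding='same', a strided convolution divides the spatial size by the stride, rounding up. The following sketch, using a hypothetical helper `conv_same_size`, traces how the two convolutional layers shrink a 28x28 input.

```python
import math

def conv_same_size(size, stride):
    """Spatial size after a convolution with padding='same' (hypothetical helper)."""
    return math.ceil(size / stride)

size = 28  # side length of an MNIST image
sizes = [size]
for stride in (2, 2):  # strides of the two Conv2D layers
    size = conv_same_size(size, stride)
    sizes.append(size)

# Number of values per image after the second layer, which has 128 filters
features = size * size * 128
print(sizes, features)  # [28, 14, 7] 6272
```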

Next, we will write the training code. For each batch, a single training step updates both the generator and the discriminator. We record the operations needed to compute gradients with tf.GradientTape, and then apply the gradients with tf.keras.optimizers.Adam optimizers.

generator = make_generator_model()
discriminator = make_discriminator_model()

# Loss function
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Discriminator loss function: real images should be scored 1, generated images 0
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

# Generator loss function: the generator wants its images to be scored 1
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# Optimizers
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
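
To make the two loss functions more concrete, here is a minimal plain-Python sketch of binary cross-entropy on a single logit, written in the usual numerically stable form. The function names are assumptions for illustration, and the sketch handles one sample at a time, whereas the Keras loss averages over a batch.

```python
import math

def bce_with_logits(label, logit):
    """Binary cross-entropy for one logit, in a numerically stable form."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def discriminator_loss_scalar(real_logit, fake_logit):
    # Real images are labeled 1, generated images are labeled 0
    return bce_with_logits(1.0, real_logit) + bce_with_logits(0.0, fake_logit)

def generator_loss_scalar(fake_logit):
    # The generator wants the discriminator to label its images 1
    return bce_with_logits(1.0, fake_logit)

# An undecided discriminator (logit 0, i.e. probability 0.5)
print(discriminator_loss_scalar(0.0, 0.0))  # 2 * ln 2, about 1.386
print(generator_loss_scalar(0.0))           # ln 2, about 0.693

# A confident and correct discriminator: its loss is small, the generator's is large
print(discriminator_loss_scalar(5.0, -5.0))
print(generator_loss_scalar(-5.0))
```

The last two lines show the adversarial relationship: when the discriminator confidently rejects generated images, its own loss is close to zero while the generator's loss grows.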

# Define training function
def train_step(images):
    # Sample one noise vector per image in the batch
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Compute the gradients of each loss with respect to its own network's variables
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Update both networks
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

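EPOCHS = 100
BATCH_SIZE = 256  # Assumed value; adjust it to fit the available memory

for epoch in range(EPOCHS):
    # Iterate over full batches; leftover images that do not fill a batch are skipped
    for i in range(train_images.shape[0] // BATCH_SIZE):
        batch_images = train_images[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
        train_step(batch_images)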
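
The slicing in the training loop can be checked with a little arithmetic. The sketch below assumes a batch size of 256, which is an illustrative value (the article's code does not state one), and shows how many full batches the 60,000 training images yield and how many images are left over.

```python
NUM_IMAGES = 60000  # size of the MNIST training set
BATCH_SIZE = 256    # assumed value for illustration

num_batches = NUM_IMAGES // BATCH_SIZE
leftover = NUM_IMAGES - num_batches * BATCH_SIZE

# Start and end indices of the first and last full batch
first = (0 * BATCH_SIZE, 1 * BATCH_SIZE)
last = ((num_batches - 1) * BATCH_SIZE, num_batches * BATCH_SIZE)

print(num_batches, leftover)  # 234 96
print(first, last)            # (0, 256) (59648, 59904)
```

With integer division, the final 96 images are never used in an epoch. Choosing a batch size that divides 60,000 evenly, or shuffling the data each epoch, would avoid always skipping the same images.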

  3. Generate new images

After training is complete, we use the generator network to generate new images. We randomly generate 16 noise vectors of length 100 and feed them into the generator to produce new images of handwritten digits, which are displayed in a 4x4 grid.

import matplotlib.pyplot as plt

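def generate_and_save_images(model, epoch, test_input):
    # Note that training is set to False,
    # so all layers run in inference mode (this matters for batch normalization).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        # Rescale from [-1, 1] back to [0, 255] for display
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    # Save the grid to disk; the file name pattern is an assumption
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
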
# Randomly generate 16 noise vectors of length 100

noise = tf.random.normal([16, 100])
generate_and_save_images(generator, 0, noise)
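
The plotting code rescales each generated pixel with `* 127.5 + 127.5`, which inverts the normalization applied to the training data. If the images were to be stored as 8-bit values instead of only displayed, rounding and clamping would also be needed. The following is a minimal sketch of that conversion, with `to_pixel` as a hypothetical helper name.

```python
def to_pixel(value):
    """Map a generator output in [-1, 1] to an integer in [0, 255] (hypothetical helper)."""
    scaled = value * 127.5 + 127.5
    # Round to the nearest integer and clamp to the valid 8-bit range
    return int(min(255, max(0, round(scaled))))

print([to_pixel(v) for v in (-1.0, 1.0)])  # [0, 255]
print(to_pixel(1.2))                       # 255, out-of-range values are clamped
```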

Once the model has been trained, the generated images should begin to resemble handwritten digits. Image quality usually improves as the number of training epochs increases. We can also try to improve the GAN further by experimenting with other hyperparameter combinations and network architectures.

In short, the GAN algorithm is a very useful deep learning algorithm that can be used to generate various types of data. In this article, we wrote an example of a GAN algorithm for generating images of handwritten digits using Python and showed how to train and use a generator network to generate new images.
