Maison >développement back-end >Tutoriel Python >Exemple d'algorithme GAN en Python

Exemple d'algorithme GAN en Python

王林
王林original
2023-06-10 09:53:501230parcourir

Generative Adversarial Networks (GAN) est un algorithme d'apprentissage en profondeur qui génère de nouvelles données via deux réseaux de neurones en compétition. GAN est largement utilisé pour les tâches de génération dans les domaines de l'image, de l'audio, du texte et d'autres domaines. Dans cet article, nous utiliserons Python pour écrire un exemple d'algorithme GAN permettant de générer des images de chiffres manuscrits.

  1. Préparation de l'ensemble de données

Nous utiliserons l'ensemble de données MNIST comme ensemble de données de formation. L'ensemble de données MNIST contient 60 000 images d'entraînement et 10 000 images de test, chaque image est une image en niveaux de gris 28 x 28. Nous utiliserons la bibliothèque TensorFlow pour charger et traiter l'ensemble de données. Avant de charger l'ensemble de données, nous devons installer la bibliothèque TensorFlow et la bibliothèque NumPy.

importer tensorflow en tant que tf
importer numpy en tant que np

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

prétraitement des ensembles de données

train_images = train_images.reshape( train_images. shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normaliser les valeurs des pixels dans la plage de [-1, 1]

  1. Conception de l'architecture GAN et formation

Notre GAN comprendra deux réseaux de neurones : un réseau générateur et un réseau discriminateur. Le réseau générateur recevra le vecteur de bruit en entrée et produira une image 28x28. Le réseau discriminateur recevra une image 28x28 en entrée et en sortie la probabilité que l'image soit une image réelle.

L'architecture du réseau générateur et du réseau discriminateur utilisera des réseaux de neurones convolutifs (CNN). Dans le réseau générateur, nous utiliserons une couche déconvolutive pour décoder le vecteur de bruit en une image 28x28. Dans le réseau discriminateur, nous utiliserons des couches convolutives pour classer les images d'entrée.

L'entrée du réseau générateur est un vecteur de bruit de longueur 100. Nous allons empiler les couches réseau en utilisant la fonction tf.keras.Sequential.

def make_generator_model():

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())

model.add(tf.keras.layers.Reshape((7, 7, 256)))
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size没有限制

model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
assert model.output_shape == (None, 7, 7, 128)
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())

model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 14, 14, 64)
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())

model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 28, 28, 1)

return model

L'entrée du réseau discriminateur est une image 28x28. Nous allons empiler les couches réseau en utilisant la fonction tf.keras.Sequential.

def make_discriminator_model():

model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                 input_shape=[28, 28, 1]))
model.add(tf.keras.layers.LeakyReLU())
model.add(tf.keras.layers.Dropout(0.3))

model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(tf.keras.layers.LeakyReLU())
model.add(tf.keras.layers.Dropout(0.3))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1))

return model

Ensuite, nous écrirons le code de formation. Nous entraînerons alternativement le réseau générateur et le réseau discriminateur dans chaque lot. Pendant le processus de formation, nous enregistrerons les dégradés à l'aide de la fonction tf.GradientTape(), puis optimiserons le réseau à l'aide de la fonction tf.keras.optimizers.Adam().

generator = make_generator_model()
discriminator = make_discriminator_model()

Fonction de perte

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Fonction de perte de discriminateur

def discriminator_loss(real_output , fausse_sortie):

real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss

fonction de perte du générateur

def generator_loss(fake_output):

return cross_entropy(tf.ones_like(fake_output), fake_output)

optimizer

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Définir la fonction d'entraînement

@tf.function
def train_step(images):

noise = tf.random.normal([BATCH_SIZE, 100])

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(noise, training=True)

    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

BATCH_SIZE = 256
EPOCHS = 100

pour l'époque dans la plage (EPOCHS):

for i in range(train_images.shape[0] // BATCH_SIZE):
    batch_images = train_images[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
    train_step(batch_images)
  1. Générer de nouvelles images

en formation Une fois terminé , nous utiliserons le réseau générateur pour générer de nouvelles images. Nous allons générer aléatoirement 100 vecteurs de bruit et les introduire dans le réseau de générateurs pour générer de nouvelles images de chiffres manuscrits.

importer matplotlib.pyplot en tant que plt

def generate_and_save_images(model, epoch, test_input):

# 注意 training` 设定为 False
# 因此,所有层都在推理模式下运行(batchnorm)。
predictions = model(test_input, training=False)

fig = plt.figure(figsize=(4, 4))

for i in range(predictions.shape[0]):
    plt.subplot(4, 4, i+1)
    plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')

plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()

Générer aléatoirement des vecteurs de bruit

noise = tf.random.normal([16, 100])
generate_and_save_images(generator, 0, bruit)

Les résultats montrent que le générateur a réussi à générer de nouvelles images de chiffres manuscrits. Nous pouvons améliorer les performances du modèle en augmentant progressivement le nombre d'époques de formation. De plus, nous pouvons encore améliorer les performances du GAN en essayant d’autres combinaisons d’hyperparamètres et architectures de réseau.

En bref, l'algorithme GAN est un algorithme d'apprentissage en profondeur très utile qui peut être utilisé pour générer différents types de données. Dans cet article, nous avons écrit un exemple d'algorithme GAN pour générer des images de chiffres manuscrits à l'aide de Python et avons montré comment entraîner et utiliser un réseau générateur pour générer de nouvelles images.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn