Maison >Périphériques technologiques >IA >Réseau contradictoire génératif, l'IA transforme les images en style bande dessinée

Réseau contradictoire génératif, l'IA transforme les images en style bande dessinée

WBOY
WBOYavant
2023-04-11 21:58:051672parcourir

Bonjour à tous.

Tout le monde joue avec la peinture IA récemment. J'ai trouvé un projet open source sur GitHub à partager avec vous.

Réseau contradictoire génératif, lIA transforme les images en style bande dessinée

Le projet partagé aujourd'hui est mis en œuvre à l'aide du GAN​ Generative Adversarial Network. Nous avons déjà partagé de nombreux articles sur les principes et la pratique du GAN. Les amis qui souhaitent en savoir plus peuvent lire des articles historiques.

Le code source et l'ensemble de données sont disponibles à la fin de l'article. Partageons comment former et exécuter le projet.

1. Préparez l'environnement

Installez tensorflow-gpu 1.15.0​, utilisez 2080Ti​ comme carte graphique GPU et cuda version 10.0.

Téléchargez le code source du projet AnimeGANv2 depuis git.

Après avoir configuré l'environnement, vous devez encore préparer l'ensemble de données et vgg19.

Réseau contradictoire génératif, lIA transforme les images en style bande dessinée

Téléchargez le fichier compressé dataset.zip, qui contient 6 000 images réelles et 2 000 images de bandes dessinées pour la formation GAN.

Réseau contradictoire génératif, lIA transforme les images en style bande dessinée

vgg19 est utilisé pour calculer la perte, qui sera présentée en détail ci-dessous.

2. Modèle de réseau

Le réseau antagoniste génératif doit définir deux modèles, l'un est le générateur et l'autre est le discriminateur.

Le réseau générateur est défini comme suit :

with tf.variable_scope('A'):
inputs = Conv2DNormLReLU(inputs, 32, 7)
inputs = Conv2DNormLReLU(inputs, 64, strides=2)
inputs = Conv2DNormLReLU(inputs, 64)

with tf.variable_scope('B'):
inputs = Conv2DNormLReLU(inputs, 128, strides=2)
inputs = Conv2DNormLReLU(inputs, 128)

with tf.variable_scope('C'):
inputs = Conv2DNormLReLU(inputs, 128)
inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r1')
inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r2')
inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r3')
inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r4')
inputs = Conv2DNormLReLU(inputs, 128)

with tf.variable_scope('D'):
inputs = Unsample(inputs, 128)
inputs = Conv2DNormLReLU(inputs, 128)

with tf.variable_scope('E'):
inputs = Unsample(inputs,64)
inputs = Conv2DNormLReLU(inputs, 64)
inputs = Conv2DNormLReLU(inputs, 32, 7)
with tf.variable_scope('out_layer'):
out = Conv2D(inputs, filters =3, kernel_size=1, strides=1)
self.fake = tf.tanh(out)

Le module principal du générateur est le bloc résiduel inverse

Réseau contradictoire génératif, lIA transforme les images en style bande dessinée

La structure résiduelle (a) et le bloc résiduel inverse (b)

La structure du réseau discriminateur est comme suit :

def D_net(x_init,ch, n_dis,sn, scope, reuse):
channel = ch // 2
with tf.variable_scope(scope, reuse=reuse):
x = conv(x_init, channel, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='conv_0')
x = lrelu(x, 0.2)

for i in range(1, n_dis):
x = conv(x, channel * 2, kernel=3, stride=2, pad=1, use_bias=False, sn=sn, scope='conv_s2_' + str(i))
x = lrelu(x, 0.2)

x = conv(x, channel * 4, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='conv_s1_' + str(i))
x = layer_norm(x, scope='1_norm_' + str(i))
x = lrelu(x, 0.2)

channel = channel * 2

x = conv(x, channel * 2, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='last_conv')
x = layer_norm(x, scope='2_ins_norm')
x = lrelu(x, 0.2)

x = conv(x, channels=1, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='D_logit')

return x

3. Perte

Avant de calculer la perte, l'image est vectorisée grâce au réseau VGG19. Ce processus est un peu comme l’opération Embedding en PNL.

Eembedding​ consiste à convertir des mots en vecteurs, et VGG19 consiste à convertir des images en vecteurs.

Réseau contradictoire génératif, lIA transforme les images en style bande dessinée

Définition VGG19

La logique de calcul de la partie perte est la suivante :

def con_sty_loss(vgg, real, anime, fake):

# 真实Réseau contradictoire génératif, lIA transforme les images en style bande dessinée向量化
vgg.build(real)
real_feature_map = vgg.conv4_4_no_activation

# 生成Réseau contradictoire génératif, lIA transforme les images en style bande dessinée向量化
vgg.build(fake)
fake_feature_map = vgg.conv4_4_no_activation

# 漫画风格向量化
vgg.build(anime[:fake_feature_map.shape[0]])
anime_feature_map = vgg.conv4_4_no_activation

# 真实Réseau contradictoire génératif, lIA transforme les images en style bande dessinée与生成Réseau contradictoire génératif, lIA transforme les images en style bande dessinée的损失
c_loss = L1_loss(real_feature_map, fake_feature_map)
# 漫画风格与生成Réseau contradictoire génératif, lIA transforme les images en style bande dessinée的损失
s_loss = style_loss(anime_feature_map, fake_feature_map)

return c_loss, s_loss

Ici, vgg19 est utilisé pour calculer respectivement la perte de l'image réelle (paramètre réel) et de l'image générée (paramètre faux). L'image générée (paramètre faux) et perte du style bande dessinée (paramètre anime).

c_loss, s_loss = con_sty_loss(self.vgg, self.real, self.anime_gray, self.generated)
t_loss = self.con_weight * c_loss + self.sty_weight * s_loss + color_loss(self.real,self.generated) * self.color_weight + tv_loss

Enfin, donnez des poids différents à ces deux pertes, afin que les images générées par le générateur conservent non seulement l'apparence des images réelles, mais migrent également vers le style bande dessinée

4. Formation

Exécutez la commande suivante dans le fichier. répertoire du projet Après avoir démarré la formation

python train.py --dataset Hayao --epoch 101 --init_epoch 10

et exécutée avec succès, vous pouvez voir les données.

Réseau contradictoire génératif, lIA transforme les images en style bande dessinée

En même temps, on constate aussi que les pertes diminuent.

Le code source et l'ensemble de données ont été empaquetés. Si vous en avez besoin, laissez simplement un message dans la zone de commentaires.

Si vous pensez que cet article vous est utile, cliquez et lisez pour m'encourager. Je continuerai à partager d'excellents projets Python+AI à l'avenir.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Cet article est reproduit dans:. en cas de violation, veuillez contacter admin@php.cn Supprimer