Rumah > Artikel > Peranti teknologi > Rangkaian musuh generatif, AI mengubah gambar menjadi gaya komik
Helo, semua.
Semua orang bermain dengan lukisan AI baru-baru ini saya menemui projek sumber terbuka di GitHub untuk dikongsi dengan anda.
Projek yang dikongsikan hari ini dilaksanakan menggunakan GAN Generative Adversarial Network Kami telah berkongsi banyak artikel sebelum ini tentang prinsip dan amalan GAN yang ingin tahu lebih lanjut boleh baca ia Artikel sejarah.
Kod sumber dan set data diperolehi pada akhir artikel Di sini kami berkongsi cara melatih dan menjalankan projek.
Pasang tensorflow-gpu 1.15.0, gunakan 2080Ti sebagai kad grafik GPU dan cuda versi 10.0.
git muat turun projek AnimeGANv2 kod sumber.
Selepas menyediakan persekitaran, anda perlu menyediakan set data dan vgg19.
Muat turun fail mampat dataset.zip, yang mengandungi 6k gambar sebenar dan 2k gambar komik untuk latihan GAN.
vgg19 digunakan untuk mengira kerugian, yang akan diperkenalkan secara terperinci di bawah.
Menjana rangkaian lawan memerlukan penentuan dua model, satu adalah penjana dan satu lagi adalah diskriminasi.
Rangkaian penjana ditakrifkan seperti berikut:
with tf.variable_scope('A'): inputs = Conv2DNormLReLU(inputs, 32, 7) inputs = Conv2DNormLReLU(inputs, 64, strides=2) inputs = Conv2DNormLReLU(inputs, 64) with tf.variable_scope('B'): inputs = Conv2DNormLReLU(inputs, 128, strides=2) inputs = Conv2DNormLReLU(inputs, 128) with tf.variable_scope('C'): inputs = Conv2DNormLReLU(inputs, 128) inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r1') inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r2') inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r3') inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r4') inputs = Conv2DNormLReLU(inputs, 128) with tf.variable_scope('D'): inputs = Unsample(inputs, 128) inputs = Conv2DNormLReLU(inputs, 128) with tf.variable_scope('E'): inputs = Unsample(inputs,64) inputs = Conv2DNormLReLU(inputs, 64) inputs = Conv2DNormLReLU(inputs, 32, 7) with tf.variable_scope('out_layer'): out = Conv2D(inputs, filters =3, kernel_size=1, strides=1) self.fake = tf.tanh(out)
Modul utama dalam penjana ialah blok baki songsang
sisa Struktur (a) dan blok baki terbalik (b)
Struktur rangkaian diskriminator adalah seperti berikut:
def D_net(x_init,ch, n_dis,sn, scope, reuse): channel = ch // 2 with tf.variable_scope(scope, reuse=reuse): x = conv(x_init, channel, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='conv_0') x = lrelu(x, 0.2) for i in range(1, n_dis): x = conv(x, channel * 2, kernel=3, stride=2, pad=1, use_bias=False, sn=sn, scope='conv_s2_' + str(i)) x = lrelu(x, 0.2) x = conv(x, channel * 4, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='conv_s1_' + str(i)) x = layer_norm(x, scope='1_norm_' + str(i)) x = lrelu(x, 0.2) channel = channel * 2 x = conv(x, channel * 2, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='last_conv') x = layer_norm(x, scope='2_ins_norm') x = lrelu(x, 0.2) x = conv(x, channels=1, kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='D_logit') return x
Gunakan VGG19 sebelum mengira kerugian Rangkaian mengvektorkan imej. Proses ini sedikit seperti operasi Embedding dalam NLP.
Pembenaman ialah tentang menukar perkataan kepada vektor, dan VGG19 ialah tentang menukar gambar kepada vektor.
Definisi VGG19
Logik untuk mengira kerugian adalah seperti berikut:
def con_sty_loss(vgg, real, anime, fake): # 真实Rangkaian musuh generatif, AI mengubah gambar menjadi gaya komik向量化 vgg.build(real) real_feature_map = vgg.conv4_4_no_activation # 生成Rangkaian musuh generatif, AI mengubah gambar menjadi gaya komik向量化 vgg.build(fake) fake_feature_map = vgg.conv4_4_no_activation # 漫画风格向量化 vgg.build(anime[:fake_feature_map.shape[0]]) anime_feature_map = vgg.conv4_4_no_activation # 真实Rangkaian musuh generatif, AI mengubah gambar menjadi gaya komik与生成Rangkaian musuh generatif, AI mengubah gambar menjadi gaya komik的损失 c_loss = L1_loss(real_feature_map, fake_feature_map) # 漫画风格与生成Rangkaian musuh generatif, AI mengubah gambar menjadi gaya komik的损失 s_loss = style_loss(anime_feature_map, fake_feature_map) return c_loss, s_loss
Di sini vgg19 digunakan untuk mengira yang sebenar imej (parameter sebenar) masing-masing Kehilangan dengan imej yang dijana (parameter palsu), kehilangan imej yang dijana (parameter palsu) dengan gaya komik (parameter anime).
c_loss, s_loss = con_sty_loss(self.vgg, self.real, self.anime_gray, self.generated) t_loss = self.con_weight * c_loss + self.sty_weight * s_loss + color_loss(self.real,self.generated) * self.color_weight + tv_loss
Akhirnya berikan kedua-dua kerugian ini berat yang berbeza, supaya gambar yang dihasilkan oleh penjana bukan sahaja mengekalkan rupa gambar sebenar, tetapi juga berhijrah ke gaya komik
Laksanakan arahan berikut dalam direktori projek untuk memulakan latihan
python train.py --dataset Hayao --epoch 101 --init_epoch 10
Selepas operasi berjaya, anda boleh melihat data.
Dalam masa yang sama juga dapat dilihat kerugian yang semakin berkurangan.
Kod sumber dan set data telah dibungkus Jika anda memerlukannya, tinggalkan mesej di ruangan komen.
Jika anda mendapati artikel ini berguna kepada anda, sila klik dan baca untuk menggalakkan saya, saya akan terus berkongsi projek Python+AI yang sangat baik pada masa hadapan.
Atas ialah kandungan terperinci Rangkaian musuh generatif, AI mengubah gambar menjadi gaya komik. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!