Maison  >  Article  >  développement back-end  >  Exemple de migration de style d'image en Python

Exemple de migration de style d'image en Python

WBOY
WBOYoriginal
2023-06-11 20:44:251377parcourir

Le transfert de style d'image est une technologie basée sur l'apprentissage profond qui permet de transférer le style d'une image à une autre image. Ces dernières années, la technologie de transfert de style d’image a été largement utilisée dans les domaines de l’art et des effets spéciaux cinématographiques et télévisuels. Dans cet article, nous présenterons comment implémenter la migration de style d'image à l'aide du langage Python.

1. Qu'est-ce que le transfert de style d'image

Le transfert de style d'image peut transférer le style d'une image à une autre image. Le style peut être le style de peinture de l'artiste, le style de prise de vue du photographe ou d'autres styles. Le but du transfert de style d’image est de préserver le contenu de l’image originale tout en lui donnant un nouveau style.

La technologie de transfert de style d'image est une technologie d'apprentissage en profondeur basée sur un réseau neuronal convolutif (CNN). Son idée principale est d'extraire les informations de contenu et de style de l'image via un modèle CNN pré-entraîné et d'utiliser des méthodes d'optimisation pour synthétiser les deux. dans un nouveau sur l'image. Généralement, les informations sur le contenu d'une image sont extraites via les couches de convolution profondes de CNN, tandis que les informations de style de l'image sont extraites via la corrélation entre les noyaux de convolution de CNN.

2. Implémenter la migration de style d'image

Les principales étapes pour implémenter la migration de style d'image en Python comprennent le chargement d'images, le prétraitement des images, la création de modèles, le calcul des fonctions de perte, l'utilisation de méthodes d'optimisation pour itérer et générer des résultats. Ensuite, nous les aborderons étape par étape.

  1. Chargement des images

Tout d'abord, nous devons charger une image originale et une image de référence. L'image d'origine est l'image dont le style doit être transféré et l'image de référence est l'image de style qui doit être transférée. Le chargement des images peut être effectué à l'aide du module PIL (Python Imaging Library) de Python.

from PIL import Image
import numpy as np

# 载入原始图像和参考图像
content_image = Image.open('content.jpg')
style_image = Image.open('style.jpg')

# 将图像转化为numpy数组,方便后续处理
content_array = np.array(content_image)
style_array = np.array(style_image)
  1. Prétraitement des images

Le prétraitement comprend la conversion des images originales et des images de référence dans un format que le réseau neuronal peut traiter, c'est-à-dire la conversion de l'image en Tensor et la standardisation en même temps. Ici, nous utilisons le module de prétraitement fourni par PyTorch pour terminer.

import torch
import torch.nn as nn
import torchvision.transforms as transforms

# 定义预处理函数
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 将图像进行预处理
content_tensor = preprocess(content_image).unsqueeze(0).to(device)
style_tensor = preprocess(style_image).unsqueeze(0).to(device)
  1. Création de modèles

Le modèle de transfert de style d'image peut utiliser des modèles qui ont été formés sur des bases de données d'images à grande échelle. Les modèles couramment utilisés incluent VGG19 et ResNet. Ici, nous utilisons le modèle VGG19 pour compléter. Tout d'abord, nous devons charger le modèle VGG19 pré-entraîné et supprimer la dernière couche entièrement connectée, ne laissant que la couche convolutive. Ensuite, nous devons ajuster les informations de contenu et les informations de style de l'image en modifiant les poids de la couche convolutive.

import torchvision.models as models

class VGG(nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG, self).__init__()
        vgg19 = models.vgg19(pretrained=True).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        self.slice5 = nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg19[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg19[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg19[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg19[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg19[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h_relu1 = self.slice1(x)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        return h_relu1, h_relu2, h_relu3, h_relu4, h_relu5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device).eval()
  1. Calculer la fonction de perte

Étant donné que le but du transfert de style d'image est de conserver le contenu de l'image originale tout en lui donnant un nouveau style, nous devons définir une fonction de perte pour atteindre cet objectif. La fonction de perte se compose de deux parties, l’une est la perte de contenu et l’autre la perte de style.

La perte de contenu peut être définie en calculant l'erreur quadratique moyenne entre l'image originale et l'image générée dans la carte des caractéristiques de la couche convolutive. La perte de style est définie en calculant l'erreur quadratique moyenne entre la matrice de Gram entre la carte caractéristique de l'image générée et l'image de style dans la couche convolutive. La matrice de Gram est ici la matrice de corrélation entre les noyaux de convolution de la carte de caractéristiques.

def content_loss(content_features, generated_features):
    return torch.mean((content_features - generated_features)**2)

def gram_matrix(input):
    batch_size , h, w, f_map_num = input.size()
    features = input.view(batch_size * h, w * f_map_num)
    G = torch.mm(features, features.t())
    return G.div(batch_size * h * w * f_map_num)

def style_loss(style_features, generated_features):
    style_gram = gram_matrix(style_features)
    generated_gram = gram_matrix(generated_features)
    return torch.mean((style_gram - generated_gram)**2)

content_weight = 1
style_weight = 1000

def compute_loss(content_features, style_features, generated_features):
    content_loss_fn = content_loss(content_features, generated_features[0])
    style_loss_fn = style_loss(style_features, generated_features[1])
    loss = content_weight * content_loss_fn + style_weight * style_loss_fn
    return loss, content_loss_fn, style_loss_fn
  1. Itérer à l'aide de méthodes d'optimisation

Après avoir calculé la fonction de perte, nous pouvons utiliser des méthodes d'optimisation pour ajuster les valeurs de pixels de l'image générée afin de minimiser la fonction de perte. Les méthodes d'optimisation couramment utilisées incluent la méthode de descente de gradient et l'algorithme L-BFGS. Ici, nous utilisons l'optimiseur LBFGS fourni par PyTorch pour terminer la migration des images. Le nombre d'itérations peut être ajusté selon les besoins. Habituellement, 2 000 itérations peuvent obtenir de meilleurs résultats.

from torch.optim import LBFGS

generated = content_tensor.detach().clone().requires_grad_(True).to(device)

optimizer = LBFGS([generated])

for i in range(2000):

    def closure():
        optimizer.zero_grad()
        generated_features = model(generated)
        loss, content_loss_fn, style_loss_fn = compute_loss(content_features, style_features, generated_features)
        loss.backward()
        return content_loss_fn + style_loss_fn

    optimizer.step(closure)

    if i % 100 == 0:
        print('Iteration:', i)
        print('Total loss:', closure().tolist())
  1. Résultats de sortie

Enfin, nous pouvons enregistrer l'image générée localement et observer l'effet de la migration du style d'image.

import matplotlib.pyplot as plt

generated_array = generated.cpu().detach().numpy()
generated_array = np.squeeze(generated_array, 0)
generated_array = generated_array.transpose(1, 2, 0)
generated_array = np.clip(generated_array, 0, 1)

plt.imshow(generated_array)
plt.axis('off')
plt.show()

Image.fromarray(np.uint8(generated_array * 255)).save('generated.jpg')

3. Résumé

Cet article présente comment utiliser le langage Python pour implémenter la technologie de transfert de style d'image. En chargeant l'image, en prétraitant l'image, en construisant le modèle, en calculant la fonction de perte, en itérant avec la méthode d'optimisation et en produisant le résultat, nous pouvons transférer le style d'une image à une autre. Dans les applications pratiques, nous pouvons ajuster des paramètres tels que les images de référence et le nombre d'itérations en fonction des différents besoins pour obtenir de meilleurs résultats.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Article précédent:Exemples SVM en PythonArticle suivant:Exemples SVM en Python