>  기사  >  백엔드 개발  >  Python의 이미지 스타일 마이그레이션 예

Python의 이미지 스타일 마이그레이션 예

WBOY
WBOY원래의
2023-06-11 20:44:251391검색

이미지 스타일 전송은 한 이미지의 스타일을 다른 이미지로 전송할 수 있는 딥러닝 기반 기술입니다. 최근 몇 년 동안 이미지 스타일 전송 기술은 예술, 영화, TV 특수 효과 분야에서 널리 사용되었습니다. 이번 글에서는 Python 언어를 사용하여 이미지 스타일 마이그레이션을 구현하는 방법을 소개하겠습니다.

1. 이미지 스타일 전송이란 무엇인가요?

이미지 스타일 전송은 한 이미지의 스타일을 다른 이미지로 전송할 수 있습니다. 스타일은 아티스트의 그림 스타일, 사진가의 촬영 스타일 또는 기타 스타일일 수 있습니다. 이미지 스타일 전송의 목표는 원본 이미지의 내용을 보존하면서 새로운 스타일을 부여하는 것입니다.

이미지 스타일 전송 기술은 CNN(Convolutional Neural Network) 기반의 딥러닝 기술로, 사전 학습된 CNN 모델을 통해 이미지의 내용과 스타일 정보를 추출하고 최적화 방법을 사용하여 두 가지를 합성하는 것이 핵심 아이디어입니다. 이미지의 새 항목으로 이동합니다. 일반적으로 이미지의 내용 정보는 CNN의 Deep Convolutional Layer를 통해 추출되고, 이미지의 스타일 정보는 CNN의 Convolution Kernel 간의 Correlation을 통해 추출됩니다.

2. 이미지 스타일 마이그레이션 구현

Python에서 이미지 스타일 마이그레이션을 구현하는 주요 단계에는 이미지 로드, 이미지 전처리, 모델 구축, 손실 함수 계산, 최적화 방법을 사용한 반복 및 결과 출력이 포함됩니다. 다음으로 이러한 내용을 단계별로 다루겠습니다.

  1. 이미지 로드

먼저 원본 이미지와 참조 이미지를 로드해야 합니다. 원본 이미지는 스타일을 전송해야 하는 이미지이고, 참조 이미지는 전송하려는 스타일 이미지입니다. 이미지 로딩은 Python의 PIL(Python Imaging Library) 모듈을 사용하여 수행할 수 있습니다.

from PIL import Image
import numpy as np

# 载入原始图像和参考图像
content_image = Image.open('content.jpg')
style_image = Image.open('style.jpg')

# 将图像转化为numpy数组,方便后续处理
content_array = np.array(content_image)
style_array = np.array(style_image)
  1. 이미지 전처리

전처리에는 원본 이미지와 참조 이미지를 신경망이 처리할 수 있는 형식으로 변환하는 것, 즉 이미지를 Tensor로 변환하는 동시에 표준화를 수행하는 것이 포함됩니다. 여기서는 PyTorch에서 제공하는 전처리 모듈을 사용하여 완료합니다.

import torch
import torch.nn as nn
import torchvision.transforms as transforms

# 定义预处理函数
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 将图像进行预处理
content_tensor = preprocess(content_image).unsqueeze(0).to(device)
style_tensor = preprocess(style_image).unsqueeze(0).to(device)
  1. 모델 구축

이미지 스타일 전송 모델은 대규모 이미지 데이터베이스에서 훈련된 모델을 사용할 수 있습니다. 일반적으로 사용되는 모델에는 VGG19 및 ResNet이 있습니다. 여기서는 VGG19 모델을 사용하여 완성합니다. 먼저 사전 훈련된 VGG19 모델을 로드하고 마지막 완전 연결 레이어를 제거하고 컨볼루셔널 레이어만 남겨야 합니다. 그런 다음 컨볼루셔널 레이어의 가중치를 수정하여 이미지의 콘텐츠 정보와 스타일 정보를 조정해야 합니다.

import torchvision.models as models

class VGG(nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG, self).__init__()
        vgg19 = models.vgg19(pretrained=True).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        self.slice5 = nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg19[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg19[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg19[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg19[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg19[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h_relu1 = self.slice1(x)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        return h_relu1, h_relu2, h_relu3, h_relu4, h_relu5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device).eval()
  1. 손실 함수 계산

이미지 스타일 전송의 목표는 원본 이미지의 내용을 유지하면서 새로운 스타일을 부여하는 것이므로 이 목표를 달성하려면 손실 함수를 정의해야 합니다. 손실 함수는 두 부분으로 구성됩니다. 하나는 콘텐츠 손실이고 다른 하나는 스타일 손실입니다.

컨볼루션 레이어의 특징 맵에서 원본 이미지와 생성된 이미지 사이의 평균 제곱 오차를 계산하여 콘텐츠 손실을 정의할 수 있습니다. 스타일 손실은 생성된 이미지의 특징 맵과 컨볼루셔널 레이어의 스타일 이미지 사이의 그램 행렬 사이의 평균 제곱 오차를 계산하여 정의됩니다. 여기서 그램 행렬은 기능 맵의 컨볼루션 커널 간의 상관 행렬입니다.

def content_loss(content_features, generated_features):
    return torch.mean((content_features - generated_features)**2)

def gram_matrix(input):
    batch_size , h, w, f_map_num = input.size()
    features = input.view(batch_size * h, w * f_map_num)
    G = torch.mm(features, features.t())
    return G.div(batch_size * h * w * f_map_num)

def style_loss(style_features, generated_features):
    style_gram = gram_matrix(style_features)
    generated_gram = gram_matrix(generated_features)
    return torch.mean((style_gram - generated_gram)**2)

content_weight = 1
style_weight = 1000

def compute_loss(content_features, style_features, generated_features):
    content_loss_fn = content_loss(content_features, generated_features[0])
    style_loss_fn = style_loss(style_features, generated_features[1])
    loss = content_weight * content_loss_fn + style_weight * style_loss_fn
    return loss, content_loss_fn, style_loss_fn
  1. 최적화 방법을 사용하여 반복

손실 함수를 계산한 후 최적화 방법을 사용하여 생성된 이미지의 픽셀 값을 조정하여 손실 함수를 최소화할 수 있습니다. 일반적으로 사용되는 최적화 방법에는 경사하강법과 L-BFGS 알고리즘이 있습니다. 여기서는 PyTorch에서 제공하는 LBFGS 최적화 프로그램을 사용하여 이미지 마이그레이션을 완료합니다. 반복 횟수는 필요에 따라 조정할 수 있습니다. 일반적으로 2000번 반복하면 더 나은 결과를 얻을 수 있습니다.

from torch.optim import LBFGS

generated = content_tensor.detach().clone().requires_grad_(True).to(device)

optimizer = LBFGS([generated])

for i in range(2000):

    def closure():
        optimizer.zero_grad()
        generated_features = model(generated)
        loss, content_loss_fn, style_loss_fn = compute_loss(content_features, style_features, generated_features)
        loss.backward()
        return content_loss_fn + style_loss_fn

    optimizer.step(closure)

    if i % 100 == 0:
        print('Iteration:', i)
        print('Total loss:', closure().tolist())
  1. 출력 결과

마지막으로 생성된 이미지를 로컬에 저장하고 이미지 스타일 마이그레이션 효과를 관찰할 수 있습니다.

import matplotlib.pyplot as plt

generated_array = generated.cpu().detach().numpy()
generated_array = np.squeeze(generated_array, 0)
generated_array = generated_array.transpose(1, 2, 0)
generated_array = np.clip(generated_array, 0, 1)

plt.imshow(generated_array)
plt.axis('off')
plt.show()

Image.fromarray(np.uint8(generated_array * 255)).save('generated.jpg')

3. 요약

이 글에서는 Python 언어를 사용하여 이미지 스타일 전송 기술을 구현하는 방법을 소개합니다. 이미지 로드, 이미지 전처리, 모델 구축, 손실 함수 계산, 최적화 방법 반복 및 결과 출력을 통해 한 이미지의 스타일을 다른 이미지로 전송할 수 있습니다. 실제 적용에서는 더 나은 결과를 얻기 위해 다양한 요구에 따라 참조 이미지 및 반복 횟수와 같은 매개변수를 조정할 수 있습니다.

위 내용은 Python의 이미지 스타일 마이그레이션 예의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.
이전 기사:Python의 SVM 예제다음 기사:Python의 SVM 예제