
The implementation process of image compression: variational autoencoder


Variational Autoencoder (VAE) is an unsupervised neural network used for image compression and generation. Unlike a traditional autoencoder, a VAE can not only reconstruct its input images but also generate new images similar to them. The core idea is to encode the input image into a distribution over latent variables and sample from that distribution to generate new images. What makes VAE distinctive is that it is trained with variational inference: its parameters are learned by maximizing a lower bound on the log-likelihood of the observed data. This allows the model to learn the underlying structure of the data and acquire the ability to generate new samples. VAE has seen notable success in many areas, including image generation, attribute editing, and image reconstruction.
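Concretely, the quantity a VAE maximizes is the evidence lower bound (ELBO) on the log-likelihood of each observed image x. With the encoder's approximate posterior q(z|x) and the prior p(z), it reads

\log p(x) \;\ge\; \mathbb{E}_{z \sim q(z \mid x)}\big[\log p(x \mid z)\big] \;-\; \mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)

Maximizing the right-hand side is equivalent to minimizing a reconstruction loss plus a KL regularization term, which is exactly the loss implemented later in this article.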

The structure of a VAE resembles that of an autoencoder, consisting of an encoder and a decoder. The encoder compresses the input image into the parameters of a latent distribution: a mean vector and a (log-)variance vector. A latent vector is sampled from this distribution, and the decoder maps it back into an image. To keep the latent distribution well-behaved, VAE adds a KL-divergence regularization term that pulls the latent distribution toward the standard normal distribution, which improves the model's expressiveness and generative ability.
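Because the encoder outputs a diagonal Gaussian q(z|x) = N(μ, σ²) and the prior is the standard normal N(0, I), the KL term has a simple closed form over the d latent dimensions (this is exactly the expression the loss code below computes):

\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\big) \;=\; -\tfrac{1}{2}\sum_{j=1}^{d}\big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\big)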

The following walks through a VAE implementation, using the MNIST handwritten digit dataset as the example.

First, import the necessary libraries and load the dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt  # used below to visualize generated samples

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data/', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
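As a quick sanity check (not part of the original article), you can pull one batch and confirm the tensor shapes the encoder will receive:

# Peek at one batch: MNIST images arrive as [batch, 1, 28, 28], scaled to [0, 1] by ToTensor
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([128, 1, 28, 28])
print(images.min().item(), images.max().item())   # 0.0 1.0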

Next, define the network structure of the encoder and decoder.

# Define the encoder
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.fc21 = nn.Linear(256, 20)  # mean vector
        self.fc22 = nn.Linear(256, 20)  # log-variance vector

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.relu(self.conv3(x))
        x = x.view(-1, 128 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        mean = self.fc21(x)
        log_var = self.fc22(x)
        return mean, log_var


# Define the decoder
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(20, 256)
        self.fc2 = nn.Linear(256, 128 * 7 * 7)
        self.conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = x.view(-1, 128, 7, 7)
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = torch.sigmoid(self.conv3(x))
        return x


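Before wiring the two halves into the full model, here is a quick shape check (a sketch, assuming the definitions above; the tensor x is just random test input):

# Verify the shapes round-trip: 28x28 image -> two 20-dim vectors -> 28x28 image
enc, dec = Encoder(), Decoder()
x = torch.randn(4, 1, 28, 28)
mean, log_var = enc(x)
print(mean.shape, log_var.shape)   # torch.Size([4, 20]) torch.Size([4, 20])
print(dec(mean).shape)             # torch.Size([4, 1, 28, 28])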

Next comes the VAE model itself. Its forward pass encodes the input into a mean and log-variance, samples a latent vector with the reparameterization trick (z = mean + std * eps, which keeps the sampling step differentiable), and decodes it back into an image; the loss combines the reconstruction error with the KL-divergence regularization term.

# Define the VAE model
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mean, log_var):
        # Sample z = mean + std * eps with eps ~ N(0, I), keeping gradients w.r.t. mean and std
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mean

    def forward(self, x):
        mean, log_var = self.encoder(x)
        z = self.reparameterize(mean, log_var)
        x_recon = self.decoder(z)
        return x_recon, mean, log_var

    def loss_function(self, x_recon, x, mean, log_var):
        # Reconstruction term plus the closed-form KL divergence to N(0, I)
        recon_loss = nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        return recon_loss + kl_loss

    def sample(self, num_samples):
        # Draw latent codes from the standard normal prior and decode them
        z = torch.randn(num_samples, 20)
        return self.decoder(z)
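A minimal smoke test of the untrained model, reusing the images batch from the earlier sanity check (the loss value itself is meaningless before training):

# One forward pass: the reconstruction has the input's shape, the loss is a scalar
vae_test = VAE()
x_recon, mean, log_var = vae_test(images[:8])
loss = vae_test.loss_function(x_recon, images[:8], mean, log_var)
print(x_recon.shape, loss.item())   # torch.Size([8, 1, 28, 28]) and a positive scalar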

Finally, we define the optimizer and start training the model.

# Define the optimizer
vae = VAE()
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_recon, mean, log_var = vae(data)
        loss = vae.loss_function(x_recon, data, mean, log_var)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Epoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'.format(
                epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))
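Once training finishes, it is common (though not shown in the original article) to persist the learned weights; the filename here is our own choice:

# Save the trained weights for later reuse
torch.save(vae.state_dict(), 'vae_mnist.pt')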

After training is completed, we can use VAE to generate new handwritten digit images.

# Generate handwritten digit images
samples = vae.sample(10)
fig, ax = plt.subplots(1, 10, figsize=(10, 1))
for i in range(10):
    ax[i].imshow(samples[i].detach().numpy().reshape(28, 28), cmap='gray')
    ax[i].axis('off')
plt.show()
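Beyond sampling from the prior, here is a short illustration of the latent space the model has learned (this example is ours, not the original article's): linearly interpolating between the latent codes of two training images produces a smooth morph between digits.

# Interpolate between the latent means of two training images and decode each step
with torch.no_grad():
    x0 = train_dataset[0][0].unsqueeze(0)   # [1, 1, 28, 28]
    x1 = train_dataset[1][0].unsqueeze(0)
    z0, _ = vae.encoder(x0)
    z1, _ = vae.encoder(x1)
    frames = [vae.decoder((1 - t) * z0 + t * z1) for t in torch.linspace(0, 1, 10)]

fig, ax = plt.subplots(1, 10, figsize=(10, 1))
for i, frame in enumerate(frames):
    ax[i].imshow(frame.squeeze().numpy(), cmap='gray')
    ax[i].axis('off')
plt.show()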

VAE is a powerful model for image compression and generation: it compresses an input image into a distribution over latent variables and samples from that distribution to generate new images. Unlike a traditional autoencoder, it adds a KL-divergence regularization term that keeps the latent distribution well-behaved. Implementing a VAE amounts to defining the encoder and decoder networks and computing a loss that combines the reconstruction error with the KL term. By training the VAE, the latent distribution of the input images can be learned, and new images can be generated from it.
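To make the "compression" reading concrete, here is a minimal round-trip sketch with the trained model: a 28×28 image (784 pixel values) is summarized by its 20-dimensional latent mean, from which the decoder reconstructs an approximation.

# Compress one image to a 20-dim code, then reconstruct it
with torch.no_grad():
    x = train_dataset[0][0].unsqueeze(0)   # [1, 1, 28, 28]
    mean, _ = vae.encoder(x)               # [1, 20] -- the compressed representation
    x_hat = vae.decoder(mean)              # [1, 1, 28, 28] reconstruction
print(x.numel(), '->', mean.numel())       # 784 -> 20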

This concludes the basic introduction to VAE and its implementation process. I hope it is helpful to readers.

