Maison >Périphériques technologiques >IA >Auto-encodeurs variationnels : théorie et mise en œuvre

Auto-encodeurs variationnels : théorie et mise en œuvre

PHPz
PHPzavant
2024-01-24 11:36:07743parcourir

如何实现变分自动编码器 变分自动编码器的原理和实现步骤

Variational Autoencoder (VAE) est un modèle génératif basé sur des réseaux de neurones. Son objectif est d'apprendre des représentations de variables latentes de faible dimension de données de grande dimension et d'utiliser ces variables latentes pour la reconstruction et la génération de données. Par rapport aux auto-encodeurs traditionnels, le VAE peut générer des échantillons plus réalistes et plus diversifiés en apprenant la distribution de l'espace latent. Le mode de mise en œuvre de la VAE sera présenté en détail ci-dessous.

1. Principes de base de la VAE

L'idée de base de la VAE est de réaliser une réduction de dimensionnalité et une reconstruction des données en mappant des données de haute dimension sur un espace latent de basse dimension. Il se compose de deux parties : encodeur et décodeur. L'encodeur mappe les données d'entrée x à la moyenne μ et à la variance σ^2 de l'espace latent. De cette manière, VAE peut échantillonner les données dans l'espace latent et reconstruire les résultats échantillonnés dans les données originales via le décodeur. Cette structure codeur-décodeur permet à VAE de générer de nouveaux échantillons avec une bonne continuité dans l'espace latent, rapprochant ainsi les échantillons similaires dans l'espace latent. Par conséquent, VAE ne peut pas seulement être utilisé pour la réduction de dimensionnalité et

\begin{aligned}
\mu &=f_{\mu}(x)\
\sigma^2 &=f_{\sigma}(x)
\end{aligned}

où f_{mu} et f_{sigma} peuvent être n'importe quel modèle de réseau neuronal. Généralement, nous utilisons un Perceptron multicouche (MLP) pour implémenter l'encodeur.

Le décodeur mappe la variable latente z à l'espace de données d'origine, c'est-à-dire :

x'=g(z)

où, g peut également être n'importe quel modèle de réseau neuronal. De même, nous utilisons généralement un MLP pour implémenter le décodeur.

En VAE, la variable latente $z$ est échantillonnée à partir d'une distribution a priori (généralement une distribution gaussienne), à ​​savoir :

z\sim\mathcal{N}(0,I)

De cette façon, nous pouvons minimiser l'erreur de reconstruction et la variable latente KL divergence est utilisée pour former la VAE pour parvenir à la réduction de dimensionnalité et à la génération de données. Plus précisément, la fonction de perte de VAE peut être exprimée comme suit :

\mathcal{L}=\mathbb{E}_{z\sim q(z|x)}[\log p(x|z)]-\beta\mathrm{KL}[q(z|x)||p(z)]

où q(z|x) est la distribution a posteriori, c'est-à-dire la distribution conditionnelle de la variable latente z lorsque l'entrée x est donnée ; z) est la distribution génératrice, c'est-à-dire la distribution de données correspondante lorsqu'une variable latente $z$ est donnée p(z) est la distribution a priori, c'est-à-dire que la distribution marginale de la variable latente bêta est un hyperparamètre utilisé ; pour équilibrer l'erreur de reconstruction et la divergence KL.

En minimisant la fonction de perte ci-dessus, nous pouvons apprendre une fonction de transformation f(x), qui peut mapper les données d'entrée x à la distribution q(z|x) de l'espace latent, et peut échantillonner les variables latentes de it z, réalisant ainsi une réduction de dimensionnalité et une génération de données.

2. Étapes de mise en œuvre de VAE

Ci-dessous, nous présenterons comment implémenter un modèle VAE de base, y compris la définition de l'encodeur, du décodeur et de la fonction de perte. Nous prenons comme exemple l'ensemble de données de chiffres manuscrits du MNIST. Cet ensemble de données contient 60 000 échantillons d'apprentissage et 10 000 échantillons de test, chaque échantillon est une image en niveaux de gris 28x28.

2.1 Prétraitement des données

Tout d'abord, nous devons prétraiter l'ensemble de données MNIST, convertir chaque échantillon en un vecteur à 784 dimensions et le normaliser dans la plage de [0,1] Inside. Le code est le suivant :

# python

import torch

import torchvision.transforms as transforms

from torchvision.datasets import MNIST

# 定义数据预处理

transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换成Tensor格式
    transforms.Normalize(mean=(0.

2.2 Définir la structure du modèle

Ensuite, nous devons définir la structure du modèle VAE, y compris l'encodeur, le décodeur et la fonction d'échantillonnage de la variable latente. Dans cet exemple, nous utilisons un MLP à deux couches comme encodeur et décodeur, le nombre d'unités cachées dans chaque couche étant respectivement de 256 et 128. La dimension de la variable latente est 20. Le code est le suivant :

import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super(VAE, self).__init__()

        # 定义编码器的结构
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, latent_dim*2)  # 输出均值和方差
        )

        # 定义解码器的结构
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出范围在[0, 1]之间的概率
        )

    # 潜在变量的采样函数
    def sample_z(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    # 前向传播函数
    def forward(self, x):
        # 编码器
        h = self.encoder(x)
        mu, logvar = h[:, :latent_dim], h[:, latent_dim:]
        z = self.sample_z(mu, logvar)

        # 解码器
        x_hat = self.decoder(z)
        return x_hat, mu, logvar

Dans le code ci-dessus, nous utilisons un MLP à deux couches comme encodeur et décodeur. L'encodeur mappe les données d'entrée sur la moyenne et la variance de l'espace latent, où la dimension de la moyenne est de 20 et la dimension de la variance est également de 20, ce qui garantit que la dimension de la variable latente est de 20. Le décodeur mappe les variables latentes sur l'espace de données d'origine, où la dernière couche utilise la fonction Sigmoïde pour limiter la plage de sortie à [0, 1].

Lors de la mise en œuvre du modèle VAE, nous devons également définir la fonction de perte. Dans cet exemple, nous utilisons l'erreur de reconstruction et la divergence KL pour définir la fonction de perte, où l'erreur de reconstruction utilise la fonction de perte d'entropie croisée et la divergence KL utilise la distribution normale standard comme distribution a priori. Le code est le suivant :

# 定义损失函数
def vae_loss(x_hat, x, mu, logvar, beta=1):
    # 重构误差
    recon_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')

    # KL散度
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta*kl_loss

Dans le code ci-dessus, nous utilisons la fonction de perte d'entropie croisée pour calculer l'erreur de reconstruction et la divergence KL pour calculer la différence entre la distribution de la variable latente et la distribution a priori. Parmi eux, bêta est un hyperparamètre utilisé pour équilibrer l’erreur de reconstruction et la divergence KL.

2.3 Modèle de formation

Enfin, nous devons définir la fonction de formation et entraîner le modèle VAE sur l'ensemble de données MNIST. Pendant le processus de formation, nous devons d'abord calculer la fonction de perte du modèle, puis utiliser l'algorithme de rétropropagation pour mettre à jour les paramètres du modèle. Le code est le suivant :

# python
# 定义训练函数

def train(model, dataloader, optimizer, device, beta):
    model.train()
    train_loss = 0

for x, _ in dataloader:
    x = x.view(-1, input_dim).to(device)
    optimizer.zero_grad()
    x_hat, mu, logvar = model(x)
    loss = vae_loss(x_hat, x, mu, logvar, beta)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

return train_loss / len(dataloader.dataset)

Maintenant, nous pouvons utiliser la fonction de formation ci-dessus pour entraîner le modèle VAE sur l'ensemble de données MNIST. Le code est le suivant :

# 定义模型和优化器
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 训练模型
num_epochs = 50
for epoch in range(num_epochs):
    train_loss = train(model, trainloader, optimizer, device, beta=1)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}')

# 测试模型
model.eval()
with torch.no_grad():
    test_loss = 0
    for x, _ in testloader:
        x = x.view(-1, input_dim).to(device)
        x_hat, mu, logvar = model(x)
        test_loss += vae_loss(x_hat, x, mu, logvar, beta=1).item()
    test_loss /= len(testloader.dataset)
    print(f'Test Loss: {test_loss:.4f}')

Pendant le processus de formation, nous utilisons l'optimiseur Adam et l'hyperparamètre beta=1 pour mettre à jour les paramètres du modèle. Une fois la formation terminée, nous utilisons l'ensemble de test pour calculer la fonction de perte du modèle. Dans cet exemple, nous utilisons l'erreur de reconstruction et la divergence KL pour calculer la fonction de perte. Ainsi, plus la perte de test est faible, meilleure est la représentation potentielle apprise par le modèle et plus les échantillons générés sont réalistes.

2.4 Générer un échantillon

最后,我们可以使用VAE模型生成新的手写数字样本。生成样本的过程非常简单,只需要在潜在空间中随机采样,然后将采样结果输入到解码器中生成新的样本。代码如下:

# 生成新样本
n_samples = 10
with torch.no_grad():
    # 在潜在空间中随机采样
    z = torch.randn(n_samples, latent_dim).to(device)
    # 解码生成样本
    samples = model.decode(z).cpu()
    # 将样本重新变成图像的形状
    samples = samples.view(n_samples, 1, 28, 28)
    # 可视化生成的样本
    fig, axes = plt.subplots(1, n_samples, figsize=(20, 2))
    for i, ax in enumerate(axes):
        ax.imshow(samples[i][0], cmap='gray')
        ax.axis('off')
    plt.show()

在上述代码中,我们在潜在空间中随机采样10个点,然后将这些点输入到解码器中生成新的样本。最后,我们将生成的样本可视化展示出来,可以看到,生成的样本与MNIST数据集中的数字非常相似。

综上,我们介绍了VAE模型的原理、实现和应用,可以看到,VAE模型是一种非常强大的生成模型,可以学习到高维数据的潜在表示,并用潜在表示生成新的样本。

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Cet article est reproduit dans:. en cas de violation, veuillez contacter admin@php.cn Supprimer