ホームページ >テクノロジー周辺機器 >AI >変分オートエンコーダ: 理論と実装

変分オートエンコーダ: 理論と実装

PHPz
PHPz転載
2024-01-24 11:36:07743ブラウズ

如何实现变分自动编码器 变分自动编码器的原理和实现步骤

変分オートエンコーダ (VAE) は、ニューラル ネットワークに基づく生成モデルです。その目標は、高次元データの低次元の潜在変数表現を学習し、これらの潜在変数をデータの再構築と生成に使用することです。従来のオートエンコーダと比較して、VAE は潜在空間の分布を学習することで、より現実的で多様なサンプルを生成できます。 VAEの実装方法については、以下で詳しく紹介します。

1. VAE の基本原理

VAE の基本的な考え方は、高次元のデータを低次元のデータにマッピングすることです。潜在空間、次元の削減と再構築。これは、エンコーダとデコーダの 2 つの部分で構成されます。エンコーダーは入力データ x を潜在空間の平均 μ と分散 σ^2 にマッピングします。このようにして、VAE は潜在空間内のデータをサンプリングし、デコーダを通じてサンプリング結果を元のデータに再構築できます。このエンコーダ/デコーダ構造により、VAE は潜在空間内で連続性の高い新しいサンプルを生成できるようになり、類似したサンプルが潜在空間内でより近くに存在するようになります。したがって、VAE は次元削減に使用できるだけでなく、

\begin{aligned}
\mu &=f_{\mu}(x)\
\sigma^2 &=f_{\sigma}(x)
\end{aligned}

ここで、f_{\mu} と f_{\sigma} は任意のニューラル ネットワーク モデルにすることができます。通常、エンコーダーの実装にはマルチレイヤー パーセプトロン (MLP) を使用します。

デコーダは、潜在変数 z を元のデータ空間、つまり、次のようにマッピングします。

x'=g(z)

その中で、 g は任意のニューラル ネットワーク モデルにすることもできます。同様に、通常は MLP を使用してデコーダを実装します。

VAE では、潜在変数 $z$ は事前分布 (通常はガウス分布) からサンプリングされます。つまり:

z\sim\mathcal{N}(0,I)

このようにして、VAE は次のことができます。潜在変数の再構成誤差と KL 発散を最小限に抑えてトレーニングすることで、次元の削減とデータの生成を実現します。具体的には、VAE の損失関数は次のように表すことができます。

\mathcal{L}=\mathbb{E}_{z\sim q(z|x)}[\log p(x|z)]-\beta\mathrm{KL}[q(z|x)||p(z)]

ここで、q(z|x) は事後分布、つまり入力 x が与えられたときの潜在変数 z の条件付き分布です。 p(x|z ) は生成分布、つまり潜在変数 $z$ が与えられた場合の対応するデータ分布、p(z) は事前分布、つまり潜在変数 z の周辺分布です。 beta は、再構成エラーと KL 発散のバランスを取るために使用されるハイパーパラメーターです。

上記の損失関数を最小化することで、入力データ x を潜在空間の分布 q(z|x) にマッピングできる変換関数 f(x) を学習できます。 、そこから潜在変数 z をサンプリングすることで、次元の削減とデータの生成を実現できます。

2. VAE 実装手順

以下では、エンコーダー、デコーダー、損失関数定義を含む、基本的な VAE モデルを実装する方法を紹介します。 MNIST の手書き数字データ セットを例にとると、このデータ セットには 60,000 個のトレーニング サンプルと 10,000 個のテスト サンプルが含まれており、各サンプルは 28x28 のグレースケール画像です。

2.1 データの前処理

まず、MNIST データセットを前処理して、各サンプルを 784 次元のベクトルに変換し、正規化する必要があります。 [0,1]の範囲まで。コードは次のとおりです。

# python

import torch

import torchvision.transforms as transforms

from torchvision.datasets import MNIST

# 定义数据预处理

transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换成Tensor格式
    transforms.Normalize(mean=(0.

2.2 モデル構造の定義

次に、VAE モデルの構造を定義する必要があります。潜在変数のエンコーダ、デコーダおよびサンプリング関数。この例では、エンコーダとデコーダとして 2 層 MLP を使用し、各層の隠れユニットの数はそれぞれ 256 と 128 です。潜在変数の次元は 20 です。コードは次のとおりです。

import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super(VAE, self).__init__()

        # 定义编码器的结构
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, latent_dim*2)  # 输出均值和方差
        )

        # 定义解码器的结构
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出范围在[0, 1]之间的概率
        )

    # 潜在变量的采样函数
    def sample_z(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    # 前向传播函数
    def forward(self, x):
        # 编码器
        h = self.encoder(x)
        mu, logvar = h[:, :latent_dim], h[:, latent_dim:]
        z = self.sample_z(mu, logvar)

        # 解码器
        x_hat = self.decoder(z)
        return x_hat, mu, logvar

上記のコードでは、エンコーダおよびデコーダとして 2 層 MLP を使用します。エンコーダーは入力データを潜在空間の平均と分散にマッピングします。平均の次元は 20、分散の次元も 20 であり、これにより潜在変数の次元は 20 になります。デコーダーは潜在変数を元のデータ空間にマップし直します。最後の層はシグモイド関数を使用して出力範囲を [0, 1] に制限します。

VAE モデルを実装するときは、損失関数も定義する必要があります。この例では、再構成誤差と KL 発散を使用して損失関数を定義します。再構成誤差はクロスエントロピー損失関数を使用し、KL 発散は事前分布として標準正規分布を使用します。コードは次のとおりです。

# 定义损失函数
def vae_loss(x_hat, x, mu, logvar, beta=1):
    # 重构误差
    recon_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')

    # KL散度
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta*kl_loss

上記のコードでは、クロスエントロピー損失関数を使用して再構成誤差を計算し、KL 発散を使用して潜在変数の分布と事前分布の差を計算します。 。このうち、 \beta は再構成誤差と KL 発散のバランスを取るために使用されるハイパーパラメーターです。

2.3 トレーニング モデル

最後に、トレーニング関数を定義し、MNIST データ セットで VAE モデルをトレーニングする必要があります。トレーニング プロセスでは、まずモデルの損失関数を計算し、次にバックプロパゲーション アルゴリズムを使用してモデル パラメーターを更新する必要があります。コードは次のとおりです。

# python
# 定义训练函数

def train(model, dataloader, optimizer, device, beta):
    model.train()
    train_loss = 0

for x, _ in dataloader:
    x = x.view(-1, input_dim).to(device)
    optimizer.zero_grad()
    x_hat, mu, logvar = model(x)
    loss = vae_loss(x_hat, x, mu, logvar, beta)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

return train_loss / len(dataloader.dataset)

ここで、上記のトレーニング関数を使用して、MNIST データ セットで VAE モデルをトレーニングできます。コードは次のとおりです。

# 定义模型和优化器
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 训练模型
num_epochs = 50
for epoch in range(num_epochs):
    train_loss = train(model, trainloader, optimizer, device, beta=1)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}')

# 测试模型
model.eval()
with torch.no_grad():
    test_loss = 0
    for x, _ in testloader:
        x = x.view(-1, input_dim).to(device)
        x_hat, mu, logvar = model(x)
        test_loss += vae_loss(x_hat, x, mu, logvar, beta=1).item()
    test_loss /= len(testloader.dataset)
    print(f'Test Loss: {test_loss:.4f}')

トレーニング プロセス中に、Adam オプティマイザーと \beta=1 のハイパーパラメーターを使用してモデル パラメーターを更新します。トレーニングが完了したら、テスト セットを使用してモデルの損失関数を計算します。この例では、再構成誤差と KL 発散を使用して損失関数を計算します。そのため、テスト損失が小さいほど、モデルによって学習される潜在的な表現が向上し、生成されるサンプルがより現実的になります。

2.4 サンプルの生成

最后,我们可以使用VAE模型生成新的手写数字样本。生成样本的过程非常简单,只需要在潜在空间中随机采样,然后将采样结果输入到解码器中生成新的样本。代码如下:

# 生成新样本
n_samples = 10
with torch.no_grad():
    # 在潜在空间中随机采样
    z = torch.randn(n_samples, latent_dim).to(device)
    # 解码生成样本
    samples = model.decode(z).cpu()
    # 将样本重新变成图像的形状
    samples = samples.view(n_samples, 1, 28, 28)
    # 可视化生成的样本
    fig, axes = plt.subplots(1, n_samples, figsize=(20, 2))
    for i, ax in enumerate(axes):
        ax.imshow(samples[i][0], cmap='gray')
        ax.axis('off')
    plt.show()

在上述代码中,我们在潜在空间中随机采样10个点,然后将这些点输入到解码器中生成新的样本。最后,我们将生成的样本可视化展示出来,可以看到,生成的样本与MNIST数据集中的数字非常相似。

综上,我们介绍了VAE模型的原理、实现和应用,可以看到,VAE模型是一种非常强大的生成模型,可以学习到高维数据的潜在表示,并用潜在表示生成新的样本。

以上が変分オートエンコーダ: 理論と実装の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事は163.comで複製されています。侵害がある場合は、admin@php.cn までご連絡ください。