搜尋
首頁科技週邊人工智慧變分自動編碼器:理論與實作方案

變分自動編碼器:理論與實作方案

Jan 24, 2024 am 11:36 AM
機器學習人工神經網絡

如何实现变分自动编码器 变分自动编码器的原理和实现步骤

變分自動編碼器(VAE)是一種基於神經網路的生成模型。它的目標是學習高維度資料的低維潛在變數表示,並利用這些潛在變數進行資料的重構和生成。相較於傳統的自動編碼器,VAE透過學習潛在空間的分佈,可以產生更真實且多樣性的樣本。以下將詳細介紹VAE的實作方法。

1.VAE的基本原理

#VAE的基本想法是透過將高維度資料映射到低維的潛在空間,實現數據的降維和重構。它由編碼器和解碼器兩個部分組成。編碼器將輸入資料x映射到潛在空間的平均值μ和變異數σ^2。透過這種方式,VAE可以在潛在空間中對資料進行取樣,並透過解碼器將取樣結果重構為原始資料。這種編碼器-解碼器結構使得VAE能夠產生新的樣本,並且在潛在空間中具有良好的連續性,使得相似的樣本在潛在空間中距離較近。因此,VAE不僅可以用於降維和

\begin{aligned}
\mu &=f_{\mu}(x)\
\sigma^2 &=f_{\sigma}(x)
\end{aligned}

其中,f_{\mu}和f_{\sigma}可以是任意的神經網路模型。通常情況下,我們使用一個多層感知機(Multilayer Perceptron,MLP)來實作編碼器。

解碼器則將潛在變數z映射回原始資料空間,即:

x'=g(z)

其中,g也可以是任意的神經網路模型。同樣地,我們通常使用一個MLP來實作解碼器。

在VAE中,潛在變數$z$是從一個先驗分佈(通常是高斯分佈)中取樣得到的,即:

z\sim\mathcal{N}(0,I)

這樣,我們就可以透過最小化重建誤差和潛在變數的KL散度來訓練VAE,從而實現資料的降維和生成。具體來說,VAE的損失函數可以表示為:

\mathcal{L}=\mathbb{E}_{z\sim q(z|x)}[\log p(x|z)]-\beta\mathrm{KL}[q(z|x)||p(z)]

其中,q(z|x)是後驗分佈,即給定輸入x時潛在變數z的條件分佈;p(x|z )是生成分佈,即給定潛在變數$z$時對應的資料分佈;p(z)是先驗分佈,即潛在變數z的邊緣分佈;\beta是超參數,用於平衡重構誤差和KL散度。

透過最小化上述損失函數,我們可以學習到一個轉換函數f(x),它可以將輸入資料x對應到潛在空間的分佈q(z|x)中,並且可以從中採樣得到潛在變數z,從而實現資料的降維和生成。

2.VAE的實作步驟

#下面我們將介紹如何實作一個基本的VAE模型,包括編碼器、解碼器和損失函數的定義。我們以MNIST手寫數位資料集為例,此資料集包含60,000個訓練樣本和10,000個測試樣本,每個樣本為一張28x28的灰階影像。

2.1資料預處理

首先,我們需要對MNIST資料集進行預處理,將每個樣本轉換成一個784維的向量,並將其歸一化到[0,1]的範圍內。程式碼如下:

# python

import torch

import torchvision.transforms as transforms

from torchvision.datasets import MNIST

# 定义数据预处理

transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换成Tensor格式
    transforms.Normalize(mean=(0.

2.2 定義模型結構

接下來,我們需要定義VAE模型的結構,包括編碼器、解碼器和潛在變數的取樣函數。在本例中,我們使用一個兩層的MLP作為編碼器和解碼器,每層的隱藏單元數分別為256和128。潛在變數的維度為20。程式碼如下:

import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super(VAE, self).__init__()

        # 定义编码器的结构
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, latent_dim*2)  # 输出均值和方差
        )

        # 定义解码器的结构
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出范围在[0, 1]之间的概率
        )

    # 潜在变量的采样函数
    def sample_z(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    # 前向传播函数
    def forward(self, x):
        # 编码器
        h = self.encoder(x)
        mu, logvar = h[:, :latent_dim], h[:, latent_dim:]
        z = self.sample_z(mu, logvar)

        # 解码器
        x_hat = self.decoder(z)
        return x_hat, mu, logvar

在上述程式碼中,我們使用一個兩層的MLP作為編碼器和解碼器。編碼器將輸入資料對應到潛在空間的平均值和方差,其中平均值的維度為20,方差的維度也為20,這樣可以保證潛在變數的維度為20。解碼器將潛在變數映射回原始資料空間,其中最後一層使用Sigmoid函數將輸出範圍限制在[0, 1]之間。

在實作VAE模型時,我們還需要定義損失函數。在本例中,我們使用重構誤差和KL散度來定義損失函數,其中重構誤差使用交叉熵損失函數,KL散度使用標準常態分佈作為先驗分佈。程式碼如下:

# 定义损失函数
def vae_loss(x_hat, x, mu, logvar, beta=1):
    # 重构误差
    recon_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')

    # KL散度
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta*kl_loss

在上述程式碼中,我們使用交叉熵損失函數計算重建誤差,使用KL散度計算潛在變數的分佈與先驗分佈之間的差異。其中,\beta是一個超參數,用於平衡重建誤差和KL散度。

2.3 訓練模型

最後,我們需要定義訓練函數,並在MNIST資料集上訓練VAE模型。訓練過程中,我們首先需要計算模型的損失函數,然後使用反向傳播演算法更新模型參數。程式碼如下:

# python
# 定义训练函数

def train(model, dataloader, optimizer, device, beta):
    model.train()
    train_loss = 0

for x, _ in dataloader:
    x = x.view(-1, input_dim).to(device)
    optimizer.zero_grad()
    x_hat, mu, logvar = model(x)
    loss = vae_loss(x_hat, x, mu, logvar, beta)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

return train_loss / len(dataloader.dataset)

現在,我們可以使用上述訓練函數在MNIST資料集上訓練VAE模型了。程式碼如下:

# 定义模型和优化器
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 训练模型
num_epochs = 50
for epoch in range(num_epochs):
    train_loss = train(model, trainloader, optimizer, device, beta=1)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}')

# 测试模型
model.eval()
with torch.no_grad():
    test_loss = 0
    for x, _ in testloader:
        x = x.view(-1, input_dim).to(device)
        x_hat, mu, logvar = model(x)
        test_loss += vae_loss(x_hat, x, mu, logvar, beta=1).item()
    test_loss /= len(testloader.dataset)
    print(f'Test Loss: {test_loss:.4f}')

在訓練過程中,我們使用Adam優化器和\beta=1的超參數來更新模型參數。在訓練完成後,我們使用測試集計算模型的損失函數。在本例中,我們使用重構誤差和KL散度來計算損失函數,因此測試損失越小,表示模型學習到的潛在表示越好,產生的樣本也越真實。

2.4 產生樣本

最后,我们可以使用VAE模型生成新的手写数字样本。生成样本的过程非常简单,只需要在潜在空间中随机采样,然后将采样结果输入到解码器中生成新的样本。代码如下:

# 生成新样本
n_samples = 10
with torch.no_grad():
    # 在潜在空间中随机采样
    z = torch.randn(n_samples, latent_dim).to(device)
    # 解码生成样本
    samples = model.decode(z).cpu()
    # 将样本重新变成图像的形状
    samples = samples.view(n_samples, 1, 28, 28)
    # 可视化生成的样本
    fig, axes = plt.subplots(1, n_samples, figsize=(20, 2))
    for i, ax in enumerate(axes):
        ax.imshow(samples[i][0], cmap='gray')
        ax.axis('off')
    plt.show()

在上述代码中,我们在潜在空间中随机采样10个点,然后将这些点输入到解码器中生成新的样本。最后,我们将生成的样本可视化展示出来,可以看到,生成的样本与MNIST数据集中的数字非常相似。

综上,我们介绍了VAE模型的原理、实现和应用,可以看到,VAE模型是一种非常强大的生成模型,可以学习到高维数据的潜在表示,并用潜在表示生成新的样本。

以上是變分自動編碼器:理論與實作方案的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述
本文轉載於:网易伏羲。如有侵權,請聯絡admin@php.cn刪除
外推指南外推指南Apr 15, 2025 am 11:38 AM

介紹 假設有一個農民每天在幾週內觀察農作物的進展。他研究了增長率,並開始思考他的植物在幾週內可以生長的高度。從Th

軟AI的興起及其對當今企業的意義軟AI的興起及其對當今企業的意義Apr 15, 2025 am 11:36 AM

軟AI(被定義為AI系統,旨在使用近似推理,模式識別和靈活的決策執行特定的狹窄任務 - 試圖通過擁抱歧義來模仿類似人類的思維。 但是這對業務意味著什麼

為AI前沿的不斷發展的安全框架為AI前沿的不斷發展的安全框架Apr 15, 2025 am 11:34 AM

答案很明確 - 只是雲計算需要向雲本地安全工具轉變,AI需要專門為AI獨特需求而設計的新型安全解決方案。 雲計算和安全課程的興起 在

生成AI的3種方法放大了企業家:當心平均值!生成AI的3種方法放大了企業家:當心平均值!Apr 15, 2025 am 11:33 AM

企業家,並使用AI和Generative AI來改善其業務。同時,重要的是要記住生成的AI,就像所有技術一樣,都是一個放大器 - 使得偉大和平庸,更糟。嚴格的2024研究O

Andrew Ng的新簡短課程Andrew Ng的新簡短課程Apr 15, 2025 am 11:32 AM

解鎖嵌入模型的力量:深入研究安德魯·NG的新課程 想像一個未來,機器可以完全準確地理解和回答您的問題。 這不是科幻小說;多虧了AI的進步,它已成為R

大語言模型(LLM)中的幻覺是不可避免的嗎?大語言模型(LLM)中的幻覺是不可避免的嗎?Apr 15, 2025 am 11:31 AM

大型語言模型(LLM)和不可避免的幻覺問題 您可能使用了諸如Chatgpt,Claude和Gemini之類的AI模型。 這些都是大型語言模型(LLM)的示例,在大規模文本數據集上訓練的功能強大的AI系統

60%的問題 -  AI搜索如何消耗您的流量60%的問題 - AI搜索如何消耗您的流量Apr 15, 2025 am 11:28 AM

最近的研究表明,根據行業和搜索類型,AI概述可能導致有機交通下降15-64%。這種根本性的變化導致營銷人員重新考慮其在數字可見性方面的整個策略。 新的

麻省理工學院媒體實驗室將人類蓬勃發展成為AI R&D的核心麻省理工學院媒體實驗室將人類蓬勃發展成為AI R&D的核心Apr 15, 2025 am 11:26 AM

埃隆大學(Elon University)想像的數字未來中心的最新報告對近300名全球技術專家進行了調查。由此產生的報告“ 2035年成為人類”,得出的結論是,大多數人擔心AI系統加深的採用

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

AI Hentai Generator

AI Hentai Generator

免費產生 AI 無盡。

熱門文章

R.E.P.O.能量晶體解釋及其做什麼(黃色晶體)
4 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳圖形設置
4 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您聽不到任何人,如何修復音頻
4 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.聊天命令以及如何使用它們
4 週前By尊渡假赌尊渡假赌尊渡假赌

熱工具

DVWA

DVWA

Damn Vulnerable Web App (DVWA) 是一個PHP/MySQL的Web應用程序,非常容易受到攻擊。它的主要目標是成為安全專業人員在合法環境中測試自己的技能和工具的輔助工具,幫助Web開發人員更好地理解保護網路應用程式的過程,並幫助教師/學生在課堂環境中教授/學習Web應用程式安全性。 DVWA的目標是透過簡單直接的介面練習一些最常見的Web漏洞,難度各不相同。請注意,該軟體中

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

MantisBT

MantisBT

Mantis是一個易於部署的基於Web的缺陷追蹤工具,用於幫助產品缺陷追蹤。它需要PHP、MySQL和一個Web伺服器。請查看我們的演示和託管服務。

SublimeText3 英文版

SublimeText3 英文版

推薦:為Win版本,支援程式碼提示!

mPDF

mPDF

mPDF是一個PHP庫,可以從UTF-8編碼的HTML產生PDF檔案。原作者Ian Back編寫mPDF以從他的網站上「即時」輸出PDF文件,並處理不同的語言。與原始腳本如HTML2FPDF相比,它的速度較慢,並且在使用Unicode字體時產生的檔案較大,但支援CSS樣式等,並進行了大量增強。支援幾乎所有語言,包括RTL(阿拉伯語和希伯來語)和CJK(中日韓)。支援嵌套的區塊級元素(如P、DIV),