首頁  >  文章  >  科技週邊  >  Soft Diffusion:Google新框架從通用擴散過程中正確調度、學習和取樣

Soft Diffusion:Google新框架從通用擴散過程中正確調度、學習和取樣

王林
王林轉載
2023-04-30 13:22:061337瀏覽

我們知道,基於分數的模型和去噪擴散機率模型(DDPM)是兩類強大的生成模型,它們透過反轉擴散過程來產生樣本。這兩類模型已經在 Yang Song 等研究者的論文《Score-based generative modeling through stochastic differential equations》中統一到了單一的框架下,並被廣泛地稱為擴散模型。

目前,擴散模型在包括影像、音訊、視訊生成以及解決逆問題等一系列應用中取得了巨大的成功。 Tero Karras 等研究者在論文《Elucidating the design space of diffusionbased generative models》中對擴散模型的設計空間進行了分析,並確定了3 個階段,分別為i) 選擇噪聲水平的調度,ii) 選擇網絡參數化(每個參數化產生不同的損失函數),iii) 設計取樣演算法。

近日,在Google研究院和UT-Austin 合作的一篇arXiv 論文《Soft Diffusion: Score Matching for General Corruptions》中,幾位研究者認為擴散模型仍有一個重要的步驟:損壞(corrupt)。一般來說,損壞是一個添加不同幅度雜訊的過程,對於 DDMP 還需要重縮放。雖然有人嘗試使用不同的分佈來進行擴散,但仍缺乏一個通用的框架。因此,研究者提出了一個用於更通用損壞過程的擴散模型設計框架。

具體地,他們提出了一個名為 Soft Score Matching 的新訓練目標和一種新穎的採樣方法 Momentum Sampler。理論結果表明,對於滿足正則條件的損壞過程,Soft Score MatchIng 能夠學習它們的分數(即似然梯度),擴散必須將任何影像轉換為具有非零似然的任何影像。

在實驗部分,研究者在 CelebA 以及 CIFAR-10 上訓練模型,其中在 CelebA 上訓練的模型實現了線性擴散模型的 SOTA FID 分數——1.85。同時與使用原版高斯去噪擴散訓練的模型相比,研究者訓練的模型速度顯著更快。

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

#論文網址:https://arxiv.org/pdf/2209.05442.pdf

#方法概覽

通常來說,擴散模型透過反轉逐漸增加雜訊的損壞過程來產生影像。研究者展示如何學習對涉及線性確定性退化和隨機加性雜訊的擴散進行反轉。

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

具體地,研究者展示了使用更通用損壞模型訓練擴散模型的框架,包含三個部分,分別為新的訓練目標Soft Score Matching、新穎採樣方法Momentum Sampler 與損壞機制的調度。

首先來看訓練目標 Soft Score Matching,這個名字的靈感來自於軟過濾,是一種攝影術語,指的是去除精細細節的過濾器。它以一種可證明的方式學習常規線性損壞過程的分數,還在網路中合併入了過濾過程,並訓練模型來預測損壞後與擴散觀察相匹配的圖像。

只要擴散將非零機率指定為任何乾淨、損壞的影像對,則該訓練目標可以證明學習到了分數。另外,當損壞中存在加性雜訊時,此條件總是可以被滿足。

具體地,研究者探討如下形式的損壞過程。

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

在過程中,研究者發現雜訊在實證(即更好的結果)和理論(即為了學習分數)這兩方面都很重要。這也成為了其與反轉確定性損壞的並發工作 Cold Diffusion 的關鍵區別。

其次是采样方法 Momentum Sampling。研究者证明,采样器的选择对生成样本质量具有显著影响。他们提出了 Momentum Sampler,用于反转通用线性损坏过程。该采样器使用了不同扩散水平的损坏的凸组合,并受到了优化中动量方法的启发。

这一采样方法受到了上文 Yang Song 等人论文提出的扩散模型连续公式化的启发。Momentum Sampler 的算法如下所示。

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

下图直观展示了不同采样方法对生成样本质量的影响。图左使用 Naive Sampler 采样的图像似乎有重复且缺少细节,而图右 Momentum Sampler 显著提升了采样质量和 FID 分数。

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

最后是调度。即使退化的类型是预定义的(如模糊),决定在每个扩散步骤中损坏多少并非易事。研究者提出一个原则性工具来指导损坏过程的设计。为了找到调度,他们将沿路径分布之间的 Wasserstein 距离最小化。直观地讲,研究者希望从完全损坏的分布平稳过渡到干净的分布。

实验结果

研究者在 CelebA-64 和 CIFAR-10 上评估了提出的方法,这两个数据集都是图像生成的标准基线。实验的主要目的是了解损坏类型的作用。

研究者首先尝试使用模糊和低幅噪声进行损坏。结果表明,他们提出的模型在 CelebA 上实现了 SOTA 结果,即 FID 分数为 1.85,超越了所有其他仅添加噪声以及可能重缩放图像的方法。此外在 CIFAR-10 上获得的 FID 分数为 4.64,虽未达到 SOTA 但也具有竞争力。

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

此外,在 CIFAR-10 和 CelebA 数据集上,研究者的方法在另一项指标采样时间上也表现更好。另一个额外的好处是具有显著的计算优势。与图像生成去噪方法相比,去模糊(几乎没有噪声)似乎是一种更有效的操纵。

下图展示了 FID 分数如何随着函数评估数量(Number of Function Evaluations, NFE)而变。从结果可以看到,在 CIFAR-10 和 CelebA 数据集上,研究者的模型可以使用明显更少的步骤来获得与标准高斯去噪扩散模型相同或更好的质量。

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

以上是Soft Diffusion:Google新框架從通用擴散過程中正確調度、學習和取樣的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:51cto.com。如有侵權,請聯絡admin@php.cn刪除