首頁  >  文章  >  科技週邊  >  顏水成/程明明新作! Sora核心元件DiT訓練加速10倍,Masked Diffusion Transformer V2開源

顏水成/程明明新作! Sora核心元件DiT訓練加速10倍,Masked Diffusion Transformer V2開源

王林
王林轉載
2024-03-13 17:58:18412瀏覽

作為Sora引人注目的核心技術之一,DiT利用Diffusion Transformer將生成模型擴展到更大的規模,從而實現出色的影像生成效果。

然而,更大的模型規模導致訓練成本飆升。

Sea AI Lab、南開大學、崑崙萬維2050研究院的顏水成和程明明研究團隊在ICCV 2023會議上提出了一個名為Masked Diffusion Transformer的新模型。該模型利用mask建模技術,透過學習語意表徵資訊來加快Diffusion Transfomer的訓練速度,並在影像生成領域中取得了SoTA的效果。這項創新為圖像生成模型的發展帶來了新的突破,為研究者提供了一個更有效率的訓練方法。透過結合不同領域的專業知識和技術,研究團隊成功地提出了一種能夠提高訓練速度並改善產生效果的解決方案。他們的工作為人工智慧領域的發展貢獻了重要的創新思路,為未來的研究和實踐提供了有益的啟

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源圖片

##論文網址:https://arxiv.org/abs/2303.14389

GitHub網址:https://github.com/sail-sg/MDT颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

#近日,Masked Diffusion Transformer V2再次刷新SoTA, 相比DiT的訓練速度提升10倍以上,並實現了ImageNet benchmark 上1.58的FID score。

最新版本的論文和程式碼都已開源。

背景

儘管以DiT 為代表的擴散模型在影像生成領域取得了顯著的成功,但研究者發現擴散模型往往難以有效率地學習影像中物體各部分之間的語意關係,這一限制導致了訓練過程的低收斂效率。

圖片

#例如上圖所示,DiT在第50k次訓練步驟時已經學會生成狗的毛髮紋理,然後在第200k次訓練步驟時才學會生成狗的一隻眼睛和嘴巴,但是卻漏生成了另一隻眼睛。

即使在第300k次訓練步驟時,DiT產生的狗的兩隻耳朵的相對位置也不是非常準確。

這個訓練學習過程揭示了擴散模型未能有效率地學習到影像中物體各部分之間的語意關係,而只是獨立地學習每個物體的語意資訊。 颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

研究者推測這一現象的原因是擴散模型透過最小化每個像素的預測損失來學習真實影像資料的分佈,這個過程忽略了影像中物體各部分之間的語意相對關係,因此導致模型的收斂速度緩慢。

方法:Masked Diffusion Transformer

######受到上述觀察的啟發,研究者提出了Masked Diffusion Transformer (MDT) 提高擴散模型的訓練效率和生成品質。 ############MDT提出了一個針對Diffusion Transformer 設計的mask modeling表徵學習策略,以明確地增強Diffusion Transformer對上下文語義資訊的學習能力,並增強圖像中物體之間語意資訊的關聯學習。 ##################圖片################如上圖所示,MDT在維持擴散訓練過程的同時引入mask modeling學習策略。透過mask部分加雜訊的圖像token,MDT利用一個非對稱Diffusion Transformer (Asymmetric Diffusion Transformer) 架構從未被mask的加噪聲的圖像token預測被mask部分的圖像token,從而同時實現mask modeling 和擴散訓練過程。 ##########

在推理過程中,MDT仍維持標準的擴散生成過程。 MDT的設計有助於Diffusion Transformer同時具有mask modeling表徵學習帶來的語意資訊表達能力和擴散模型對影像細節的生成能力。

具體而言,MDT透過VAE encoder將圖片對應到latent空間,並在latent空間中處理以節省計算成本。

在訓練過程中,MDT先mask掉部分加雜訊後的影像token,並將剩餘的token送入Asymmetric Diffusion Transformer來預測去雜訊後的全部影像token。

Asymmetric Diffusion Transformer架構

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源圖片

如上圖所示,Asymmetric Diffusion Transformer架構包含encoder、side-interpolater(輔助插值器)和decoder。

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源圖片

在訓練過程中,Encoder只處理未被mask的token;而在在推理過程中,由於沒有mask步驟,它會處理所有token。

因此,為了確保在訓練或推理階段,decoder總是能處理所有的token,研究者提出了一個方案:在訓練過程中,透過一個由DiT block組成的輔助插值器(如上圖所示),從encoder的輸出中插值預測出被mask的token,並在推理階段將其移除因而不增加任何推理開銷。

MDT的encoder和decoder在標準的DiT block中插入全域和局部位置編碼資訊以幫助預測mask部分的token。

Asymmetric Diffusion Transformer V2

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源圖片

如上圖所示,MDTv2透過引入了一個針對Masked Diffusion過程設計的更為高效的宏觀網路結構,進一步優化了diffusion和mask modeling的學習過程。

這包括在encoder中融合了U-Net式的long-shortcut,在decoder中整合了dense input-shortcut。

其中,dense input-shortcut將添加噪後的被mask的token送入decoder,保留了被mask的token對應的噪聲信息,從而有助於diffusion過程的訓練。

此外,MDT還引入了包括採用更快的Adan優化器、time-step相關的損失權重,以及擴大掩碼比率等更優的訓練策略來進一步加速Masked Diffusion模型的訓練過程。

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源實驗結果

ImageNet 256基準產生品質比較

圖片

上表比較了不同模型尺寸下MDT與DiT在ImageNet 256基準下的效能比較。

顯而易見,MDT在所有模型規模上都以較少的訓練成本實現了更高的FID分數。

MDT的參數和推理成本與DiT基本一致,因為如前文所介紹的,MDT推理過程中仍保持與DiT一致的標準的diffusion過程。 ############對於最大的XL模型,經過400k步驟訓練的MDTv2-XL/2,顯著超過了經過7000k步驟訓練的DiT-XL/2,FID分數提高了1.92。在這一setting下,結果顯示了MDT相對DiT有約18倍的訓練加速。 ##########

对于小型模型,MDTv2-S/2 仍然以显著更少的训练步骤实现了相比DiT-S/2显著更好的性能。例如同样训练400k步骤,MDTv2以39.50的FID指标大幅领先DiT 68.40的FID指标。

更重要的是,这一结果也超过更大模型DiT-B/2在400k训练步骤下的性能(39.50 vs 43.47)。

ImageNet 256基准CFG生成质量比较

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源图片

我们还在上表中比较了MDT与现有方法在classifier-free guidance下的图像生成性能。

MDT以1.79的FID分数超越了以前的SOTA DiT和其他方法。MDTv2进一步提升了性能,以更少的训练步骤将图像生成的SOTA FID得分推至新低,达到1.58。

与DiT类似,我们在训练过程中没有观察到模型的FID分数在继续训练时出现饱和现象。

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源MDT在PaperWithCode的leaderboard上刷新SoTA

收敛速度比较

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源图片

上图比较了ImageNet 256基准下,8×A100 GPU上DiT-S/2基线、MDT-S/2和MDTv2-S/2在不同训练步骤/训练时间下的FID性能。

得益于更优秀的上下文学习能力,MDT在性能和生成速度上均超越了DiT。MDTv2的训练收敛速度相比DiT提升10倍以上。

MDT在训练步骤和训练时间方面大相比DiT约3倍的速度提升。MDTv2进一步将训练速度相比于MDT提高了大约5倍。

例如,MDTv2-S/2仅需13小时(15k步骤)就展示出比需要大约100小时(1500k步骤)训练的DiT-S/2更好的性能,这揭示了上下文表征学习对于扩散模型更快的生成学习至关重要。

总结&讨论

MDT通过在扩散训练过程中引入类似于MAE的mask modeling表征学习方案,能够利用图像物体的上下文信息重建不完整输入图像的完整信息,从而学习图像中语义部分之间的关联关系,进而提升图像生成的质量和学习速度。

研究者认为,通过视觉表征学习增强对物理世界的语义理解,能够提升生成模型对物理世界的模拟效果。这正与Sora期待的通过生成模型构建物理世界模拟器的理念不谋而合。希望该工作能够激发更多关于统一表征学习和生成学习的工作。

参考资料:

https://arxiv.org/abs/2303.14389

以上是顏水成/程明明新作! Sora核心元件DiT訓練加速10倍,Masked Diffusion Transformer V2開源的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:51cto.com。如有侵權,請聯絡admin@php.cn刪除