一秒十圖!英偉達MIT聯手刷新SOTA,一步擴散解鎖實時高質量可控圖像生成
擴散生成模型通常需要50-100次迭代去噪步驟,效率很低,時間步蒸餾技術可以極大提高推理效率,「基于分布的蒸餾」方法,如生成對抗網絡GAN及其變分分數蒸餾VSD變體,以及「基于軌跡的蒸餾方法」(如直接蒸餾、漸進蒸餾、一致性模型)可以實現10-100倍的圖像生成加速效果。
但仍然存在一些關鍵難點,比如基于GAN的方法由于對抗動態的振蕩特性和模式坍塌問題,訓練過程不穩定;基于VSD的方法需要聯合訓練一個額外的擴散模型,增加了計算開銷;一致性模型雖然穩定,但在極少數步驟(例如少于4步)的情況下,生成質量會下降。
如何開發一個能夠兼顧效率、靈活性和質量的蒸餾框架成了模型部署的關鍵。
圖片
論文地址:https://arxiv.org/pdf/2503.09641
項目主頁:https://nvlabs.github.io/Sana/Sprint/
基于OpenAI提出的連續時間一致性模型(sCM)的方法,研究人員提出SANA-Sprint,進一步結合了LADD的對抗蒸餾技術,幫助模型在蒸餾過程中更好地保留細節信息,從而實現超快速且高質量的文本到圖像生成,同時避免了離散化帶來的誤差,保留了傳統一致性模型的優勢。
SANA-Sprint的核心在于其創新的混合蒸餾框架和對ControlNet的集成,主要貢獻包括:
1.混合蒸餾框架:設計了一種新穎的混合蒸餾框架,將預訓練的流匹配模型無縫轉換為TrigFlow模型,集成了連續時間一致性模型(sCM)和潛在對抗擴散蒸餾(LADD)。
sCM確保了模型與教師模型的一致性和多樣性保留,而LADD則增強了單步生成的保真度,從而實現了統一的步長自適應采樣。
圖片
2.卓越的速度/質量權衡:SANA-Sprint僅需1-4步即可實現卓越的性能。在H100上,SANA-Sprint僅需0.10-0.18秒即可生成1024x1024的圖像,在MJHQ-30K數據集上實現了7.59的FID和0.74的GenEval分數,超越了FLUX-schnell(7.94FID/0.71GenEval),速度提升了10倍。
圖片
3.實時交互式生成:通過將ControlNet與SANA-Sprint集成,實現了在H100上僅需0.25秒的實時交互式圖像生成。這為需要即時視覺反饋的應用(如ControlNet引導的圖像生成/編輯)提供了可能,實現了更好的人機交互。
圖片
SANA-Sprint不僅在速度和性能上表現出色,生成的圖像質量也非常高。
圖片
SANA-Sprint
SANA-Sprint方法主要包括以下四個關鍵步驟:
1. 無訓練轉換到TrigFlow
研究人員提出了一種簡單的方法,通過直接的數學輸入和輸出轉換,將預訓練的流匹配模型轉換為TrigFlow模型。這使得可以直接使用已有的預訓練模型,無需額外的TrigFlow模型的訓練。
動機是,雖然sCM使用TrigFlow公式簡化了連續時間一致性模型的訓練,但大多數基于分數的生成模型(如擴散模型和流匹配模型)并不直接支持TrigFlow。
為了克服這一挑戰,SANA-Sprint提出了一種無需重新訓練的轉換方法,通過數學變換將流匹配模型轉換TrigFlow模型,從而避免了復雜的額外算法設計和額外的計算成本。
2. 混合蒸餾策略
混合蒸餾策略結合了sCM和LADD兩種蒸餾方法。sCM利用TrigFlow的公式簡化了連續時間一致性模型的訓練,而LADD則通過對抗訓練在潛在空間中直接進行判別,進一步提升了生成質量。
3. 穩定訓練的關鍵技術
- 密集時間嵌入(Dense Time-Embedding):為了穩定連續時間一致性模型的訓練,SANA-Sprint采用了密集時間嵌入設計。通過將噪聲系數
調整為
- Query-Key歸一化(QK-Normalization):在Transformer模型的自注意力和交叉注意力機制中引入了RMS歸一化,進一步穩定了訓練過程,尤其是在大模型和高分辨率場景下。
4. 集成ControlNet
將SANA-Sprint的訓練流程應用于ControlNet任務,利用圖像和文本提示作為條件,實現了SANA-ControlNet模型,并通過蒸餾得到SANA-Sprint-ControlNet,支持實時的圖像編輯和生成。
實驗結果
研究人員采用了兩階段的訓練策略,詳細的設置和評估協議在論文附錄中進行了概述。
教師模型通過剪枝和微調SANA-1.5 4.8B模型得到,然后使用文中提出的訓練范式進行蒸餾,使用包括FID、MJHQ-30K上的CLIP Score和GenEval在內的指標評估性能。
實驗結果表明,SANA-Sprint在速度和質量方面均達到了最先進的水平。
- 效率與性能對比:在4步推理下,SANA-Sprint 0.6B實現了5.34個樣本/秒的吞吐量和0.32秒的延遲,FID為6.48,GenEval為0.76;SANA-Sprint 1.6B 的吞吐量略低(5.20個樣本/秒),但GenEval提升至0.77,優于更大的模型如FLUX-schnell 12B,其吞吐量僅為0.5個樣本/秒,延遲為2.10秒。
- 單步生成性能:SANA-Sprint在單步生成方面也表現出色,實現了7.59的FID和0.74的GenEval分數,超越了其他單步生成方法。
- 實時交互式生成:集成ControlNet的SANA-Sprint模型在H100上實現了約200毫秒的推理速度,支持近乎實時的交互。
結論與展望
SANA-Sprint是一款高效的擴散模型,用于超快速的單步文本到圖像生成,同時保留了多步采樣的靈活性。通過采用結合了連續時間一致性蒸餾(sCM)和潛在對抗蒸餾(LADD)的混合蒸餾策略,SANA-Sprint在一步內實現了7.59的FID和0.74的GenEval分數,無需針對特定步驟進行訓練。
該統一的步長自適應模型僅需0.1秒即可在H100上生成高質量的1024x1024圖像,在速度和質量的權衡方面樹立了新的標桿。
展望未來,SANA-Sprint的即時反饋特性將為實時交互應用(如響應迅速的創意工具和AIPC)開啟新的可能性。
參考資料: