李飛飛團隊新作:DiT不訓練直接改架構,模型深度減半,質量還提高了
模型架構設計在機器學習中扮演著核心角色,與數據、算法、算力和基準測試一樣重要。它定義了模型函數、算子選擇(如注意力機制、卷積)和配置設定(如模型深度、寬度)等等模型要素。
盡管如此,由于從頭訓練模型的成本過高 —— 尤其人們難以獲得關于架構設計的深刻洞見(即哪些方案有效、哪些無效)。因此,研究新架構仍是一項挑戰,對生成模型而言尤為如此。
在本文中,來自斯坦福大學、 Liquid AI 等機構的研究者探索了這一問題,即對預訓練模型進行架構編輯來研究新架構。
- 論文鏈接:https://arxiv.org/pdf/2506.05340v1
- 論文主頁:https://grafting.stanford.edu/
- 論文標題: Exploring Diffusion Transformer Designs via Grafting
具體而言,該研究提出了一種編輯預訓練擴散 transformer(DiT)的簡單方法,即 Grafting(嫁接),該方法可以在較小的計算預算下實現新的架構。
嫁接過程如下:
- 激活蒸餾:此階段通過回歸目標(regression objective)蒸餾原始算子的激活特征,將其功能遷移至新算子。該階段核心在于實現算子間的功能傳遞。
- 輕量級調優:此階段通過使用有限的數據進行調優,減輕了由于集成多個新算子而導致的誤差傳播。
此外,架構編輯還涵蓋多種策略,如添加、刪除和替換算子。
本文還基于 DiT-XL/2 構建了一個測試平臺,以研究嫁接對模型質量的影響。
利用該測試平臺,本文通過嫁接技術開發了一系列混合設計:用門控卷積、局部注意力和線性注意力取代 Softmax 注意力,用可變擴展率和卷積變體取代 MLP。
值得注意的是,許多混合設計使用不到 2% 的預訓練計算資源就實現了良好的質量(FID:2.38–2.64,而 DiT-XL/2 為 2.27)。然后,本文嫁接了一個文本轉圖像模型 (PixArt-Σ),實現了 1.43 倍的加速,而 GenEval 分數下降不到 2%。
最后,本文展示了一個案例研究,該研究通過嫁接技術將每對序列 Transformer 模塊轉換為并行模塊,從而重構了 DiT-XL/2。這將模型深度減少到原來一半,并獲得了比其他同等深度模型更高的質量(FID:2.77)。
總而言之,該研究展示了可以通過預訓練 DiT 來探索新的擴散模型設計,其修改范圍涵蓋從算子替換到架構重構。
一、嫁接擴散 Transformer
1. 兩階段嫁接方法
嫁接旨在通過編輯預訓練模型的計算圖來實現新架構。由于該研究專注于用替代方案替換現有算子,這引出了兩個問題:
問題 1:在將新算子集成到計算圖之前,應該如何初始化?
對應第一階段:通過激活蒸餾進行初始化。由于 DiT 的激活是連續且平滑的,這可以被視為一個回歸問題:
問題 2:當多個算子集成到計算圖時,如何減輕誤差傳播?
對應第二階段:輕量級調優。隨著更多算子被替換,初始化誤差會不斷傳播,導致與預訓練模型的行為出現偏差。
本文采用端到端微調來緩解階段 1 的累積誤差。微調目標函數如公式 1 所示。
實踐中,本文發現,即使替換 DiT-XL/2 中的所有 MHA 或 MLP 層,僅使用 10% 的訓練數據也能恢復競爭性能。
2. 自嫁接基準
在研究新的架構設計之前,該研究引入了自嫁接(self-grafting),這是一種簡單的對照設置:將現有算子(如 MHA、MLP)替換為相同類型但權重隨機初始化的算子。這樣可以保持計算圖的結構 —— 包括算子類型和參數數量 —— 但改變了具體的計算過程。自嫁接有三方面作用:(1)評估在不改變架構的情況下嫁接流程本身的效果;(2)為比較不同的替換方案提供一個性能基準;(3)研究影響性能的因素,如數據規模、回歸目標和超參數。
3. 激活行為分析以及自嫁接結果
本文首先分析了 DiT-XL/2 層中的 MHA 和 MLP 算子激活行為。在這兩種情況下,本文觀察到激活值存在較大差異,尤其是在較深的層中(表 1 (i, ii))。
經過分析,本文得出通過選擇特定于算子的回歸目標,可以實現高質量的初始化。
如表 1 (iii,iv) 所示,回歸目標的選擇會影響性能。對于 MHA,L1 實現了最佳 FID(2.51),其次是 Huber(2.55)和 L2(2.58)。對于 MLP,L2 表現最佳(2.33),而 L1 表現不佳(2.83);值得注意的是,MLP 的參數量是 MHA 的 2 倍。
這表明高質量的初始化需要量身定制的、激活感知的策略。
研究還發現,使用 10% 的數據進行完全自嫁接可實現接近基線的性能。表明在適度的數據和計算預算下完全自嫁接是可行的。
二、實驗
1. 實驗一:通過嫁接實現混合架構
本節實驗圍繞這個問題進行:當現有算子被高效的替代方案取代時,我們能否保持模型質量?
為了探究這個問題,本文研究了以下嫁接過程:
- 待替換算子的類型 ——MHA 或 MLP;
- 替換算子的類型 —— 例如卷積;
- 層選擇策略 —— 替換所有層中的算子或使用啟發式選擇;
- 替換率 —— 全部替換或部分替換。
為了實驗,該研究構建了一個測試平臺,并提出兩種層選擇策略:完全替換和交錯替換。測試平臺詳見表 3。
此外,該研究還引入了 Hyena-X 和 Hyena-Y 兩種新的高效門控卷積算子,并設計為 MHA 的直接替代品。Figure 3 展示了它們的結構。
MHA 結果。通過嫁接替換 DiT-XL/2 中的 MHA 算子,獲得了良好的質量 - 效率權衡。主要發現如下:
- 在交錯嫁接下,較小的感受野表現出驚人的效果。實驗發現,在 50% 交錯替換比例下,滑動窗口注意力(SWA)、Hyena-X/Y 和 Mamba-2 等替代方案均能保持 FID 分數與基線(2.27)差距在 0.5 以內。尤其值得注意的是,盡管 SWA 和 Hyena 變體的感受野有限(卷積核 K=4 / 窗口 w=4),其 FID 下降幅度卻極小。
- 替換策略:交錯替換 vs. 完全替換。將交錯替換比例從 50% 提升至 75% 時,性能通常下降,但 SWA 在 75% 交錯替換下仍有效(FID=3.09)。100% 替換時,性能急劇惡化(所有 FID > 75),這與局部性分析一致,表明只有部分層是局部且適合嫁接的。
數據規模和層選擇的消融實驗結果。
MLP 結果顯示通過嫁接的方式替換 MLP 算子是有效的。
經過實驗,得出要點 1:嫁接對于在較小的計算預算下構建具有良好生成質量的高效混合架構非常有效。交錯設計尤其有效。
2. 實驗二:通過嫁接改進文本到圖像的擴散 Transformers
結果。嫁接模型在實時計算速度(wall-clock time)上實現了 1.43 倍的提升,同時生成評估分數(GenEval)僅出現小幅下降(47.78 vs. 49.75)。特定屬性的指標(Attribute-specific metrics)基本保持可比,并且定性樣本也展現出良好的對齊度和質量。在一些紋理區域觀察到了局部性的失真(artifacts),這可能是由于 LoRA 的適應能力以及所使用的合成數據質量不高所致(失敗案例詳見圖 D.3,D.4)
要點 2:在文生圖 DiTs 中成功應用嫁接技術,構建的混合架構在實現顯著加速的同時,生成質量損失極小。
了解更多內容,請參考原論文。