英偉達港大聯手革新視覺注意力機制!GSPN高分辨率生成加速超84倍
視覺注意力機制,又有新突破,來自香港大學和英偉達。
Transformer的自注意力在NLP和計算機視覺領域表現出色——它能捕捉遠距離依賴,構建深度上下文。然而,面對高分辨率圖像時,傳統自注意力有兩個大難題:
- 計算量巨大:O(N2) 的復雜度讓處理長上下文變得非常耗時。
- 破壞空間結構:將二維圖像拉平成一維序列,會丟失像素之間的空間關系。
雖然線性注意力和Mamba等方法能把復雜度降到O(N),但它們還是把圖像當作一維序列處理,無法真正利用二維空間信息。
為此,香港大學與英偉達聯合推出了廣義空間傳播網絡(GSPN)。
GSPN采用二維線性傳播,結合“穩定性–上下文條件”,將計算量從 O(N2) 或 O(N) 再降到√N級別,并完整保留圖像的空間連貫性。這樣,不僅大幅提升了效率,還在多個視覺任務上刷新了性能紀錄。
兼具空間連貫性和計算效率
GSPN的核心技術是二維線性傳播與穩定性-上下文條件,基于此,現有注意力機制與GSPN的對比如下:
- 傳統Transformer:將圖像視為一維序列,計算復雜度為O(N2),處理高分辨率圖像效率低下,且忽略空間連貫性。
- 線性注意力和狀態空間模型(如Mamba):雖將計算復雜度降至O(N),但仍抽象掉了視覺任務關鍵的空間結構。
- GSPN:直接對二維圖像進行線掃描,通過穩定性-上下文條件確保傳播穩定性和長距離依賴,有效序列長度降至√N,兼具空間連貫性和計算效率。
二維線性傳播:從行列并行到密集連接
作為GSPN的核心組件,二維線性傳播包括兩個關鍵點:
- 線掃描機制
對于二維圖像,二維線性傳播通過逐行或逐列的順序處理進行其遵循線性循環過程,隱藏層通過前一行的隱藏狀態和當前輸入計算得出:
將上述公式按展開,可以得到向量化的輸入與一個下三角矩陣的乘積,輸出則為輸入的加權和。其中與注意力機制里的大矩陣意義類似—都描述了任意兩個像素之間的連接關系。
在傳播過程中,研究人員并不對所有像素做全連接,而是只在前一行的左、中、右三個相鄰像素之間建立三向連接。這樣既能大幅減少參數量,又能保證信息像全連接那樣完整傳播。
GSPN有兩種變體,一種捕捉整個輸入的全局上下文(下圖左),另一種專注于局部區域以實現更快的傳播(下圖右)。這些變體使GSPN能夠作為現有注意力模塊的直接替代品無縫集成到現代視覺架構中。
- 元素之間的密集鏈接
那么上述三連接掃描方式是否能形成像注意力機制的那樣的密集鏈接呢?
下面的示意圖給出了肯定答案。通過左→右、右→左、上→下、下→上四個方向的掃描,所有的像素都會和其他任意像素產生連接。
此外,GSPN引入了一個可學習的合并器,用于聚合來自所有掃描方向的空間信息,增強了模型動態適應視覺數據二維結構的能力。
研究人員發現GSPN天然對位置信息敏感,無需位置嵌入,避免了常見的混疊問題。
穩定性-上下文條件:確保長距離傳播的可靠性
眾所周知,線性系統容易出現不穩定。
為了讓 GSPN在長距離傳播時既穩定又高效,研究人員提出了定理1和定理 2(統稱“穩定性–上下文條件”)。推導結果表明(過程見GSPN附錄),只要把傳播矩陣做成行隨機矩陣(元素非負、每行之和為1),就能在保證信息不丟失、信號不過度放大或衰減的前提下,維持系統穩定。
更簡單的是,只需在CUDA內核外加個Sigmoid激活,再對每行做歸一化,就能輕松實現這一點。
GSPN引入了一個根據線性傳播構建的序列長度為的全新內核。無需依賴ViT的大規模矩陣乘法,也不同于Mamba的并行方案。該核能在批量樣本、所有通道以及與傳播方向垂直的行/列上一次性并行執行,極大地減少循環開銷,實現了快速且可擴展的線性傳播。
模塊化設計
與ViT和Mamba類似,研究人員推薦了以下GSPN模塊化設計。
- GSPN模塊通過共享1×1卷積進行降維,再通過三個獨立的1×1卷積生成依賴于輸入的參數,用于二維線性傳播,這些投影和傳播封裝在模塊化的GSPN單元中。
- 圖像分類架構采用Swin-Transformer的四級分層架構,通過堆疊設計良好的GSPN塊,在相鄰層級間進行下采樣操作,平衡計算效率和表示能力。
- 類條件圖像生成架構重新設計生成架構,通過向量嵌入加法集成時間步和條件信息,包含跳躍連接和線性投影,去除位置嵌入并引入FFN進行通道混合。
- 文本到圖像生成架構將GSPN模塊直接集成到Stable Diffusion架構中,替換所有自注意力層,利用預訓練權重初始化參數,加速訓練。
需要注意的是,GSPN本身是一個獨立的注意力機制層,可以非常靈活的用在任何視覺網絡中。比如在Stable Diffusion中,研究人員就保留了大部分網絡結構甚至訓練參數,直接將注意力層替換為GSPN,仍然取得了非常好的效果。
實驗:從理解到高分辨率生成的全面領先
具體來看實驗結果。
在圖像分類領域,GSPN實現了效率與精度雙優。
在ImageNet中,GSPN-T在5.3 GFLOPs計算量下,Top-1準確率達82.2%,超越LocalVMamba-T(81.9%)和ViT類模型,參數效率提升顯著。
圖像生成方面,在類條件生成任務中,GSPN-XL/2在ImageNet 256×256任務中以65.6%參數實現FID 3.2,優于DiT-XL/2(FID 3.5),生成速度提升1.5倍。
文本到圖像生成任務上,在SD-XL模型中,生成16K×8K圖像的推理時間加速超84倍,且在未見分辨率外推的場景下FID分數(30.86)優于基線(32.71)。
值得一提的是,GSPN具備兩個優勢:
- 任意尺寸兼容:基于穩定性-上下文條件的歸一化權重,GSPN可直接處理2K至16K分辨率圖像,無需像Mamba那樣依賴額外歸一化層。
- 實時生成場景:單卡支持16K分辨率生成,適用于電影特效、虛擬場景搭建等對高分辨率和速度敏感的領域。
這使其具備了從學術到工業的落地潛力。
總結來說,GSPN通過二維結構感知和線性復雜度設計,重新定義了視覺注意力機制的范式。
其在保持空間連貫性的同時實現計算效率躍升,尤其在高分辨率生成任務中的突破,為多模態模型和實時視覺應用提供了新方向。
論文: https://arxiv.org/abs/2501.12381
項目主頁: https://whj363636.github.io/GSPN/
代碼:https://github.com/NVlabs/GSPN