上交最新時空預測模型PredFormer,純Transformer架構,多個數據集取得SOTA效果
今天給大家介紹一篇時空預測最新模型PredFormer,由上海交大等多所高校發表,采用純Transformer模型結構,在多個數據集中取得SOTA效果。
1.背景
時空預測學習是一個擁有廣泛應用場景的領域,比如天氣預測,交通流預測,降水預測,自動駕駛,人體運動預測等。
提起時空預測,不得不提到經典模型ConvLSTM和最經典的benchmark moving mnist,在ConvLSTM時代,對于Moving MNIST的預測存在肉眼可見的偽影和預測誤差。而在最新模型PredFormer中,對Moving MNIST的誤差達到肉眼難以分辨的近乎完美的預測結果。
在以前的時空預測工作中,主要分為兩個流派,基于循環(自回歸)的模型,以ConvLSTM/PredRNN//E3DLSTM/SwinLSTM/VMRNN等工作為代表;更近年來,研究者提出無需循環的SimVP框架,由CNN Encoder-Decoder結構和一個時間轉換器組成,以SimVP/TAU/OpenSTL等工作為代表。
RNN。系列模型的缺陷在于,無法并行化,自回歸速度慢,顯存占用高,效率低;CNN系列模型無需循環提高了效率,得益于歸納偏置,但往往以犧牲泛化性和可擴展性為代價,模型上限低。
于是作者提出了問題,時空預測,真的需要RNN嗎?真的需要CNN嗎?是否能夠設計一個模型,可以自動地學習數據中的時空依賴,而不需要依賴于歸納偏置呢?
一個直覺的想法是利用Transformer,因為它在各種視覺任務中的廣泛成功,并且是RNN和CNN的有力替代者。在此前的時空預測工作中,已有研究者把Transformer嵌入到上述兩種框架中,比如SwinLSTM(ICCV23)融合了Swin Transformer和LSTM,比如OpenSTL(NeurIPS23)把各種MetaFormer結構(比如ViT,Swin Transformer等)作為SimVP框架中的時間轉換器。但是,純Transformer結構的網絡鮮有探索。
但純Transformer模型的挑戰在于,如何在一個框架中同時處理時間和空間信息。一個簡單的想法是合并空間序列和時間序列,計算時空全注意力,由于Transformer的計算復雜度是序列長度的二次復雜度,這樣的做法會導致計算復雜度較大。
在這篇文章中,作者提出了用于時空預測學習的新框架PredFormer,這是一個純ViT模型,既沒有自回歸也沒有任何卷積。作者利用精心設計的基于門控Transfomer模塊,對3D Attention進行了全面的分析,包括時空全注意力,時空分解的注意力,和時空交錯的注意力。PredFormer 采用非循環、基于Transformer的設計,既簡單又高效,更少參數量,Flops,更快推理速度,性能顯著優于以前的方法。在合成和真實數據集上進行的大量實驗表明,PredFormer 實現了最先進的性能。在 Moving MNIST 上,PredFormer 相對于 SimVP 實現了 51.3% 的 MSE 降低,突破性地達到11.6。對于 TaxiBJ,該模型將 MSE 降低了 33.1%,并將 FPS 從 533 提高到 2364。此外,在 WeatherBench 上,它將 MSE 降低了 11.1%,同時將 FPS 從 196 提高到 404。這些準確度和效率方面的性能提升證明了 PredFormer 在實際應用中的潛力。
2.實現方法
PredFormer模型遵循標準ViT的設計,先對輸入進行Patch Embedding,把輸入為[B, T, C, H, W]的時空序列轉換為[B, T, N, D]的張量。在位置編碼環節,作者采用了不同于一般ViT設計的可學習的位置編碼,而是采用了基于sin函數的絕對位置編碼,作者在消融實驗中進一步闡述了絕對位置編碼在時空任務中的優越性。
PredFormer的編碼器部分,由門控Transfomer模塊以不同的方式堆疊而成。由于編碼器部分是純Transformer結構,沒有任何卷積,也沒有分辨率的下降,每一個門控Transformer模塊都建模了全局信息,這允許模型只需使用一個簡單的解碼器就可以構成一個性能強大的預測模型。作者采用了一個線性層作為解碼器來進行Patch Recovery,這讓模型的輸出從[B, T, N, D]恢復到[B, T, C, H, W]。
不同于標準Transformer模型采用MLP作為FFN,PredFormer采用了Gated Linear Unit(GLU)作為FFN,這是受GLU在NLP任務中優于MLP啟發的改進。作者在消融實驗中進一步闡述了GLU相比于MLP在時空任務上的優越性。
作者對3D Attention進行了全面的分析,并提出了9種PredFormer變體。在以前用于視頻分類的Video ViT設計中,TimesFormer(ICML21), ViviT(ICCV21), TSViT(CVPR23)等工作也對時空分解進行了分析,但是TimesFormer是在self-attention層面進行分解,也就是spatial attention和temporal attention共用一個MLP。ViviT則是提出了在Encoder層面(先空間后時間),self-attention層面和head層面進行時空分解。而TSViT發現先時間后空間的Encoder對衛星序列圖像分類更有效。
不同于以上工作,PredFormer是在Gated Transformer Block(GTB)層面(多了基于Gated Linear Unit)進行時空分解。對時間和空間的self-attention都加GLU是至關重要的,因為它可以讓學習到的特征互相作用并且增強非線性。
PredFormer提出了時空全注意力Encoder,時間在前和空間在前的2種分解Encoder和6種新穎的時空交錯的Encoder,一共9種模型。PredFormer提出了PredFormer Layer的概念,即一個既能建模空間信息,又能建模時間信息的最小單元。基于這種想法,作者提出了三種基本范式,二元組(由一個Temporal GTB和一個Spatial GTB組成,有T-S和S-T兩種方式),三元組(T-S-T和S-T-S),四元組(兩個二元組以相反的方向重組)。
這一設計源于不同的時空預測任務往往有著不同的空間分辨率和時間分辨率(時間間隔以及變化程度),這意味著不同的數據集上對時間信息和空間信息的依賴程度不同,作者設計了這些模型以提高PredFormer模型在不同任務上的適應性。
3.實驗效果
在實驗部分,作者控制了提出的每種變體使用相同的GTB數目,這可以保證模型的參數量基本一致,從而對比不同模型的性能。
實驗發現了一些規律,(1)時間在前的分解Encoder模型優于時空全注意力模型,由于空間在前的分解Encoder模型 (2)時空交錯的6種模型在大多數任務上表現都很好,都能達到sota,但最優模型因為數據集本身的不同時空依賴特性而不同,這體現了PredFormer這種框架和時空交錯設計的優勢 (3)作者在討論環節提出了建議,在其他的時空預測任務上,從四元組-TSST開始嘗試,因為這個模型在三個數據集上都表現sota,先調整M個TSST(即4M個門控Transformer)的M參數,然后嘗試M個TST和M個STS以確定數據集是時間依賴更強或空間依賴更強的模型。得益于Transformer構架的可擴展性,不同于SimVP框架的CNN Encoder-Decoder模型,對spatial和temporal的hidden dim以及block數都設置了不同的值,PredFormer對spatial和temporal GTB采用相同的固定的參數,因此只需要調整M的值,在比較少次數的調整后就可以達到最優性能。
ViT模型的訓練通常要求較大的數據集,在時空預測任務上,大多數據集在幾千到幾萬的量級,數據集少,因此很容易過擬合。作者還探索了不同的正則化策略,包括dropout和drop path,通過廣泛的消融實驗,作者發現同時使用dropout和uniform的drop path(不同于一般ViT使用線性增加的drop path rate)會產生最優的模型效果。
作者還進行了可視化比較,可以看到,在PredFormer相對于TAU明顯減少了預測誤差。作者還給出了一個特殊例子來證明PredFormerr模型相比于CNN模型在泛化性上的優越性。在交通流預測任務上,當第四幀相比前三幀明顯減少流量時,TAU受限于歸納偏置仍然預測了較高的流量,而PredFormer卻能捕捉到這里的變化。PredFormerr預測劇烈變化的能力在交通流和天氣預測中可能有非常寶貴的應用價值。
本文轉載自 ??圓圓的算法筆記??,作者:yujin
