Transformer結構優勢 ,How Much Attention Do You Need?
前言
本期基于凱斯西儲大學(CWRU)軸承數據,進行 Transformer 的結構優勢進行講解,結合論文《How Much Attention Do You Need? 》,探索不同模塊對故障分類任務的影響力。
1 《How Much Attention Do You Need? 》
1.1 論文解析
論文提到三個觀點:
(1)Source attention on lower encoder layers brings no additional benefit (x4.2).
解釋:
- Source Attention:通常指的是在編碼器(encoder)中,模型關注輸入序列的不同部分的機制。在Transformer中,編碼器的每一層都通過自注意力(self-attention)機制來處理輸入序列。
- Lower Encoder Layers:指的是編碼器中的靠前或較淺的層。
- 結論意義:在模型的較淺層次,對源輸入進行復雜的注意力機制處理并沒有顯著的性能提升。也就是說,較低層次的編碼器主要在做基礎特征提取,而復雜的注意力模式在這些層次上未能發揮其潛力。因此,將注意力機制的復雜性集中在編碼器的更深層次可能更為有效。
(2)Multiple source attention layers and residual feed-forward layers are key (x4.3).
解釋:
- Multiple Source Attention Layers:在模型中使用多個層次的注意力機制來處理源輸入序列。
- Residual Feed-Forward Layers:在每個注意力層之后,通常會有一個殘差結構的前饋神經網絡(Feed-Forward Neural Network),這對于學習復雜的特征變換是重要的。
- 結論意義:這表明,對源輸入進行多層次的注意力處理,以及在每個注意力層之后使用殘差前饋層,是模型性能的關鍵因素。這可能是因為多層次的關注機制允許模型在不同的抽象層次上理解輸入序列,并通過殘差連接有效地訓練深層模型。
(3)Self-attention is more important for the source than for the target side (x4.4).
解釋:
- Self-Attention:一種注意力機制,輸入的每個元素(如序列中的一個詞)對自身序列中的其他元素進行關注,以捕捉序列內部的相關性。
- Source Side vs. Target Side:在序列到序列模型中,源指的是輸入序列(例如,翻譯任務中的源語言),而目標指的是輸出序列(例如,翻譯任務中的目標語言)。
- 結論意義:這個結論表明,自注意力機制在源輸入序列的處理過程中比在目標輸出序列生成過程中更為重要。這可能是因為在源輸入的編碼階段,理解和建模句子內部的長程依賴性和上下文關系至關重要,而在目標側,可能更多依賴上下文和已生成的部分來預測下一個元素。
2.1 具體結構探究
(1)雙向 RNN 替換多頭注意力
RNN嵌入Transformer后,上圖展示了對原生RNN逐步加入Transformer的各個構件后的效果。從上面的逐步對比過程可以看到,原生RNN的效果在不斷穩定提升。但是原本的Transformer相比,性能仍然有差距。
(2)CNN 替換多頭注意力
上圖展示了對CNN進行不斷加入Transformer的各個構件后的過程以及其對應效果。同樣的,性能也有不同幅度的提升。但是也與原本的Transformer性能存在一些差距。
(3)論文結論:
我們發現基于RNN的模型受益于多源注意機制和剩余前饋塊。另一方面,基于CNN的模型可以通過層歸一化和前饋塊來改進。這些變化使基于RNN和CNN的模型更接近Transformer。此外,我們還展示了可以成功地組合體系結構。
我們發現自我注意在編碼器方面比在解碼器方面重要得多,即使沒有自我注意的模型也表現得非常好。對于我們評估的數據集,在大多數情況下,編碼器側具有自關注的模型以及解碼器側具有RNN或CNN的模型與Transformer模型相比具有競爭力。
2 結合故障診斷進行探索
2.1 探索目標分析
根據上述結論,編碼器結構在特征提取過程中扮演著重要角色,而編碼器結構又由多頭注意力、前饋神經網絡、殘差連接等部分組成。結合故障診斷任務,我們可以鎖定2個研究對象:
- 編碼器整體結構
- 多頭注意力機制
2.2 多頭注意力機制的優勢
(1)多角度關注:
不同的頭可以從不同的角度和細粒度的層次來關注輸入序列的不同部分,從而捕捉更豐富的特征和關系。
(2)提高模型的表達能力:
多頭機制使得模型可以在不同的子空間中并行學習,這增強了其表達復雜關系和模式的能力。
(3)穩定訓練:
通過縮放點積和多頭的并行計算,模型能更好地處理長序列并穩定梯度。
3 軸承故障數據的預處理
3.1 導入數據
參考之前的文章,進行故障10分類的預處理,凱斯西儲大學軸承數據10分類數據集:
train_set、val_set、test_set 均為按照7:2:1劃分訓練集、驗證集、測試集,最后保存數據
3.2 故障數據預處理與數據集制作
4 編碼器整體結構的實驗對比
4.1 對比模型為:
- 模型 A :Transformer 編碼器結構
- 模型 B :多頭注意力機制
4.2 西儲大學十分類數據集實驗對比
(1)模型 A:
模型評估:
準確率、精確率、召回率、F1 Score
(2)模型 B:
模型評估:
準確率、精確率、召回率、F1 Score
4.3 東南大學齒輪箱軸承故障-五分類數據集實驗對比
(1)模型 A:
模型評估:
準確率、精確率、召回率、F1 Score
(2)模型 B:
模型評估:
準確率、精確率、召回率、F1 Score
5 實驗對比結果分析
通過兩個數據集的對比實驗,我們可以發現,Transformer 編碼器層在故障信號分類任務上取得了不錯的效果,但是僅用多頭注意力機制分類效果有一定程度的下降,證明Transformer 編碼器整體結構在故障信號分類任務上的優越性!大家還可以進一步細致的探索結構中的其他部分。
