Stacking Your Transformer:通過堆疊加快 LLM 預訓練
一、背景
我們之前的文章中介紹了幾種模型增長的方案,然而針對 LLM 場景卻又缺乏足夠的數據支撐以及最佳實踐。比如說不知道在 LLM 場景中這些方案的差異有多大,是否和模型規模、數據量、以及訓練 FLOPs 有關?本文中我們介紹來自香港大學、清華大學和香港科技大學的 Stacking Your Transformer,作者做了大量的實驗來嘗試回答上述問題。
對應的論文為:[2405.15319] Stacking Your Transformers: A Closer Look at Model Growth for Efficient LLM Pre-Training
對應的 Blog:https://llm-stacking.github.io/
對應的代碼庫:GitHub - tongxuluo/prts
PS:需要說明的是,論文中的實驗很多,這里也只列了一部分。當然,也有待完善的地方,比如論文中對比了 4 個抽象的增長算子的影響,但這并不代表就完全等效于之前的工作,如果能有一些具體的對比會更清晰;此外,論文中大小模型訓練數據量也和之前的方法很不同,比如 [2309.03852] FLM-101B: An Open LLM and How to Train It with $100K Budget 中先在 16B 模型訓練 245.37B Token,然后在 51B 模型訓練 39.64B Token,最后在 101B 模型訓練 26.54B Token;而本文中的方案基本是在小規模模型訓練很少的 Token,比如 10B 規模。
二、摘要
論文中,作者進一步探索了模型擴展在 LLM 預訓練中的可行性。首先,作者確定了 3 個關鍵障礙:
- O1:缺乏全面的評估。
- O2:未經測試的擴展可行性。
- O3:缺乏經驗指南。
為了解決 O1 問題,作者將現有的方案總結為 4 個原子增長算子,并在標準的 LLM 預訓練中對其進行了系統的評估。結果表明,與 Baseline 相比,深度堆疊算子 Gstack 表現出了顯著的加速,從而提升了在 8 個 NLP 基準的整體性能。基于此,作者深入的研究了 Gstack,以便解決 O2 和 O3。對于 O2,作者實驗表明,Gstack 是可擴展的,并且始終表現良好,例如,與直接使用 300B Token 訓練的 7B 模型相比,Gstack 只使用 194B Token 就可以達到相同損失,加速 54.6%。對于 O3,作者通過建模確定 Gstack 的增長規劃(Growth Timing)和增長因子(Growth Factor),使其在常見的 LLM 預訓練中更實用。
三、方法
3.1 O1:4 種增長算子
如下圖 Figure 2 所示,作者將之前方案中的生長方案總結為 4 個生長算子:
- (a):Gdirect 通過拷貝、切分和堆疊的方式實現,分為寬度方向和深度方向。
- (b):Glearn 通過學習映射函數的方式實現。
- (c):Gzero 通過擴充 0 值的方式實現。
- (d):Grandom 通過隨機初始化然后增加 Mask 的方式實現。(PS:是否也可以不使用隨機初始化,比如拷貝后添加 Mask?)?
為了對比不同方法的效果,作者制定了一個統一的方案:總共兩個訓練階段,增長前的小模型訓練,增長后的大模型訓練。其小模型訓練的 Token 數 d,大模型訓練的 Token 數 D 以及模型增長因子 g(對應非 Embedding 參數) 作為超參。
如下圖 Figure 3 所示,作者先用 d=10B Token 預訓練了一個 400M 參數的模型,然后擴展為 1.1B 參數,對應增長因子 g=4,并繼續使用額外的 D=97.5B Token 進行訓練,以此來驗證不同方案的效果。可以看出深度堆疊 Gdirect 獲得了最好的效果,與直接訓練 100B Token大模型相比可以加速 49.1%,同時寬度擴展基本都是負優化。
PS:實際上采用不同的 d 來訓練小模型得到的結果很不一樣,比如作者實際分別測試了使用 d=10B 和 d=50B Token 來訓練小模型的結果,可以發現在 d=50B 的時候 Gzero 在深度上獲得了更好的結果, Grandom 在寬度上獲得更好的效果。
3.2 O2:Gstack 擴展可行性
從上面可以看出,在作者 400M 的實驗中,模型深度堆疊的方案能獲得不錯的收益,因此作者也聚焦在模型深度堆疊場景。如下所示,作者將這種方式稱為 Gstack:
擴展模型規模:如下圖 Figure 4 和 Figure 5 所示,作者在 3B 模型和 7B 模型上進行了驗證,其中 g=4,d=10B,具體來說,小模型的層數分別是 3B 和 7B 模型的 1/4,小模型訓練 10B Token:
- Figure 4:3B 模型從頭開始訓練
- 訓練 180B Token 達到 loss 與 Gstack 花費 48.6% FLOPs 的 loss 相當。
- 訓練 240B Token 達到 loss 與 Gstack 花費 54.5% FLOPs 的 loss 相當。
- Figure 5:7B 從頭開始訓練,160B,220B 和 280B Token 對應 Gstack 的 FLOPs 為 40.8%, 55.3% 和 53.8%。?
從上可以看出,當對比 1B、3B 和 7B 模型時,Gstack 帶來的優勢并沒有隨著模型尺寸的增加而減少,這意味著,即使更大的模型中也可以利用 Gstack 來加速。
擴展數據規模:如下圖 Figure 6 所示,作者進一步探索了不斷擴展數據規模的時候 Gstack 是否還有優勢。從 Figure 6a 可以看出,對于 410M 的模型,訓練遠超縮放法則確定的 Token 數(8B),模型的 Loss 一直在下降,并且 Gstack 一直能獲得更低的 Loss。如下圖 Figure 6b 所示,作者進一步預估當訓練 Token 數達到 8T 時(1000倍),Gstack 對應的 Loss 依然更低,表明擴大數據規模 Gstack 依然會有加速的效果。
Scaling Laws:如下圖 Figure 7 所示,作者根據 410M、1.1B、3B、7B 模型的相關實驗擬合了縮放法則曲線。可以看出,Gstack 在基于此預估出來的 13B 和 70B 模型模型上依然能夠獲得加速:
3.3 O3:Growth Timing d 和 Growth Factor g
在上述的實驗中作者直接采用了 d=10B 和 g=4,那么這個是否是最佳組合呢,在小模型上訓練更多數據是否能帶來更高的加速比?從 1 B 模型直接擴展到 30B 模型是否可行?為了回答這些問題,作者通過建模來確定 d 和 g 的影響。
如下圖 Figure 8 所示,作者在 410M、1.1B 和 3B 模型上探索了 Growth Timing d 的影響。可以看出,對于給定的 FLOP 預算,可以確定一個最優的 Growth Timing d,比如說,對于 410M 模型,最優的 d 為 5-10B Token,對于 1.1B 模型,最優的 d 為 10-20B Token:
如下圖 Figure 9 所示,作者進一步擬合了對于給定大模型預訓練 Token 數 C 和參數量 N 的情況下來預測 d 的曲線:
如下圖 Figure 10 所示,作者在 1.1B(24 層) 和 3B(32 層) 模型上探索了 Growth Factor g 的影響。可以看出,對于 1.1 B 的模型,最優的 g 為 2-3 左右,對于 3B 模型,最優的 g 為 4-5 左右。對于 3B 模型來說,即使 g=16 時(對應 small model 為 2 層) Gstack 依然能有加速。
四、附錄
4.1 如何堆疊?
在 [2011.13635] Progressively Stacking 2.0: A Multi-stage Layerwise Training Method for BERT Training Speedup 中,作者對 StackingBert-v1 進行了擴展。具體來說,將一個 N 層 Encoder 的模型分 K+1 次訓練,第一次訓練一個 N/k 層的 Bert 模型,然后每次擴展 N/k 層并且進行訓練。其中綠色為凍結的層,紅色為訓練的層。也就是每次擴展后只訓練新擴展的層,全部擴展完之后再解凍所有層繼續訓練:
本文作者采用了稍有不同的方案,具體來說,不是逐層堆疊 N/k 個 Layer,而是分兩次堆疊。比如對于 6 層到 24 層的模型,第一次從 6 層到 12 層,第二次直接從 12 層到 24 層。作者也對兩種方案進行了消融實驗,如下圖 Figure 33 所示,可以看出本文 Gstack 的方案會更優一些:
此外,如下圖 Table 5 所示,作者也探索了不同的堆疊方式,比如只堆疊中間層,或只堆疊首/尾層,最終發現還是全部堆疊的方案最優:
4.2 為什么沒保證 FPI?
Function Preserving Initialization(FPI):目標是給定一個源模型,用它初始化目標模型,能保證給定相同的輸入,目標模型和源模型有相同的輸出。
在之前的很多工作中都在嘗試保證 FPI,那是否一定要滿足這個要求呢?針對這個問題作者也做了一些實驗,具體來說,在 Gdirect 中通過添加噪聲(加 20% 噪聲),不加噪聲,以及從頭訓練的方式進行對比,如下圖 Figure 39 可以看出,初始階段加噪的效果確實更差,但是隨著 FLOPs 的增加,加噪的方式反而更好:
五、參考鏈接
- ??https://arxiv.org/abs/2405.15319??
- ??https://llm-stacking.github.io/??
- ??https://github.com/tongxuluo/prts??
- ??https://arxiv.org/abs/2309.03852??
- ??https://arxiv.org/abs/2011.13635??
本文轉載自??AI閑談??,作者: AI閑談
