RegMix-用回歸任務解決大模型數據混合問題
寫在前面
大型語言模型在預訓練過程中,如何選取數據的混合比例(利用較少的Tokens來實現較小的Loss從而加速預訓練過程)是一個復雜和關鍵的問題。手動確認數據集中各個組成的比例是不可擴展的,并且很可能不是最優選擇。
今天給大家介紹一個用回歸任務解決大模型數據混合問題的方法-RegMix。其核心思想是,利用不同的數據混合比例先訓練多個小模型并獲取其結果,在利用這些樣本訓練一個回歸模型,再遍歷所有比例利用回歸模型找到最優的數據混合比例,最后用最優數據混合比例訓練更大的語言模型。
Paper: https://arxiv.org/abs/2407.01492
Github: https://github.com/sail-sg/regmix
通過訓練512個1M的小模型,擬合回歸模型,找到top64的數據混合比例,訓練1B模型,最優數據混合比例訓練的模型的驗證集loss也是最低。
方法
整體流程如上圖所示,
- 生成隨機數據混合比例,按照比例采用混合數據并訓練小模型;
- 利用數據混合比例作為特征值,模型訓練的目標值作為標簽,擬合回歸模型;
- 在模擬更大數據混合比例空間,利用回歸模型預測最佳目標值,以獲取最佳混合比例;
- 使用模擬出的最佳混合比例的數據訓練更大的模型。
訓練小模型時越多越好,但為了節約成本需要盡量減少小模型訓練次數,那么在初始化數據混合比例時就需要時多樣化的,并且每個數據領域需要都存在極端值,數據采用過程主要是基于Tokens(chunk-level)分布的狄利克雷分布來實現。
詳見:mixture_config/synthesize_mixture.py
同時在擬合回歸模型時,采用了線性回歸和LightGBM兩種回歸模型。
結果
數據集采用Pile dataset中不涉及版權的17個數據集,如下表所示,
512個1M小模型在1B Tokens訓練得到的回歸模型,與在25B Tokens數據下訓練的1B模型,排序具有97.12%的高相關性,如下表所示,
同時訓練次數要比訓練的總Token數要重要,更影響回歸模型的效果,并且采用LightGBM建模要比線性回歸建模要好。
PS:跟作者@乾神交流過,512個樣本訓練回歸模型會不會數據量太少,乾神說他們做過1024的實驗,但并回歸模型效果無明顯提高,并且從成本考慮,那么512最佳。
不同的數據混合比例對下游任務結果影響較大,在Lambada數據集上最好和最差的效果相差14.6%,如下表所示,
同時發現了一個與傳統理解不一致的結果,一般我們任務維基數據質量很高,是評估大型語言模型最具代表性的數據集。但實驗結果發現,網絡數據集上評估的效果,更能體現模型在下游任務上的好壞,如下圖所示,可以發現Pile-CC數據集作為驗證時損失值與下游任務的相關性更強。
并且RegMix可以發現各領域數據之間是如何相互作用的,數據領域之間復雜的相互作用利用人類固有經驗很難直接區分。
本文轉載自 ??NLP工作站??,作者:劉聰NLP
