補(bǔ)齊Transformer規(guī)劃短板,田淵棟團(tuán)隊的Searchformer火了
最近幾年,基于 Transformer 的架構(gòu)在多種任務(wù)上都表現(xiàn)卓越,吸引了世界的矚目。使用這類架構(gòu)搭配大量數(shù)據(jù),得到的大型語言模型(LLM)等模型可以很好地泛化用于真實世界用例。
盡管有如此成功,但基于 Transformer 的架構(gòu)和 LLM 依然難以處理規(guī)劃和推理任務(wù)。之前已有研究證明 LLM 難以應(yīng)對多步規(guī)劃任務(wù)或高階推理任務(wù)。
為了提升 Transformer 的推理和規(guī)劃性能,近些年研究社區(qū)也提出了一些方法。一種最常見且有效的方法是模擬人類的思考過程:先生成中間「思維」,然后再輸出響應(yīng)。比如思維鏈(CoT)提示法就是鼓勵模型預(yù)測中間步驟,進(jìn)行按步驟的「思考」。思維樹(ToT)則使用了分支策略和評判方法,讓模型生成多個不同的思維路徑,然后從中選出最佳路徑。盡管這些技術(shù)通常是有效的,但也有研究表明,在很多案例中,這些方法會讓模型的性能下降,原因包括自我強(qiáng)制(self-enforcing)。
另一方面,在一個數(shù)據(jù)集上有效的技術(shù)可能無法很好地處理其它數(shù)據(jù)集,原因可能包括所涉及的推理類型發(fā)生了變化,比如從空間推理變成了數(shù)學(xué)推理或常識推理。
相較之下,傳統(tǒng)的符號式規(guī)劃和搜索技術(shù)卻能表現(xiàn)出很好的推理能力。此外,這些傳統(tǒng)方法計算得到的解決方案通常有形式上的保證,因為符號規(guī)劃算法通常遵循明確定義的基于規(guī)則的搜索過程。
為了讓 Transformer 具備復(fù)雜推理能力,Meta FAIR 田淵棟團(tuán)隊近日提出了 Searchformer。
- 論文標(biāo)題:Beyond A?: Better Planning with Transformers via Search Dynamics Bootstrapping
- 論文地址:https://arxiv.org/pdf/2402.14083.pdf
Searchformer 是一種 Transformer 模型,但針對迷宮導(dǎo)航和推箱子等多步規(guī)劃任務(wù),它卻能計算出最優(yōu)規(guī)劃并且所用搜索步驟數(shù)也能遠(yuǎn)少于 A? 搜索等符號規(guī)劃算法。
為了做到這一點,該團(tuán)隊提出了一種新方法:搜索動態(tài)引導(dǎo)(search dynamics bootstrapping)。該方法首先是訓(xùn)練一個 Transformer 模型來模仿 A? 的搜索過程(如圖 1 所示,然后對其進(jìn)行微調(diào),使其能用更少的搜索步數(shù)找到最優(yōu)規(guī)劃。
更詳細(xì)地說,第一步,訓(xùn)練一個模仿 A? 搜索的 Transformer 模型。這里,該團(tuán)隊的做法是針對隨機(jī)生成的規(guī)劃任務(wù)實例運行 A* 搜索。在執(zhí)行 A? 時,該團(tuán)隊會記錄執(zhí)行的計算和最優(yōu)規(guī)劃并將其整理成詞序列,即 token。這樣一來,所得到的訓(xùn)練數(shù)據(jù)集就包含了 A? 的執(zhí)行軌跡并編碼了有關(guān) A? 本身的搜索動態(tài)的信息。然后,訓(xùn)練一個 Transformer 模型,讓其能針對任意規(guī)劃任務(wù)沿最優(yōu)規(guī)劃生成這些 token 序列。
第二步,使用專家迭代(expert iteration)方法進(jìn)一步提升使用上述經(jīng)過搜索增強(qiáng)的序列(包含 A? 的執(zhí)行軌跡)訓(xùn)練的 Searchformer。專家迭代方法可讓 Transformer 憑借更少的搜索步驟生成最優(yōu)解。這個過程會得到一種神經(jīng)規(guī)劃算法,其隱式地編碼在該 Transformer 的網(wǎng)絡(luò)權(quán)重之中,并且它有很高的概率以少于 A? 搜索的搜索步數(shù)找到最優(yōu)規(guī)劃。比如說,在執(zhí)行推箱子任務(wù)時,新模型能解答 93.7% 的測試任務(wù),同時搜索步數(shù)比 A? 搜索平均少 26.8%。
該團(tuán)隊表示:這為 Transformer 超越傳統(tǒng)符號規(guī)劃算法鋪平了道路。
實驗
為了更好地理解訓(xùn)練數(shù)據(jù)和模型參數(shù)量對所得模型性能的影響,他們進(jìn)行了一些消融研究。
他們使用了兩類數(shù)據(jù)集訓(xùn)練模型:一種的 token 序列中只包含解(solution-only,其中只有任務(wù)描述和最終規(guī)劃);另一種則是搜索增強(qiáng)型序列(search-augmented,其中包含任務(wù)描述、搜索樹動態(tài)和最終規(guī)劃)。
實驗中,該團(tuán)隊使用了 A? 搜索的一種確定性和非確定性變體來生成每個序列數(shù)據(jù)集。
迷宮導(dǎo)航
在第一個實驗中,該團(tuán)隊訓(xùn)練了一組編碼器 - 解碼器 Transformer 模型來預(yù)測 30×30 迷宮中的最優(yōu)路徑。
圖 4 表明,通過預(yù)測中間計算步驟,可在數(shù)據(jù)量少時獲得更穩(wěn)健的性能表現(xiàn)。
圖 5 給出了僅使用解訓(xùn)練的模型的性能。
圖 6 展示了任務(wù)難度對每個模型的性能的影響。
整體而言,盡管當(dāng)使用的訓(xùn)練數(shù)據(jù)集足夠大和足夠多樣化時,僅使用解訓(xùn)練的模型也能預(yù)測得到最優(yōu)規(guī)劃,但當(dāng)數(shù)據(jù)量少時,經(jīng)過搜索增強(qiáng)的模型的表現(xiàn)明顯好得多,并且也能更好地擴(kuò)展用于更困難的任務(wù)。
推箱子
為了測試能否在不同且更復(fù)雜的任務(wù)(具有不同的 token 化模式)上得到類似的結(jié)果,該團(tuán)隊還生成了一個推箱子的規(guī)劃數(shù)據(jù)集進(jìn)行測試。
圖 7 展示了每種模型針對每個測試任務(wù)生成正確規(guī)劃的概率。
可以看到,和上一個實驗一樣,通過使用執(zhí)行軌跡進(jìn)行訓(xùn)練,搜索增強(qiáng)型模型的表現(xiàn)優(yōu)于僅使用解訓(xùn)練的模型。
Searchformer:通過引導(dǎo)方法提升搜索動態(tài)
最后一個實驗,該團(tuán)隊研究了搜索增強(qiáng)型模型可以如何迭代提升,從而憑借更少的搜索步數(shù)計算出最優(yōu)規(guī)劃。這里的目標(biāo)是在縮短搜索軌跡長度的同時依然得到最優(yōu)解。
圖 8 表明,新提出的搜索動態(tài)引導(dǎo)方法能夠迭代式地縮短 Searchformer 模型生成的序列的長度。