StaR | 用少量推理數(shù)據(jù)讓模型學(xué)會(huì)通用推理能力,顯著提升模型復(fù)雜推理
一、概述
?Title:STaR: Bootstrapping Reasoning With Reasoning
?URL:?? https://arxiv.org/abs/2203.14465??
?Authors:Eric Zelikman, Yuhuai Wu, Jesse Mu, Noah D. Goodman
?Code:?? https://github.com/ezelikman/STaR??
1 Motivation
?Step-by-step推理步驟生成可以提升語言模型在復(fù)雜推理任務(wù)(如數(shù)學(xué)或常識(shí)問答)上的性能,但是當(dāng)前要讓LLM能生成rationale推理過程,要么需要構(gòu)建龐大的推理數(shù)據(jù)集,要么在只使用少量示例(但推理時(shí)犧牲了準(zhǔn)確性)。
?需要一種方法來利用少量的推理示例和大量未經(jīng)過推理的數(shù)據(jù)來提升模型的推理能力。
2 Methods
1 省流版總結(jié):
- 使用少量推理示例(few-shot)引導(dǎo)語言模型生成多個(gè)問題的推理Rational過程。
- 對(duì)于模型生成的錯(cuò)誤答案,通過提供正確答案(Hint)來生成新的推理過程(稱為“rationalization”)。
- 在所有最終生成正確答案的推理上微調(diào)模型(Finetune)。
- 重復(fù)上述過程,直到performance不再提升(注意每次都使用original的預(yù)模型進(jìn)行continually training來避免overfitting)。
2 專業(yè)版總結(jié):
本文提出了一種名為“Self-Taught Reasoner”(STaR)的方法來解決語言模型在復(fù)雜推理任務(wù)上性能提升的問題。**STaR方法的核心思想是通過迭代地利用少量推理示例(rationales)和大量無推理數(shù)據(jù)集,逐步引導(dǎo)模型提升進(jìn)行復(fù)雜推理的能力。**具體來說,STaR方法包括以下幾個(gè)步驟:
- Rationale Generation Bootstrapping:首先,使用少量帶有推理過程的示例作為提示,引導(dǎo)預(yù)訓(xùn)練的大型語言模型(LLM)生成多個(gè)問題的推理過程。這個(gè)過程被稱為“rationale generation”。
- Filtering and Finetuning:接著,只保留那些生成了正確答案的推理過程,并在這些數(shù)據(jù)上對(duì)模型進(jìn)行微調(diào)(finetune)。這一步驟的目的是強(qiáng)化模型生成高質(zhì)量推理過程的能力。
- Rationalization:對(duì)于模型未能正確回答的問題,STaR采用一種稱為“rationalization”的技術(shù)。在這個(gè)階段,模型被提供正確答案作為提示,然后生成一個(gè)合理的推理過程來解釋這個(gè)答案。這樣做可以讓模型從錯(cuò)誤中學(xué)習(xí),并改進(jìn)其推理策略。
- Iterative Improvement:重復(fù)上述過程,每次都使用上一輪微調(diào)后的模型來生成新的訓(xùn)練數(shù)據(jù)。通過這種方式,模型逐漸學(xué)習(xí)如何更好地生成推理過程,并解決越來越復(fù)雜的問題。
- 5.Performance Evaluation:在每次迭代后,評(píng)估模型在測(cè)試集上的性能,直到性能達(dá)到飽和或不再顯著提升。
3 Rationalization指的是什么?
Q1:為什么要用Rationalization?
? 直接讓LLM生成推理思考過程,這些思考過程有些是對(duì)的,有些是錯(cuò)的,直接拿正確的思考過程,來訓(xùn)練llm生成rational,由于沒有增量信息,會(huì)導(dǎo)致模型不能從failed example中學(xué)習(xí),這樣就不能讓模型具備對(duì)new problems進(jìn)行推理的能力。
Q2: 如何生成Rational
? 如下圖所示,直接讓LLM生成推理過程,對(duì)于failed的例子,加上label作為hint,基于hint,可以生成正確的推理過程。
3 Conclusion
? STaR顯著提升了在多個(gè)數(shù)據(jù)集上的性能,相對(duì)于直接預(yù)測(cè)最終答案的模型,其效果更加突出。
? 在CommonsenseQA數(shù)據(jù)集上的表現(xiàn)與微調(diào)一個(gè)大30倍的最先進(jìn)語言模型相當(dāng)。
? STaR使得模型能夠通過學(xué)習(xí)自身生成的推理步驟逐步提升推理能力。
二、詳細(xì)內(nèi)容
1 實(shí)驗(yàn)設(shè)計(jì)
數(shù)據(jù)集:
- 算術(shù)問題:使用隨機(jī)生成的加法問題來測(cè)試STaR在處理數(shù)字運(yùn)算任務(wù)上的性能。
- 常識(shí)問答(CommonsenseQA):使用CommonsenseQA(CQA)數(shù)據(jù)集,這是一個(gè)多項(xiàng)選擇的常識(shí)推理任務(wù),測(cè)試STaR在自然語言推理上的能力。
- 小學(xué)數(shù)學(xué)(Grade School Math, GSM8K):使用GSM8K數(shù)據(jù)集,包含小學(xué)水平的數(shù)學(xué)問題,這些問題以自然語言的形式表述,需要進(jìn)行多步計(jì)算來得出答案。
Baseline:模型采用的是6B的開源模型(GPT-J),其checkpoint和fine-tuning code都開源了。
2 Rationalization能快速提升accuracy(從失敗中學(xué)習(xí)能快速成長(zhǎng)!!!)
說明;rationalization指的就是對(duì)于failed的example,加上hint,生成正確的推理過程數(shù)據(jù)并用于訓(xùn)練。
結(jié)論:隨著STaR算法迭代次數(shù)的增加,模型在算術(shù)任務(wù)上的準(zhǔn)確率逐漸提高。特別是在使用rationalization的情況下,準(zhǔn)確率提升更加塊。
3 STaR + rationalization比直接FT和few-shot效果好很多
? CQA數(shù)據(jù)集
? GSM8K數(shù)據(jù)集
說明:
? Direct Finetuned:不輸出中間推理過程
? STaR without rationalization:不從失敗樣例中學(xué)習(xí)(以label作為hint生成推理過程用于ft)
? STaR with rationalization:從失敗中學(xué)習(xí)
結(jié)論1:生成中間推理過程能顯著提升最終的精度,例如就算使用100%的數(shù)據(jù),不加推理過程,精度只能到60%,加上后用更少的數(shù)據(jù)卻能更高的精度(大于68%)。
結(jié)論2:rationalization從失敗中學(xué)習(xí)能進(jìn)一步提升精度。
三、總結(jié)
STaR方法的關(guān)鍵在于,它允許模型通過自我生成的推理過程來自我改進(jìn),而不需要人工標(biāo)注大量的推理數(shù)據(jù)集。此外,**通過rationalization技術(shù),STaR能夠確保模型從其錯(cuò)誤中學(xué)習(xí),從而提高整體的推理能力。**論文的實(shí)驗(yàn)結(jié)果表明,STaR在多個(gè)數(shù)據(jù)集上的性能顯著優(yōu)于直接預(yù)測(cè)答案的模型,并且與使用30倍更大模型的微調(diào)性能相當(dāng)。
本文轉(zhuǎn)載自??NLP PaperWeekly??,作者: NLP PaperWeekly ????
