參數量僅為1/700,性能超越GPT-3.5!CMU+清華開源Prompt2Model框架
基于大型語言模型(LLM),開發者或用戶可以通過描述任務,并給出幾個樣例來構造自然語言提示,很輕松地就能實現指定的功能。
不過從某種程度上來說,相比傳統的、面向任務開發的NLP系統,大型語言模型在計算資源需求等方面是一種極大的退步。
最近,卡內基梅隆大學和清華大學的研究人員提出了一種通用的模型構造方法Prompt2Model,開發者只需要構造自然語言提示,就可以訓練出一個可用于指定任務的模型,并易于部署。
論文鏈接:https://arxiv.org/abs/2308.12261
代碼鏈接:https://github.com/neulab/prompt2model
Prompt2Model框架包括檢索現有的數據集、生成訓練數據、搜索與訓練模型、微調訓練、自動化評估和部署等多個步驟。
三個任務的實驗結果證明,給出相同的少樣本提示作為輸入,Prompt2Model可以訓練出一個比大型語言模型更強的小模型,在參數量僅為gpt-3.5-turbo的1/700的情況下,實現了20%的性能提升。
Prompt2Model框架
Prompt2Model系統相當于一個平臺,可以對機器學習管道中的組件進行自動化:包括數據收集、模型訓練、評估和部署。
系統的核心是自動數據收集系統,利用數據集檢索和基于LLM的數據集生成來獲取與用戶需求相關的標注數據;
然后檢索預訓練模型,并在收集到的訓練數據上進行微調;
最后使用相同數據集下的劃分測試集,對得到的模型進行評估,也可以創建一個與模型交互web UI
Prompt2Model非常通用,設計上也遵循模塊化、可擴展,每個組件都可以由開發者進行重新實現或禁用。
下面介紹Prompt2Model各個組件的設計思路,以及文章作者給出的參考實現。
提示解析器(Prompt Parser)
作為系統的主要輸入,用戶需要提供類似LLMs使用的提示詞,包括指令,或是預期回復的幾個演示樣例。
開放式的接口(open-ended interface)對用戶來說很方便,并且端到端(end-to-end)機器學習管道也會從提示解析器中受益,例如將提示分割成指令、單獨的演示樣例,或是將指令翻譯成英語。
參考實現:研究人員將提示解析為指令(instruction)和演示(demonstration),其中指令表示主要的任務或目標,演示代表模型的預期行為。
可以利用具有上下文學習能力的大型語言模型(OpenAI gpt-3.5-turbo-0613)對用戶提示進行分割;如果用戶指令被識別為非英語,則使用DeepL API.2將其翻譯成英語。
數據集檢索器(Dataset Retriever)
用戶給出一個提示后,系統首先會進行檢索,嘗試發現那些符合用戶描述,且已經標注好的數據集,主要包括三個決策:
1. 要搜索哪些數據集?
2. 如何對數據集索引以支持搜索?
3. 哪些數據集是用戶任務所需要的,哪些應該被省略?
參考實現:研究人員先在Huggingface上,為所有的數據集提取用戶描述,然后利用DataFinder的雙編碼檢索器對數據集進行相關度排序。
然后系統會向用戶展示排名靠前的k(=25)個數據集,用戶可以選擇相關數據集,也可以聲明沒有適合目標任務的數據;如果存在可用數據,用戶還需要從數據集的模式中指定輸入和輸出列。
數據集生成器(Dataset Generator)
并不是所有的用戶任務都有完美匹配的數據集,但有些數據與任務在一定程度上是相關的。
為了支持更廣泛的任務,根據提示解析器得到的用戶要求,可以用數據集生成器來產生「合成訓練集」,主要難點在于如何降低成本、提升生成速度、生成樣本多樣性以及質量控制。
參考實現中,研究人員設計的策略包括:
1. 高多樣性的少樣本提示
使用自動化提示工程來生成多樣化的數據集,用先前生成的示例的隨機樣本來擴充用戶提供的演示示例,以促進多樣性并避免生成重復的示例。
生成200個問答樣本時,該策略可以將重復樣本從200降低到25個。
2. 溫度退火(Temperature Annealing)
根據已經生成的示例數量,將采樣溫度從低(輸出結果更確定)調整到高(輸出更隨機),有助于保持輸出質量,同時會促進數據多樣化。
3. 自洽解碼(Self-Consistency Decoding)
鑒于LLM可能為相同的輸入產生非唯一或不正確的輸出,研究人員使用自洽過濾(self-consistency filtering)來選擇偽標簽,具體來說,通過選擇最頻繁的答案,為每個唯一的輸入創建一個一致的輸出;在平局的情況下,啟發式地選擇最短的答案,可以提高生成數據集的準確性,同時確保樣本的唯一性。
4. 異步批處理(Asynchronous Batching)
API請求使用zeno-build進行并行化,引入額外的機制,如動態批大小和節流(throttling)來優化API的用量。
模型檢索器(Model Retriever)
除了訓練數據外,完成任務還需要確定一個合適的模型進行微調,研究人員認為這也是一個檢索問題,每個模型可以由一段「用戶生成的描述」和「元數據」(如受歡迎度、支持的任務等)。
參考實現:為了用統一的模型接口支持海量任務,所以研究人員將系統限制在Huggingface上的編碼器解碼器架構,對于模型蒸餾來說數據效率更高。
然后使用用戶指令作為查詢,基于Huggingface上模型的文本描述進行搜索,不過由于模型的描述通常很少,且包含大量模式化文本,通常只有幾個詞能表示模型的內容。
遵照HyDE框架,先使用gpt-3.5-turbo根據用戶的指示創建一個假設模型描述(hypothetical model description)作為擴展查詢,然后用BM25算法計算查詢模型的相似度分數。
為了確保模型易于部署,用戶可以設定模型的尺寸閾值(默認3GB),并過濾掉所有超過該閾值的模型。
一般來說,高下載量的模型可能質量也更高,也可以把下載量當作參數對模型進行排序:
模型訓練器(Model Trainer)
給定數據集和預訓練模型后,就可以對模型進行訓練、微調,其中所有的任務都可以當作是文本到文本的生成任務。
參考實現:在處理數據集時,研究人員會用到兩個數據集,一個是生成的,另一個是檢索到的,并將數據列文本化后與用戶指令合并到一起添加到模型輸入中。
在微調時,將兩個數據集組合起來后隨機打亂,然后訓練學生模型。
在所有的任務中都使用相同的超參數,使用AdamW優化器,以學習率5e-5訓練3個epoch,每個任務大約需要一小時。
模型評估器(Model Evaluator)
除去用作訓練模型的數據后,其余數據可以用來評估模型的訓練效果,主要難點在與如何在海量的目標任務中選擇出合適的評估指標。
參考實現:研究人員選擇三個通用的指標,即精確匹配、ChrF++和BERScore對所有任務實現自動化評估。
精確匹配(EM)可以衡量模型輸出與參考答案之間完美匹配的程度;ChrF++可以平衡精確度和召回率來評估文本生成質量;BERTScore可以通過比較嵌入空間中的模型輸出和引用來捕獲語義相似性。
使用XLM-R作為BERTScore的編碼器可以支持多語言任務的評估。
演示創建器(Demo Creator)
為了讓開發者可以將模型發布給普通用戶,可以在該模塊中創建一個圖形接口以供交互。
參考實現:研究人員使用Gradio構建了一個模型訪問界面。
實驗部分
實驗設置
作為概念驗證,研究人員測試了該系統在三項任務中學習模型的能力:
1. 機器閱讀問題回答:使用SQuAD作為基準數據集來評估。
2. 日語NL-to-Code:從日語查詢中生成代碼是一個有難度的任務,雖然之前有相關工作,但沒有可用的標注數據或與訓練模型,使用MCoNaLa進行評估。
3. 時態表達式規范化(Temporal Expression Normalization):目前沒有任何類型的預訓練模型或訓練數據集可用,使用Temporal數據集作為基準評估。
雖然Prompt2Model提供了自動模型評估的能力,在生成和檢索的數據測試上,但在這里使用真實的基準數據集來衡量我們的管道訓練準確模型的能力。
在基線模型的選取上,由于該工作的主要目標就是訓練一個小模型可以與大型語言模型相匹配或是更強,所以研究人員選擇gpt-3.5-turbo作為基準數據集的對比基線。
實驗結果
在下游任務中的表現上,Prompt2Model在三個任務中的兩個都實現了遠超gpt-3.5-turbo的性能。
值得注意的是,檢索到的SQuAD和Temporal模型是Flan-T5,僅有250M的參數量,比gpt-3.5-turbo(175B參數)小700倍。
還可以觀察到,Prompt2Model在MCoNaLa的日語轉Python任務上的性能明顯比gpt-3.5-turbo差。
可能的解釋是,生成的日語查詢數據集多樣性相對較低:5000個樣本中有45個都是「在數字列表中找到最大值」的不同說法,而在其他數據集中沒有觀察到這種高的冗余度,表明gpt-3.5-turbo可能很難為非英語的語言生成多樣化的文本。
另一個原因可能是缺乏合適的學生模型,模型型檢索器找到的模型是在多種自然語言或代碼上訓練的,沒有都是多語言的,導致預訓練模型缺乏表征日語輸入、Python輸出相關的參數知識。