告別偏科,能玩轉多模態、多任務、多領域的強化智能體終于來了
隨著 Llama 3 發布,未來大模型的參數量已飆升至驚人的 4000 億。盡管每周幾乎都有一個聲稱性能超強的大模型出來炸場,但 AI 應用還在等待屬于它們的「ChatGPT 時刻」。其中,AI 智能體無疑是最被看好的賽道。
就連吳恩達都說,GPT-4 加上 AI 智能體,可能提前達到 GPT-5 的效果。
不過,我們熟知的智能體往往有點「偏科」。例如,第一個 AI 軟件工程師 Devin,專精于代碼。會打游戲的智能體往往也只能在某一個游戲里秀操作。尋找一個能夠同時擅長多個領域,并能在其中無縫切換的通用模型仍是機器學習研究中的一個關鍵目標。
為了解決這個問題,研究者們對于智能體如何結合計算機視覺(CV)和自然語言處理(NLP)任務進行了廣泛探索,但將強化學習(RL)任務整合進來的研究相對較少。這是由于 RL 任務本質上是異質的,這使得將 RL 任務與對話和圖像識別等其他任務結合起來更加困難。這要求智能體能融會貫通不同領域任務中的不同模態、任務復雜性和數據類型。要達到全能型智能體,主要需要解決以下問題:(1)如何設計一個能夠處理多種數據類型和模態的統一模型結構?(2)如何有效地平衡不同任務的學習進度和優先級?(3)如何確保智能體制定合適的學習目標,以避免不同任務之間的干擾和負向遷移?
來自 Hugging Face、法國國家信息與自動化研究所(INRIA)和波爾多大學的四位研究者提出了智能體中的「六邊形戰士」——Jack of All Trades (JAT)。JAT 是一個基于 Transformer 的多模態通用強化學習智能體框架。在此框架下,智能體能夠通過同一套參數應對不同復雜度的多種任務,化身既會打游戲,又能控制機器人的全能高手。論文同時發布了大量 RL 智能體與 JAT 數據集。這是首個用于通用智能體訓練的數據集 JAT 數據集,包含了由專家智能體收集的數十萬條軌跡。
- 論文名稱:《Jack of All Trades, Master of Some, a Multi-Purpose Transformer Agent》
- 論文鏈接:https://huggingface.co/papers/2402.09844
- 代碼鏈接:https://github.com/huggingface/jat
- 項目鏈接:https://huggingface.co/jat-project/jat
- 數據集:https://huggingface.co/datasets/jat-project/jat-dataset
模型架構
JAT 的核心結構基于 Transformer,使用了 EleutherAI 的 GPT-Neo 實現。JAT 最大的創新點在于其嵌入機制,從本質上解決了數據類型不同的問題。JAT 模型將觀察嵌入與其對應的獎勵值和動作嵌入交錯排列,形成一個序列。
圖 1.JAT 網絡架構。對于序列中的決策任務,一方面輸入觀察嵌入與獎勵值,另一方面行動嵌入被編碼并被交錯放置。模型使用因果掩碼自回歸地生成下一個嵌入,并根據預期的模態進行解碼。
因此,每個嵌入要么對應一個與獎勵相關聯的觀察嵌入,要么對應一個動作嵌入。JAT 如何進一步對這些信息進行編碼呢?這要取決于數據的類型。如果觀察嵌入或動作嵌入的數據類型是圖像,那么 JAT 將使用 CNN。如果是連續向量,則使用線性層。如果是離散值,則使用線性投影層。模型的輸出也遵循相同的邏輯,具體取決于預測目標的數據類型。預測基于因果推理進行,將觀察嵌入向后移動一個時間步,確保智能體可以根據所有先前的觀察和動作嵌入來預測下一個動作嵌入。
這種嵌入設計讓研究團隊在訓練智能體執行 NLP 和 CV 任務時興致盎然。對于和文本相關的任務,作者讓 JAT 模型采用 GPT-2 的分詞策略,將文本轉換為一個整數序列,然后通過一個查找表映射到一個嵌入向量序列。對于和圖像有關的任務,JAT 模型將選擇 ViT 方法,將圖像切割成小塊后,通過線性層轉換為嵌入向量序列。JAT 模型再將圖像和文本的向量序列拼接在一起,形成一個統一的序列,輸入到 Transformer 中。
考慮到數據的模態變來變去,JAT 如何計算損失函數呢?它將針對每種模態分別計算 loss。對于圖像和連續值,它使用均方誤差(MSE)損失。對于離散值,它使用交叉熵損失。最終的損失是序列中每種元素損失的平均值。那么,這是否意味著 JAT 在預測動作嵌入和觀察嵌入時的權重是相同的呢?實際上不是,在此后的章節中將一步探討這個問題。
實驗結果
研究團隊共采用了 157 個訓練任務來 JAT 評估。他們將這些任務分為 10 類,并記錄了 JAT 的總獎勵值。
JAT 模型在最終的檢查點上達到了 65.8% 的專家得分,說明 JAT 能夠在非常廣泛的任務上達到專家水平。以下具體列出了 JAT 在四個常見的智能體訓練環境中的得分:
- 對于 Atari 57,應用 JAT 模型的智能體實現了專家分數的 14.1%,這相當于人類表現的 37.6%。Atari 視頻游戲廣泛被用作評估和開發強化學習算法的基準環境,其中《吃豆人》是一款標志性游戲。在這一系列的 21 款游戲中,JAT 智能體的表現已經超越了人類玩家。值得注意的是, JAT 只用了單一網絡就在所有 Atari 視頻游戲中達到了這種水平;
- 對于 BabyAI,應用 JAT 模型的智能體達到了專家分數的 99.0%,只有一個任務的表現未能超過專家水平的 50%;
- 對于 Meta-World,應用 JAT 模型的智能體達到了專家分數的 65.5%;
- 對于 MuJoCo,應用 JAT 模型的智能體達到了專家分數的 84.8%。
JAT 智能體在 Atari 57 基線上和人類表現的對比
這些 JAT 智能體都可以通過項目主頁下載,進一步測試和體驗。更多細節請參閱論文原文。
專家智能體和 JAT 數據集
專家策略
傳統的強化學習往往在單一環境中尋找專家策略,即在一個特定任務中尋找讓模型表現最優的方法。構建跨領域的多功能智能體,也離不開這種方法。論文作者選擇了 Atari、BabyAI、Meta-World 和 MuJoCo 一系列性質不同,難度各異的訓練環境,直到訓練出表現最好的智能體。這一系列采用 JAT 框架的專家智能體已經在項目主頁上發布。
JAT 數據集
論文作者隨論文同步發布了 JAT 數據集,這是首個針對通用智能體訓練的專項數據集。其中包含了數十萬條由上述專家智能體收集的軌跡數據。使用起來也很方便,可以像加載 Hugging Face 平臺上的其他數據集一樣簡單。以下是調用代碼示例:
JAT 數據集不僅包含強化學習的數據,還整合了來自維基百科等文本數據集,以及 Oscar、OK-VQA、Conceptual Captions 等針對視覺任務的數據集,提供了更豐富的數據類型選擇。
增加模型預測觀察嵌入的能力
智能體學得更好更快了
在訓練強化學習智能體時,主要目標是使其在未曾遇到的任務中實現獎勵最大化。然而,如果要求智能體預測未來可能遇到的情境,這一額外任務會促進還是阻礙其學習過程呢?
關于這個問題存在兩種相反的觀點。一方面,學會預判可能會讓智能體對環境有更深入的理解,從而學得更好更快。另一方面,這可能會分散智能體對其主要目標的注意力,導致在預測觀察嵌入和行動嵌入時都表現平庸。
為了得到問題的答案,論文作者進行了一個實驗,使用了一個結合了觀察損失和行動損失的損失函數,并通過權重參數 k 來平衡這兩種損失。
研究團隊在 95% 的置信區間內,針對選定任務,測量了預判將如何影響模型學習。每項任務進行了 100 次評估,基于這些評估得到了 k 值的范圍。結果表明,適當選擇 k 值可以顯著提升智能體的表現。
當 k 值過高(高于 0.5)時,預測觀察嵌入的額外任務阻礙了學習過程。但當 k 值較低時,對學習的影響可以忽略不計,且智能體的表現與沒有額外預判任務時的表現相似。
研究團隊發現,當 k=0.005 時,存在一個最佳臨界點。這意味著,只要平衡得當,為智能體增加預測觀察嵌入的任務,實際上可以提高智能體的學習效率。這一發現對于設計類似的智能體具有重要意義,突顯了輔助目標在提升智能體學習效率方面的潛在價值。
未來展望
JAT 項目為通用智能體研究領域開辟了全新的方向。研究團隊表示目前只是初步探索,以下幾點思路可供未來研究者深入挖掘:
改進數據的質量:盡管填補了之前少有通用智能體訓練數據集的空缺,JAT 數據集仍處于初級階段。其中的專家軌跡僅來自每個環境中的一名專家智能體,這可能導致一些誤差。雖然研究團隊已盡力讓智能體達到最優表現,但某些環境仍具挑戰性。在這些環境中,智能體仍有很大進步空間。收集到更多數據,訓練更多的專家智能體,將在很大程度上解決這些問題。
使用離線強化學習:JAT 智能體是仿照基線一比一地訓練出來的。這意味著,其一,智能體無法利用次優的軌跡;其二,JAT 智能體無法超越專家。論文選擇了這種方法是因為它比較簡單,但研究團隊相信,使用離線強化學習可以提高智能體的性能,同時,實現起來也不會過于復雜。
發揮更智能的多任務采樣策略的全部潛力:目前,JAT 智能體均勻地從所有任務中采樣數據,但這種方法可能限制了它的全部潛力。通過動態調整采樣率,專注于最具挑戰性的任務,或許也可以加速智能體的學習過程,并解鎖顯著的性能提升。
本文轉自 機器之心 ,作者:機器之心
