無限生成視頻,還能規劃決策,擴散強制整合下一token預測與全序列擴散
近日,MIT CSAIL 的一個研究團隊(一作為 MIT 在讀博士陳博遠)成功地將全序列擴散模型與下一 token 模型的強大能力統合到了一起,提出了一種訓練和采樣范式:Diffusion Forcing(DF)。
- 論文標題:Diffusion Forcing:Next-token Prediction Meets Full-Sequence Diffusion
- 論文地址:https://arxiv.org/pdf/2407.01392
- 項目網站:https://boyuan.space/diffusion-forcing
- 代碼地址:https://github.com/buoyancy99/diffusion-forcing?
如下所示,擴散強制在一致性和穩定性方面都明顯勝過全序列擴散和教師強制這兩種方法。
在該框架中,每個 token 都關聯了一個隨機的、獨立的噪聲水平,并且可使用一種共享的下一 token 預測模型或下幾 token 預測模型根據任意的、獨立的、每 token 的方案對 token 進行去噪。
該方法的研究靈感來自這一觀察:對 token 加噪聲的過程就是一種形式的部分掩碼過程 —— 零噪聲就意味著未對 token 加掩碼,而完整噪聲則是完全掩蔽 token。因此,DF 可強迫模型學習去除任何可變有噪聲 token 集合的掩碼(圖 2)。
與此同時,通過將預測方法參數化為多個下一 token 預測模型的組合,該系統可以靈活地生成不同長度的序列,并以組合方式泛化到新的軌跡(圖 1)。
該團隊將用于序列生成的 DF 實現成了因果擴散強制(Causal Diffusion Forcing/CDF),其中未來 token 通過一個因果架構依賴于過去 token。他們訓練該模型一次性去噪序列的所有 token(其中每個 token 都有獨立的噪聲水平)。
在采樣期間,CDF 會將一個高斯噪聲幀序列逐漸地去噪成潔凈的樣本,其中不同幀在每個去噪步驟可能會有不同的噪聲水平。類似于下一 token 預測模型,CDF 可以生成長度可變的序列;不同于下一 token 預測,CDF 的表現非常穩定 —— 不管是預測接下來的一個 token,還是未來的數千 token,甚至是連續 token。
此外,類似于全序列擴散,它也可接收引導,從而實現高獎勵生成。通過協同利用因果關系、靈活的范圍和可變噪聲調度,CDF 能實現一項新功能:蒙特卡洛樹引導(MCTG)。相比于非因果全序列擴散模型,MCTG 能極大提升高獎勵生成的采樣率。圖 1 給出了這些能力的概況。
Diffusion Forcing(擴散強制)
1、將加噪過程視為部分掩碼
首先,我們可以將任意 token 集合(不管是否為序列)視為一個通過 t 索引的有序集合。那么,使用教師強制(teacher forcing)訓練下一 token 預測便可被解釋成掩蔽掉時間 t 的每個 token x_t 基于過去 x_{1:t?1} 預測它們。
對于序列,可將這種操作描述成:沿時間軸執行掩碼。我們可以將全序列前向擴散(即逐漸向數據
添加噪聲的過程)看作一種部分掩碼(partial masking),這可被稱為「沿噪聲軸執行掩碼)。
事實上,在 K 步加噪之后,
(大概)就是白噪聲了,不再有任何有關原數據的信息。
如圖 2 所示,該團隊建立了一個統一視角來看待沿這兩個軸的掩碼。
2、擴散強制:不同 token 的噪聲水平不同
擴散強制(DF)框架可用于訓練和采樣任意序列長度的有噪聲 token
,其中的關鍵在于每個 token 的噪聲水平 k_t 會隨時間步驟而變化。
這篇論文關注的重點是時間序列數據,因此他們通過一種因果架構實例化了 DF,并由此得到了因果擴散強制(CDF)。簡單來說,這是使用基礎循環神經網絡(RNN)獲得的一種最小實現。
權重為 θ 的 RNN 維護著獲悉過去 token 影響的隱藏狀態 z_t,其會通過一個循環層根據動態
而演化。當獲得輸入噪聲觀察
時,就以馬爾可夫方式更新該隱藏狀態。
當 k_t=0 時,這就是貝葉斯過濾中的后驗更新;而當 k_t= K(純噪聲、無信息)時,這就等價于建模貝葉斯過濾中的「后驗分布」p_θ(z_t | z_{t?1})。
給定隱藏狀態 z_t,觀察模型 p_θ(x_t^0 | z_t) 的目標是預測 x_t;這個單元的輸入 - 輸出行為與標準的條件擴散模型一樣:以條件變量 z_{t?1} 和有噪聲 token 為輸入,預測無噪聲的 x_t=x_t^0,并由此間接地通過仿射重新參數化預測噪聲 ε^{k_t}。因此,我們就可以直接使用經典的擴散目標來訓練(因果)擴散強制。根據噪聲預測結果 ε_θ,可以對上述單元進行參數化。然后,通過最小化以下損失來找到參數 θ:
算法 1 給出了偽代碼。重點在于,該損失捕獲了貝葉斯過濾和條件擴散的關鍵元素。該團隊也進一步重新推斷了用于擴散強制的擴散模型訓練中的常用技術,詳見原論文的附錄部分。他們也得出了一個非正式的定理。
定理 3.1(非正式)。擴散強制訓練流程(算法 1)是在期望對數似然
上優化證據下限(ELBO)的重新加權,其中期望值會在噪聲水平上平均,而
是根據前向過程加噪。此外,在適當條件下,優化 (3.1) 式還可以同時最大化所有噪聲水平序列的似然下限。
擴散強制采樣和所得到的能力
算法 2 描述了采樣過程,其定義是:在二維的 M × T 網格 K ∈ [K]^{M×T} 上指定噪聲調度;其中列對應于時間步驟 t,m 索引的行則決定了噪聲水平。
為了生成長度為 T 的整個序列,先將 token x_{1:T} 初始化為白噪聲,對應于噪聲水平 k = K。然后沿著網格逐行向下迭代,并從左到右逐列去噪,直到噪聲水平達到 K。到最后一行 m = 0 時,token 的噪聲已清理干凈,即噪聲水平為 K_{0,t} ≡ 0。
這個采樣范式會帶來如下新能力:
- 讓自回歸生成變得穩定
- 保持未來的不確定
- 長期引導能力
將擴散強制用于靈活的序列決策
擴散強制的新能力也帶來了新的可能性。該團隊基于此為序列決策(SDM)設計了一種全新框架,并且將其成功應用到了機器人和自主智能體領域。
首先,定義一個馬爾可夫決策過程,該過程具有動態 p (s_{t+1}|s_t, a_t)、觀察 p (o_t|s_t) 和獎勵 p (r_t|s_t, a_t)。這里的目標是訓練一個策略 π(a_t|o_{1:t}),使得軌跡
的預期累積獎勵最大化。這里分配 token x_t = [a_t, r_t, o_{t+1}]。一條軌跡就是一個序列 x_{1:T},其長度可能是可變的;訓練方式則如算法 1 所示。
在執行過程的每一步 t,都有一個隱藏狀態 z_{t-1} 總結過去的無噪聲 token x_{1:t-1}。基于這個隱藏狀態,根據算法 2 采樣一個規劃
,其中
包含預測的動作、獎勵和觀察。H 是一個前向觀察窗口,類似于模型預測控制中的未來預測。在采用了規劃的動作之后,環境會得到一個獎勵和下一個觀察,從而得到下一個 token。其中隱藏狀態可以根據后驗 p_θ(z_t|z_{t?1}, x_t, 0) 獲得更新。
該框架既可以作為策略,也可以作為規劃器,其優勢包括:
- 具有靈活的規劃范圍
- 可實現靈活的獎勵引導
- 能實現蒙特卡洛樹引導(MCTG),從而實現未來不確定性
實驗
該團隊評估了擴散強制作為生成序列模型的優勢,其中涉及視頻和時間序列預測、規劃和模仿學習等多種應用。
視頻預測:一致且穩定的序列生成和無限展開
針對視頻生成式建模任務,他們基于 Minecraft 游戲視頻和 DMLab 導航為因果擴散強制訓練了一個卷積 RNN 實現。
圖 3 展示了擴散強制與基準的定性結果。
可以看到,擴散強制能穩定地展開,甚至能超過其訓練范圍;而教師強制和全序列擴散基準會很快發散。
擴散規劃:MCTG、因果不確定性、靈活的范圍控制
擴散強制的能力能為決策帶來獨有的好處。該團隊使用一種標準的離線強化學習框架 D4RL 評估了新提出的決策框架。
表 1 給出了定性和定量的評估結果??梢钥吹剑瑪U散強制在全部 6 個環境中都優于 Diffuser 和所有基準。
可控的序列組合生成
該團隊發現,僅需修改采樣方案,就可以靈活地組合訓練時間觀察到的序列的子序列。
他們使用一個 2D 軌跡數據集進行了實驗:在一個方形平面上,所有軌跡都是始于一角并最終到達對角,形成一種十字形。
如上圖 1 所示,當不需要組合行為時,可讓 DF 保持完整記憶,復制十字形的分布。當需要組合時,可讓模型使用 MPC 無記憶地生成更短的規劃,從而實現對這個十字形的子軌跡的縫合,得到 V 形軌跡。
機器人:長范圍模仿學習和穩健的視覺運動控制
擴散強制也為真實機器人的視覺運動控制帶來了新的機會。
模仿學習是一種常用的機器人操控技術,即學習專家演示的觀察到動作的映射。但是,缺乏記憶往往會讓模仿學習難以完成長范圍的任務。DF 不僅能緩解這個短板,還能讓模仿學習更穩健。
使用記憶進行模仿學習。通過遙控 Franka 機器人,該團隊收集了一個視頻和動作數據集。如圖 4 所示,任務就是利用第三個位置交換蘋果和橘子的位置。水果的初始位置是隨機的,因此可能的目標狀態有兩個。
此外,當第三個位置有一個水果時,就無法通過當前觀察推斷出所需結果 —— 策略必須記住初始配置才能決定移動哪個水果。不同于常用的行為克隆方法,DF 可以自然地將記憶整合進自己的隱藏狀態中。結果發現,DF 能實現 80% 的成功率,而擴散策略(當前最佳的無記憶模仿學習算法)卻失敗了。
此外,DF 還能更穩健地應對噪聲并助益機器人預訓練。
時間序列預測:擴散強制是一種優秀的通用序列模型
對于多變量時間序列預測任務,該團隊的研究表明 DF 足以與之前的擴散模型和基于 Transformer 的模型媲美。
更多技術細節和實驗結果請參閱原論文。
本文轉自機器之心 ,作者:機器之心
