清華&英偉達(dá)揭露ICML最佳論文隱藏?cái)?shù)值問(wèn)題,簡(jiǎn)單掩碼模型是等價(jià)更好選擇 | ICLR25
掩碼離散擴(kuò)散模型,可能并沒有看上去那么厲害。
這是清華及英偉達(dá)研究人員最新提出的觀點(diǎn)。
他們發(fā)現(xiàn),作為離散擴(kuò)散模型中性能最強(qiáng)的類別,掩碼擴(kuò)散模型可能有點(diǎn)“被包裝過(guò)度”了。為啥呢?
第一,這類模型所宣稱的超高性能,其實(shí)是由于一個(gè)技術(shù)上的小瑕疵,用32位計(jì)算時(shí),模型會(huì)產(chǎn)生一種“降溫”效果,使模型看起來(lái)表現(xiàn)很好,但實(shí)際上只是多樣性被降低了。用更精確的64位計(jì)算,就會(huì)發(fā)現(xiàn)它們的表現(xiàn)并不如宣稱的那么好。
第二,這些模型引入了“時(shí)間”的概念,看起來(lái)很高級(jí),但研究發(fā)現(xiàn)這完全沒必要。
第三,這些模型其實(shí)與已有的簡(jiǎn)單掩碼模型完全等價(jià),只要正確設(shè)置簡(jiǎn)單模型的參數(shù),就能達(dá)到相同效果。
目前,這篇研究已入選ICLR 2025。
具體說(shuō)了啥?一起來(lái)看。
背景
隨著SEDD獲得ICML 2024最佳論文獎(jiǎng),起源于D3PM的離散擴(kuò)散模型迎來(lái)了復(fù)興并成為自回歸范式的有力競(jìng)爭(zhēng)者,在文本、蛋白質(zhì)等離散序列生成任務(wù)上掀起了研究熱潮。
作為離散擴(kuò)散模型中性能最強(qiáng)的類別,掩碼式離散擴(kuò)散模型(簡(jiǎn)稱掩碼擴(kuò)散模型)在后續(xù)工作中被進(jìn)一步簡(jiǎn)化,從而在理論形式上與連續(xù)空間擴(kuò)散模型對(duì)齊。
掩碼擴(kuò)散模型通過(guò)引入一個(gè)連續(xù)的“時(shí)間”或“噪聲水平”的概念,定義了一個(gè)從原始數(shù)據(jù)逐漸“加噪”(掩碼)到完全掩碼狀態(tài)的前向過(guò)程,以及一個(gè)學(xué)習(xí)從掩碼狀態(tài)逐步“去噪”(預(yù)測(cè)被掩碼部分)恢復(fù)數(shù)據(jù)的反向(生成)過(guò)程。
在使用生成式困惑度(Gen PPL)作為衡量文本生成質(zhì)量的指標(biāo)時(shí),掩碼擴(kuò)散模型在先前工作中均顯示出了隨采樣步數(shù)增加的性能提升,并在足夠多步數(shù)下超越自回歸模型。
這種對(duì)比是否公平?同時(shí),作為離散空間中的“擴(kuò)散”模型,是否意味著其可以借鑒標(biāo)準(zhǔn)擴(kuò)散模型相關(guān)算法來(lái)增強(qiáng)性能?
論文從訓(xùn)練和采樣兩個(gè)方面對(duì)掩碼擴(kuò)散模型進(jìn)行解構(gòu)。
掩碼擴(kuò)散模型與掩碼模型的差異
盡管掩碼擴(kuò)散模型借鑒了擴(kuò)散模型的框架,但其核心操作與經(jīng)典的掩碼語(yǔ)言模型(如 BERT,Mask-Predict) 或掩碼圖像生成模型(如 MaskGIT)有著相似之處:都是對(duì)部分內(nèi)容進(jìn)行掩碼(masking),然后預(yù)測(cè)被掩碼的內(nèi)容。
BERT在訓(xùn)練時(shí)只會(huì)掩碼一小部分token,適用任務(wù)為表征學(xué)習(xí)、文本理解而非生成,而Mask-Predict與MaskGIT擴(kuò)大了掩碼比例的范圍并可用于文本、圖像生成。
相比于掩碼模型,掩碼擴(kuò)散模型引入了一個(gè)關(guān)鍵的復(fù)雜性:時(shí)間步(time step)。其訓(xùn)練和采樣都嚴(yán)格依賴于一個(gè)預(yù)先定義的、隨時(shí)間變化的掩碼(噪聲)調(diào)度。
模型需要根據(jù)當(dāng)前的時(shí)間步 t 來(lái)預(yù)測(cè)原始數(shù)據(jù)。
具體而言,它和掩碼模型的差異體現(xiàn)在:
在訓(xùn)練中,掩碼模型被掩碼的token數(shù)量及不同掩碼比例對(duì)應(yīng)的損失權(quán)重可以隨意設(shè)定;掩碼擴(kuò)散模型同一時(shí)間對(duì)應(yīng)被掩碼token的數(shù)量是不確定的,不同時(shí)間的分布及權(quán)重需要特殊設(shè)置使得損失構(gòu)成模型似然(likelihood)的證據(jù)下界(ELBO)。
在采樣中,掩碼模型按照token為粒度,逐token解碼;掩碼擴(kuò)散模型以時(shí)間為粒度進(jìn)行離散化,從時(shí)間t轉(zhuǎn)移到更小的時(shí)間s時(shí),每個(gè)token被解碼與否通過(guò)概率采樣決定,被解碼token的數(shù)量是不確定的。
掩碼擴(kuò)散模型的采樣存在隱性數(shù)值問(wèn)題
先前評(píng)估掩碼擴(kuò)散模型性能的關(guān)鍵指標(biāo)之一是Gen PPL,其通過(guò)計(jì)算參考模型(如GPT-2)對(duì)模型生成內(nèi)容的“驚訝程度”來(lái)衡量生成質(zhì)量。
然而,Gen PPL 指標(biāo)對(duì)采樣過(guò)程中的超參數(shù)(如采樣溫度)極為敏感,并且可以通過(guò)調(diào)整這些參數(shù)輕易地“刷低”數(shù)值,但這并不代表模型本身的生成能力有實(shí)質(zhì)提升。
本研究首次揭露,掩碼擴(kuò)散模型特有的采樣過(guò)程存在隱藏的數(shù)值問(wèn)題,即使在常用的32位浮點(diǎn)數(shù)精度下也會(huì)帶來(lái)類似于降低溫度的效果。
為了顯示這一點(diǎn),論文額外測(cè)試了生成句子的熵(entropy)來(lái)衡量生成多樣性。
隨著采樣步數(shù)的增加,Gen PPL不斷下降并超過(guò)自回歸模型(左圖),然而熵也在持續(xù)降低(右圖)。
當(dāng)采樣過(guò)程以64位精度進(jìn)行時(shí),熵穩(wěn)定在與自回歸模型類似的水平,而Gen PPL則顯著升高并遠(yuǎn)遠(yuǎn)落后于自回歸模型。
論文通過(guò)進(jìn)一步的數(shù)學(xué)推導(dǎo),從理論上解釋了這一溫度降低效果的根源。
具體而言,在[0,1)區(qū)間上均勻采樣的浮點(diǎn)數(shù)實(shí)際范圍為[0,1-ε],其中ε是一個(gè)接近0的小數(shù),這會(huì)導(dǎo)致基于Gumbel-max技巧的類別采樣(categorical sampling)存在截?cái)鄦?wèn)題。
最終的分布不服從原先的類別概率π,而會(huì)偏移到,其中
這一偏移會(huì)加強(qiáng)原先概率已經(jīng)比較大的類別,從而達(dá)到類似降低溫度的效果。
通過(guò)對(duì)類別采樣部分的代碼做對(duì)照試驗(yàn),文章驗(yàn)證了64位采樣+手動(dòng)截?cái)啻_實(shí)可以復(fù)現(xiàn)32位采樣的效果。
同時(shí),上述數(shù)值問(wèn)題對(duì)于單個(gè)token并不顯著,逐token解碼的模型(如自回歸模型、掩碼模型)在32位下基本不受影響。
然而,此問(wèn)題會(huì)在掩碼擴(kuò)散模型中額外影響所有token之間的交互,導(dǎo)致某些token被優(yōu)先解碼,進(jìn)一步降低生成多樣性。
可以說(shuō),這是掩碼擴(kuò)散模型采樣過(guò)程獨(dú)有的數(shù)值問(wèn)題。
掩碼模型與掩碼擴(kuò)散模型的等價(jià)性
先前工作從最優(yōu)網(wǎng)絡(luò)的角度證明了掩碼擴(kuò)散模型中的時(shí)間并不必要,本論文進(jìn)一步在訓(xùn)練和采樣兩方面證明掩碼擴(kuò)散模型和掩碼模型的等價(jià)性。
具體而言:
在訓(xùn)練損失函數(shù)上,掩碼擴(kuò)散模型與時(shí)間有關(guān)的似然下界等價(jià)于掩碼模型的以token為粒度的損失函數(shù),只要滿足:(1)被掩碼token的個(gè)數(shù)n在1和L之間均勻采樣,其中L是序列的總長(zhǎng)度(2)預(yù)測(cè)損失對(duì)n取均值,即施加“似然權(quán)重”1/n來(lái)實(shí)現(xiàn)最大似然訓(xùn)練。
需要注意的是,對(duì)不同時(shí)間/掩碼比例施加的權(quán)重并不影響網(wǎng)絡(luò)在無(wú)限容量下的最優(yōu)值,而決定了網(wǎng)絡(luò)訓(xùn)練過(guò)程中的重點(diǎn)優(yōu)化區(qū)域。
文本生成的自回歸范式采取了最大似然訓(xùn)練,而在圖像上,最大似然訓(xùn)練往往會(huì)帶來(lái)生成質(zhì)量的下降。
在采樣過(guò)程上,掩碼擴(kuò)散模型逆時(shí)間的采樣過(guò)程可以通過(guò)論文提出的首達(dá)采樣器(first-hitting sampler),轉(zhuǎn)化為與掩碼模型相同的逐token采樣,最多需要L步便可達(dá)到?jīng)]有離散化誤差的精確采樣,而掩碼擴(kuò)散模型原有采樣過(guò)程需要對(duì)時(shí)間無(wú)限細(xì)分才能完全精確。同時(shí),采用逐token解碼可以避免上文所述的隱藏?cái)?shù)值問(wèn)題。
結(jié)語(yǔ)
掩碼擴(kuò)散模型引入的“時(shí)間”概念可能不僅無(wú)益,反而有害(導(dǎo)致數(shù)值問(wèn)題和不必要的復(fù)雜性)。
同時(shí),其雖然帶有“擴(kuò)散”兩字,但與連續(xù)空間上的擴(kuò)散模型及其相關(guān)算法關(guān)系不大,如在論文中,作者仿照擴(kuò)散模型為掩碼擴(kuò)散模型開發(fā)了高階采樣算法,其并不如連續(xù)空間中的加速效果顯著。
在實(shí)踐中,使用掩碼擴(kuò)散模型、引入連續(xù)時(shí)間相關(guān)的訓(xùn)練/采樣過(guò)程或許并不必要,簡(jiǎn)單的掩碼模型(如 MaskGIT 及其變種)在概念上更簡(jiǎn)潔,實(shí)現(xiàn)上更穩(wěn)定,并且在理論上具有同等的潛力。
掩碼模型作為自回歸模型使用隨機(jī)token順序和雙向注意力機(jī)制的變種,同樣是基于似然的模型,可以作為建模離散數(shù)據(jù)生成的另一種選擇。
由于雙向注意力和KV cache機(jī)制不兼容,掩碼模型在長(zhǎng)上下文的推理速度上相較自回歸模型存在瓶頸。
近期工作通過(guò)在雙向注意力和因果注意力機(jī)制之間插值、使用隨機(jī)順序的自回歸模型等方法使模型保持雙向感知能力的同時(shí),推理速度向自回歸模型靠近。
也有工作探究非掩碼類型的離散擴(kuò)散模型與連續(xù)空間擴(kuò)散模型的理論聯(lián)系,其擴(kuò)散機(jī)制更加屬實(shí),而非如掩碼擴(kuò)散模型一樣是可有可無(wú)的噱頭。
論文第一作者鄭凱文為清華大學(xué)計(jì)算機(jī)系三年級(jí)博士生,在ICML、NeurIPS、ICLR發(fā)表擴(kuò)散模型相關(guān)一作5篇。文章通訊作者為朱軍教授,合作者張欽圣、陳永昕、毛含子為英偉達(dá)研究員,劉洺堉為英偉達(dá)副總裁與Deep Imagination研究組主管。
論文標(biāo)題:Masked Diffusion Models are Secretly Time-Agnostic Masked Models and Exploit Inaccurate Categorical Sampling