ICML 2025 | 生成式視角重塑監(jiān)督學(xué)習(xí)!標(biāo)簽不只是答案,更是學(xué)習(xí)指南
生成式視角可以對(duì)監(jiān)督學(xué)習(xí)重新思考乃至重新定義!
想象你在教一個(gè)學(xué)生解數(shù)學(xué)題——你會(huì)直接讓他交卷對(duì)答案,還是會(huì)讓他參考完整答案來(lái)理解解題思路?
如今,一種全新的監(jiān)督學(xué)習(xí)范式正受到關(guān)注:標(biāo)簽不應(yīng)只是用于對(duì)照回答的標(biāo)準(zhǔn)答案,更可能成為學(xué)習(xí)過(guò)程中的輔助參考。
受生成式一致性模型的啟發(fā),來(lái)自上海交大、SII、MIT、港中文深圳等機(jī)構(gòu)的研究團(tuán)隊(duì)在ICML 2025最新提出預(yù)測(cè)一致性學(xué)習(xí)(PCL,Predictive Consistency Learning)。
PCL通過(guò)擴(kuò)散模型的擴(kuò)散過(guò)程消減標(biāo)簽的信息,將噪聲標(biāo)簽(Noised Labels)引入模型的輸入,使得模型在數(shù)據(jù)輸入和噪聲標(biāo)簽的共同參照下預(yù)測(cè)完整標(biāo)簽,實(shí)現(xiàn)標(biāo)簽信息的復(fù)用和價(jià)值挖掘。
訓(xùn)練過(guò)程概覽
傳統(tǒng)監(jiān)督學(xué)習(xí)中,輸入通過(guò)神經(jīng)網(wǎng)絡(luò)預(yù)測(cè)
,通過(guò)對(duì)比和標(biāo)準(zhǔn)答案
之間的關(guān)系,來(lái)計(jì)算損失和反向傳播更新模型,對(duì)應(yīng)損失函數(shù)
其中為具體損失函數(shù),
為神經(jīng)網(wǎng)絡(luò)函數(shù)。受生成一致性模型中一致性映射思想的啟發(fā),PCL對(duì)應(yīng)一種全新的監(jiān)督學(xué)習(xí)范式,旨在通過(guò)漸進(jìn)式分解標(biāo)簽信息來(lái)更好地捕捉復(fù)雜標(biāo)簽的完整表征,使得模型在部分標(biāo)簽信息的提示下實(shí)現(xiàn)完整標(biāo)簽信息的預(yù)測(cè)。
與傳統(tǒng)方法不同,PCL學(xué)習(xí)框架將完整標(biāo)簽的學(xué)習(xí)過(guò)程分解為逐步逼近的漸進(jìn)式任務(wù):模型會(huì)接收一個(gè)含有部分標(biāo)簽信息的額外輸入作為提示,首先學(xué)習(xí)捕捉互補(bǔ)的部分標(biāo)簽信息,隨后逐步逼近完整標(biāo)簽。
為了系統(tǒng)化地規(guī)劃標(biāo)簽學(xué)習(xí)過(guò)程,研究人員借鑒了擴(kuò)散模型和一致性模型中的加噪過(guò)程,通過(guò)生成帶噪標(biāo)簽作為額外的輸入提示,使模型能夠在學(xué)習(xí)帶噪部分的同時(shí)補(bǔ)充完整信息。
具體而言,PCL在訓(xùn)練時(shí):以輸入數(shù)據(jù)為條件,將不同噪聲水平的帶噪標(biāo)簽映射回真實(shí)標(biāo)簽,噪聲程度由時(shí)間步控制;約束不同噪聲時(shí)間步的預(yù)測(cè)結(jié)果均一致地逼近目標(biāo)標(biāo)簽。
模型每次采樣兩個(gè)不同的時(shí)間步,要求模型在不同時(shí)間步的提示下盡可能精準(zhǔn)還原標(biāo)簽,同時(shí)預(yù)測(cè)的結(jié)果盡可能保持一致。
通過(guò)這種跨噪聲水平的一致性約束,模型能夠?qū)W習(xí)從完全噪聲到精確標(biāo)簽的不同層級(jí)的標(biāo)簽信息,從而構(gòu)建更具表達(dá)力的映射關(guān)系。
預(yù)測(cè)一致性機(jī)制的作用在于,將低噪聲條件下的預(yù)測(cè)精度傳遞至高噪聲條件,同時(shí)約束模型在不同噪聲水平下表征的不變性,從而減小對(duì)于標(biāo)簽提示的過(guò)度依賴,服務(wù)于測(cè)試過(guò)程。最終損失函數(shù)形式為:
其中模型接收輸入,兩個(gè)不同時(shí)間步的噪聲標(biāo)簽
和相應(yīng)的時(shí)間步
,模型在兩個(gè)不同噪聲標(biāo)簽的提示下完成預(yù)測(cè),在預(yù)測(cè)結(jié)果逼近噪聲標(biāo)簽的同時(shí),額外約束兩個(gè)預(yù)測(cè)結(jié)果的一致性。
分別控制預(yù)測(cè)精度loss和預(yù)測(cè)一致性loss的權(quán)重。
標(biāo)簽噪聲過(guò)程
離散標(biāo)簽的噪聲過(guò)程:對(duì)于多維分類標(biāo)簽,其中
表示類別數(shù),
表示維度,研究人員遵循離散擴(kuò)散模型將噪聲過(guò)程建模為在每個(gè)時(shí)間步
引入類別噪聲到標(biāo)簽中。
他們將標(biāo)簽表示為,它是
個(gè)獨(dú)熱編碼向量的拼接。噪聲可以理解為在每個(gè)維度的不同類別之間進(jìn)行轉(zhuǎn)換。從初始點(diǎn)
開始,噪聲過(guò)程定義為:
其中是在
個(gè)獨(dú)熱向量上的分類分布,概率由
給出,
是轉(zhuǎn)移矩陣,決定了在時(shí)間步
引入的噪聲,對(duì)應(yīng)的標(biāo)簽類別以概率
轉(zhuǎn)移到任何其他類別。
隨著時(shí)間的推移,當(dāng)接近最終時(shí)間步
時(shí),標(biāo)簽會(huì)收斂到
個(gè)類別的均勻分布。由于噪聲矩陣可以事先計(jì)算,因此跨步噪聲計(jì)算代價(jià)很低。
連續(xù)標(biāo)簽的噪聲過(guò)程:對(duì)于多維連續(xù)標(biāo)簽,其中
表示維度,研究人員遵循高斯擴(kuò)散模型將擴(kuò)散過(guò)程建模為在每個(gè)時(shí)間步向標(biāo)簽引入高斯噪聲。在每個(gè)時(shí)間步
,高斯噪聲被應(yīng)用于標(biāo)簽,逐步將其推向一個(gè)噪聲分布。噪聲過(guò)程定義為:
其中是均值為
,協(xié)方差為
的高斯分布,
控制在時(shí)間步
上添加噪聲的方差。隨著時(shí)間的推移,當(dāng)
接近最終時(shí)間步
時(shí),標(biāo)簽會(huì)收斂到一個(gè)以零為中心的高斯分布。同樣噪聲函數(shù)可以事先計(jì)算,因此跨步噪聲計(jì)算代價(jià)很低。
嵌入空間的噪聲過(guò)程:在標(biāo)簽過(guò)于復(fù)雜,無(wú)法直接表示為分類或連續(xù)值,或者類別數(shù)過(guò)大時(shí),PCL直接向標(biāo)簽的潛在嵌入空間引入高斯噪聲,這種方式與連續(xù)標(biāo)簽的噪聲過(guò)程一致。
測(cè)試過(guò)程概覽
在訓(xùn)練完成后的推理階段,可以通過(guò)從隨機(jī)噪聲分布采樣標(biāo)簽作為提示信息,并進(jìn)行單次前向傳播來(lái)進(jìn)行高效預(yù)測(cè)。
由于不包含任何信息量,這個(gè)推理過(guò)程實(shí)際上和傳統(tǒng)監(jiān)督學(xué)習(xí)的直接預(yù)測(cè)是一致的。然而在訓(xùn)練階段的改進(jìn)使得PCL模型具有更好的預(yù)測(cè)能力,即使不依賴任何標(biāo)簽提示已經(jīng)能夠超越傳統(tǒng)監(jiān)督學(xué)習(xí)的精度。
在訓(xùn)練過(guò)程中,當(dāng)較小時(shí),直接預(yù)測(cè)精度往往較高,因?yàn)闃?biāo)簽提示包含更多的信息。目標(biāo)是通過(guò)訓(xùn)練將這種高精度逐步轉(zhuǎn)移到較大的
值,從而提升整體模型性能。
在理想情況下,當(dāng)一致性損失趨于零時(shí),可以通過(guò)一步推理獲得最優(yōu)結(jié)果,但實(shí)際上,通過(guò)逐步將從
降至0可以規(guī)劃不同層級(jí)標(biāo)簽信息的預(yù)測(cè),帶來(lái)精度的提升。
為了實(shí)現(xiàn)這種提升,可以采用多步推理策略,通過(guò)對(duì)上一步標(biāo)簽重新引入噪聲作為下一步預(yù)測(cè)的標(biāo)簽提示并且交替執(zhí)行預(yù)測(cè),使得模型能夠在多個(gè)推理步驟中逐步細(xì)化其輸出,并利用早期預(yù)測(cè)中嵌入的越來(lái)越豐富的提示信息。
給定一系列時(shí)間點(diǎn),在每一步
,上一步預(yù)測(cè)
會(huì)通過(guò)噪聲函數(shù)被擾動(dòng)到狀態(tài)
作為下一步預(yù)測(cè)的噪聲提示信息,從而修正
預(yù)測(cè)。
噪聲水平隨著每一步的進(jìn)行而降低,即。然后,模型通過(guò)應(yīng)用
對(duì)標(biāo)簽進(jìn)行更精確的預(yù)測(cè)。這個(gè)過(guò)程會(huì)在接下來(lái)的步驟中重復(fù)進(jìn)行,每一步新的標(biāo)簽提示信息都包含了從前一步獲取的更精確的信息。
這使得模型能夠逐步恢復(fù)的全部信息,通過(guò)將可能的近似預(yù)測(cè)作為標(biāo)簽提示,并利用逐步增益的信息來(lái)進(jìn)行最終預(yù)測(cè)。
信息論視角
從理論角度出發(fā),在標(biāo)準(zhǔn)監(jiān)督學(xué)習(xí)場(chǎng)景下,模型的主要目標(biāo)是捕捉輸入和標(biāo)簽
之間的互信息。
通常,由于輸入的信息量常常遠(yuǎn)遠(yuǎn)大于標(biāo)簽
,模型希望通過(guò)學(xué)習(xí)一個(gè)壓縮的特征表示來(lái)最大化
并最小化
,其中
是從
提取的特征表示。默認(rèn)情況下,從
到
之間的映射是直接且容易捕捉的。
然而,隨著任務(wù)的復(fù)雜性增加,標(biāo)簽的信息也變得越來(lái)越復(fù)雜,例如高維度、復(fù)雜的內(nèi)部結(jié)構(gòu)等。這使得從
到
之間的映射變得更加復(fù)雜,模型需要應(yīng)對(duì)更加困難的學(xué)習(xí)問(wèn)題。
為了更有效地建模,相較于一次性學(xué)習(xí)
所有的信息,PCL的設(shè)計(jì)實(shí)則提出了一種結(jié)構(gòu)化的學(xué)習(xí)過(guò)程,逐步捕捉這些信息。
為了將標(biāo)簽信息分解為一個(gè)更為漸進(jìn)的學(xué)習(xí)過(guò)程,PCL引入了一個(gè)附加的噪聲標(biāo)簽,用于在每次迭代中調(diào)節(jié)學(xué)習(xí)的標(biāo)簽信息量。通過(guò)引入
,原始的互信息
可以分解為如下形式:
由于是由
推導(dǎo)而來(lái)的,當(dāng)
已知時(shí),
對(duì)
并沒(méi)有額外的信息,因此冗余項(xiàng)
,公式簡(jiǎn)化為:
這一分解揭示了兩個(gè)關(guān)鍵成分,其中第一個(gè)項(xiàng)捕捉了在給定
的條件下可以學(xué)習(xí)的
的增量信息。該項(xiàng)作為
的下界,且它們之間的差距可以通過(guò)
的信息量進(jìn)行控制。
通過(guò)最大化,模型逐步學(xué)習(xí)捕捉
的完整信息內(nèi)容。具體而言,當(dāng)
時(shí),
提供的信息極少,迫使模型完全捕捉
;而當(dāng)
時(shí),
逼近
,允許模型專注于優(yōu)化標(biāo)簽的細(xì)節(jié)。
在訓(xùn)練過(guò)程中,通過(guò)隨機(jī)采樣一批值,模型能夠同時(shí)學(xué)習(xí)標(biāo)簽的不同方面。最初,模型期望能夠輕松捕捉標(biāo)簽的部分細(xì)節(jié),通過(guò)迭代訓(xùn)練,模型逐步積累
的完整信息內(nèi)容。
在實(shí)現(xiàn)方面,模型暴露于帶噪聲的。模型的輸入包括
和
,其中
作為條件輸入。盡管引入
作為輔助輸入有助于學(xué)習(xí),但最終目標(biāo)是使模型盡可能少地依賴
來(lái)進(jìn)行預(yù)測(cè)。
形式上,目標(biāo)是最小化噪聲條件依賴,該項(xiàng)衡量模型預(yù)測(cè)在多大程度上依賴于噪聲標(biāo)簽
。理想情況下,這一項(xiàng)應(yīng)該為零,表明在給定
和模型參數(shù)
的條件下,模型的預(yù)測(cè)與
無(wú)關(guān)。數(shù)學(xué)上,它可以通過(guò)以下公式進(jìn)行度量:
該項(xiàng)對(duì)應(yīng)預(yù)測(cè)一致性的約束,確保對(duì)于所有的和
,都有
。這種正則化確保了模型的預(yù)測(cè)在不同噪聲水平下保持一致,從而減少了對(duì)
的依賴,鼓勵(lì)
盡可能編碼所有必要的信息,以實(shí)現(xiàn)準(zhǔn)確的預(yù)測(cè)。
實(shí)驗(yàn)結(jié)果
由于PCL作為一種新穎的訓(xùn)練范式被提出,因此主要的基準(zhǔn)對(duì)比是傳統(tǒng)的監(jiān)督學(xué)習(xí)。研究者在不同模態(tài)的經(jīng)典代表性模型骨干網(wǎng)絡(luò)上進(jìn)行比較,以展示PCL的通用適用性。這些任務(wù)包括視覺(jué)模態(tài)的語(yǔ)義分割、圖模態(tài)的N體問(wèn)題仿真和語(yǔ)言模態(tài)的next-token prediction監(jiān)督微調(diào)。
在圖像語(yǔ)義分割任務(wù)中,上圖展示了PCL的預(yù)測(cè)過(guò)程。模型首先在完全隨機(jī)噪聲的提示下進(jìn)行預(yù)測(cè),然后將上一步的標(biāo)簽預(yù)測(cè)加噪到更小的噪聲程度,作為下一步的標(biāo)簽提示。
通過(guò)這種遞進(jìn)式的噪聲處理和多步推理,最終得到更加精確的預(yù)測(cè)結(jié)果。與傳統(tǒng)監(jiān)督學(xué)習(xí)(SL)進(jìn)行對(duì)比,PCL在單步預(yù)測(cè)時(shí)就已經(jīng)超過(guò)了SL,而隨著預(yù)測(cè)步驟的增多,預(yù)測(cè)質(zhì)量持續(xù)提升。
上圖展示了在給定標(biāo)簽提示的情況下,不同時(shí)間步設(shè)置對(duì)模型預(yù)測(cè)錯(cuò)誤范圍的影響??梢园l(fā)現(xiàn),設(shè)置較大的時(shí)間步傾向于鼓勵(lì)模型改進(jìn)更廣泛的結(jié)構(gòu)關(guān)系,而設(shè)置較小的時(shí)間步則鼓勵(lì)模型專注于更精細(xì)的細(xì)節(jié),例如物體的邊界。
這一現(xiàn)象表明,模型通過(guò)引入時(shí)間步的設(shè)計(jì),能夠在標(biāo)簽預(yù)測(cè)過(guò)程中分層次地學(xué)習(xí)不同粒度的信息,從全局結(jié)構(gòu)到局部細(xì)節(jié)。
上表展示了在語(yǔ)義分割任務(wù)上,PCL與SL的定量表現(xiàn)對(duì)比,進(jìn)一步驗(yàn)證了PCL在提升預(yù)測(cè)精度方面的優(yōu)勢(shì)。
在圖模態(tài)的預(yù)測(cè)任務(wù)中,上圖展示了不同學(xué)習(xí)階段下模型在預(yù)測(cè)階段的推理步數(shù)對(duì)預(yù)測(cè)質(zhì)量的影響。在訓(xùn)練尚不完全時(shí),推理步數(shù)越多,預(yù)測(cè)精度越高。
然而,隨著訓(xùn)練逐漸完成,觀察到隨著推理步數(shù)的增加,預(yù)測(cè)誤差會(huì)持續(xù)下降,但在達(dá)到某個(gè)臨界點(diǎn)后,誤差可能會(huì)反彈上升。
這種現(xiàn)象源于訓(xùn)練與推理階段的差異:在訓(xùn)練階段,模型始終以真實(shí)標(biāo)簽的噪聲擾動(dòng)版本作為輸入,而在推理階段,模型依賴于自身的中間預(yù)測(cè)結(jié)果,這些預(yù)測(cè)可能包含誤差,并在多步迭代中逐漸累積。
由此產(chǎn)生了一個(gè)權(quán)衡問(wèn)題:更多的推理步數(shù)有助于捕捉更精細(xì)的預(yù)測(cè)細(xì)節(jié),但也增加了誤差累積的風(fēng)險(xiǎn)。為了優(yōu)化這一平衡,研究人員通過(guò)驗(yàn)證集確定最佳的推理步數(shù),并在測(cè)試階段引入早停機(jī)制,在誤差開始上升之前終止推理流程。
值得注意的是,單步預(yù)測(cè)的精度相比于傳統(tǒng)監(jiān)督學(xué)習(xí)已經(jīng)有了顯著提升。
上表展示了PCL相較于SL在預(yù)測(cè)精度上的顯著提升,進(jìn)一步驗(yàn)證了PCL在處理復(fù)雜預(yù)測(cè)任務(wù)中的優(yōu)勢(shì)。
在語(yǔ)言模態(tài)的next-token prediction監(jiān)督微調(diào)任務(wù)中,研究人員對(duì)比了使用SL和PCL微調(diào)LLaMa2-7B模型的效果,結(jié)果表明,PCL相較于SL在性能上具有優(yōu)勢(shì)。
由于噪聲過(guò)程尚未進(jìn)行定制化,并且next token作為標(biāo)簽信息的提示量相對(duì)單薄,當(dāng)前的框架仍然有較大的提升空間。
未來(lái)的研究可以進(jìn)一步優(yōu)化噪聲過(guò)程并增強(qiáng)標(biāo)簽信息的豐富度,從而進(jìn)一步提升PCL在語(yǔ)言任務(wù)中的表現(xiàn)。
論文鏈接:https://openreview.net/pdf?id=FO2fu3daSL
代碼鏈接:https://github.com/Thinklab-SJTU/predictive-consistency-learning