達(dá)摩院開(kāi)源半監(jiān)督學(xué)習(xí)框架Dash,刷新多項(xiàng)SOTA
一、研究背景
監(jiān)督學(xué)習(xí)(Supervised Learning)?
我們知道模型訓(xùn)練的目的其實(shí)是學(xué)習(xí)一個(gè)預(yù)測(cè)函數(shù),在數(shù)學(xué)上,這可以刻畫(huà)成一個(gè)學(xué)習(xí)從數(shù)據(jù) (X) 到標(biāo)注 (y) 的映射函數(shù)。監(jiān)督學(xué)習(xí)就是一種最常用的模型訓(xùn)練方法,其效果的提升依賴于大量的且進(jìn)行了很好標(biāo)注的訓(xùn)練數(shù)據(jù),也就是所謂的大量帶標(biāo)簽數(shù)據(jù) ((X,y))。但是標(biāo)注數(shù)據(jù)往往需要大量的人力物力等等,因此效果提升的同時(shí)也會(huì)帶來(lái)成本過(guò)高的問(wèn)題。在實(shí)際應(yīng)用中經(jīng)常遇到的情況是有少量標(biāo)注數(shù)據(jù)和大量未標(biāo)注數(shù)據(jù),由此引出的半監(jiān)督學(xué)習(xí)也越來(lái)越引起科學(xué)工作者的注意。
半監(jiān)督學(xué)習(xí)(Semi-Supervised Learning)?
半監(jiān)督學(xué)習(xí)同時(shí)對(duì)少量標(biāo)注數(shù)據(jù)和大量未標(biāo)注數(shù)據(jù)進(jìn)行學(xué)習(xí),其目的是借助無(wú)標(biāo)簽數(shù)據(jù)來(lái)提高模型的精度。比如 self-training 就是一種很常見(jiàn)的半監(jiān)督學(xué)習(xí)方法,其具體流程是對(duì)于標(biāo)注數(shù)據(jù) (X, y) 學(xué)習(xí)數(shù)據(jù)從 X 到 y 的映射,同時(shí)利用學(xué)習(xí)得到的模型對(duì)未標(biāo)注數(shù)據(jù) X 預(yù)測(cè)出一個(gè)偽標(biāo)簽,通過(guò)對(duì)偽標(biāo)簽數(shù)據(jù) (X,
)進(jìn)一步進(jìn)行監(jiān)督學(xué)習(xí)來(lái)幫助模型進(jìn)行更好的收斂和精度提高。
核心解決問(wèn)題?
現(xiàn)有的半監(jiān)督學(xué)習(xí)框架對(duì)無(wú)標(biāo)簽數(shù)據(jù)的利用大致可以分為兩種,一是全部參與訓(xùn)練,二是用一個(gè)固定的閾值卡出置信度較高的樣本進(jìn)行訓(xùn)練 (比如 FixMatch)。由于半監(jiān)督學(xué)習(xí)對(duì)未標(biāo)注數(shù)據(jù)的利用依賴于當(dāng)前模型預(yù)測(cè)的偽標(biāo)簽,所以偽標(biāo)簽的正確與否會(huì)給模型的訓(xùn)練帶來(lái)較大的影響,好的預(yù)測(cè)結(jié)果有助于模型的收斂和對(duì)新的模式的學(xué)習(xí),差的預(yù)測(cè)結(jié)果則會(huì)干擾模型的訓(xùn)練。所以我們認(rèn)為:不是所有的無(wú)標(biāo)簽樣本都是必須的!
二、論文 & 代碼
- 論文鏈接:https://proceedings.mlr.press/v139/xu21e/xu21e.pdf
- 代碼地址:https://github.com/idstcv/Dash
- 技術(shù)應(yīng)用:https://modelscope.cn/models/damo/cv_manual_face-liveness_flrgb/summary
這篇論文創(chuàng)新性地提出用動(dòng)態(tài)閾值(dynamic threshold)的方式篩選無(wú)標(biāo)簽樣本進(jìn)行半監(jiān)督學(xué)習(xí)(semi-supervised learning,SSL)的方法,我們改造了半監(jiān)督學(xué)習(xí)的訓(xùn)練框架,在訓(xùn)練過(guò)程中對(duì)無(wú)標(biāo)簽樣本的選擇策略進(jìn)行了改進(jìn),通過(guò)動(dòng)態(tài)變化的閾值來(lái)選擇更有效的無(wú)標(biāo)簽樣本進(jìn)行訓(xùn)練。Dash 是一個(gè)通用策略,可以輕松與現(xiàn)有的半監(jiān)督學(xué)習(xí)方法集成。實(shí)驗(yàn)方面,我們?cè)?CIFAR-10, CIFAR-100, STL-10 和 SVHN 等標(biāo)準(zhǔn)數(shù)據(jù)集上充分驗(yàn)證了其有效性。理論方面,論文從非凸優(yōu)化的角度證明了 Dash 算法的收斂性質(zhì)。
三、方法
Fixmatch 訓(xùn)練框架?
在引出我們的方法 Dash 之前,我們介紹一下 Google 提出的 FixMatch 算法,一種利用固定閾值選擇無(wú)標(biāo)簽樣本的半監(jiān)督學(xué)習(xí)方法。FixMatch 訓(xùn)練框架是之前的 SOTA 解決方案。整個(gè)學(xué)習(xí)框架的重點(diǎn)可以歸納為以下幾點(diǎn):
1、對(duì)于無(wú)標(biāo)簽數(shù)據(jù)經(jīng)過(guò)弱數(shù)據(jù)增強(qiáng)(水平翻轉(zhuǎn)、偏移等)得到的樣本通過(guò)當(dāng)前的模型得到預(yù)測(cè)值
2、對(duì)于無(wú)標(biāo)簽數(shù)據(jù)經(jīng)過(guò)強(qiáng)數(shù)據(jù)增強(qiáng)(RA or CTA)得到的樣本通過(guò)當(dāng)前的模型得到預(yù)測(cè)值
3、把具有高置信度的弱數(shù)據(jù)增強(qiáng)的結(jié)果,通過(guò) one hot 的方式形成偽標(biāo)簽
,然后用
和 X 經(jīng)過(guò)強(qiáng)數(shù)據(jù)增強(qiáng)得到的預(yù)測(cè)值
進(jìn)行模型的訓(xùn)練。
fixmatch 的優(yōu)點(diǎn)是用弱增強(qiáng)數(shù)據(jù)進(jìn)行偽標(biāo)簽的預(yù)測(cè),增加了偽標(biāo)簽預(yù)測(cè)的準(zhǔn)確性,并在訓(xùn)練過(guò)程中用固定的閾值 0.95(對(duì)應(yīng) loss 為 0.0513) 選取高置信度(閾值大于等于 0.95,也就是 loss 小于等于 0.0513)的預(yù)測(cè)樣本生成偽標(biāo)簽,進(jìn)一步穩(wěn)定了訓(xùn)練過(guò)程。
Dash 訓(xùn)練框架?
針對(duì)全部選擇偽標(biāo)簽和用固定閾值選擇偽標(biāo)簽的問(wèn)題,我們創(chuàng)新性地提出用動(dòng)態(tài)閾值來(lái)進(jìn)行樣本篩選的策略。即動(dòng)態(tài)閾值 是隨 t 衰減的
式中 C=1.0001,是有標(biāo)簽數(shù)據(jù)在第一個(gè) epoch 之后 loss 的平均值,我們選擇那些
的無(wú)標(biāo)簽樣本參與梯度回傳。下圖展示了不同
值下的閾值
的變化曲線。可以看到參數(shù)
控制了閾值曲線的下降速率。
的變化曲線類似于模擬訓(xùn)練模型時(shí)損失函數(shù)下降的趨勢(shì)。
下圖對(duì)比了訓(xùn)練過(guò)程中的 FixMath 和 Dash 選擇的正確樣本數(shù)和錯(cuò)誤樣本數(shù)隨訓(xùn)練進(jìn)行的變化情況(使用的數(shù)據(jù)集是 cifar100)。從圖中可以很清楚地看到,對(duì)比 FixMatch,Dash 可以選取更多正確 label 的樣本,同時(shí)選擇更少的錯(cuò)誤 label 的樣本,從而最終有助于提高訓(xùn)練模型的精度。
我們的算法可以總結(jié)為如下 Algorithm 1。Dash 是一個(gè)通用策略,可以輕松與現(xiàn)有的半監(jiān)督學(xué)習(xí)方法集成。為了方便,在本文的實(shí)驗(yàn)中我們主要將 Dash 與 FixMatch 集成。更多理論證明詳見(jiàn)論文。
四、結(jié)果
我們?cè)诎氡O(jiān)督學(xué)習(xí)常用數(shù)據(jù)集:CIFAR-10,CIFAR-100,STL-10 和 SVHN 上進(jìn)行了算法的驗(yàn)證。結(jié)果分別如下:
可以看到我們的方法在多個(gè)實(shí)驗(yàn)設(shè)置上都取得了比 SOTA 更好的結(jié)果,其中需要說(shuō)明的是針對(duì) CIFAR-100 400label 的實(shí)驗(yàn),ReMixMatch 用了 data align 的額外 trick 取得了更好的結(jié)果,在 Dash 中加入 data align 的 trick 之后可以取得 43.31% 的錯(cuò)誤率,低于 ReMixMatch 44.28% 的錯(cuò)誤率。
五、應(yīng)用
實(shí)際面向任務(wù)域的模型研發(fā)過(guò)程中,該半監(jiān)督 Dash 框架經(jīng)常會(huì)被應(yīng)用到。接下來(lái)給大家介紹下我們研發(fā)的各個(gè)域上的開(kāi)源免費(fèi)模型,歡迎大家體驗(yàn)、下載(大部分手機(jī)端即可體驗(yàn)):
- ?https://modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface/summary?
- ?https://modelscope.cn/models/damo/cv_resnet101_face-detection_cvpr22papermogface/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-detection_tinymog/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-detection_ulfd/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-detection_mtcnn/summary?
- ?https://modelscope.cn/models/damo/cv_resnet_face-recognition_facemask/summary?
- ?https://modelscope.cn/models/damo/cv_ir50_face-recognition_arcface/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-liveness_flir/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-liveness_flrgb/summary?
- ?https://modelscope.cn/models/damo/cv_manual_facial-landmark-confidence_flcm/summary?
- ?https://modelscope.cn/models/damo/cv_vgg19_facial-expression-recognition_fer/summary?
- ?https://modelscope.cn/models/damo/cv_resnet34_face-attribute-recognition_fairface/summary?