被DeepSeek帶火的知識蒸餾詳解!
今天來詳細(xì)了解DeepSeek中提到的知識蒸餾技術(shù),主要內(nèi)容來自三巨頭之一Geoffrey Hinton的一篇經(jīng)典工作:https://arxiv.org/pdf/1503.02531。
主要從背景、定義、原理、代碼復(fù)現(xiàn)等幾個方面來介紹:
1、背景介紹
訓(xùn)練與部署的不一致性
在機(jī)器學(xué)習(xí)和深度學(xué)習(xí)領(lǐng)域,訓(xùn)練模型和部署模型通常存在顯著差異。訓(xùn)練階段,為了追求最佳性能,我們通常會使用復(fù)雜的模型架構(gòu)和大量的計算資源,從海量且高度冗余的數(shù)據(jù)集中提取有用信息。例如,一些最先進(jìn)的模型可能包含數(shù)十億甚至上百億的參數(shù),或者通過多個模型集成來進(jìn)一步提升性能。然而,這些龐大的模型在實(shí)際部署時面臨諸多問題:
- 推斷速度慢大模型在處理數(shù)據(jù)時需要更多的時間來完成計算,這在需要實(shí)時響應(yīng)的場景中是不可接受的。
- 資源要求高大模型需要大量的內(nèi)存和顯存來存儲模型參數(shù)和中間計算結(jié)果,這使得它們難以部署在資源受限的設(shè)備上,如移動設(shè)備或嵌入式系統(tǒng)。
因此,在部署階段,我們對模型的延遲和計算資源有著嚴(yán)格的限制。這就引出了模型壓縮的需求——在盡量不損失性能的前提下,減少模型的參數(shù)量,使其更適合實(shí)際應(yīng)用環(huán)境。
模型壓縮與知識蒸餾
模型壓縮用于解決訓(xùn)練階段與部署階段之間的不一致性,特別是在模型規(guī)模與實(shí)際應(yīng)用需求之間的矛盾,在盡量不損失模型性能的前提下,減少模型的參數(shù)量和計算復(fù)雜度,使其更適合在資源受限的環(huán)境中部署。
知識蒸餾(Knowledge Distillation)是其中一種非常有效的模型壓縮技術(shù)。
2、什么是知識蒸餾?
知識蒸餾是一種模型壓縮技術(shù),通過訓(xùn)練一個小而高效的學(xué)生模型來模仿一個預(yù)訓(xùn)練的大且復(fù)雜的教師模型(或一組模型)的行為。這種訓(xùn)練設(shè)置通常被稱為“教師-學(xué)生”模式,其中大型模型作為教師,小型模型作為學(xué)生。教師模型的知識通過最小化損失函數(shù)傳遞給學(xué)生模型,目標(biāo)是匹配教師模型預(yù)測的類概率分布。
知識蒸餾的核心思想是將一個復(fù)雜且性能強(qiáng)大的“教師”模型的知識遷移到一個更小、更輕量的“學(xué)生”模型中。通過這種方式,學(xué)生模型可以在保持較小參數(shù)量的同時,盡可能地繼承教師模型的性能。
該方法最早由Bucila等人在2006年提出(https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf),并在2015年由Hinton等人推廣,成為知識蒸餾領(lǐng)域的奠基之作。Hinton的工作引入了帶溫度參數(shù)的softmax函數(shù),進(jìn)一步增強(qiáng)了知識轉(zhuǎn)移的有效性。
在 Geoffrey Hinton 等人的論文《Distilling the Knowledge in a Neural Network》中,作者通過昆蟲的幼蟲和成蟲形態(tài)來比喻機(jī)器學(xué)習(xí)中的訓(xùn)練階段和部署階段。這種比喻強(qiáng)調(diào)了不同階段對于模型的不同需求:就像昆蟲的幼蟲形態(tài)專注于從環(huán)境中吸收能量和養(yǎng)分,而其成蟲形態(tài)則優(yōu)化于移動和繁殖;同樣地,在機(jī)器學(xué)習(xí)中,訓(xùn)練階段需要處理大量數(shù)據(jù)以提取結(jié)構(gòu)信息,而部署階段則更關(guān)注于實(shí)時性和計算資源的有效利用。
教師與學(xué)生模型
教師模型:指的是那些龐大、復(fù)雜且可能計算成本高昂的模型或模型集合(ensemble)。這些模型雖然在訓(xùn)練階段能夠提供優(yōu)秀的性能,但由于其復(fù)雜性,不適合直接用于實(shí)際部署。教師模型通常經(jīng)過充分訓(xùn)練,可以很好地泛化到新數(shù)據(jù)上。
學(xué)生模型:相對較小且計算效率更高的模型,旨在模仿教師模型的表現(xiàn)同時保持低延遲和高效能,便于大規(guī)模部署。學(xué)生模型通過“知識蒸餾”過程從教師模型那里獲得知識,從而能夠在資源受限的環(huán)境下運(yùn)行。
知識的理解
關(guān)于“知識”的理解存在一個常見的誤解,即將其簡單等同于模型的權(quán)重參數(shù)。然而,Hinton 等人提倡一種更為抽象的觀點(diǎn),即知識應(yīng)被視為從輸入向量到輸出向量的學(xué)習(xí)映射關(guān)系。這不僅包括了對正確答案的概率預(yù)測,還涵蓋了對不正確答案之間細(xì)微差異的理解。例如,在圖像分類任務(wù)中,即使某個類別的概率非常小,它與其他錯誤類別相比仍然可能存在顯著差異,這些差異反映了模型如何泛化的關(guān)鍵信息。
蒸餾技術(shù)
知識蒸餾是一種技術(shù),它允許我們將教師模型中的知識轉(zhuǎn)移到學(xué)生模型中。這不僅僅是復(fù)制權(quán)重的過程,而是涉及到使用教師模型生成的軟目標(biāo)(soft targets)作為指導(dǎo),幫助學(xué)生模型學(xué)習(xí)到相似的泛化能力。通過調(diào)整 softmax 層的溫度參數(shù),可以使這些軟目標(biāo)更加平滑,從而讓學(xué)生模型能夠捕捉到更多有用的信息,并減少過擬合的風(fēng)險。
3、知識蒸餾的工作原理
Soft Target vs Hard Target
Hard Targets是最常見的訓(xùn)練標(biāo)簽形式,通常指的是每個訓(xùn)練樣本對應(yīng)的一個確切的類別標(biāo)簽。例如,在分類問題中,如果任務(wù)是對圖像進(jìn)行數(shù)字識別(如MNIST數(shù)據(jù)集),那么hard target就是一個具體的數(shù)字(比如1)。在這種情況下,模型被訓(xùn)練來最大化正確類別的概率,而其他所有類別的概率則應(yīng)盡可能地小。這通常通過最小化交叉熵?fù)p失函數(shù)來實(shí)現(xiàn),該函數(shù)懲罰了模型對正確標(biāo)簽預(yù)測的概率不足,并且不考慮錯誤標(biāo)簽之間的相對概率。
相比之下,Soft Targets提供了一個更加細(xì)致的概率分布,不僅包含了正確的類別,還包括了模型認(rèn)為可能相關(guān)的其他類別的概率。這意味著除了給出最有可能的類別之外,soft targets還提供了關(guān)于模型對于哪些類別可能是正確的、以及這些類別之間如何相互關(guān)聯(lián)的信息。這種類型的標(biāo)簽可以通過一個已經(jīng)訓(xùn)練好的教師模型生成,它為每個輸入產(chǎn)生一個概率分布而不是單一的類別標(biāo)簽。
基于soft targets訓(xùn)練模型相較于hard targets訓(xùn)練模型,尤其是在資源受限的情況下有效地轉(zhuǎn)移復(fù)雜模型的知識,具有以下幾個顯著的優(yōu)勢和意義:
- 傳遞更豐富的信息:Soft targets不僅包含正確類別的概率,還包括了對其他類別的相對概率估計。這意味著模型可以學(xué)習(xí)到不同類別之間的相似性和差異性,使得即使是對于錯誤的類別預(yù)測,也能反映出它們之間的細(xì)微差別,從而提供比單一正確答案更多的信息。例如,在圖像識別任務(wù)中,一個特定圖像可能與多個類別有一定的相似性,而這些信息在hard target中是丟失的。
- 減少梯度方差:當(dāng)使用較高的溫度參數(shù)T時,softmax輸出的概率分布變得更加平滑,這意味著每個樣本提供的信息量相對于hard target來說增加了,同時減少了梯度估計中的方差。這對于小數(shù)據(jù)集上的訓(xùn)練尤其有益,因為它可以防止過擬合并幫助模型更好地泛化到未見過的數(shù)據(jù)。
- 提高泛化能力:相比于hard targets僅給出一個確切的類別標(biāo)簽,通過模仿教師模型(通常是大型或集成模型)生成的soft targets,學(xué)生模型能夠?qū)W習(xí)到教師模型如何進(jìn)行泛化的細(xì)節(jié)。這有助于模型更好地學(xué)習(xí)數(shù)據(jù)中的潛在結(jié)構(gòu)和模式,特別是在處理復(fù)雜或模糊邊界的問題時。對于那些難以明確區(qū)分的類別,soft targets能夠指導(dǎo)模型認(rèn)識到哪些錯誤類別之間更加接近,從而改進(jìn)其泛化能力。
- 加速收斂和降低過擬合風(fēng)險:由于soft targets通常比hard targets擁有更高的熵,它們能提供更多的信息,并且減少梯度估計的方差。這對于小規(guī)模數(shù)據(jù)集特別有用,因為它可以幫助防止模型過度擬合訓(xùn)練數(shù)據(jù)中的噪聲。論文中提到的一個實(shí)驗顯示,使用soft targets的學(xué)生模型即使只用3%的數(shù)據(jù)進(jìn)行訓(xùn)練,也能夠幾乎恢復(fù)全量數(shù)據(jù)所能提供的信息,同時不需要早期停止來防止過擬合。這表明soft targets作為一種有效的正則化方法,能夠幫助模型更好地泛化。
- 增強(qiáng)模型的魯棒性:通過利用soft targets進(jìn)行訓(xùn)練,模型可以學(xué)習(xí)到輸入數(shù)據(jù)的內(nèi)在分布特性,而不僅僅是表面特征。這意味著模型對于輸入數(shù)據(jù)的小變化(如輕微的圖像變換)會更加穩(wěn)健,因為它們已經(jīng)學(xué)會了識別那些對最終分類決策影響較小的變化。
帶溫度的Softmax
傳統(tǒng)softmax函數(shù)傾向于產(chǎn)生極端的概率分布,導(dǎo)致非正確類別的概率接近于零,這限制了其對模型訓(xùn)練的幫助。
如圖所示,當(dāng)輸入一張馬的圖片時,對于未調(diào)整溫度(默認(rèn)為1)的 Softmax 輸出,正標(biāo)簽的概率接近 1,而負(fù)標(biāo)簽的概率接近 0。這種尖銳的分布對學(xué)生模型不夠友好,因為它只提供了關(guān)于正確答案的信息,而忽略了錯誤答案的信息。即驢比汽車更像馬,識別為驢的概率應(yīng)該大于識別為汽車的概率。
帶溫度的Softmax是一種調(diào)整softmax函數(shù)輸出的方法,通過引入一個額外參數(shù)——溫度(temperature, T),來控制輸出概率分布的平滑度。這個概念在知識蒸餾中尤為重要,因為它能夠影響教師模型如何向?qū)W生模型傳遞知識。
- 當(dāng)T=1時,該公式退化為標(biāo)準(zhǔn)的softmax。
- 當(dāng)T>1時,輸出的概率分布變得更平滑,即不同類別的概率差異減小,這使得即使是不太可能的類別也會分配到一定的概率值。
- 當(dāng)T<1時,結(jié)果是相反的,輸出的概率分布變得更加尖銳,增加了最有可能類別的概率,同時進(jìn)一步降低了其他類別的概率。
關(guān)于T的取值,原文說是“不要太大,也不要太小”,如果T的值太小,就會導(dǎo)致原本概率值比較小的類別除以后還是比較小,就獲取不到數(shù)據(jù)中有效的信息。如果T的值太大,負(fù)標(biāo)簽帶的信息會被放大,就會引入噪聲,也就是將一些沒有用信息給放的很大。這個實(shí)際還是調(diào)參吧,記得之前看到過文章說一般是20以內(nèi)。當(dāng)需要考慮負(fù)標(biāo)簽之間的關(guān)系時,可以采用較大的溫度。例如,在自然語言處理任務(wù)中,模型可能需要學(xué)習(xí)到“貓”和“狗”之間的相似性,而不僅僅是它們的硬標(biāo)簽。在這種情況下,較大的溫度可以使模型更好地捕捉到這些關(guān)系。反之,如果為了消除負(fù)標(biāo)簽中噪聲的影響,可以采用較小的溫度。
蒸餾過程
準(zhǔn)備階段:
- 一個已經(jīng)預(yù)訓(xùn)練好的、泛化能力強(qiáng)的教師網(wǎng)絡(luò)。
- 構(gòu)造好數(shù)據(jù)集,這個數(shù)據(jù)集可以使用用來訓(xùn)練教師網(wǎng)絡(luò)的數(shù)據(jù)集,也可以專門準(zhǔn)備一個數(shù)據(jù)集來做蒸餾的過程。
- 搭建好學(xué)生網(wǎng)絡(luò)。
蒸餾過程:
- 輸入數(shù)據(jù)通過教師網(wǎng)絡(luò): 輸入數(shù)據(jù)首先通過教師網(wǎng)絡(luò)得到logits(即softmax層之前的輸出)。使用較高的溫度t對這些logits應(yīng)用softmax函數(shù),生成soft labels。
- 輸入數(shù)據(jù)通過學(xué)生網(wǎng)絡(luò): 同樣的輸入數(shù)據(jù)也通過學(xué)生網(wǎng)絡(luò)得到自己的logits。 這些logits分別通過相同的溫度t和標(biāo)準(zhǔn)溫度T=1的softmax函數(shù)處理,得到soft predictions和hard predictions。Soft predictions用于與教師網(wǎng)絡(luò)的soft labels比較,而hard predictions則用于與真實(shí)標(biāo)簽比較。
- 計算損失:
Distillation Loss:使用交叉熵?fù)p失函數(shù)計算教師網(wǎng)絡(luò)的soft labels和學(xué)生網(wǎng)絡(luò)的soft predictions之間的差異。這一步幫助學(xué)生網(wǎng)絡(luò)學(xué)習(xí)到教師網(wǎng)絡(luò)對于不同類別間細(xì)微差別的理解。
Student Loss:同樣使用交叉熵?fù)p失函數(shù)計算學(xué)生網(wǎng)絡(luò)的hard predictions與真實(shí)標(biāo)簽之間的差異。這確保了學(xué)生網(wǎng)絡(luò)也能直接從數(shù)據(jù)中學(xué)習(xí)。
最終損失:這兩個損失(distillation loss和student loss)通常會加權(quán)求和形成最終的損失函數(shù)。權(quán)重的選擇可以根據(jù)具體任務(wù)調(diào)整,但一般情況下,distillation loss的權(quán)重相對較小,以避免過度依賴教師網(wǎng)絡(luò)的預(yù)測。
其中,α是介于0和1之間的一個系數(shù),用來平衡兩個損失項的重要性。作者發(fā)現(xiàn)學(xué)生網(wǎng)絡(luò)的損失的權(quán)重小一點(diǎn)比較好,比如可以嘗試 0.3、0.4。
最后作者說,由于的梯度大約是
的梯度的
,因此在
前乘上
可以保證兩個損失部分的梯度量貢獻(xiàn)基本一致。
4、代碼復(fù)現(xiàn)
使用mnist公開數(shù)據(jù)集簡單復(fù)現(xiàn)了上述流程,知識蒸餾主要代碼如下:
def train_kd_student(epoch):
student_kd.train()
optimizer = torch.optim.SGD(student_kd.parameters(), lr=lr, momentum=momentum)
print(f'\nTraining KD Student Epoch: {epoch}')
train_loss = 0
correct = 0
total = 0
with tqdm(train_loader, desc=f"Training KD Student Epoch {epoch}", total=len(train_loader)) as pbar:
for batch_idx, (inputs, targets) in enumerate(pbar):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
logits_student = student_kd(inputs)
with torch.no_grad():
logits_teacher = teacher_net(inputs)
ce_loss = nn.CrossEntropyLoss()(logits_student, targets)
kd_loss = loss(logits_student, logits_teacher, temperature=T)
total_loss = ALPHA * ce_loss + BETA * kd_loss
total_loss.backward()
optimizer.step()
train_loss += total_loss.item()
_, predicted = logits_student.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
# 記錄訓(xùn)練損失
writers['student_kd'].add_scalar('Training Loss', ce_loss.item(),
epoch * len(train_loader) + batch_idx)
writers['student_kd'].add_scalar('Training Loss', kd_loss.item(),
epoch * len(train_loader) + batch_idx)
writers['student_kd'].add_scalar('Training Loss', total_loss.item(),
epoch * len(train_loader) + batch_idx)
pbar.set_postfix(loss=train_loss/(batch_idx+1), acc=f"{100.*correct/total:.1f}%")
# 記錄訓(xùn)練準(zhǔn)確率
acc = 100. * correct / total
writers['student_kd'].add_scalar('Training Accuracy', acc, epoch)
學(xué)生模型和教師模型代碼如下:
class LeNet(nn.Module):
def __init__(self, num_classes=10):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=84)
self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
def forward(self, x):
x = self.maxpool(F.relu(self.conv1(x)))
x = self.maxpool(F.relu(self.conv2(x)))
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
class LeNetHalfChannel(nn.Module):
def __init__(self, num_classes=10):
super(LeNetHalfChannel, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(in_features=3 * 12 * 12, out_features=num_classes)
def forward(self, x):
x = self.maxpool(F.relu(self.conv1(x)))
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
return x
# 初始化模型
teacher_net = LeNet().to(device=device)
student_plain = LeNetHalfChannel().to(device=device) # 單獨(dú)訓(xùn)練的學(xué)生
student_kd = LeNetHalfChannel().to(device=device) # 知識蒸餾的學(xué)生
teacher_net.load_state_dict(torch.load('./model/model.pt'))
對比僅使用學(xué)生模型、僅使用教師模型和使用教師模型蒸餾學(xué)生模型的模型性能,結(jié)果如下:
Teacher Model Accuracy: 97.99%
Plain Student Best Accuracy: 85.77%
KD Student Best Accuracy: 96.70%
可以看到,使用知識蒸餾得到的模型性能十分接近教師模型,遠(yuǎn)超僅訓(xùn)練學(xué)生模型的性能。
完整工程代碼地址:https://github.com/jinbo0906/awesome-model-compression/blob/main/knowledge%20distillation/kd.ipynb
5、DeepSeek-R1的蒸餾方法
DeepSeek-R1 的蒸餾過程基于其自身生成的合成推理數(shù)據(jù)。由DeepSeek-R1 模型生成的 800,000 個數(shù)據(jù)樣本,對較小的基礎(chǔ)模型(例如 Qwen 和 Llama 系列)僅進(jìn)行監(jiān)督微調(diào),從而將大型模型的推理能力高效地遷移到小型模型中。
從這里可以了解到,在大模型時代,蒸餾不僅僅完全通過學(xué)習(xí)教師模型的軟標(biāo)簽實(shí)現(xiàn),也可以通過學(xué)習(xí)教師模型的輸出結(jié)構(gòu)化數(shù)據(jù),從而學(xué)習(xí)到教師模型中一些強(qiáng)大的性能。這種方式對算力的需求大大減少,推動未來AI能力的普惠化。
這種方式降低了對模型訓(xùn)練方法的要求,但是對數(shù)據(jù)質(zhì)量的要求則大大增加,從這里也可以看到,DeepSeek-R1模型生成數(shù)據(jù)的質(zhì)量非常高,可能遠(yuǎn)遠(yuǎn)超過了人工標(biāo)注,在DeepSeek-R1的論文中也提到模型的self-evolution能力,可能是未來非常值得關(guān)注的一個研究方向。
6、總結(jié)
知識蒸餾不僅能夠有效地降低模型的復(fù)雜度和計算成本,還能保持較高的模型性能。通過對教師模型知識的巧妙遷移,學(xué)生模型能夠在資源受限的環(huán)境中展現(xiàn)出色的表現(xiàn)。未來知識蒸餾將在更多領(lǐng)域發(fā)揮重要作用。