從零開始微調(diào)Embedding模型:基于BERT的實(shí)戰(zhàn)教程
背景
在理解與學(xué)會了Naive RAG的框架流程后,就很自然地關(guān)注到embedding模型,與問題相關(guān)的文本召回,也有很多論文在做這方面的創(chuàng)新。
以前一直不知道embedding模型是如何微調(diào)出來的,一直聽說是微調(diào)BERT,但是不知道是怎么微調(diào)出來的。直到在B站上看到bge模型微調(diào)的視頻[參考資料4]才理解。
于是便想著自己也微調(diào)出一個 embedding模型。涉及到下面三個階段:
- 數(shù)據(jù)集制作
- 模型訓(xùn)練
- 評估
微調(diào)實(shí)戰(zhàn)
安裝包
pip install -U FlagEmbedding[finetune]
項(xiàng)目基于 https://github.com/FlagOpen/FlagEmbedding,若遇到環(huán)境報(bào)錯,可參考該項(xiàng)目的環(huán)境,完成python環(huán)境設(shè)置
FlagEmbedding論文:C-Pack: Packed Resources For General Chinese Embeddings , 也稱 C-METB
介紹
你可以閱讀參考資料[1]和[2],先嘗試實(shí)現(xiàn)一次官方的微調(diào)教程。
官方微調(diào)的模型是BAAI/bge-large-en-v1.5,我選擇直接微調(diào)BERT模型,這樣感受微調(diào)的效果更明顯。僅僅是出于學(xué)習(xí)的目的,我才選擇微調(diào)BERT,如果大家打算用于生產(chǎn)環(huán)境,還是要選擇微調(diào)現(xiàn)成的embedding模型。因?yàn)閑mbedding模型也分為預(yù)訓(xùn)練與微調(diào)兩個階段,我們不做預(yù)訓(xùn)練。
embedding 模型需要通過encode方法把文本變成向量,而BERT模型沒有encode方法。故要使用FlagEmbedding導(dǎo)入原生的BERT模型。
from FlagEmbedding.inference.embedder.encoder_only.base import BaseEmbedder
# 省略數(shù)據(jù)集加載代碼
bert_embedding = BaseEmbedder("bert-base-uncased")
# get the embedding of the corpus
corpus_embeddings = bert_embedding.encode(corpus_dataset["text"])
print("shape of the corpus embeddings:", corpus_embeddings.shape)
print("data type of the embeddings: ", corpus_embeddings.dtype)
可瀏覽:eval_raw_bert.ipynb
項(xiàng)目文件介紹
數(shù)據(jù)集構(gòu)建:
- build_train_dataset.ipynb?: 構(gòu)建訓(xùn)練集數(shù)據(jù),隨機(jī)采樣負(fù)樣本數(shù)據(jù)通過修改?neg_num?的值,構(gòu)架了training_neg_10.json和training_neg_50.json兩個訓(xùn)練的數(shù)據(jù)集,比較增加負(fù)樣本的數(shù)量是否能提高模型召回的效果(實(shí)驗(yàn)結(jié)果表明:這里的效果并不好,提升不明顯)。
- build_eval_dataset.ipynb: 構(gòu)建測試集數(shù)據(jù),評估模型召回的效果。與FlagEmbedding數(shù)據(jù)集構(gòu)建結(jié)構(gòu)不同,我個人用這種數(shù)據(jù)集樣式更方便,不需要像FlagEmbedding一樣從下標(biāo)讀出正確的樣本的數(shù)據(jù)。
模型訓(xùn)練:
- finetune_neg10.sh
- finetune_neg50.sh
finetune_neg10.sh的代碼如下:
torchrun --nproc_per_node=1 \
-m FlagEmbedding.finetune.embedder.encoder_only.base \
--model_name_or_path bert-base-uncased \
--train_data ./ft_data/training_neg_10.json \
--train_group_size 8 \
--query_max_len 512 \
--passage_max_len 512 \
--pad_to_multiple_of 8 \
--query_instruction_for_retrieval 'Represent this sentence for searching relevant passages: ' \
--query_instruction_format '{}{}' \
--output_dir ./output/bert-base-uncased_neg10 \
--overwrite_output_dir \
--learning_rate 1e-5 \
--fp16 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--warmup_ratio 0.1 \
--logging_steps 200 \
--save_steps 2000 \
--temperature 0.02 \
--sentence_pooling_method cls \
--normalize_embeddings True \
--kd_loss_type kl_div
bash finetune_neg10.sh > finetune_neg10.log 2>&1 & 把訓(xùn)練的日志保存到 finetune_neg10.log 日志文件中,訓(xùn)練用時6分鐘。
neg10代表每條數(shù)據(jù)10個負(fù)樣本,neg50代表每條數(shù)據(jù)50個負(fù)樣本。
評估:
評估是在所有語料上完成的評估,并不是在指定的固定數(shù)量的負(fù)樣本上完成的評估。由于是在全部語料上完成召回,故使用到了faiss向量數(shù)據(jù)庫。
- eval_raw_bert.ipynb: 評估BERT原生模型
- eval_train_neg10.ipynb: 評估基于10條負(fù)樣本微調(diào)后的模型
- eval_train_neg50.ipynb: 評估基于50條負(fù)樣本微調(diào)后的模型
- eval_bge_m3.ipynb: 評估 BAAI 現(xiàn)在表現(xiàn)效果好的 BGE-M3 模型
結(jié)論:通過評估結(jié)果,可看出BERT經(jīng)過微調(diào)后的提升明顯,但依然達(dá)不到BGE-M3 模型的效果。
微調(diào)硬件配置要求
微調(diào)過程中GPU顯存占用達(dá)到了9G左右
設(shè)備只有一臺GPU
debug 重要代碼分析【選看】
下述代碼是舊版本的代碼,不是最新的FlagEmbedding的代碼:
- 視頻教程,bge模型微調(diào)流程:https://www.bilibili.com/video/BV1eu4y1x7ix/
推薦使用23年10月份的代碼進(jìn)行debug,關(guān)注核心代碼。新版的加了抽象類與繼承,增加了很多額外的東西,使用早期版本debug起來更聚焦一些。
由于需要傳遞參數(shù)再運(yùn)行腳本,需要在pycharm配置一些與運(yùn)行相關(guān)的參數(shù):
下述是embedding計(jì)算損失的核心代碼,這里的query與passage都是batch_size數(shù)量的輸入,如果只是一條query與passage,大家理解起來就容易很多。由于這里是batch_size數(shù)量的輸入,代碼中涉及到矩陣運(yùn)算會給大家?guī)砝斫饫щy。
比較難理解的是下述代碼,這里的target 其實(shí)就是label:
target = torch.arange(
scores.size(0), device=scores.device, dtype=torch.long
)
target = target * (p_reps.size(0) // q_reps.size(0))
p_reps 是相關(guān)文本矩陣, q_reps 是問題矩陣。每一個問題都對應(yīng)固定數(shù)量的相關(guān)文本。p_reps.size(0) // q_reps.size(0) 是每個問題對應(yīng)的相關(guān)文本的數(shù)量。下一行的target 乘以 相關(guān)文本的塊數(shù),得到query對應(yīng)的 Gold Truth(也稱 pos 文本)的下標(biāo),因?yàn)樵诿總€相關(guān)文本中,第一個位置都是正確文本,其后是負(fù)樣本,這些 Gold Truth 下標(biāo)之間的距離是固定,通過乘法就可以計(jì)算出每個 Gold Truth 的下標(biāo)。
額外補(bǔ)充【選看】:
在微調(diào)的過程中,不要錯誤的以為每個問題只和自己的相關(guān)文本計(jì)算score。真實(shí)的情況是,在batch_size的數(shù)據(jù)中,每個問題會與所有的相關(guān)文本計(jì)算score。根據(jù)上述代碼可看出 target 最大的取值是:query的數(shù)量 x 相關(guān)文本數(shù)量,這也印證了每個問題會與所有的相關(guān)文本都計(jì)算score。故我們在隨機(jī)采樣負(fù)樣本的時候,負(fù)樣本數(shù)量設(shè)置的太小也不用太擔(dān)心,因?yàn)樵谟?jì)算過程中負(fù)樣本的數(shù)量會乘以 batch_size。
【注意】:query的數(shù)量 = batch_size
- 損失函數(shù)
image-20250405174449800
image-20250407143100543
def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)
C-METB 論文中,關(guān)于損失函數(shù)的介紹,公式看起來很復(fù)雜,本質(zhì)就是cross_entropy。
資源分享
上述的代碼開源在github平臺,為了不增大github倉庫的容量,數(shù)據(jù)集沒有上傳到github平臺。若希望直接獲得完整的項(xiàng)目文件夾,從下述提供的網(wǎng)盤分享鏈接進(jìn)行下載:
- github開源地址:https://github.com/JieShenAI/csdn/tree/main/25/04/embedding_finetune
- 通過網(wǎng)盤分享的文件:embedding_finetune.zip鏈接: https://pan.baidu.com/s/1CDRpkkjS1-0jtmIBiTWx1A 提取碼: free
最新的代碼,請以 github 的鏈接為準(zhǔn),網(wǎng)盤分享的文件,本意只是為了存儲數(shù)據(jù),避免增加github倉庫的容量
CSDN: https://jieshen.blog.csdn.net/article/details/147043668
參考資料
[1] BAAI官方微調(diào)教程: ??https://github.com/FlagOpen/FlagEmbedding/blob/master/Tutorials/7_Fine-tuning/7.1.2_Fine-tune.ipynb??
[2] BAAI官方評估教程:??https://github.com/FlagOpen/FlagEmbedding/blob/master/Tutorials/4_Evaluation/4.1.1_Evaluation_MSMARCO.ipynb??
[3] 多文檔知識圖譜問答:??https://jieshen.blog.csdn.net/article/details/146390208??
[4] bge模型微調(diào)流程:??https://www.bilibili.com/video/BV1eu4y1x7ix/??
[5] FlagEmbedding 舊版本可用于debug的代碼:https://github.com/FlagOpen/FlagEmbedding/blob/9b6e521bcb7583ed907f044ca092daef0ee90431/FlagEmbedding/baai_general_embedding/finetune/run.py
本文轉(zhuǎn)載自??AI悠閑區(qū)??,作者:jieshenai
