The Annotated BERT: You Only Really Understand BERT When You Understand the Code
In earlier posts we implemented the Transformer and the GPT-2 pre-training process from scratch, making the process easier to follow with code comments and printed tensor shapes. Today I will use the same approach to study BERT.
The original Transformer is an Encoder-Decoder architecture, GPT is a Decoder-only model, and BERT is an Encoder-only model, so here we focus mainly on the left half of the Transformer.
Reply "bert" to the official account to get the training dataset, the code, and the paper download links.
Please read this article alongside the code:
https://github.com/AIDajiangtang/annotated-transformer/blob/master/AnnotatedBert.ipynb
0. Preparing the Training Data
0.0 Downloading the Data
The original BERT was pre-trained on BooksCorpus and English Wikipedia, but those corpora are too large for our purposes, so this time we pre-train on 50,000 movie reviews from the IMDb website. The data is a CSV file with two columns: review contains the movie review text and sentiment contains the sentiment label, positive or negative. We only use the review column here.
(Reply "bert" to the official account to get the dataset download link.)
Here is one of the reviews:
One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked.
They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO.
Trust me, this is not a show for the faint hearted or timid. This show pulls no punches with regards to drugs, sex or violence. Its is hardcore, in the classic use of the word.<br /><br />It is called OZ as that is the nickname given to the Oswald Maximum Security State Penitentary. It focuses mainly on Emerald City, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda. Em City is home to many..Aryans, Muslims, gangstas, Latinos, Christians, Italians, Irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.<br /><br />I would say the main appeal of the show is due to the fact that it goes where other shows wouldn't dare. Forget pretty pictures painted for mainstream audiences, forget charm, forget romance...OZ doesn't mess around. The first episode I ever saw struck me as so nasty it was surreal, I couldn't say I was ready for it, but as I watched more, I developed a taste for Oz, and got accustomed to the high levels of graphic violence. Not just violence, but injustice (crooked guards who'll be sold out for a nickel, inmates who'll kill on order and get away with it, well mannered, middle class inmates being turned into prison bitches due to their lack of street skills or prison experience) Watching Oz, you may become comfortable with what is uncomfortable viewing....thats if you can get in touch with your darker side.
ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000)
To speed up training, the ds_from and ds_to parameters restrict loading to the first 1,000 reviews.
0.1 Computing the Context Length
The context length is the maximum length of the input sequence. For the Transformer and GPT-2 we simply set it as a hyperparameter; today we derive it from the training data: read the 1,000 reviews row by row with pandas, split each review into sentences on '.', collect the length of every sentence into an array, and take the value at the 90th percentile of that array.
This calculation gives an optimal sentence length of 27: samples longer than 27 are truncated, and shorter ones are padded with a special token.
As a simple example, if the sentence-length array is [10, 20, 30, 40, 50, 60, 70, 80, 90, 100], the value at the 90th percentile is 90.
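As a minimal sketch, the calculation could look like the following (assuming pandas and numpy, the 'review' column name of the IMDb CSV, and the data/imdb.csv path; the notebook's IMDBBertDataset does this internally):

import pandas as pd
import numpy as np

df = pd.read_csv('data/imdb.csv')

sentence_lengths = []
for review in df['review'][:1000]:            # only the first 1,000 reviews
    for sentence in review.split('.'):        # split each review into sentences
        words = sentence.split()
        if words:
            sentence_lengths.append(len(words))

optimal_sentence_length = int(np.percentile(sentence_lengths, 90))   # 90th percentile
print(optimal_sentence_length)                # 27 for this subset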
0.2 Tokenization
We use the basic_english tokenizer, a very simple and direct method: it lowercases all text and splits it into words on whitespace, keeping most punctuation marks as separate tokens.
"Hello, world! This is an example sentence."
['hello', ',', 'world', '!', 'this', 'is', 'an', 'example', 'sentence', '.']
Next, each token is converted into a numeric id. This requires building a vocabulary from the training data, i.e., collecting all unique tokens that appear in it.
Counting shows that these 1,000 reviews contain 9,626 unique tokens.
The following special tokens are then added to the front of the vocabulary (a vocabulary-building sketch follows the list):
CLS = '[CLS]'
PAD = '[PAD]'
SEP = '[SEP]'
MASK = '[MASK]'
UNK = '[UNK]'
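A minimal sketch of how the vocabulary could be built with torchtext, assuming the basic_english tokenizer and the special tokens above placed first (all_sentences is a hypothetical list of the split sentences; the notebook may differ in details):

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')

def yield_tokens(sentences):
    for sentence in sentences:
        yield tokenizer(sentence)

specials = [CLS, PAD, MASK, SEP, UNK]   # order matches the ids seen later: [CLS]=0, [PAD]=1, [MASK]=2, [SEP]=3
vocab = build_vocab_from_iterator(yield_tokens(all_sentences), specials=specials)
vocab.set_default_index(vocab[UNK])     # unknown words map to [UNK]

print(len(vocab))                       # vocabulary size
print(vocab.lookup_indices(['hello', 'world']))   # tokens -> ids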
0.3 Constructing the Training Data
BERT is an Encoder-only architecture: every token attends to every other token, whether it comes before or after it. This lets the model absorb context from both directions, which is why Encoder-only models are well suited to understanding tasks.
A Decoder, by contrast, only attends to the tokens before it. In that sense GPT only uses the preceding context, but this autoregressive setup has its own advantage: it is well suited to generation tasks.
To learn bidirectional representations, BERT differs not only in model structure but also in how the training data is constructed.
GPT uses the current token to predict the next one. Suppose the training data is token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] with context_length=4, stride=4, and batch_size=2; the inputs and targets then look like this (a small sliding-window sketch follows):
Input IDs: [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8])]
Target IDs: [tensor([2, 3, 4, 5]), tensor([6, 7, 8, 9])]
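As a quick, illustrative sketch (not part of the BERT notebook), the pairs above can be produced with a sliding window:

import torch

token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
context_length, stride = 4, 4

input_ids, target_ids = [], []
for i in range(0, len(token_ids) - context_length, stride):
    input_ids.append(torch.tensor(token_ids[i:i + context_length]))
    target_ids.append(torch.tensor(token_ids[i + 1:i + 1 + context_length]))

print(input_ids)     # [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8])]
print(target_ids)    # [tensor([2, 3, 4, 5]), tensor([6, 7, 8, 9])]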
BERT, on the other hand, constructs its pre-training data in two ways: MLM and NSP.
MLM (masked language modeling) randomly replaces some of the words in a sample with [MASK] or with another word from the vocabulary. In this implementation, 15% of the words are selected; 80% of those are replaced with [MASK] and the remaining 20% with a random word from the vocabulary.
NSP (next sentence prediction) treats adjacent sentences as positive pairs and non-adjacent sentences as negative pairs, with a [SEP] separator placed between the two sentences.
BERT is not well suited to generation, so how does it handle downstream tasks such as question answering? The answer is that a [CLS] token is placed at the start of every sample, and its output is fed to a classifier (binary classification here).
With the method clear, we now construct the training data by iterating over the 1,000 movie reviews.
Take the first review, the one printed earlier, as an example.
Split the review into sentences on '.' and iterate over each sentence.
The first sentence:
One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked
The second sentence:
They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO
The first sentence after tokenization:
['one', 'of', 'the', 'other', 'reviewers', 'has', 'mentioned', 'that', 'after', 'watching', 'just', '1', 'oz', 'episode', 'you', "'", 'll', 'be', 'hooked']
The second sentence after tokenization:
['they', 'are', 'right', ',', 'as', 'this', 'is', 'exactly', 'what', 'happened', 'with', 'me', '.', 'the', 'first', 'thing', 'that', 'struck', 'me', 'about', 'oz', 'was', 'its', 'brutality', 'and', 'unflinching', 'scenes', 'of', 'violence', ',', 'which', 'set', 'in', 'right', 'from', 'the', 'word', 'go']
For each sentence, randomly select 15% of the words and mask them, prepend [CLS], and pad to the context length of 27; the two sentences are then concatenated with a [SEP] separator between them.
['[CLS]', 'one', 'of', 'the', 'other', 'reviewers', 'has', 'mentioned', '[MASK]', 'after', 'watching', 'just', '1', 'oz', 'episode', 'you', "'", '[MASK]', '[MASK]', 'hooked', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[SEP]', '[CLS]', 'they', 'are', 'right', ',', 'as', 'this', 'is', '[MASK]', 'what', 'happened', '[MASK]', 'me', '[MASK]', 'the', '[MASK]', 'financiers', 'that', 'struck', 'me', 'about', 'oz', 'was', 'its', 'brutality', 'and', 'unflinching']
From the masked sequence above we build an input mask: positions selected for masking (whether replaced with [MASK] or with a random word) are set to False, all other positions are True.
[True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, False, True, False, True, False, False, True, True, True, True, True, True, True, True, True, True]
Convert the masked sequence into token ids; this is the input X that will be fed to the model.
[0, 5, 6, 7, 8, 9, 10, 11, 2, 13, 14, 15, 16, 17, 18, 19, 20, 2, 2, 23, 1, 1, 1, 1, 1, 1, 1, 3, 0, 24, 25, 26, 27, 28, 29, 30, 2, 32, 33, 2, 35, 2, 7, 2, 32940, 12, 39, 35, 40, 17, 41, 42, 43, 44, 45]
Convert the original, un-masked sequence into token ids; this is the label Y.
[0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 1, 1, 1, 1, 1, 1, 1, 3, 0, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 7, 37, 38, 12, 39, 35, 40, 17, 41, 42, 43, 44, 45]
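A minimal sketch of the masking step that produces the masked sequence and the input mask (a hypothetical helper; the notebook's IMDBBertDataset implements its own version of this logic):

import random

def mask_sentence(tokens, vocab_words, mask_ratio=0.15):
    # tokens: the tokenized sentence; vocab_words: a list of vocabulary words
    masked = list(tokens)
    inverse_mask = [True] * len(tokens)            # False marks a masked position
    n_to_mask = max(1, round(len(tokens) * mask_ratio))
    for i in random.sample(range(len(tokens)), n_to_mask):
        if random.random() < 0.8:
            masked[i] = '[MASK]'                   # 80%: replace with [MASK]
        else:
            masked[i] = random.choice(vocab_words) # 20%: replace with a random word
        inverse_mask[i] = False
    return masked, inverse_mask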
The MLM loss is computed between the model output and the label Y.
What about the NSP loss? When the sentence pairs are constructed, the label is 1 if the two sentences are adjacent and 0 otherwise; a binary classification loss is then computed from the output at the [CLS] position.
From the first 1,000 reviews we finally obtain a DataFrame in which every row is one training sample, 17,122 samples in total, each with four columns:
the input X, with shape [1,55];
the label Y, with shape [1,55];
the input mask, with shape [1,55];
the NSP classification label, 0 or 1.
55 is the length of the two sentences (27 each) plus one [SEP] separator between them: 27 + 1 + 27 = 55.
1. Pre-training
Hyperparameters
EMB_SIZE = 64     # embedding dimension
HIDDEN_SIZE = 36  # hidden size
EPOCHS = 4        # number of training epochs
BATCH_SIZE = 12   # batch size
NUM_HEADS = 4     # number of attention heads
With BATCH_SIZE = 12, each batch contains 12 samples, so the input X has shape [12,55] and the label Y has shape [12,55].
1.0 Embeddings
Next, the token ids are converted into embeddings. In BERT every token involves three embeddings. The first is the token embedding, which maps the token id to an embedding vector; the second is the positional encoding; the third is the segment embedding, which indicates which sentence a token belongs to: 0 for the first sentence, 1 for the second.
With EMB_SIZE = 64, the embedding dimension is 64, and the token embedding layer of shape [9626,64] maps the input [12,55] to [12,55,64].
9626 is the vocabulary size; the [9626,64] embedding layer can be viewed as a lookup table with 9,626 slots, each storing a 64-dimensional vector.
Positional encodings can either be learned or computed with a fixed formula; here the fixed computation is used.
The segment ids have the same shape as the input X: 0 for positions belonging to the first sentence and 1 for positions in the second; they index the segment embedding table.
Finally the three embeddings are summed, and the resulting embedding of shape [12,55,64] is fed into the encoder.
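The sketch below shows one way the three embeddings can be combined, assuming a fixed sinusoidal positional encoding and a learned segment-embedding table; the class and buffer names are illustrative, not necessarily the notebook's:

import math
import torch
import torch.nn as nn

class BertEmbedding(nn.Module):
    def __init__(self, vocab_size=9626, emb_size=64, max_len=55):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, emb_size)    # [9626, 64] lookup table
        self.segment_emb = nn.Embedding(2, emb_size)           # sentence 0 / sentence 1
        # Fixed sinusoidal positional encoding, precomputed once
        pe = torch.zeros(max_len, emb_size)
        pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, emb_size, 2).float() * (-math.log(10000.0) / emb_size))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pos_emb', pe)                    # [55, 64]

    def forward(self, token_ids, segment_ids):
        # token_ids, segment_ids: [12, 55]  ->  output: [12, 55, 64]
        return self.token_emb(token_ids) + self.pos_emb[:token_ids.size(1)] + self.segment_emb(segment_ids)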
1.1 Multi-Head Attention
The first operation in the encoder is multi-head attention. Unlike in the Transformer and GPT posts, attention is not computed for [PAD] tokens: the attention scores at [PAD] positions are set to a very large negative value so that they become 0 after the softmax.
The output of multi-head attention has shape [12,55,64].
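A minimal sketch of padding-masked scaled dot-product attention; in a typical multi-head setup, EMB_SIZE = 64 would be split into NUM_HEADS = 4 heads of size 16 before this step, with learned projections for q, k, and v (those details are assumptions, not the notebook's exact code):

import torch
import torch.nn.functional as F

def masked_attention(q, k, v, pad_mask):
    # q, k, v: [batch, heads, seq, head_dim]; pad_mask: [batch, seq], True = real token
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)      # [batch, heads, seq, seq]
    # Large negative value at [PAD] key positions so softmax sends them to ~0
    scores = scores.masked_fill(~pad_mask[:, None, None, :], -1e9)
    return F.softmax(scores, dim=-1) @ v                        # [batch, heads, seq, head_dim]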
1.2 MLP
This is the same as in the Transformer and GPT; the MLP output has shape [12,55,64].
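A minimal sketch of the position-wise feed-forward block, assuming HIDDEN_SIZE = 36 is the width of its hidden layer and a GELU activation (both are assumptions about the notebook's implementation):

import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, emb_size=64, hidden_size=36, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_size, hidden_size),    # [12, 55, 64] -> [12, 55, 36]
            nn.GELU(),
            nn.Linear(hidden_size, emb_size),    # [12, 55, 36] -> [12, 55, 64]
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)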
1.3 Output
The encoder output has shape [12,55,64]; next, the parameters are updated by computing losses against the labels.
MLM loss
The encoder output [12,55,64] is passed through a linear layer [64,9626] to produce a distribution over the vocabulary, [12,55,9626].
Because the loss only needs to be computed at the [MASK] positions, a small trick sets the non-[MASK] positions of both the labels and the outputs to 0.
Finally the multi-class cross-entropy loss is computed against the label Y.
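A minimal sketch of this masked loss, using the inverse token mask from section 0.3 (True = not masked). It restricts the loss to masked positions through ignore_index instead of the zeroing trick described above, but the effect is the same (names are illustrative):

import torch.nn as nn

mlm_criterion = nn.CrossEntropyLoss(ignore_index=-100)

def mlm_loss(logits, labels, inverse_mask):
    # logits: [12, 55, 9626], labels: [12, 55], inverse_mask: [12, 55] (True = not masked)
    labels = labels.masked_fill(inverse_mask, -100)    # only masked positions contribute
    return mlm_criterion(logits.view(-1, logits.size(-1)), labels.view(-1))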
NSP loss
Another linear layer [64,2] maps the output at the leading [CLS] position, of shape [12,64], to [12,2], i.e. the scores for the positive and negative class; a cross-entropy loss is then computed against the NSP labels.
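And a minimal sketch of the NSP head, which classifies the [CLS] output as "adjacent" vs. "not adjacent" (layer names are illustrative):

import torch.nn as nn

nsp_head = nn.Linear(64, 2)                  # the [64, 2] linear layer
nsp_criterion = nn.CrossEntropyLoss()

def nsp_loss(encoder_out, nsp_labels):
    # encoder_out: [12, 55, 64], nsp_labels: [12] with values 0 or 1
    cls_out = encoder_out[:, 0, :]           # [CLS] is the first token -> [12, 64]
    logits = nsp_head(cls_out)               # [12, 2]
    return nsp_criterion(logits, nsp_labels)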
2.0 Inference
The simplest use is fill-in-the-blank: take a piece of text [1,55], replace some words with [MASK], and map the output at the [MASK] positions through the output head to [1,9626] to predict the original words.
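A minimal fill-in-the-blank sketch, assuming a trained model that maps token ids [1,55] to MLM logits [1,55,9626] and the vocab built earlier (both assumptions); the real forward pass may also need segment ids and the padding mask:

import torch

tokens = ['[CLS]', 'one', 'of', 'the', '[MASK]', 'reviewers', 'has', 'mentioned']
tokens += ['[PAD]'] * (55 - len(tokens))                             # pad to the context length
token_ids = torch.tensor([vocab[t] for t in tokens]).unsqueeze(0)    # [1, 55]

with torch.no_grad():
    logits = model(token_ids)                                        # [1, 55, 9626]

mask_pos = tokens.index('[MASK]')
predicted_id = logits[0, mask_pos].argmax().item()
print(vocab.lookup_token(predicted_id))                              # the predicted word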
Because pre-training includes next sentence prediction (NSP), we can also build a closed-set QA system: prepare several candidate answers for a question in advance, concatenate the question with each candidate answer, feed each pair into BERT, and classify with the [CLS] output.
Alternatively, we can predict the start and end positions of the answer span, which takes us into fine-tuning on downstream tasks.
Summary
At this point we have walked through the pre-training of both GPT-2 and BERT. To make these models follow human instructions, the pre-trained models still need instruction fine-tuning, which we will cover later.
References
https://arxiv.org/pdf/1810.04805
https://github.com/coaxsoft/pytorch_bert
https://towardsdatascience.com/a-complete-guide-to-bert-with-code-9f87602e4a11
https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891
https://coaxsoft.com/blog/building-bert-with-pytorch-from-scratch
Reposted from the WeChat official account 人工智能大讲堂.
