繞開(kāi)算力限制,如何用單GPU微調(diào) LLM?這是一份「梯度累積」算法教程
自從大模型變成熱門(mén)趨勢(shì)之后,GPU 就成了緊俏的物資。很多企業(yè)的儲(chǔ)備都不一定充足,更不用說(shuō)個(gè)人開(kāi)發(fā)者了。有沒(méi)有什么方法可以更高效的利用算力訓(xùn)練模型?
在最近的一篇博客,Sebastian Raschka 介紹了「梯度累積」的方法,能夠在 GPU 內(nèi)存受限時(shí)使用更大 batch size 訓(xùn)練模型,繞開(kāi)硬件限制。
在此之前,Sebastian Raschka 也分享過(guò)一篇運(yùn)用多 GPU 訓(xùn)練策略加速大型語(yǔ)言模型微調(diào)的文章,包括模型或 tensor sharding 等機(jī)制,這些機(jī)制將模型權(quán)重和計(jì)算分布在不同的設(shè)備上,以解決 GPU 的內(nèi)存限制。
微調(diào) BLOOM 模型進(jìn)行分類(lèi)
假設(shè)我們有興趣采用近期預(yù)訓(xùn)練的大型語(yǔ)言模型來(lái)處理文本分類(lèi)等下游任務(wù)。那么,我們可能會(huì)選擇使用 GPT-3 的開(kāi)源替代品 BLOOM 模型,特別是「僅有」 5.6 億個(gè)參數(shù)的 BLOOM 版本 —— 它應(yīng)該可以毫無(wú)問(wèn)題地融入至傳統(tǒng) GPU 的 RAM 中(Google Colab 免費(fèi)版本擁有 15 Gb RAM 的 GPU)。
一旦開(kāi)始,就很可能遇到問(wèn)題:內(nèi)存會(huì)在訓(xùn)練或微調(diào)期間迅速增加。訓(xùn)練這個(gè)模型的唯一方法是使批大小為 1(batch size=1)。
使用批大小為 1(batch size=1)為目標(biāo)分類(lèi)任務(wù)微調(diào) BLOOM 的代碼如下所示。你也可以在 GitHub 項(xiàng)目頁(yè)面下載完整代碼:
https://github.com/rasbt/gradient-accumulation-blog/blob/main/src/1_batchsize-1.py
你可以將此代碼直接復(fù)制并粘貼到 Google Colab 中,但還必須將隨附的 local_dataset_utilities.py 文件拖放到從該文件導(dǎo)入了一些數(shù)據(jù)集實(shí)用程序的同一文件夾中。
# pip install torch lightning matplotlib pandas torchmetrics watermark transformers datasets -U
import os
import os.path as op
import time
from datasets import load_dataset
from lightning import Fabric
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark
from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset
from local_dataset_utilities import IMDBDataset
def tokenize_text (batch):
return tokenizer (batch ["text"], truncatinotallow=True, padding=True, max_length=1024)
def train (num_epochs, model, optimizer, train_loader, val_loader, fabric):
for epoch in range (num_epochs):
train_acc = torchmetrics.Accuracy (
task="multiclass", num_classes=2).to (fabric.device)
for batch_idx, batch in enumerate (train_loader):
model.train ()
### FORWARD AND BACK PROP
outputs = model (
batch ["input_ids"],
attention_mask=batch ["attention_mask"],
labels=batch ["label"]
)
fabric.backward (outputs ["loss"])
### UPDATE MODEL PARAMETERS
optimizer.step ()
optimizer.zero_grad ()
### LOGGING
if not batch_idx % 300:
print (f"Epoch: {epoch+1:04d}/{num_epochs:04d}"
f"| Batch {batch_idx:04d}/{len (train_loader):04d}"
f"| Loss: {outputs ['loss']:.4f}")
model.eval ()
with torch.no_grad ():
predicted_labels = torch.argmax (outputs ["logits"], 1)
train_acc.update (predicted_labels, batch ["label"])
### MORE LOGGING
model.eval ()
with torch.no_grad ():
val_acc = torchmetrics.Accuracy (task="multiclass", num_classes=2).to (fabric.device)
for batch in val_loader:
outputs = model (
batch ["input_ids"],
attention_mask=batch ["attention_mask"],
labels=batch ["label"]
)
predicted_labels = torch.argmax (outputs ["logits"], 1)
val_acc.update (predicted_labels, batch ["label"])
print (f"Epoch: {epoch+1:04d}/{num_epochs:04d}"
f"| Train acc.: {train_acc.compute ()*100:.2f}%"
f"| Val acc.: {val_acc.compute ()*100:.2f}%"
)
train_acc.reset (), val_acc.reset ()
if __name__ == "__main__":
print (watermark (packages="torch,lightning,transformers", pythnotallow=True))
print ("Torch CUDA available?", torch.cuda.is_available ())
device = "cuda" if torch.cuda.is_available () else "cpu"
torch.manual_seed (123)
# torch.use_deterministic_algorithms (True)
##########################
### 1 Loading the Dataset
##########################
download_dataset ()
df = load_dataset_into_to_dataframe ()
if not (op.exists ("train.csv") and op.exists ("val.csv") and op.exists ("test.csv")):
partition_dataset (df)
imdb_dataset = load_dataset (
"csv",
data_files={
"train": "train.csv",
"validation": "val.csv",
"test": "test.csv",
},
)
#########################################
### 2 Tokenization and Numericalization
#########################################
tokenizer = AutoTokenizer.from_pretrained ("bigscience/bloom-560m", max_length=1024)
print ("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
print ("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)
print ("Tokenizing ...", flush=True)
imdb_tokenized = imdb_dataset.map (tokenize_text, batched=True, batch_size=None)
del imdb_dataset
imdb_tokenized.set_format ("torch", columns=["input_ids", "attention_mask", "label"])
os.environ ["TOKENIZERS_PARALLELISM"] = "false"
#########################################
### 3 Set Up DataLoaders
#########################################
train_dataset = IMDBDataset (imdb_tokenized, partition_key="train")
val_dataset = IMDBDataset (imdb_tokenized, partition_key="validation")
test_dataset = IMDBDataset (imdb_tokenized, partition_key="test")
train_loader = DataLoader (
dataset=train_dataset,
batch_size=1,
shuffle=True,
num_workers=4,
drop_last=True,
)
val_loader = DataLoader (
dataset=val_dataset,
batch_size=1,
num_workers=4,
drop_last=True,
)
test_loader = DataLoader (
dataset=test_dataset,
batch_size=1,
num_workers=2,
drop_last=True,
)
#########################################
### 4 Initializing the Model
#########################################
fabric = Fabric (accelerator="cuda", devices=1, precisinotallow="16-mixed")
fabric.launch ()
model = AutoModelForSequenceClassification.from_pretrained (
"bigscience/bloom-560m", num_labels=2)
optimizer = torch.optim.Adam (model.parameters (), lr=5e-5)
model, optimizer = fabric.setup (model, optimizer)
train_loader, val_loader, test_loader = fabric.setup_dataloaders (
train_loader, val_loader, test_loader)
#########################################
### 5 Finetuning
#########################################
start = time.time ()
train (
num_epochs=1,
model=model,
optimizer=optimizer,
train_loader=train_loader,
val_loader=val_loader,
fabric=fabric,
)
end = time.time ()
elapsed = end-start
print (f"Time elapsed {elapsed/60:.2f} min")
with torch.no_grad ():
model.eval ()
test_acc = torchmetrics.Accuracy (task="multiclass", num_classes=2).to (fabric.device)
for batch in test_loader:
outputs = model (
batch ["input_ids"],
attention_mask=batch ["attention_mask"],
labels=batch ["label"]
)
predicted_labels = torch.argmax (outputs ["logits"], 1)
test_acc.update (predicted_labels, batch ["label"])
print (f"Test accuracy {test_acc.compute ()*100:.2f}%")
作者使用了 Lightning Fabric,因?yàn)樗梢宰岄_(kāi)發(fā)者在不同硬件上運(yùn)行此代碼時(shí)靈活地改變 GPU 數(shù)量和多 GPU 訓(xùn)練策略。它還允許僅通過(guò)調(diào)整查準(zhǔn)率 flag 來(lái)啟用混合精度訓(xùn)練(mixed-precision training)。在這種情況下,混合精度訓(xùn)練可以將訓(xùn)練速度提高三倍,并將內(nèi)存需求降低約 25%。
上面展示的主要代碼都是在主函數(shù)(if __name__ == "__main__" 的 context)中執(zhí)行的,即使只使用單個(gè) GPU,也推薦使用 PyTorch 運(yùn)行環(huán)境執(zhí)行多 GPU 訓(xùn)練。而后,包含在 if __name__ == "__main__" 中的以下三個(gè)代碼部分負(fù)責(zé)數(shù)據(jù)加載:
# 1 加載數(shù)據(jù)集
# 2 token 化和數(shù)值化
# 3 設(shè)置數(shù)據(jù)加載器
第 4 節(jié)是初始化模型(Initializing the Model)中,然后在第 5 節(jié) 微調(diào)(Finetuning)中,調(diào)用 train 函數(shù),這是開(kāi)始讓事情變得有趣的地方。在 train (...) 函數(shù)中,實(shí)現(xiàn)了標(biāo)準(zhǔn)的 PyTorch 循環(huán)。核心訓(xùn)練循環(huán)的注釋版本如下所示:
批大小為 1(Batch size=1)的問(wèn)題是梯度更新將會(huì)變得非?;靵y和困難,正如下述訓(xùn)練模型時(shí)基于波動(dòng)的訓(xùn)練損失和糟糕的測(cè)試集性能所看到的:
...
torch : 2.0.0
lightning : 2.0.0
transformers: 4.27.2
Torch CUDA available? True
...
Epoch: 0001/0001 | Batch 23700/35000 | Loss: 0.0969
Epoch: 0001/0001 | Batch 24000/35000 | Loss: 1.9902
Epoch: 0001/0001 | Batch 24300/35000 | Loss: 0.0395
Epoch: 0001/0001 | Batch 24600/35000 | Loss: 0.2546
Epoch: 0001/0001 | Batch 24900/35000 | Loss: 0.1128
Epoch: 0001/0001 | Batch 25200/35000 | Loss: 0.2661
Epoch: 0001/0001 | Batch 25500/35000 | Loss: 0.0044
Epoch: 0001/0001 | Batch 25800/35000 | Loss: 0.0067
Epoch: 0001/0001 | Batch 26100/35000 | Loss: 0.0468
Epoch: 0001/0001 | Batch 26400/35000 | Loss: 1.7139
Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.9570
Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.1857
Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0090
Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.9790
Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0503
Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.2625
Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.1010
Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0035
Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0009
Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0234
Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.8394
Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.9497
Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.1437
Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.1317
Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0112
Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0073
Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.7393
Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0512
Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.1337
Epoch: 0001/0001 | Batch 32400/35000 | Loss: 1.1875
Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.2727
Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.1545
Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0022
Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.2681
Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.2467
Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0620
Epoch: 0001/0001 | Batch 34500/35000 | Loss: 2.5039
Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0131
Epoch: 0001/0001 | Train acc.: 75.11% | Val acc.: 78.62%
Time elapsed 69.97 min
Test accuracy 78.53%
由于沒(méi)有多的 GPU 可用于張量分片(tensor sharding),又能做些什么來(lái)訓(xùn)練具有更大批大小(batch size)的模型呢?
其中一種解決方法就是梯度累積,可以通過(guò)它來(lái)修改前面提到的訓(xùn)練循環(huán)。
什么是梯度積累?
梯度累積是一種在訓(xùn)練期間虛擬增加批大?。╞atch size)的方法,當(dāng)可用的 GPU 內(nèi)存不足以容納所需的批大小時(shí),這非常有用。在梯度累積中,梯度是針對(duì)較小的批次計(jì)算的,并在多次迭代中累積(通常是求和或平均),而不是在每一批次之后更新模型權(quán)重。一旦累積梯度達(dá)到目標(biāo)「虛擬」批大小,模型權(quán)重就會(huì)使用累積梯度進(jìn)行更新。
參考下面更新的 PyTorch 訓(xùn)練循環(huán):
如果將 accumulation_steps 設(shè)置為 2,那么 zero_grad () 和 optimizer.step () 將只會(huì)每隔一秒調(diào)用一次。因此,使用 accumulation_steps=2 運(yùn)行修改后的訓(xùn)練循環(huán)與將批大小(batch size)加倍具有相同的效果。
例如,如果想使用 256 的批大小,但只能將 64 的批大小放入 GPU 內(nèi)存中,就可以對(duì)大小為 64 的四個(gè)批執(zhí)行梯度累積。(處理完所有四個(gè)批次后,將獲得相當(dāng)于單個(gè)批大小為 256 的累積梯度。)這樣能夠有效地模擬更大的批大小,而無(wú)需更大的 GPU 內(nèi)存或跨不同設(shè)備的張量分片。
雖然梯度累積可以幫助我們訓(xùn)練具有更大批量大小的模型,但它不會(huì)減少所需的總計(jì)算量。實(shí)際上,它有時(shí)會(huì)導(dǎo)致訓(xùn)練過(guò)程略慢一些,因?yàn)闄?quán)重更新的執(zhí)行頻率較低。盡管如此,它卻能幫我們解決限制問(wèn)題,即批大小非常小時(shí)導(dǎo)致的更新頻繁且混亂。
例如,現(xiàn)在讓我們運(yùn)行上面的代碼,批大小為 1,需要 16 個(gè)累積步驟(accumulation steps)來(lái)模擬批大小等于 16。
輸出如下:
...
torch : 2.0.0
lightning : 2.0.0
transformers: 4.27.2
Torch CUDA available? True
...
Epoch: 0001/0001 | Batch 23700/35000 | Loss: 0.0168
Epoch: 0001/0001 | Batch 24000/35000 | Loss: 0.0006
Epoch: 0001/0001 | Batch 24300/35000 | Loss: 0.0152
Epoch: 0001/0001 | Batch 24600/35000 | Loss: 0.0003
Epoch: 0001/0001 | Batch 24900/35000 | Loss: 0.0623
Epoch: 0001/0001 | Batch 25200/35000 | Loss: 0.0010
Epoch: 0001/0001 | Batch 25500/35000 | Loss: 0.0001
Epoch: 0001/0001 | Batch 25800/35000 | Loss: 0.0047
Epoch: 0001/0001 | Batch 26100/35000 | Loss: 0.0004
Epoch: 0001/0001 | Batch 26400/35000 | Loss: 0.1016
Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.0021
Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.0015
Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0008
Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.0060
Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0001
Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.0426
Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.0012
Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0025
Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0025
Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0000
Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.0495
Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.0164
Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.0067
Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.0037
Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0005
Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0013
Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.0112
Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0053
Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.0012
Epoch: 0001/0001 | Batch 32400/35000 | Loss: 0.1365
Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.0210
Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.0374
Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0007
Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.0341
Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.0259
Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0005
Epoch: 0001/0001 | Batch 34500/35000 | Loss: 0.4792
Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0003
Epoch: 0001/0001 | Train acc.: 78.67% | Val acc.: 87.28%
Time elapsed 51.37 min
Test accuracy 87.37%
根據(jù)上面的結(jié)果,損失的波動(dòng)比以前小了。此外,測(cè)試集性能提升了 10%。由于只迭代了訓(xùn)練集一次,因此每個(gè)訓(xùn)練樣本只會(huì)遇到一次。訓(xùn)練用于 multiple epochs 的模型可以進(jìn)一步提高預(yù)測(cè)性能。
你可能還會(huì)注意到,這段代碼的執(zhí)行速度也比之前使用的批大小為 1 的代碼快。如果使用梯度累積將虛擬批大小增加到 8,仍然會(huì)有相同數(shù)量的前向傳播(forward passes)。然而,由于每八個(gè) epoch 只更新一次模型,因此反向傳播(backward passes)會(huì)很少,這樣可更快地在一個(gè) epoch(訓(xùn)練輪數(shù))內(nèi)迭代樣本。
結(jié)論
梯度累積是一種在執(zhí)行權(quán)重更新之前通過(guò)累積多個(gè)小的批梯度來(lái)模擬更大的批大小的技術(shù)。該技術(shù)在可用內(nèi)存有限且內(nèi)存中可容納批大小較小的情況下提供幫助。
但是,首先請(qǐng)思考一種你可以運(yùn)行批大小的場(chǎng)景,這意味著可用內(nèi)存大到足以容納所需的批大小。在那種情況下,梯度累積可能不是必需的。事實(shí)上,運(yùn)行更大的批大小可能更有效,因?yàn)樗试S更多的并行性且能減少訓(xùn)練模型所需的權(quán)重更新次數(shù)。
總之,梯度累積是一種實(shí)用的技術(shù),可以用于降低小批大小干擾信息對(duì)梯度更新準(zhǔn)確性的影響。這是迄今一種簡(jiǎn)單而有效的技術(shù),可以讓我們繞過(guò)硬件的限制。
PS:可以讓這個(gè)運(yùn)行得更快嗎?
沒(méi)問(wèn)題??梢允褂?PyTorch 2.0 中引入的 torch.compile 使其運(yùn)行得更快。只需要添加一些 model = torch.compile,如下圖所示:
GitHub 上提供了完整的腳本。
在這種情況下,torch.compile 在不影響建模性能的情況下又減少了十分鐘的訓(xùn)練時(shí)間:
poch: 0001/0001 | Batch 26400/35000 | Loss: 0.0320
Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.0010
Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.0006
Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0015
Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.0157
Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0015
Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.0540
Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.0035
Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0016
Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0015
Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0008
Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.0877
Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.0232
Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.0014
Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.0032
Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0004
Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0062
Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.0032
Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0066
Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.0017
Epoch: 0001/0001 | Batch 32400/35000 | Loss: 0.1485
Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.0324
Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.0155
Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0007
Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.0049
Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.1170
Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0002
Epoch: 0001/0001 | Batch 34500/35000 | Loss: 0.4201
Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0018
Epoch: 0001/0001 | Train acc.: 78.39% | Val acc.: 86.84%
Time elapsed 43.33 min
Test accuracy 87.91%
請(qǐng)注意,與之前相比準(zhǔn)確率略有提高很可能是由于隨機(jī)性。