Machine Learning | Building a Large Model from Scratch: Model Pretraining
1. Parameter Initialization
The parameter initialization template:
from transformers import PretrainedConfig
class MyPretrainConfig(PretrainedConfig):
    model_type = "myllm"

    def __init__(
        self,
        dim: int = 512,
        n_layers: int = 8,
        n_heads: int = 16,
        n_kv_heads: int = 8,
        vocab_size: int = 6400,
        hidden_dim: int = None,
        multiple_of: int = 64,
        norm_eps: float = 1e-5,
        max_seq_len: int = 512,
        dropout: float = 0.0,
        flash_attn: bool = True,
        use_moe: bool = False,
        num_experts_per_tok=2,
        n_routed_experts=4,
        n_shared_experts: bool = True,
        scoring_func='softmax',
        aux_loss_alpha=0.01,
        seq_aux=True,
        norm_topk_prob=True,
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.use_moe = use_moe                          # whether to enable the mixture-of-experts structure
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts        # total number of routed experts
        self.n_shared_experts = n_shared_experts        # whether to use shared experts
        self.scoring_func = scoring_func                # scoring function, defaults to 'softmax'
        self.aux_loss_alpha = aux_loss_alpha            # alpha coefficient for the auxiliary loss
        self.seq_aux = seq_aux                          # whether to compute the auxiliary loss at the sequence level
        self.norm_topk_prob = norm_topk_prob            # whether to normalize the top-k probabilities
        super().__init__(**kwargs)
This relies on the PretrainedConfig class from the transformers library. The parameters of MyPretrainConfig are as follows:
- dim: int = 512: the model (embedding) dimension, default 512
- n_layers: int = 8: the number of Transformer layers, default 8
- n_heads: int = 16: the number of attention heads, default 16
- n_kv_heads: int = 8: the number of key/value heads, default 8
- vocab_size: int = 6400: the vocabulary size, default 6400
- hidden_dim: int = None: the feed-forward hidden dimension, default None; it can be set explicitly, otherwise it is derived from dim
- multiple_of: int = 64: the feed-forward hidden dimension is rounded up to a multiple of this value, default 64
- norm_eps: float = 1e-5: the epsilon used for normalization, default 1e-5
- max_seq_len: int = 512: the maximum sequence length, default 512
- dropout: float = 0.0: the dropout probability, default 0.0
- flash_attn: bool = True: whether to use Flash Attention, default True
- use_moe: bool = False: whether to enable the mixture-of-experts structure, default False
- num_experts_per_tok=2: the number of experts selected per token, default 2
- n_routed_experts=4: the total number of routed experts, default 4
- n_shared_experts: bool = True: whether to use shared experts, default True
- scoring_func='softmax': the expert scoring function, default 'softmax'
- aux_loss_alpha=0.01: the alpha coefficient of the auxiliary loss, default 0.01
- seq_aux=True: whether to compute the auxiliary loss at the sequence level, default True
- norm_topk_prob=True: whether to normalize the top-k probabilities, default True
- **kwargs: any other keyword arguments, passed through to the parent constructor
PretrainedConfig provides the template for a pretrained model's hyperparameters; since every model is different, this configuration is usually released as a config file that ships together with the model.
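To make the "config file shipped with the model" point concrete: because MyPretrainConfig inherits from PretrainedConfig, the configuration can be written out and read back with the standard save_pretrained / from_pretrained helpers. A minimal sketch (the out/myllm directory is just an illustrative path):

config = MyPretrainConfig(dim=512, n_layers=8)
config.save_pretrained("out/myllm")                      # writes out/myllm/config.json
restored = MyPretrainConfig.from_pretrained("out/myllm") # reads the JSON back into a config object
print(restored.dim, restored.n_layers)                   # 512 8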
2. Loading the Preprocessed Data
Load the preprocessed data prepared in the previous article; the code is as follows:
data_path_list = ['./pretrain_data.bin']
train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)
train_sampler = None
num_workers = 16  # adjust according to the number of CPU cores available
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    pin_memory=True,
    drop_last=False,
    shuffle=False,
    num_workers=num_workers,
    sampler=train_sampler
)
Here PretrainDataset is the loading code; its main job is to map the data into memory so that the DataLoader can fetch samples from it:
class PretrainDataset(Dataset):
    def __init__(self, data_path_lst, max_length=512, memmap=False):
        super().__init__()
        if memmap:
            with open(data_path_lst[0], 'r') as f:
                nbytes = f.seek(0, 2)
                flen = f.tell() // np.dtype('uint16').itemsize
            self.data = np.memmap(data_path_lst[0], dtype=np.dtype('uint16'), shape=(flen // max_length, max_length))
        else:
            data_lst = []
            for data_path in data_path_lst:
                with open(data_path, 'rb') as f:
                    data = np.fromfile(f, dtype=np.uint16)
                    data_lst.append(data)
            data = np.concatenate(data_lst)
            data = data[:max_length * int(len(data) / max_length)]
            self.data = data.reshape(-1, max_length)
        print("memmap:{} train data.shape:{}".format(memmap, self.data.shape))
        print("downloading finished.....")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index: int):
        sample = self.data[index]
        X = np.array(sample[:-1]).astype(np.int64)
        Y = np.array(sample[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y)
Here Dataset is the standard base class imported via from torch.utils.data import Dataset.
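Each row of self.data is a window of max_length token ids, and __getitem__ turns it into a next-token-prediction pair: X drops the last token and Y is the same window shifted left by one. A tiny sketch with made-up token ids shows the shift and the resulting shapes:

import numpy as np
import torch

sample = np.array([10, 11, 12, 13, 14], dtype=np.uint16)      # one row of self.data (max_length = 5 here)
X = torch.from_numpy(np.array(sample[:-1]).astype(np.int64))  # tensor([10, 11, 12, 13])
Y = torch.from_numpy(np.array(sample[1:]).astype(np.int64))   # tensor([11, 12, 13, 14])
# With max_length = 512, every (X, Y) pair coming out of the DataLoader has shape (batch_size, 511).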
3. Initializing the Model
The model initialization borrows from the llama2.c code (https://github.com/karpathy/llama2.c/blob/master/model.py) and uses only the decoder part of the Transformer, i.e. a Decoder-Only architecture. The main logic is:
- Initialization: create tok_embeddings, dropout, layers, CausalLMOutputWithPast, and so on
- forward: run the layers and return the output (and, during training, the loss)
The specific code is as follows:
class Transformer(PreTrainedModel):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: MyPretrainConfig):
        super().__init__(params)
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = self.output.weight  # https://paperswithcode.com/method/weight-tying

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None
        self.OUT = CausalLMOutputWithPast()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]
        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(h[:, [-1], :])  # note: using list [-1] to preserve the time dim
            self.last_loss = None

        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('last_loss', self.last_loss)
        return self.OUT
...
Then run the model initialization above and print the model:
def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = Transformer(lm_config).to(device)
    print(f'Total LLM parameters: {count_parameters(model) / 1e6:.3f} million')
    return model

model = init_model()
print(model)
The output looks like this:
Transformer(
  (tok_embeddings): Embedding(6400, 512)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=512, out_features=512, bias=False)
        (wk): Linear(in_features=512, out_features=256, bias=False)
        (wv): Linear(in_features=512, out_features=256, bias=False)
        (wo): Linear(in_features=512, out_features=512, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=512, out_features=1408, bias=False)
        (w2): Linear(in_features=1408, out_features=512, bias=False)
        (w3): Linear(in_features=512, out_features=1408, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=512, out_features=6400, bias=False)
)
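The printed shapes also explain the reported parameter count. The feed-forward width 1408 comes from the derivation in FeedForward (4 * 512 = 2048, scaled by 2/3 to 1365, then rounded up to a multiple of 64), and a quick hand count, a sketch based only on the configuration above, reproduces the total that init_model prints:

dim, n_layers, vocab_size = 512, 8, 6400
hidden_dim = 64 * ((int(2 * 4 * dim / 3) + 63) // 64)   # 1408, as shown in the printout
attention = 2 * dim * dim + 2 * dim * 256               # wq, wo (512x512) + wk, wv (512x256) = 786,432
feed_forward = 3 * dim * hidden_dim                     # w1, w2, w3 = 2,162,688
norms = 2 * dim                                         # attention_norm + ffn_norm weights
per_layer = attention + feed_forward + norms            # 2,950,144
total = n_layers * per_layer + vocab_size * dim + dim   # layers + tied embedding/output + final norm
print(f'{total / 1e6:.3f} million')                     # 26.878 million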
We will not go through the model internals in detail here; a later article in this series will analyze the llama2.c source code and explain how the model is constructed.
4. Choosing the optimizer
After the model is initialized, choose the optimizer; the code is as follows:
scaler = torch.cuda.amp.GradScaler(enabled=(dtype in ['float16', 'bfloat16']))
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
4.1 GradScaler
In PyTorch, GradScaler is used for gradient scaling during automatic mixed precision (AMP) training. Concretely, its main functions are:
- Preventing gradient underflow: with mixed precision, weights and activations may be stored in lower precision (e.g. FP16), so the gradients computed during backpropagation can become so small that they underflow to zero; GradScaler automatically adjusts a scaling factor so that gradients do not underflow during the update.
- Speeding up training: mixed precision reduces memory usage and compute time; by dynamically adjusting the scale factor, GradScaler lets you exploit these benefits while keeping the numerics stable.
- Simplifying the code: with GradScaler, developers do not have to manage the scaling and unscaling steps by hand.
During training you typically call scaler.scale(loss).backward() to backpropagate the scaled loss, then scaler.step(optimizer) to update the model parameters, and finally scaler.update() to refresh the scale factor; this keeps training both stable and efficient.
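Putting those three calls together, here is a self-contained sketch of a single AMP training step on a toy model (it assumes a CUDA device is available; nothing here is specific to the pretraining script):

import torch
from torch import nn, optim

model = nn.Linear(16, 4).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x, y = torch.randn(8, 16, device='cuda'), torch.randn(8, 4, device='cuda')
with torch.cuda.amp.autocast():                # run the forward pass in mixed precision
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()                  # backpropagate the scaled loss
scaler.step(optimizer)                         # unscale the gradients internally, then update the parameters
scaler.update()                                # adjust the scale factor for the next step
optimizer.zero_grad(set_to_none=True)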
4.2 optimizer
The optimizer is one of the most important components in deep learning; its job is to update the model parameters so as to minimize the loss function. Concretely, the optimizer is responsible for:
- Parameter updates: the optimizer uses the computed gradients to update the model parameters (weights and biases); by adjusting these parameters it tries to improve the model's fit on the training data.
- Controlling the learning rate: optimizers use a learning rate to control the size of each update step; the learning rate is a hyperparameter that determines how far the model moves toward the optimum in each iteration.
- Implementing different optimization algorithms: PyTorch provides many optimizers (SGD, Adam, RMSprop, etc.), each with its own update rule and strategy; the choice of optimizer affects both convergence speed and final performance.
- Momentum and adaptive learning rates: some optimizers (such as Adam and RMSprop) use momentum and per-parameter adaptive learning rates to accelerate convergence and improve stability, which helps them explore the parameter space more effectively.
- Supporting regularization: some optimizers integrate regularization techniques (such as L2 weight decay) to reduce overfitting.
In the training loop below, its main role is to update the model parameters based on the computed loss:
# backward pass
scaler.scale(loss).backward()
# gradient clipping and parameter update
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
# zero the gradients
optimizer.zero_grad(set_to_none=True)
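The script keeps things simple with optim.Adam, but the configure_optimizers method in the appendix builds an AdamW optimizer that applies decoupled weight decay only to 2-D tensors (matmul weights and embeddings) and skips biases and norm weights. A condensed sketch of that grouping idea, reusing the model and learning_rate from the script above (the weight_decay and betas values here are illustrative, not taken from the script):

param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
decay_params = [p for p in param_dict.values() if p.dim() >= 2]    # matrices and embeddings
nodecay_params = [p for p in param_dict.values() if p.dim() < 2]   # biases and norm weights
optimizer = torch.optim.AdamW(
    [{'params': decay_params, 'weight_decay': 0.1},
     {'params': nodecay_params, 'weight_decay': 0.0}],
    lr=learning_rate, betas=(0.9, 0.95),
)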
5. Training Loop
With the preprocessed data loaded, the model initialized, and the optimizer set up, training can begin. The most important part of the loop is setting the learning rate for each step and then updating the parameters from the loss. The code is as follows:
for epoch in range(epochs):
    start_time = time.time()
    for step, (X, Y) in enumerate(train_loader):
        X = X.to(device)
        Y = Y.to(device)

        # set the learning rate for this step
        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # forward pass and loss computation
        with ctx:
            out = model(X, Y)
            loss = out.last_loss

        # backward pass
        scaler.scale(loss).backward()

        # gradient clipping and parameter update
        if (step + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            # zero the gradients
            optimizer.zero_grad(set_to_none=True)

        if step % 100 == 0:
            spend_time = time.time() - start_time
            print(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                    epoch,
                    epochs,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
            model.eval()
            ckp = f'{save_dir}/pretrain_{lm_config.dim}.pth'
            state_dict = model.state_dict()
            torch.save(state_dict, ckp)
            model.train()
- out = model(X, Y): forward pass, computing the output (and the loss, since targets are passed in)
- scaler.scale(loss).backward(): backward pass, computing the gradients; the parameters are only updated once every accumulation_steps steps
- model.eval() and model.train(): switch the model to evaluation mode while the current checkpoint is saved to the target folder, then switch back to training mode
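The lr set at the top of the loop comes from the get_lr helper in the appendix: with warmup_iters = 0 there is no warmup, and the rate follows a cosine decay from learning_rate down to learning_rate / 10 over the total number of iterations. A condensed re-derivation of that cosine branch, checking a few points:

import math

learning_rate = 1e-4                # as in the script
min_lr = learning_rate / 10
for it, total in [(0, 1000), (500, 1000), (1000, 1000)]:
    coeff = 0.5 * (1.0 + math.cos(math.pi * it / total))  # 1.0 at the start, 0.5 halfway, 0.0 at the end
    print(min_lr + coeff * (learning_rate - min_lr))      # ≈ 1e-4, 5.5e-5, 1e-5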
I ran the full training on a T4 GPU and it took 30+ hours; on CPU expect roughly 4x as long. The complete code is in the appendix for anyone who wants to try it.
Appendix
Complete code:
import os
import time
import math
import warnings
import inspect
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from contextlib import nullcontext
from model.model import Transformer
from torch.utils.data import Dataset
from transformers import PretrainedConfig
from typing import Any, Optional, Tuple
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings('ignore')
basepath = "../datasets"
class MyPretrainConfig(PretrainedConfig):
model_type = "myllm"
def __init__(
self,
dim: int = 512,
n_layers: int = 8,
n_heads: int = 16,
n_kv_heads: int = 8,
vocab_size: int = 6400,
hidden_dim: int = None,
multiple_of: int = 64,
norm_eps: float = 1e-5,
max_seq_len: int = 512,
dropout: float = 0.0,
flash_attn: bool = True,
num_experts_per_tok=2,
n_routed_experts=4,
n_shared_experts: bool = True,
scoring_func='softmax',
aux_loss_alpha=0.01,
seq_aux=True,
norm_topk_prob=True,
**kwargs,
):
self.dim = dim
self.n_layers = n_layers
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.multiple_of = multiple_of
self.norm_eps = norm_eps
self.max_seq_len = max_seq_len
self.dropout = dropout
self.flash_attn = flash_attn
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts        # total number of routed experts
        self.n_shared_experts = n_shared_experts        # whether to use shared experts
        self.scoring_func = scoring_func                # scoring function, defaults to 'softmax'
        self.aux_loss_alpha = aux_loss_alpha            # alpha coefficient for the auxiliary loss
        self.seq_aux = seq_aux                          # whether to compute the auxiliary loss at the sequence level
        self.norm_topk_prob = norm_topk_prob            # whether to normalize the top-k probabilities
super().__init__(**kwargs)
class PretrainDataset(Dataset):
def __init__(self, data_path_lst, max_length=512, memmap=False):
super().__init__()
if memmap:
with open(data_path_lst[0], 'r') as f:
nbytes = f.seek(0, 2)
flen = f.tell() // np.dtype('uint16').itemsize
self.data = np.memmap(data_path_lst[0], dtype=np.dtype('uint16'), shape=(flen // max_length, max_length))
else:
data_lst = []
for data_path in data_path_lst:
with open(data_path, 'rb') as f:
data = np.fromfile(f, dtype=np.uint16)
data_lst.append(data)
data = np.concatenate(data_lst)
data = data[:max_length * int(len(data) / max_length)]
self.data = data.reshape(-1, max_length)
print("memmap:{} train data.shape:{}".format(memmap, self.data.shape))
print("downloading finished.....")
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index: int):
sample = self.data[index]
X = np.array(sample[:-1]).astype(np.int64)
Y = np.array(sample[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return freqs_cos, freqs_sin
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape xq and xk to match the complex representation
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# reshape freqs_cos and freqs_sin for broadcasting
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
# apply rotation using real numbers
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
# flatten last two dimensions
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: MyPretrainConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
model_parallel_size = 1
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
# use flash attention or a manual implementation?
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask)
def forward(
self,
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
bsz, seqlen, _ = x.shape
# QKV
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# RoPE relative positional embeddings
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
# make heads into a batch dimension
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# flash implementation
if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
# manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
# restore time as batch dimension and concat heads
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# final projection into the residual stream
output = self.wo(output)
output = self.resid_dropout(output)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: MyPretrainConfig):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=args.hidden_dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(self, x, freqs_cos, freqs_sin):
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
class Transformer(PreTrainedModel):
last_loss: Optional[torch.Tensor]
def __init__(self, params: MyPretrainConfig):
super().__init__(params)
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
# share the unembedding parameters with the embedding parameters
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
# some useful precompute for the RoPE relative positional embeddings
freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
self.last_loss = None
self.OUT = CausalLMOutputWithPast()
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.dropout(h)
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]
for layer in self.layers:
h = layer(h, freqs_cos, freqs_sin)
h = self.norm(h)
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
self.last_loss = None
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('last_loss', self.last_loss)
return self.OUT
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in self.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
def estimate_mfu(self, fwdbwd_per_iter, dt):
""" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
# first estimate the number of flops we do per iteration.
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
N = sum(p.numel() for p in self.parameters())
cfg = self.params
L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
flops_per_token = 6*N + 12*L*H*Q*T
flops_per_fwdbwd = flops_per_token * T
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
# express our flops throughput as ratio of A100 bfloat16 peak flops
flops_achieved = flops_per_iter * (1.0/dt) # per second
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
mfu = flops_achieved / flops_promised
return mfu
@torch.inference_mode()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
Also note this is a super inefficient version of sampling with no key/value cache.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
logits = self(idx_cond)
logits = logits[:, -1, :] # crop to just the final time step
if temperature == 0.0:
# "sample" the single most likely index
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# pluck the logits at the final step and scale by desired temperature
logits = logits / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
return idx
def get_lr(it, all):
warmup_iters = 0
lr_decay_iters = all
min_lr = learning_rate / 10
if it < warmup_iters:
return learning_rate * it / warmup_iters
if it > lr_decay_iters:
return min_lr
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (learning_rate - min_lr)
def init_model():
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
model = Transformer(lm_config).to(device)
    print(f'Total LLM parameters: {count_parameters(model) / 1e6:.3f} million')
return model
if __name__ == "__main__":
# -----------------------------------------------------------------------------
lm_config = MyPretrainConfig()
max_seq_len = lm_config.max_seq_len
out_dir = 'out'
    epochs = 20           # number of training epochs
    batch_size = 8        # batch size
    learning_rate = 1e-4  # learning rate
device = 'cuda:0' # or cpu
dtype = 'bfloat16'
save_dir = os.path.join(out_dir)
os.makedirs(save_dir, exist_ok=True)
os.makedirs(out_dir, exist_ok=True)
tokens_per_iter = batch_size * max_seq_len
torch.manual_seed(1337)
device_type = device if "cuda" in device else "cpu"
print(f"device_type: {device_type}")
ctx = (
nullcontext()
if device_type == "cpu"
else torch.cuda.amp.autocast()
)
# -----------------------------------------------------------------------------
# -----init dataloader------
data_path_list = [f'{basepath}/pretrain_data.bin']
train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)
train_sampler = None
    num_workers = 16  # adjust according to the number of CPU cores available
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=num_workers,
sampler=train_sampler
)
# init model
model = init_model()
print(model)
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype in ['float16', 'bfloat16']))
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# training loop
accumulation_steps = 8
iter_per_epoch = len(train_loader)
for epoch in range(epochs):
start_time = time.time()
for step, (X, Y) in enumerate(train_loader):
X = X.to(device)
Y = Y.to(device)
            # set the learning rate for this step
lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
            # forward pass and loss computation
with ctx:
out = model(X, Y)
loss = out.last_loss
            # backward pass
scaler.scale(loss).backward()
            # gradient clipping and parameter update
if (step + 1) % accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
                # zero the gradients
optimizer.zero_grad(set_to_none=True)
if step % 100 == 0:
spend_time = time.time() - start_time
print(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
epoch,
epochs,
step,
iter_per_epoch,
loss.item(),
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
model.eval()
ckp = f'{save_dir}/pretrain_{lm_config.dim}.pth'
state_dict = model.state_dict()
torch.save(state_dict, ckp)
model.train()
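Once training has finished, the saved checkpoint can be reloaded for evaluation or for the next stage (e.g. fine-tuning). A minimal sketch, assuming the MyPretrainConfig and Transformer definitions above and the default output path out/pretrain_512.pth:

lm_config = MyPretrainConfig()
model = Transformer(lm_config)
state_dict = torch.load('out/pretrain_512.pth', map_location='cpu')  # load the saved weights
model.load_state_dict(state_dict)
model.eval()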
References
(1)https://github.com/jingyaogong/minimind?tab=readme-ov-file#%E6%95%B0%E6%8D%AE%E9%9B%86%E4%B8%8B%E8%BD%BD%E5%9C%B0%E5%9D%80
(2)https://github.com/karpathy/llama2.c/blob/master/train.py