使用SPIN技術對LLM進行自我博弈微調訓練
2024年是大型語言模型(llm)的快速發展的一年,對于大語言模型的訓練一個重要的方法是對齊方法,它包括使用人類樣本的監督微調(SFT)和依賴人類偏好的人類反饋強化學習(RLHF)。這些方法在llm中發揮了至關重要的作用,但是對齊方法對人工注釋數據有的大量需求。這一挑戰使得微調成為一個充滿活力的研究領域,研究人員積極致力于開發能夠有效利用人類數據的方法。
加州大學最近的一項研究介紹了一種名為SPIN(Self Play fIne tuNing)的新技術。SPIN從AlphaGo Zero和AlphaZero等游戲中成功的自我對弈機制中汲取靈感。它能夠使LLM參與自我游戲的能力。這消除了對專業注釋者的需求,無論是人類還是更高級的模型(如GPT-4)。SPIN涉及訓練一個新的語言模型,并通過一系列迭代來區分它自己生成的響應和人類生成的響應。最終目標是開發得到一種語言模型,使其產生的反應與人類產生的反應沒有區別。
自我博弈
自我博弈是一種算法通過對抗自身副本來學習的技術。這種方法增加了學習環境的挑戰性和復雜性,允許代理與自己的不同版本進行交互。例如AlphaGo Zero,就是一個自我博弈的案例。
自我博弈在MARL中的有效性已經得到證實,但將其應用于大型語言模型(llm)的增強是一種新的方法。在大型語言模型中應用自我博弈有可能進一步提高他們的能力,使他們能夠生成更連貫、信息豐富的文本。
自我游戲既可以用于競爭環境,也可以用于合作環境。在競爭環境中,算法的副本相互競爭以達到特定的目標。在協作設置中,算法的副本一起工作以實現共同的目標。它還可以與其他學習技術相結合,如監督學習和強化學習,以進一步提高算法的性能。
SPIN
SPIN就像一個雙人游戲。在這個游戲中:
主模型(新LLM) -這個代理的角色是學習如何區分由語言模型(LLM)生成的響應和由人類創建的響應。在每個迭代中,主模型是正在積極訓練的LLM。其目標是提高其識別和區分反應的能力。
對手模型(舊LLM) -對手模型的任務是生成與人類產生的反應沒有區別的結果。對手模型是來自前一個迭代(輪)的LLM。它使用自我博弈機制,根據過去的知識產生結果。對手模型的目標是創造逼真的反應,讓新的LLM無法判斷他是否是機器生成的。
這個流程是不是很像GAN,但是還是不太一樣
SPIN的動態涉及使用監督微調(SFT)數據集,該數據集由輸入(x)和輸出(y)對組成。這些示例由人工注釋,并作為訓練主模型識別類人響應的基礎。一些公開的SFT數據集包括Dolly15K、Baize、Ultrachat等。
主模型的訓練
為了訓練主模型區分語言模型(LLM)和人類反應,SPIN使用了一個目標函數。這個函數測量真實數據和對手模型產生的反應之間的預期值差距。主模型的目標是最大化這一期望值差距。這包括將高值分配給與真實數據的響應配對的提示,并將低值分配給由對手模型生成的響應配對。這個目標函數被表述為最小化問題。
主模型的工作是最小化損失函數,即衡量來自真實數據的配對分配值與來自對手模型反應的配對分配值之間的差異。在整個訓練過程中,主模型調整其參數以最小化該損失函數。這個迭代過程一直持續下去,直到主模型能夠熟練地有效區分LLM的反應和人類的反應。
對手模型的更新
更新對手模型涉及改進主模型的能力,他們在訓練時已經學會區分真實數據和語言模型反應。隨著主模型的改進及其對特定函數類的理解,我們還需要更新如對手模型的參數。當主玩家面對相同的提示時,它便會使用學習得到的辨別能力去評估它們的價值。
對手模型玩家的目標是增強語言模型,使其響應與主玩家的真實數據無法區分。這就需要設置一個流程來調整語言模型的參數。目的是在保持穩定性的同時,最大限度地提高主模型對語言模型反應的評價。這涉及到一種平衡行為,確保改進不會偏離原始語言模型太遠。
聽著有點亂,我們簡單總結下:
訓練的時候只有一個模型,但是將模型分為前一輪的模型(舊LLM/對手模型)和主模型(正在訓練的),使用正在訓練的模型的輸出與上一輪模型的輸出作為對比,來優化當前模型的訓練。但是這里就要求我們必須要有一個訓練好的模型作為對手模型,所以SPIN算法只適合在訓練結果上進行微調。
SPIN算法
SPIN從預訓練的模型生成合成數據。然后使用這些合成數據對新任務上的模型進行微調。
上面時原始論文中Spin算法的偽代碼,看著有點難理解,我們通過Python來復現更好地解釋它是如何工作的。
1、初始化參數和SFT數據集
原論文采用Zephyr-7B-SFT-Full作為基本模型。對于數據集,他們使用了更大的Ultrachat200k語料庫的子集,該語料庫由使用OpenAI的Turbo api生成的大約140萬個對話組成。他們隨機抽取了50k個提示,并使用基本模型來生成合成響應。
# Import necessary libraries
from datasets import load_dataset
import pandas as pd
# Load the Ultrachat 200k dataset
ultrachat_dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
# Initialize an empty DataFrame
combined_df = pd.DataFrame()
# Loop through all the keys in the Ultrachat dataset
for key in ultrachat_dataset.keys():
# Convert each dataset key to a pandas DataFrame and concatenate it with the existing DataFrame
combined_df = pd.concat([combined_df, pd.DataFrame(ultrachat_dataset[key])])
# Shuffle the combined DataFrame and reset the index
combined_df = combined_df.sample(frac=1, random_state=123).reset_index(drop=True)
# Select the first 50,000 rows from the shuffled DataFrame
ultrachat_50k_sample = combined_df.head(50000)
作者的提示模板“### Instruction: {prompt}\n\n### Response:”
# for storing each template in a list
templates_data = []
for index, row in ultrachat_50k_sample.iterrows():
messages = row['messages']
# Check if there are at least two messages (user and assistant)
if len(messages) >= 2:
user_message = messages[0]['content']
assistant_message = messages[1]['content']
# Create the template
instruction_response_template = f"### Instruction: {user_message}\n\n### Response: {assistant_message}"
# Append the template to the list
templates_data.append({'Template': instruction_response_template})
# Create a new DataFrame with the generated templates (ground truth)
ground_truth_df = pd.DataFrame(templates_data)
然后得到了類似下面的數據:
SPIN算法通過迭代更新語言模型(LLM)的參數使其與地面真實響應保持一致。這個過程一直持續下去,直到很難區分生成的響應和真實情況,從而實現高水平的相似性(降低損失)。
SPIN算法有兩個循環。內部循環基于我們正在使用的樣本數量運行,外部循環總共運行了3次迭代,因為作者發現模型的性能在此之后沒有變化。采用Alignment Handbook庫作為微調方法的代碼庫,結合DeepSpeed模塊,降低了訓練成本。他們用RMSProp優化器訓練Zephyr-7B-SFT-Full,所有迭代都沒有權重衰減,就像通常用于微調llm一樣。全局批大小設置為64,使用bfloat16精度。迭代0和1的峰值學習率設置為5e-7,迭代2和3的峰值學習率隨著循環接近自播放微調的結束而衰減為1e-7。最后選擇β = 0.1,最大序列長度設置為2048個標記。下面就是這些參數
# Importing the PyTorch library
import torch
# Importing the neural network module from PyTorch
import torch.nn as nn
# Importing the DeepSpeed library for distributed training
import deepspeed
# Importing the AutoTokenizer and AutoModelForCausalLM classes from the transformers library
from transformers import AutoTokenizer, AutoModelForCausalLM
# Loading the zephyr-7b-sft-full model from HuggingFace
tokenizer = AutoTokenizer.from_pretrained("alignment-handbook/zephyr-7b-sft-full")
model = AutoModelForCausalLM.from_pretrained("alignment-handbook/zephyr-7b-sft-full")
# Initializing DeepSpeed Zero with specific configuration settings
deepspeed_config = deepspeed.config.Config(train_batch_size=64, train_micro_batch_size_per_gpu=4)
model, optimizer, _, _ = deepspeed.initialize(model=model, config=deepspeed_config, model_parameters=model.parameters())
# Defining the optimizer and setting the learning rate using RMSprop
optimizer = deepspeed.optim.RMSprop(optimizer, lr=5e-7)
# Setting up a learning rate scheduler using LambdaLR from PyTorch
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.2 ** epoch)
# Setting hyperparameters for training
num_epochs = 3
max_seq_length = 2048
beta = 0.1
2、生成合成數據(SPIN算法內循環)
這個內部循環負責生成需要與真實數據保持一致的響應,也就是一個訓練批次的代碼
# zephyr-sft-dataframe (that contains output that will be improved while training)
zephyr_sft_output = pd.DataFrame(columns=['prompt', 'generated_output'])
# Looping through each row in the 'ultrachat_50k_sample' dataframe
for index, row in ultrachat_50k_sample.iterrows():
# Extracting the 'prompt' column value from the current row
prompt = row['prompt']
# Generating output for the current prompt using the Zephyr model
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids, max_length=200, num_beams=5, no_repeat_ngram_size=2, top_k=50, top_p=0.95)
# Decoding the generated output to human-readable text
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
# Appending the current prompt and its generated output to the new dataframe 'zephyr_sft_output'
zephyr_sft_output = zephyr_sft_output.append({'prompt': prompt, 'generated_output': generated_text}, ignore_index=True)
這就是一個提示的真實值和模型輸出的樣例。
新的df zephyr_sft_output,其中包含提示及其通過基本模型Zephyr-7B-SFT-Full生成的相應輸出。
3、更新規則
在編碼最小化問題之前,理解如何計算llm生成的輸出的條件概率分布是至關重要的。原論文使用馬爾可夫過程,其中條件概率分布pθ (y∣x)可通過分解表示為:
這種分解意味著給定輸入序列的輸出序列的概率可以通過將給定輸入序列的每個輸出標記與前一個輸出標記的概率相乘來計算。例如輸出序列為“I enjoy reading books”,輸入序列為“I enjoy”,則在給定輸入序列的情況下,輸出序列的條件概率可以計算為:
馬爾可夫過程條件概率將用于計算真值和Zephyr LLM響應的概率分布,然后用于計算損失函數。但首先我們需要對條件概率函數進行編碼。
# Conditional Probability Function of input text
def compute_conditional_probability(tokenizer, model, input_text):
# Tokenize the input text and convert it to PyTorch tensors
inputs = tokenizer([input_text], return_tensors="pt")
# Generate text using the model, specifying additional parameters
outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
# Assuming 'transition_scores' is the logits for the generated tokens
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
# Get the length of the input sequence
input_length = inputs.input_ids.shape[1]
# Assuming 'transition_scores' is the logits for the generated tokens
logits = torch.tensor(transition_scores)
# Apply softmax to obtain probabilities
probs = torch.nn.functional.softmax(logits, dim=-1)
# Extract the generated tokens from the output
generated_tokens = outputs.sequences[:, input_length:]
# Compute conditional probability
conditional_probability = 1.0
for prob in probs[0]:
token_probability = prob.item()
conditional_probability *= token_probability
return conditional_probability
損失函數它包含四個重要的條件概率變量。這些變量中的每一個都取決于基礎真實數據或先前創建的合成數據。
而lambda是一個正則化參數,用于控制偏差。在KL正則化項中使用它來懲罰對手模型的分布與目標數據分布之間的差異。論文中沒有明確提到lambda的具體值,因為它可能會根據所使用的特定任務和數據集進行調優。
def LSPIN_loss(model, updated_model, tokenizer, input_text, lambda_val=0.01):
# Initialize conditional probability using the original model and input text
cp = compute_conditional_probability(tokenizer, model, input_text)
# Update conditional probability using the updated model and input text
cp_updated = compute_conditional_probability(tokenizer, updated_model, input_text)
# Calculate conditional probabilities for ground truth data
p_theta_ground_truth = cp(tokenizer, model, input_text)
p_theta_t_ground_truth = cp(tokenizer, model, input_text)
# Calculate conditional probabilities for synthetic data
p_theta_synthetic = cp_updated(tokenizer, updated_model, input_text)
p_theta_t_synthetic = cp_updated(tokenizer, updated_model, input_text)
# Calculate likelihood ratios
lr_ground_truth = p_theta_ground_truth / p_theta_t_ground_truth
lr_synthetic = p_theta_synthetic / p_theta_t_synthetic
# Compute the LSPIN loss
loss = lambda_val * torch.log(lr_ground_truth) - lambda_val * torch.log(lr_synthetic)
return loss
如果你有一個大的數據集,可以使用一個較小的lambda值,或者如果你有一個小的數據集,則可能需要使用一個較大的lambda值來防止過擬合。由于我們數據集大小為50k,所以可以使用0.01作為lambda的值。
4、訓練(SPIN算法外循環)
這就是Pytorch訓練的一個基本流程,就不詳細解釋了:
# Training loop
for epoch in range(num_epochs):
# Model with initial parameters
initial_model = AutoModelForCausalLM.from_pretrained("alignment-handbook/zephyr-7b-sft-full")
# Update the learning rate
scheduler.step()
# Initialize total loss for the epoch
total_loss = 0.0
# Generating Synthetic Data (Inner loop)
for index, row in ultrachat_50k_sample.iterrows():
# Rest of the code
...
# Output == prompt response dataframe
zephyr_sft_output
# Computing loss using LSPIN function
for (index1, row1), (index2, row2) in zip(ultrachat_50k_sample.iterrows(), zephyr_sft_output.iterrows()):
# Assuming 'prompt' and 'generated_output' are the relevant columns in zephyr_sft_output
prompt = row1['prompt']
generated_output = row2['generated_output']
# Compute LSPIN loss
updated_model = model # It will be replacing with updated model
loss = LSPIN_loss(initial_model, updated_model, tokenizer, prompt)
# Accumulate the loss
total_loss += loss.item()
# Backward pass
loss.backward()
# Update the parameters
optimizer.step()
# Update the value of beta
if epoch == 2:
beta = 5.0
我們運行3個epoch,它將進行訓練并生成最終的Zephyr SFT LLM版本。官方實現還沒有在GitHub上開源,這個版本將能夠在某種程度上產生類似于人類反應的輸出。我們看看他的運行流程
表現及結果
SPIN可以顯著提高LLM在各種基準測試中的性能,甚至超過通過直接偏好優化(DPO)補充額外的GPT-4偏好數據訓練的模型。
當我們繼續訓練時,隨著時間的推移,進步會變得越來越小。這表明模型達到了一個閾值,進一步的迭代不會帶來顯著的收益。這是我們訓練數據中樣本提示符每次迭代后的響應。