成人免费xxxxx在线视频软件_久久精品久久久_亚洲国产精品久久久_天天色天天色_亚洲人成一区_欧美一级欧美三级在线观看

大模型中常用的注意力機制GQA詳解以及Pytorch代碼實現

人工智能
分組查詢注意力 (Grouped Query Attention) 是一種在大型語言模型中的多查詢注意力 (MQA) 和多頭注意力 (MHA) 之間進行插值的方法,它的目標是在保持 MQA 速度的同時實現 MHA 的質量。

分組查詢注意力 (Grouped Query Attention) 是一種在大型語言模型中的多查詢注意力 (MQA) 和多頭注意力 (MHA) 之間進行插值的方法,它的目標是在保持 MQA 速度的同時實現 MHA 的質量。

這篇文章中,我們將解釋GQA的思想以及如何將其轉化為代碼。

GQA是在論文 GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints paper.中提出,這是一個相當簡單和干凈的想法,并且建立在多頭注意力之上。

GQA

標準多頭注意層(MHA)由H個查詢頭、鍵頭和值頭組成。每個頭都有D個維度。Pytorch的代碼如下:

from torch.nn.functional import scaled_dot_product_attention
 
 # shapes: (batch_size, seq_len, num_heads, head_dim)
 query = torch.randn(1, 256, 8, 64)
 key = torch.randn(1, 256, 8, 64)
 value = torch.randn(1, 256, 8, 64)
 
 output = scaled_dot_product_attention(query, key, value)
 print(output.shape) # torch.Size([1, 256, 8, 64])

對于每個查詢頭,都有一個對應的鍵。這個過程如下圖所示:

而GQA將查詢頭分成G組,每組共享一個鍵和值。可以表示為:

使用可視化的表示就能非常清楚的了解GQA的工作原理,就像我們上面說的那樣,GQA是一個相當簡單和干凈的想法

Pytorch代碼實現

讓我們編寫代碼將這種將查詢頭劃分為G組,每個組共享一個鍵和值。我們可以使用einops庫有效地執行對張量的復雜操作。

首先,定義查詢、鍵和值。然后設置注意力頭的數量,數量是隨意的,但是要保證num_heads_for_query % num_heads_for_key = 0,也就是說要能夠整除。我們的定義如下:

import torch
 
 # shapes: (batch_size, seq_len, num_heads, head_dim)
 query = torch.randn(1, 256, 8, 64)
 key = torch.randn(1, 256, 2, 64)
 value = torch.randn(1, 256, 2, 64)
 
 num_head_groups = query.shape[2] // key.shape[2]
 print(num_head_groups) # each group is of size 4 since there are 2 kv_heads

為了提高效率,交換seq_len和num_heads維度,einops可以像下面這樣簡單地完成:

from einops import rearrange
 
 query = rearrange(query, "b n h d -> b h n d")
 key = rearrange(key, "b s h d -> b h s d")
 value = rearrange(value, "b s h d -> b h s d")

然后就是需要在查詢矩陣中引入”分組“的概念。

from einops import rearrange
 query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
 print(query.shape) # torch.Size([1, 4, 2, 256, 64])

上面的代碼我們將二維重塑為二維:對于我們定義的張量,原始維度8(查詢的頭數)現在被分成兩組(以匹配鍵和值中的頭數),每組大小為4。

最后最難的部分是計算注意力的分數。但其實它可以在一行中通過insum操作完成的

from einops import einsum, rearrange
 # g stands for the number of groups
 # h stands for the hidden dim
 # n and s are equal and stands for sequence length
  
 scores = einsum(query, key, "b g h n d, b h s d -> b h n s")
 print(scores.shape) # torch.Size([1, 2, 256, 256])

scores張量和上面的value張量的形狀是一樣的。我們看看到底是怎么操作的

einsum幫我們做了兩件事:

1、一個查詢和鍵的矩陣乘法。在我們的例子中,這些張量的形狀是(1,4,2,256,64)和(1,2,256,64),所以沿著最后兩個維度的矩陣乘法得到(1,4,2,256,256)。

2、對第二個維度(維度g)上的元素求和——如果在指定的輸出形狀中省略了維度,einsum將自動完成這項工作,這樣的求和是用來匹配鍵和值中的頭的數量。

最后是注意分數與值的標準乘法:

import torch.nn.functional as F
 
 scale = query.size(-1) ** 0.5
 attention = F.softmax(similarity / scale, dim=-1)
 
 # here we do just a standard matrix multiplication
 out = einsum(attention, value, "b h n s, b h s d -> b h n d")
 
 # finally, just reshape back to the (batch_size, seq_len, num_kv_heads, hidden_dim)
 out = rearrange(out, "b h n d -> b n h d")
 print(out.shape) # torch.Size([1, 256, 2, 64])

這樣最簡單的GQA實現就完成了,只需要不到16行python代碼:

最后再簡單提一句MQA:多查詢注意(MQA)是另一種簡化MHA的流行方法。所有查詢將共享相同的鍵和值。原理圖如下:

可以看到,MQA和MHA都可以從GQA推導出來。具有單個鍵和值的GQA相當于MQA,而具有與頭數量相等的組的GQA相當于MHA。

GQA的好處是什么?

GQA是最佳性能(MQA)和最佳模型質量(MHA)之間的一個很好的權衡。

下圖顯示,使用GQA,可以獲得與MHA幾乎相同的模型質量,同時將處理時間提高3倍,達到MQA的性能。這對于高負載系統來說可能是必不可少的。

在pytorch中沒有GQA的官方實現。所以我找到了一個比較好的非官方實現,有興趣的可以試試:

https://github.com/fkodom/grouped-query-attention-pytorch

GQA論文:

https://arxiv.org/pdf/2305.13245.pdf

責任編輯:華軒 來源: DeepHub IMBA
相關推薦

2024-06-28 08:04:43

語言模型應用

2023-07-30 15:42:45

圖神經網絡PyTorch

2024-07-16 14:15:09

2023-05-05 13:11:16

2024-12-09 00:00:10

2025-02-26 14:32:51

2024-12-17 14:39:16

2018-08-26 22:25:36

自注意力機制神經網絡算法

2024-09-19 10:07:41

2021-08-04 10:17:19

開發技能代碼

2017-08-03 11:06:52

2024-04-17 12:55:05

谷歌模型注意力

2024-10-31 10:00:39

注意力機制核心組件

2024-11-04 10:40:00

AI模型

2024-08-12 08:40:00

PyTorch代碼

2025-02-24 13:00:00

YOLOv12目標檢測Python

2022-03-25 11:29:04

視覺算法美團

2024-07-01 12:17:54

2014-11-04 10:34:27

JavaCache

2020-09-17 12:40:54

神經網絡CNN機器學習
點贊
收藏

51CTO技術棧公眾號

主站蜘蛛池模板: 国产亚洲第一页 | 三级免费 | 日韩av黄色 | 天天综合国产 | 久久久涩 | 国产精品国产三级国产aⅴ入口 | 日日夜夜视频 | 91传媒在线观看 | 国产综合一区二区 | 久久久久久久久久久蜜桃 | 日韩成年人视频在线 | 久久久国产一区二区三区 | 在线日韩不卡 | 亚州中文 | jizz18国产| 日韩中文字幕在线播放 | 久久精品久久久 | 日韩一区二区在线免费观看 | 欧美一级黄带 | 精品av| 成人国产一区二区三区精品麻豆 | 不卡一二三区 | 日本成人在线免费视频 | 成人网av| 欧美在线a| 高清不卡毛片 | 久久精品亚洲精品国产欧美 | 热99精品视频| 日本精品一区二区三区视频 | 一区二区中文字幕 | a在线观看| 日本电影网站 | 一区视频在线播放 | 国产精品视频免费观看 | 亚洲www| 久久av一区二区 | 久久国产麻豆 | 亚洲国产精品一区二区久久 | 亚洲网一区 | 91精品久久久久久综合五月天 | 日本在线视频一区二区 |