可解釋性 CLIP:架構(gòu)解釋及其在分割中的應(yīng)用
可解釋性是人工智能模型的關(guān)鍵話題之一。最近的復(fù)雜人工智能傾向于成為一個(gè)黑盒算法,使得人類難以理解人工智能為何提供這些結(jié)果。最近,我讀了一篇論文,“通過(guò)增強(qiáng)開(kāi)放詞匯任務(wù)的可解釋性進(jìn)行CLIP手術(shù)”[1],主要關(guān)于CLIP的可解釋技術(shù)。盡管這篇論文展示了CLIP的極佳可解釋性,但很少有博客對(duì)此進(jìn)行解釋。因此,我將在這篇博客中介紹CLIP_Surgery的架構(gòu)及其應(yīng)用。
1. 快速回顧C(jī)LIP
CLIP是由OpenAI[2]開(kāi)發(fā)的改變游戲規(guī)則的人工智能之一。得益于其獨(dú)特的架構(gòu),它能夠進(jìn)行零樣本圖像分類。架構(gòu)如下所示:
CLIP具有圖像和文本編碼器,用于創(chuàng)建圖像和文本嵌入。訓(xùn)練數(shù)據(jù)是圖像和文本對(duì),例如帶有文本“一只狗的照片”的狗的圖像。它利用對(duì)比預(yù)訓(xùn)練來(lái)對(duì)齊圖像和文本嵌入,如果圖像和文本是一對(duì),則對(duì)齊,如果不是一對(duì)則不進(jìn)行對(duì)齊。為了直觀理解,讓我們考慮以下示例。在這個(gè)示例中,我們使用三個(gè)圖像和文本對(duì)(上圖中的N = 3)。
圖像和文本編碼器輸出的嵌入維度始終是(1,512),每個(gè)圖像和文本都是如此。在這個(gè)示例中,我們有維度為(3,512)的圖像和文本嵌入。使用嵌入的余弦相似度,我們可以計(jì)算相似度矩陣,如上圖中的矩陣。在對(duì)比預(yù)訓(xùn)練中,CLIP利用這個(gè)相似度矩陣來(lái)對(duì)齊匹配對(duì)(=對(duì)角線元素)使其相似,而其他對(duì)(=其他元素)則變得不相似。具體來(lái)說(shuō),論文[2]中的偽代碼過(guò)程如下:
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
計(jì)算圖像和文本嵌入的余弦相似度后,它們應(yīng)用交叉熵?fù)p失,使得相似度矩陣中的對(duì)角線元素變?yōu)橐唬渌刈優(yōu)榱恪W髡叻Q這種計(jì)算為對(duì)比損失。CLIP僅通過(guò)這種對(duì)比損失進(jìn)行訓(xùn)練。
對(duì)于零樣本分類,過(guò)程如下。首先,我們輸入n個(gè)候選文本并獲得維度為(n,512)的嵌入。接下來(lái),我們計(jì)算目標(biāo)圖像嵌入和候選文本嵌入之間的相似性。最后,我們可以選擇最相似的候選作為類別。不是很簡(jiǎn)單嗎?
這個(gè)過(guò)程簡(jiǎn)單直觀,但我們需要用數(shù)百萬(wàn)的圖像和文本對(duì)以及數(shù)百個(gè)GPU來(lái)訓(xùn)練CLIP。從原始論文中,他們使用了非常大的小批量大小32,768,并在592個(gè)V100 GPU上訓(xùn)練了18天。因此,許多公司將這個(gè)模型作為基礎(chǔ)模型,而不是從頭開(kāi)始訓(xùn)練。
2. CLIP手術(shù)算法的解釋
CLIP手術(shù)主要是為了增強(qiáng)CLIP結(jié)果的可解釋性而開(kāi)發(fā)的。令人驚訝的是,CLIP手術(shù)可以在沒(méi)有任何額外訓(xùn)練的情況下可視化對(duì)應(yīng)標(biāo)簽的激活圖。由于其良好的激活圖可視化,這種技術(shù)可以應(yīng)用于“分割任何事物”,這是分割任務(wù)的基礎(chǔ)模型。我將在后面的章節(jié)中介紹應(yīng)用。
作者徹底檢查了注意力層,以實(shí)現(xiàn)無(wú)需訓(xùn)練的良好可解釋性。請(qǐng)參見(jiàn)下圖。
左側(cè)顯示了原始CLIP的注意力層,而右側(cè)顯示了CLIP手術(shù)的注意力層。他們明確指出,查詢-鍵自注意力激活了與標(biāo)簽相對(duì)應(yīng)的相反語(yǔ)義區(qū)域。另一方面,值-值自注意力可以只關(guān)注語(yǔ)義區(qū)域。這是什么意思?下圖顯示了查詢-鍵自注意力和值-值自注意力的激活圖可視化。
如您所見(jiàn),查詢鍵自注意力除了目標(biāo)標(biāo)簽區(qū)域外,還可視化了不相關(guān)的區(qū)域。相反,值-值自注意力可以專注于相應(yīng)的目標(biāo)標(biāo)簽區(qū)域。基于實(shí)驗(yàn),查詢-鍵自注意力可能導(dǎo)致特征圖混淆。請(qǐng)注意,這一事實(shí)是啟發(fā)式的,并非由數(shù)學(xué)定理推導(dǎo)出來(lái)。
此外,他們意識(shí)到激活圖在所有標(biāo)簽中具有冗余特征。請(qǐng)參見(jiàn)下圖。
如您所見(jiàn),冗余區(qū)域在跨標(biāo)簽的相同位置出現(xiàn)。因此,他們想出了一個(gè)主意,即通過(guò)移除所有標(biāo)簽中的共同激活區(qū)域來(lái)移除冗余特征。
他們是如何實(shí)現(xiàn)的?具體來(lái)說(shuō),官方實(shí)現(xiàn)如下。
# weights to restrain influence of obvious classes on others
# (batch_size, 1, 512) @ (the number of labels, 512).T = (batch_size, 1, the number of labels)
prob = image_features[:, :1, :] @ text_features.t()
# prob has (batch_size, 1, the number of labels)
prob = (prob * 2).softmax(-1)
# w has (batch_size, 1, the number of labels)
w = prob / prob.mean(-1, keepdim=True)
# element-wise multiplied features
# b is batch_size
# n_t is the number of labels
# n_i is the number of tokens (=197)
# c is the feature dimension (=512)
b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
# feats has (batch_size, n_i, n_t, c)
feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
feats *= w.reshape(1, 1, n_t, 1)
# redundant_feats has (batch_size, n_i, n_t, c)
redundant_feats = feats.mean(2, keepdim=True) # along cls dim
feats = feats - redundant_feats
# sum the element-wise multiplied features as cosine similarity
# similarity has (batch_size, n_i, n_t)
similarity = feats.sum(-1)
為了更好地演示,我為代碼中的每次計(jì)算添加了維度大小轉(zhuǎn)換。現(xiàn)在,讓我們一步一步地弄清楚。
第一個(gè)模塊計(jì)算權(quán)重向量,以保持每個(gè)類別的影響相等。首先,我們從圖像嵌入中提取類別標(biāo)記。在變換器架構(gòu)中,類別標(biāo)記是標(biāo)記維度中的第一個(gè)。請(qǐng)注意,類別標(biāo)記應(yīng)該包含有關(guān)所有其他標(biāo)記的信息(如果您不熟悉視覺(jué)變換器,可以參考這篇博客[5])。然后,我們計(jì)算余弦相似度并獲得相似度矩陣。接下來(lái),我們將相似度矩陣的值轉(zhuǎn)換為標(biāo)簽維度上的概率,并獲取權(quán)重矩陣。
在第二個(gè)模塊中,我們計(jì)算除了冗余特征之外的特征矩陣。首先,我們計(jì)算圖像和文本嵌入的逐元素特征矩陣。直觀地說(shuō),跨標(biāo)簽的激活區(qū)域在這張圖中將具有更高的值,如上圖所示。因此,我們可以通過(guò)在標(biāo)簽上計(jì)算平均值從特征矩陣中獲得冗余特征。從原始特征矩陣中減去冗余特征后,我們可以獲得純凈的特征矩陣。
在最后一個(gè)模塊中,我們通過(guò)沿特征維度對(duì)特征矩陣求和來(lái)獲得相似度矩陣。
對(duì)于特征圖可視化,我們需要將相似度矩陣歸一化、重塑和插值到輸入圖像大小(您可以稍后檢查使用附加代碼的實(shí)現(xiàn))作為后處理。下圖顯示了CLIP手術(shù)的結(jié)果。
如您所見(jiàn),它可以捕獲與標(biāo)簽相對(duì)應(yīng)的語(yǔ)義區(qū)域。您可以感受到這種可視化的強(qiáng)大之處。
到目前為止,我們已經(jīng)看到了CLIP手術(shù)的詳細(xì)算法。在最后一節(jié)中,我們將檢查其對(duì)現(xiàn)實(shí)世界數(shù)據(jù)的能力及其應(yīng)用。
3. 應(yīng)用:檢查現(xiàn)實(shí)世界數(shù)據(jù)的能力以及為“分割任何事物”提供點(diǎn)
在最后一節(jié)中,我將指導(dǎo)您了解CLIP手術(shù)在現(xiàn)實(shí)世界數(shù)據(jù)和“分割任何事物”(SAM)中的應(yīng)用。讓我們深入了解它們!
(1) 環(huán)境設(shè)置
作為第一步,您需要設(shè)置一個(gè)環(huán)境。我使用了ubuntu20.04、cuda11.7和Python3.10環(huán)境。首先,我使用conda創(chuàng)建虛擬環(huán)境。
conda create --name sam python==3.10 -y
conda activate sam
conda install pip
## optional: To avoid install libraries on the local environment,
## check the which pip will be used to store libraries
which pip
# I use /opt/conda/envs/sam/bin/pip in my enviornment.
接下來(lái),您需要按照官方說(shuō)明安裝Pytorch和torchvision。您可以安裝與您的環(huán)境相對(duì)應(yīng)的版本。例如,下面的命令是我的案例。
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
然后,您需要使用以下命令安裝SAM存儲(chǔ)庫(kù)和模型權(quán)重。
pip install git+https://github.com/facebookresearch/segment-anything.git
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
您還需要安裝CLIP手術(shù)存儲(chǔ)庫(kù)。
git clone https://github.com/xmed-lab/CLIP_Surgery.git
最后,您需要安裝幾個(gè)包。您可以通過(guò)pip以“pip install <library>”的格式安裝它們。
tqdm==4.66.5
ftfy==6.2.3
matplotlib
opencv-python
regex
現(xiàn)在,您已經(jīng)完成了環(huán)境設(shè)置。
(2) CLIP手術(shù)對(duì)Flickr30k數(shù)據(jù)集的能力
首先,我想檢查CLIP手術(shù)對(duì)現(xiàn)實(shí)世界數(shù)據(jù)的能力,使用Flickr30k數(shù)據(jù)集。因此,我將比較CLIP和CLIP手術(shù)激活圖。我稍后會(huì)附上使用的代碼。下圖是比較的結(jié)果。
如您所見(jiàn),原始CLIP無(wú)法精確檢測(cè)對(duì)象,但當(dāng)對(duì)象存在時(shí),CLIP手術(shù)可以檢測(cè)與標(biāo)簽相對(duì)應(yīng)的對(duì)象。然而,當(dāng)對(duì)象不存在時(shí),例如貓和植物,CLIP手術(shù)仍然存在問(wèn)題。這個(gè)問(wèn)題的一個(gè)原因是后處理中的最小-最大歸一化。當(dāng)激活圖中只有不相關(guān)區(qū)域時(shí),最小-最大歸一化可能會(huì)增強(qiáng)它們值之間的差異。為了解決這個(gè)問(wèn)題,我們可以在最小-最大歸一化之前簡(jiǎn)單地添加一個(gè)閾值。在Flickr數(shù)據(jù)集的情況下,相關(guān)區(qū)域值的閾值是0.1以上,這是通過(guò)相似度圖的直方圖檢查的。結(jié)果如下所示。
多虧了閾值,我們可以移除不相關(guān)的區(qū)域。閾值可能根據(jù)數(shù)據(jù)集而變化;因此,我們應(yīng)該使用直方圖檢查并找到該值。
(3) 為“分割任何事物”提供點(diǎn)
由于激活圖可視化的精確性,CLIP手術(shù)可以應(yīng)用于“分割任何事物”的點(diǎn)提供器。供您參考,SAM是Meta在2023年開(kāi)發(fā)的分割基礎(chǔ)模型之一。下圖顯示了架構(gòu)。
SAM的分割能力令人難以置信。然而,它不是通過(guò)帶有標(biāo)簽的分割數(shù)據(jù)集訓(xùn)練的,所以我們需要在指定對(duì)象時(shí)提供一些點(diǎn)、邊界框或掩碼。正如您所猜,這些類型的注釋非常耗時(shí)。在這里,CLIP手術(shù)幫助我們自動(dòng)找到點(diǎn)。讓我們看看如何在實(shí)際實(shí)現(xiàn)中結(jié)合CLIP手術(shù)和SAM。
為了為SAM生成點(diǎn),我們對(duì)激活圖進(jìn)行下采樣并對(duì)值進(jìn)行排序以選擇相關(guān)區(qū)域。在官方實(shí)現(xiàn)中,他們使用維度為(7 x 7)的激活圖來(lái)找到最相關(guān)的區(qū)域。當(dāng)目標(biāo)對(duì)象不存在時(shí)也存在問(wèn)題,所以我稍微修改了原始實(shí)現(xiàn),添加了一個(gè)閾值。結(jié)果如下所示。
橙色點(diǎn)指的是與標(biāo)簽相關(guān)的點(diǎn),而藍(lán)色點(diǎn)代表標(biāo)簽的負(fù)點(diǎn)。如您所見(jiàn),它可以以相當(dāng)準(zhǔn)確的精度檢測(cè)目標(biāo)標(biāo)簽坐標(biāo)。請(qǐng)注意,點(diǎn)的準(zhǔn)確性來(lái)自CLIP的能力。因此,如果CLIP不理解目標(biāo),它無(wú)法準(zhǔn)確提供目標(biāo)點(diǎn)。我將附上在此應(yīng)用中使用的Jupyter筆記本。
詳細(xì)代碼可以參考鏈接:https://gist.github.com/tanukon/55715b577a32998f3417e7cea268c658#file-clip_surgery_experiment-ipynb