微軟工程師用PyTorch實(shí)現(xiàn)圖注意力網(wǎng)絡(luò),可視化效果驚艷
近日,一個(gè)關(guān)于圖注意力網(wǎng)絡(luò)可視化的項(xiàng)目吸引了大批研究人員的興趣,上線僅僅一天,收獲 200+ 星。該項(xiàng)目是關(guān)于用 PyTorch 實(shí)現(xiàn)的圖注意力網(wǎng)絡(luò)(GAT),包括易于理解的可視化。

項(xiàng)目地址:https://github.com/gordicaleksa/pytorch-GAT
在正式介紹項(xiàng)目之前,先提一下圖神經(jīng)網(wǎng)絡(luò)(GNN)。GNN 是一類基于深度學(xué)習(xí)的處理圖域信息的方法。由于其較好的性能和可解釋性,GNN 最近已成為一種廣泛應(yīng)用的圖分析方法。現(xiàn)已廣泛應(yīng)用于計(jì)算生物學(xué)、計(jì)算藥理學(xué)、推薦系統(tǒng)等。
GNN 把深度學(xué)習(xí)應(yīng)用到圖結(jié)構(gòu) (Graph) 中,其中的圖卷積網(wǎng)絡(luò) GCN 可以在 Graph 上進(jìn)行卷積操作,但是 GCN 存在一些缺陷。因此,Bengio 團(tuán)隊(duì)在三年前提出了圖注意力網(wǎng)絡(luò)(GAT,Graph Attention Network),旨在解決 GCN 存在的問題。
GAT 是空間(卷積)GNN 的代表。由于 CNNs 在計(jì)算機(jī)視覺領(lǐng)域取得了巨大的成功,研究人員決定將其推廣到圖形上,因此 GAT 應(yīng)運(yùn)而生。
現(xiàn)在,有人用 PyTorch 實(shí)現(xiàn)了 GAT 可視化。我們來看看該項(xiàng)目是如何實(shí)現(xiàn)的。

可視化
Cora 可視化
說到 GNN,就不得不介紹一下 Cora 數(shù)據(jù)集。Cora 數(shù)據(jù)集由許多機(jī)器學(xué)習(xí)論文組成,是近年來圖深度學(xué)習(xí)很喜歡使用的數(shù)據(jù)集。Cora 中的節(jié)點(diǎn)代表研究論文,鏈接是這些論文之間的引用。項(xiàng)目作者添加了一個(gè)用于可視化 Cora 和進(jìn)行基本網(wǎng)絡(luò)分析的實(shí)用程序。Cora 如下圖所示:

節(jié)點(diǎn)大小對應(yīng)于其等級(jí)(即進(jìn)出邊的數(shù)量)。邊的粗細(xì)大致對應(yīng)于邊的「popular」或「連接」程度。以下是顯示 Cora 上等級(jí)(進(jìn)出邊的數(shù)量)分布的圖:

進(jìn)和出的等級(jí)圖是一樣的,因?yàn)樘幚淼氖菬o向圖。在底部的圖(等級(jí)分布)上,我們可以看到一個(gè)有趣的峰值發(fā)生在 [2,4] 范圍內(nèi)。這意味著多數(shù)節(jié)點(diǎn)有少量的邊,但是有 1 個(gè)節(jié)點(diǎn)有 169 條邊(綠色大節(jié)點(diǎn))。
注意力可視化
有了一個(gè)訓(xùn)練好的 GAT 模型以后,我們就可以將某些節(jié)點(diǎn)所學(xué)的注意力可視化。節(jié)點(diǎn)利用注意力來決定如何聚合周圍的節(jié)點(diǎn),如下圖所示:

這是 Cora 節(jié)點(diǎn)中邊數(shù)最多的節(jié)點(diǎn)之一(引用)。顏色表示同一類的節(jié)點(diǎn)。
熵直方圖
另一種理解 GAT 沒有在 Cora 上學(xué)習(xí)注意力模式 (即它在學(xué)習(xí)常量注意力) 的方法是,將節(jié)點(diǎn)鄰域的注意力權(quán)重視為概率分布,計(jì)算熵,并在每個(gè)節(jié)點(diǎn)鄰域積累信息。
我們希望 GAT 的注意力分布有偏差。你可以看到橙色的直方圖是理想均勻分布的樣子,而淺藍(lán)色的是學(xué)習(xí)后的分布,它們是完全一樣的。

分析 Cora 嵌入空間 (t-SNE)
GAT 的輸出張量為 shape=(2708,7),其中 2708 是 Cora 中的節(jié)點(diǎn)數(shù),7 是類數(shù)。用 t-SNE 把這些 7 維向量投影成 2D,得到:

使用方法
方法 1:Jupyter Notebook
只需從 Anaconda 控制臺(tái)運(yùn)行 Jupyter Notebook,它將在你的默認(rèn)瀏覽器中打開 session。打開 The Annotated GAT.ipynb 即可開始。
注意,如果你得到了 DLL load failed while importing win32api: The specified module could not be found,只需要 pip uninstall pywin32,或者 pip install pywin32、onda install pywin32。
方法 2:使用你選擇的 IDE
如果使用自己選擇的 IDE,只需要將 Python 環(huán)境和設(shè)置部分連接起來。
訓(xùn)練 GAT
在 Cora 上訓(xùn)練 GAT 所需的一切都已經(jīng)設(shè)置好了,運(yùn)行時(shí)只需調(diào)用 python training_script.py
此外,你還可以:
添加 --should_visualize - 以可視化你的圖形數(shù)據(jù)
在數(shù)據(jù)的測試部分添加 --should_test - 以評估 GAT
添加 --enable_tensorboard - 開始保存度量標(biāo)準(zhǔn)(準(zhǔn)確率、損失)
代碼部分的注釋很完善,因此你可以了解到訓(xùn)練本身是如何運(yùn)行的。
該腳本將:
將 checkpoint* .pth 模型轉(zhuǎn)儲(chǔ)到 models/checkpoints/
將 final* .pth 模型轉(zhuǎn)儲(chǔ)到 models/binaries/
將度量標(biāo)準(zhǔn)保存到中 runs/,只需 tensorboard --logdir=runs 在 Anaconda 中運(yùn)行即可將其可視化
定期將一些訓(xùn)練元數(shù)據(jù)寫入控制臺(tái)
通過 tensorboard --logdir=runs 在控制臺(tái)中調(diào)用,并將 http://localhost:6006/URL 粘貼到瀏覽器中,可以在訓(xùn)練過程中將度量標(biāo)準(zhǔn)可視化:


可視化工具
如果要可視化 t-SNE 嵌入,請注意或嵌入該 visualize_gat_properties 函數(shù)的注釋,并設(shè)置 visualization_type 為:
VisualizationType.ATTENTION - 如果希望可視化節(jié)點(diǎn)附近的注意力
VisualizationType.EMBEDDING - 如果希望可視化嵌入(通過 t-SNE)
VisualizationType.ENTROPY - 如果想可視化熵直方圖
然后,你就得到了一張優(yōu)秀的可視化效果圖(VisualizationType.ATTENTION 可選):


硬件需求
GAT 不需要那種很強(qiáng)的硬件資源,尤其是如果你只想運(yùn)行 Cora 的話,有 2GB 以上的 GPU 就可以了。
在 RTX 2080 GPU 上訓(xùn)練 GAT 大約需要 10 秒;
保留 1.5 GB 的 VRAM 內(nèi)存(PyTorch 的緩存開銷,為實(shí)際張量分配的內(nèi)存少得多);
模型本身只有 365 KB。
視頻鏈接:https://v.qq.com/x/page/v3225t65a0q.html?start=8