圖像化神經網路(2) — 更進一步探討GNN

Martin Huang
12 min readDec 21, 2022

--

第一篇在此。本篇延續[1],後1/3,「Into the Weeds」段落。

其他型態的圖像

在前一篇我們討論的圖像中,各種值(端點、連結、整體)儲存的都是向量。圖像其實可以有更複雜的架構,關鍵在於訊息傳遞允許更有彈性的傳遞。例如多重連結,即兩個端點間有不同種類的連結,用於描述端點間不同性質的關係。在這種狀況,訊息傳遞可以根據不同種類的連結分開進行(圖1.)。圖像也可以是巢狀結構,即一個端點代表一個圖像。在這個層級上,GNN學到端點的資訊將會代表一個圖像,而連結的資訊則是圖像間的互動或關係,其整體則為更宏觀。

圖1. 各種型態的連結。來源[1]

關於GNN訓練的batch

批次訓練(batch)在深度學習已經是幾乎必備的方式,主要是有更具效率的更新權重,以及節省記憶體的益處。在GNN要怎麼延續使用batch的訓練方式呢?

如果是以整個圖像考量,因為每個端點的連結,以及相鄰的端點數量不固定,可能造成GNN訓練時的批次量(batch size)無法固定。因此,提出從圖像中取出部分,但性質上能代表原圖像的次圖像(subgraph)概念。這個類似採樣的步驟,相當受到選到的端點的聯繫,以及連結的影響。此外,不同性質的圖像,對於這樣做法受到的影響程度也不同。例如文章引用的連結,或許局部圖像網路還能代表某種程度的整體;但對於一個化合物分子結構的圖像而言,取樣之後只拿到部分的原子結構,等於是一個新的分子了。如何取樣本身就是一個值得研究的問題。

假如我們希望保留端點相鄰的關係,在取樣時,我們可以設定「同時取走端點周圍相鄰的端點」,取相鄰幾步範圍內的端點也可以設定,同時取他們的連結。如此一來,可以每次取固定數目的端點,然後再加上相鄰的端點,組成端點集(node-set),批次訓練時,可以將這些相鄰的端點視為次圖像。損失(loss)的評估也僅限於端點集內的端點,因為所有相鄰的端點,其連結可能不完整(如果其相鄰的端點超過設定的步數,就不會被納入)。

另一個做法是,先隨機取一端點,並囊括其周圍一定步數的端點。接著再從這些端點之中任選一個端點,再延伸一定步數出去,如此反覆,直到達到特定的端點、連結數量,或是想要的次圖像建立。這樣一來,批次訓練的端點數量就可以固定了。以下列出幾種取樣的方法。

來源[1]

和一般深度學習的批次訓練一樣,GNN的批次訓練目的也是在節省記憶體,所以對於架構越大的圖像,批次訓練就越重要。

歸納偏置(inductive bias)

所謂歸納偏置,指的是模型在預測時,能基於掌握到的特徵而做出判斷。也就是說,模型「確實」有學到東西。例如,利用CNN做貓狗圖片的二元分類,這些圖片中,貓狗在圖片中的位置並不固定。然而,模型仍然能做出準確的分類,不受到貓狗在圖片中位置的差異影響。這個CNN模型就展現出歸納偏置的能力。一般要驗證這個能力,不會用訓練時,也就是模型「看過」的資料測試。Grad-CAM則是可以解釋歸納偏置的原因。

關於Grad-CAM:請看這篇

GNN的模型也必須要有這種特性。在圖像中,模型必須保留值之間的關聯性(像關聯矩陣那樣)同時又保有翻轉不變性。找到圖像中具關連性的特徵,是模型成功的關鍵,也是人類想知道的。

聚集方式的比較

前一篇提過GNN的運作裡,聚集是很關鍵的步驟。由於每個端點相鄰的端點數目不固定,而且也希望這個過程可微分,因此必須找到一個平滑的、不受到端點順序或鄰近端點數目影響的方法。

如何選擇聚集的方法也是一個值得研究的問題。一個概念是:相似的輸入會得到相似的聚集結果,反之亦然。比較簡單的做法是像CNN裡的pooling一樣,取最大值、平均值,總和等等。另外,取變異數也可以考慮。當然,實際上這些方法,對於相同的輸入也未必能有相同的輸出。一個具體的例子如下:比較最大值、平均值和總和,輸出的結果:

左圖:三種聚集方法中,最大值無法分辨兩種輸入。右圖:只有總和可以分辨兩種不同的輸入。來源[1]

因此,實際上沒有一個方法是最佳解,總會遇到使用單一方法遇到的短版。平均值方法適用於相鄰的端點數目變異很大,或想要看各局部的端點群,將其標準化以方便比較。最大值適合挑出個局部端點群中較突出的。總和在兩種方法中取平衡,可以看出局部端點群的趨勢,但沒有標準化,也可能把極端值列入。然而,實務上最常用的還是加總。

聚集的方法,我個人看起來已經很像演算法了。他也是一種處理資料的方法,而資料科學裡的方法千百種,沒有絕對的適應症,只有對於資料本身適不適合的問題。只有足夠了解資料,才能挑選出有效率的方法。

把GCN視為逼近次圖像的函數

這段如果有更好的見解,歡迎幫我補充。

大意是說,一個具有k層的GNN,可以像把傳統k層的神經網路,視為擷取在k範圍內的部分輸入一樣,擷取k步範圍內的次圖像。

例如,聚焦在單一個端點時,k層的GNN意味著擷取從該端點算起,k步內的鄰端點所組成的次圖像。連結也是一樣。但GNN不只擷取單一端點的所有次圖像,它擷取的是所有端點,從各自為中心算起k步內所組成的次圖像。因此,隨著圖像的端點數目增加,次圖像的總量將會以可觀的速率上升。比起先將所有的次圖像全部找出來,再訓練GNN,利用GNN自己的運算方式,沿著固定的方向,依序取端點為中心,會比較有效率。可以和上面的批次取樣一起搭配著參考。

圖像的端點-連結雙重性

在利用GNN解決圖像問題時,圖像的端點和連結,往往背後指向同一個訊息。亦即,端點和連結只是用不同的方式,儲存相同的資訊。這使圖像具雙重性:有時是要解決端點的問題,但轉換成連結的形式會比較容易處理;反之亦然。利用這個概念,延伸出雙重性GNN(Dual-Primal Graph Convolutional Networks)[2]。

圖像卷積如同矩陣相乘,而矩陣相乘如同走在圖像中

在前一篇花了不少篇幅討論圖像卷積和訊息傳遞,但實作上究竟怎麼完成?

我們一樣聚焦回到單一個端點。關聯矩陣A描述該端點與其他端點的互動情形,而特徵矩陣X則記錄儲存各端點的資訊,包含該端點。關聯矩陣的維度為端點數*端點數,而特徵矩陣的維度則為端點數*端點資訊維度。讓B=AX,關聯矩陣和特徵矩陣相乘的過程中,訊息傳遞出去了。

此B矩陣內,各元素的值可表示為

但A是關聯矩陣,其特性為一二元矩陣,即僅兩端點有連結時才有值,否則為0。故對單一端點而言,此內積意義等同於「聚集所有與該端點有連結的j維端點特徵」。在此乘法中,訊息傳遞並未更新端點特徵,而僅是聚集而已。然而,要更新,也只是在這之前/之後再送入一個可更新的神經網路即可。

關聯矩陣非常鬆散,如同上一段和前一篇所講到的,在沒有關聯的端點間,其值全為0。加總時,有非常多項其實都是0,如果還做運算,是浪費運算資源的。因此adjacent list就很重要了,它記錄有連結的兩個端點,我們只要根據這個去搜尋有值的元素即可,也就不必做整個矩陣的乘法,和加總。

如果我們要讓訊息再傳遞一步要怎麼做?簡單,再乘一次關聯矩陣。第一個關聯矩陣,讓端點的資訊透過關聯矩陣,傳到了與它有直接連結的最近端點。第二個關聯矩陣,則讓訊息從這些端點,再傳到與他們有直接連結的最近端點。隨著關聯矩陣的乘方增加,我們可以把訊息傳遞到距離更遠的端點上。因此,從某方面來說,矩陣的反覆乘積,可以視為遍歷圖像的過程。或直接將兩個關聯矩陣先相乘,A²:

第一項只有在關聯矩陣中,端點i與端點1,以及端點1和端點j都有連結時,值才不為0。亦即訊息透過端點1,從節點i跨過兩個連結到節點j。藉由加總,我們一樣可以取得所有符合這樣情況可做訊息傳遞的項目,從而將訊息聚集。由此可推廣到A³、A⁴…A^n。

圖像的Attention網路

另外一種傳遞訊息的方法就是使用Attention網路。例如在傳遞某個端點的訊息到鄰近端點時我們使用加總聚集,但這時也可以加點權重(weight sum)。但要注意的是,權重必須有翻轉不變性,以免影響到端點間關聯性。

要如何取得權重?一個做法是建立一個函數,其輸入為某端點周圍的兩個節點,輸出則為此對端點的權重。利用這個方式,我們可以根據某端點周圍,和該端點的關聯性,給予對應的權重。這個權重也可以再標準化,例如使用softmax函數,將權重壓到0–1之間,對於任務越有關聯的兩個端點,甚至是和某端點有連結的,其權重越高。這是源自於圖像attention網路(GAT)[3]和Set transformer[4]的概念。

由於輸入函數的是一對端點,其翻轉不變性可被保存。在輸入函數前,端點的向量會透過線性映射轉換成query和key,來增強此函數的效果。其實就是對應到attention網路裡面QKV的觀念啦。Q和K由端點的資訊線性變換而來,相乘之後取softmax,作為權重再和也經過線性變換,代表原始輸入向量的V相乘。

來源[1]

這個權重還有個好處:作為解釋一對端點間在此任務的重要性,若此對端點有連結,那就同時解釋此連結在任務的重要性(利用圖像的雙重性)。

圖像的可解釋性/歸因性

訓練完成,部署GNN時,當然會在意其在外部資料的表現。評估表現的一個重要部分是可解釋性,應該說,在GNN,可解釋性比其他神經網路更重要,因為我們就是要尋找關聯,或者關鍵結構,才建立圖像的。由於使用圖像的場合,彼此間差異很大,GNN要解釋的內容也隨著任務種類而有差異。舉例來說,在分子構造,我們可能在意某個次分子結構的有無,是否會有影響;而在文章引用系統,我們可能更在意文章之間的關聯程度。

有一些神經網路被設計來滿足解釋的需求。例如GNNExplainer[5],他能夠萃取和任務相關的重要次圖像。Attribution techniques[6]則會將次圖像依據對任務影響的重要性排序。

來源[1]

生成式模型

對標CV和NLP的GNN都有了,那有沒有對標GAN的呢?

先不提對抗,但生成圖像的網路是值得訓練的。在某些情境下,我們可能希望有個神經網路可以根據需求,產生一些圖像構造,輔助專家激發新的想法。但生成網路最大的挑戰,當然是決定圖像的架構,這包括要使用那些端點,以及彼此要如何連結。比較簡單的具體流程是把一個圖像輸入之後,神經網路可產生出類似屬性或構造的圖像。其中一個解決方法,關鍵是產生adjacent list,除了標示出端點,也標示連結。連結的有無是一種二元分類任務,如果我們專注在預測只有連結的,或只看沒有連結的部分,運算的負擔會減輕許多。

另一個解決方法,是序列式的產生圖像,從既有(輸入)的圖像,逐個端點和連結去調整。這邊有提到利用policy gradient來避免預測離散的梯度,貌似是在強化學習的場合使用,我缺乏這部分的背景知識,所以暫時先保留。有找到一篇解釋policy gradient的網路文章,連結在此,如果有興趣的人可以再過去看看囉。

結語

圖像化神經網路處理因為各種場合需要而建立的圖像學習問題。這兩篇文章中整理了圖像和GNN的基本觀念、一些關鍵,以及稍作深入的探討。GNN playground是很有趣的互動設計,推薦讀者多玩玩看。

這系列後面還會有幾篇文章,有空的話我會實作,到時再和大家分享。謝謝大家看到這邊。

參考資料

[1] A Gentle Introduction to Graph Neural Networks. https://distill.pub/2021/gnn-intro/
[2] Dual-Primal Graph Convolutional Networks. F. Monti, O. Shchur, A. Bojchevski, O. Litany, S. Gunnemann, M.M. Bronstein. 2018.
[3] Graph Attention Networks. P. Velickovic, G. Cucurull, A. Casanova, A. Romero, P. Lio, Y. Bengio. 2017.
[4] Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks. J. Lee, Y. Lee, J. Kim, A.R. Kosiorek, S. Choi, Y.W. Teh. 2018.
[5] GNNExplainer: Generating Explanations for Graph Neural Networks. Z. Ying, D. Bourgeois, J. You, M. Zitnik, J. Leskovec. Advances in Neural Information Processing Systems, Vol 32, pp. 9244–9255. Curran Associates, Inc. 2019.
[6] Explainability Methods for Graph Convolutional Neural Networks. P.E. Pope, S. Kolouri, M. Rostami, C.E. Martin, H. Hoffmann. 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). 2019.

--

--

Martin Huang
Martin Huang

Written by Martin Huang

崎嶇的發展 目前主攻CV,但正在往NLP的路上。 歡迎合作或聯絡:martin12345m@gmail.com

No responses yet