CNN은 이미지, 텍스트, 비디오 데이터를 다루는 데 아주 유용하다. 이 데이터들을 graph의 관점에서 보면 고정된 size를 갖는 간단한 graph(혹은 sequence)로 생각해볼 수 있다.
하지만 실생활에서의 수많은 데이터는 size가 정해져있지 않고, 훨씬 복잡한 graph로 표현해야 하는 경우가 많다. 이러한 경우에는 CNN보다는 GNN이 훨씬 유용하다. (특히 scene graph 연구를 하다보면 GNN을 필수적으로 사용해야 할 것이다.)
물론 scene(vision)과 관련된 task에서는 attention 개념을 활용하는 모델이 더 많이 사용될 수는 있겠지만, graph 이론을 배워보는 과정에서 attention을 적용하지 않고 GNN의 graph representation 성능을 최대한 끌어올린 GIN 모델을 제안한 논문을 읽어보고, 정확히 이해해보려 한다.
목차
Core Questions
논문을 읽고 얻은 insight는 다음과 같다.
- What did authors try to accomplish? (Contributions)
- Graph Isomorphism Network (GIN)을 제안했다.
- Expressive power(서로 다른 graph는 서로 다른 embedding space로 mapping하는 능력)가 아주 높다.
- 그래프 구조의 유사도를 알아낸다.
- 다양한 GNN 변형 모델(GCN, GraphSAGE 등)의 한계점 및 특정 task에서 뛰어난 점을 이론적으로 분석했다. (포스팅에서는 다루지 않음)
- Graph Isomorphism Network (GIN)을 제안했다.
- What were the key elements of the approach?
- GNN에 Weisfeiler-Lehman (WL) graph isomorphism test 개념을 적용한 idea
- Close connection between GNNs + WL Test
- Weisfeiler-Lehman graph isomorphsim test : injective function 개념
- GNN에 Weisfeiler-Lehman (WL) graph isomorphism test 개념을 적용한 idea
- What can I use myself?
- GNN 기반의 모델을 알아보고, 앞으로 연구에 응용하여 사용해볼 수 있을 것이다.
- What other references do I want to follow?
- GATs (Graph ATtention networks) 등, 다른 GNN 변형 모델
Introduction
그래프 구조의 데이터를 학습할 때 중요한 것 중 하나는 구조를 효과적으로 representation하는 것이다. 따라서 그래프 구조를 학습하는 것을 representation learning이라고도 한다.
2016년부터 이러한 representation learning 방법으로 neural network를 사용하는 GNN과 그것을 기반으로 변형한 다양한 연구가 많았다.
GNN은 message passing(또는 recursive neighborhood aggregation)이라는 원리로 동작한다. Aggregation이 \(k\)번 반복된 후에, 어떤 노드의 feature vector는 그 노드의 \(k\)-hop neighborhood (k번 떨어져 있는 이웃 노드)까지의 구조적인 정보를 갖게 된다.
그래프 전체의 representation은 이러한 node feature들을 pooling하여 얻을 수 있다.
이를 기반으로, neighborhood aggregation 또는 graph-level pooling을 다른 방법으로 하는 다양한 모델이 등장했고, 이 모델들은 node classification, link prediction, graph classification 등의 그래프를 활용하는 다양한 task에 최고 성능을 보였다.
하지만 GNN을 변형하는 idea는 대부분 GNN의 한계와 특징에 대한 이해 없이, 직관 혹은 수없이 많은 실험을 통해 얻은 경우가 많다. (애초에 GNN의 representational capacity를 이론적으로 분석한 경우가 거의 없다고 한다.)
논문에서 제안하는 새로운 GNN 기반 모델인 GIN은 어떻게 representational power를 높였는지 알아보자.
또한 GNN의 representational capacity를 이론적으로 분석해보고, GNN 기반 모델들이 어떻게 학습 과정에서 graph를 표현하고, 서로 다른 graph 구조를 구별하는지 알아볼 것이다.
Preliminaries
GNN과 WL test가 어떤건지 좀 더 자세히 살펴보자.
우선 공통적으로 사용되는 notation은 아래와 같다.
- \(G = (V, E)\) : node \(v \in V\)의 feature vector가 \(X_v\)인 그래프를 나타낸다.
- Node classification에서 : 각 노드 \(v \in V\)가 label \(y_v\)를 갖고, 목표는 representation vector \(h_v\)를 학습하여 \(v\)의 label을 예측하여 \(y_v = f(h_v)\)를 만족하도록 하는 것이다.
- Graph classification에서 : 그래프 set \(\{G_1, \dots, G_N\} \subseteq \mathcal{g} \)와 각 graph의 label \(\{y_1, \dots, y_N\} \subseteq \mathcal{y}\)가 주어졌을 때, representation vector \(h_G\)를 학습하여 그래프의 label을 예측하여 \(y_G = g(h_G)\)를 만족하도록 하는 것이다.
Graph Neural Networks (GNNs)
GNN은 node classification에서 노드의 representation vector \(h_v\)를 학습하기 위해 노드 feature \(X_v\)를 사용하고, graph classification에서는 그래프의 representation vector \(h_G\)를 학습하기 위해 graph structure를 사용한다.
그 과정에서, 어떤 노드의 representation을 반복하여 update하는데, 이때 그 노드의 이웃 노드의 representation을 합한다. 이를 neighborhood aggregation strategy라 한다.
이러한 aggregation을 \(k\)번 반복하면 한 노드의 representation은 \(k\)-hop neighborhood (k만큼 떨어진 이웃노드) 까지의 구조적인 정보를 담게 된다.
이 과정을 수식으로 나타내면 다음과 같다.
\( a_v^{(k)} = \operatorname{AGGREGATE}^{(k)} \left( \left\{ h_u^{(k-1)} : u \in \mathcal{N} (v) \right\} \right) \)
\( h_v^{(k)} = \operatorname{COMBINE}^{(k)} \left( h_v^{(k-1)}, a_v^{(k)} \right) \)
각 term의 의미는 다음과 같다.
- \(h_v^{(k)}\) : \(k\)번째 iteration(layer)에서 node \(v\)의 feature vector
- initial value, 즉 \(h_v^{(0)}\)는 node feature vector \(X_v\)로 둔다.
- \(\mathcal{N} (v)\) : \(v\)의 인접 노드 set
여기서는 \(\operatorname{AGGREGATE}^{(k)}(\cdot)\)와 \(\operatorname{COMBINE}^{(k)} (\cdot)\)을 어떤 함수로 설정할 것인지에 따라 여러 모델이 존재한다.
여기서는 대표적인 GraphSAGE와 GCN 모델을 살펴보자.
먼저, GraphSAGE 모델에서는 \(\operatorname{AGGREGATE}\) step에서는 max-pooling, \(\operatorname{COMBINE}\) step에서는 linear mapping 함수를 사용한다.
\( a_v^{(k)} = \operatorname{MAX} \left( \left\{ \operatorname{ReLU} \left( W \cdot h_u^{(k-1)} \right), \forall u \in \mathcal{N}(v) \right\} \right) \)
\( h_v^{(k)} = W \cdot \left[ h_v^{(k-1)}, a_v^{(k)} \right] \)
여기서 \(W\)는 학습 가능한 matrix이며, \(\operatorname{MAX}\)는 element-wise max-pooling을 의미한다.
그리고 Graph Convolutional Network (GCN) 모델에서는 element-wise mean-pooling을 사용하며, \(\operatorname{AGGREGATE}, \operatorname{COMBINE}\) step을 묶어서 다음과 같이 진행한다.
\( h_v^{(k)} = \operatorname{ReLU}\left( W \cdot \operatorname{MEAN} \left\{ h_u^{(k-1)}, \forall u \in \mathcal{N} (v) \cup \{ v \} \right\} \right) \)
이렇게 얻은 node representation \(h_v^{(k)}\)를 사용하여 node classification 혹은 graph classification을 수행한다.
Node classification task에서는 마지막 iteration의 \(h_v^{(K)}\)로 prediction을 수행한다.
Graph classification task에서는 node feature를 모든 node들에 대해 합하여(aggregation) 그래프 전체의 representation \(h_G\)를 얻기 위한 과정이 더 필요한데, 이때 사용하는 함수를 \(\operatorname{READOUT}\) 함수라 한다.
\( h_G = \operatorname{READOUT} \left( \left\{ h_v^{(K)} \vert v \in G \right\} \right) \)
\(\operatorname{READOUT}\) 함수는 입력의 순서에 상관 없이 같은 출력을 생성하는(permutation invariant) 함수(예를 들어, summation, graph-level pooling function 등)를 사용할 수 있다.
Weisfeiler-Lehman test
Graph isomorphism problem은 두 그래프가 위상수학적으로(topologically) 같은지를 알아보는 NP-hard 문제로, 아직 다항 시간(polynomial-time) 알고리즘이 존재하지 않는다. (Graph isomorphism problem, WL test에 대한 더 자세한 내용은 링크를 참조하자.)
Weisfeiler-Lehman test (WL test)는 이 문제를 (그나마) 효율적으로 풀 수 있는 방법으로, 간단히 다음 과정을 거쳐 graph의 종류를 구분하는 알고리즘이다.
- 어떤 node의 label과 그 노드의 이웃 노드의 label을 합한다(aggregation).
- Hash table을 사용하여 합한 label(key)을 새로운 (unique한) label(value)으로 지정한다.
위 과정을 반복한 후, 두 그래프의 노드의 label이 다르면 non-isomorphic (위상학적으로 같지 않은) graph로 본다.
이후에 이를 기반으로 WL subtree kernel이라는 개념이 등장했다(color refinement algorithm). WL subtree kernel은 graph의 유사도를 측정하는 것으로, WL test의 각 iteration에서의 node label에 color라는 값을 할당하여 이를 그래프의 feature vector(=color count vector, rooted subtree)로 사용한다.
즉, \(k\)번째 iteration에서의 노드의 label이 곧 그 노드를 root로 하고, height가 \(k\)인 subtree 구조를 갖는 것이다.
따라서 graph feature는 해당 graph의 서로 다른 rooted subtree를 세는 vector로 볼 수 있다. (rooted subtree도 결국 vector로 나타날 것이다.)
Graph Isomorphism Network (GIN)
그림에서 rooted subtree는 어떤 노드의 feature이다. (그 노드를 root로 하여 그 이웃 노드들과의 구조 정보를 표현)
GNN은 이러한 rooted subtree 구조를 반복적으로 update한다.
각 feature vector마다 unique한 label \(\{a, b, c, \dots \}\)을 할당한다. 따라서 한 노드의 이웃 노드 set의 feature vector들은 multiset\)을 형성하게 된다.
그냥 set이 아니라 multiset인 이유는, 이웃 노드 중 서로 다른 노드가 같은 feature vector를 가질 수 있는데, 이 경우 multiset에는 같은 feature vector가 여러 번 등장하게 되기 때문이다.
Multiset은 2개 element를 갖는 tuple로 다음과 같이 정의할 수 있다.
\( X = (S, m) \)
여기서 \(S\)는 feqture vector들의 set(여기는 서로 다른 unique한 feature vector만 포함된다.)이고, \(m: S \rightarrow \mathbb{N}_{\geq 1}\)을 통해 여러 번 포함된 요소를 표현해준다.
이에 따라, GNN의 representational power가 강하다는 것은 곧 두 노드의 subtree 구조가 같을 때만 GNN의 맵핑에 의해 두 노드의 embedding이 같아진다는 것이다.
이를 분석하기 위해서는 subtree는 이웃 노드, 즉 multiset에 의해 재귀적으로 정의되므로, GNN이 두 multiset을 같은 embedding(=representation)으로 맵핑하는지를 알아보면 될 것이다. 다시말해, GNN의 representational power가 최대라면, 서로 다른 두 이웃 노드(multiset)를 절대 같은 representation으로 맵핑하지 않을 것이다. 이는 곧 aggregation 과정이 injective(일대일 대응)하다는 의미이다.
Building Powerful Graph Neural Networks
Maximally Powerful GNN
가장 강력한 GNN이 되기 위해서는 다음 정리를 만족해야 한다.
GNN \( \mathcal{A} : \mathcal{g} \rightarrow \mathbb{R}^d \)에 대해, layer 수가 충분할 경우, \(\mathcal{A}\)가 WL test 결과 non-isomorphic한 graph \(G_1, G_2\)를 서로 다른 embedding으로 맵핑할 조건은 다음과 같다.
a) \(\mathcal{A}\)는 다음 식을 반복하며 node feature를 aggregate 및 update한다.
\( h_v^{(k)} = \phi \left( h_v^{(k-1)}, f \left( \left\{ h_u^{(k-1)} : u \in \mathcal{N} (v) \right\} \right) \right) \)
여기서 함수 \(f\)는 multiset(이웃 노드)에 적용되며, \(\phi\)는 일대일 함수이다.
b) \(\mathcal{A}\)의 graph-level readout 함수는 node features \(\left\{ h_v^{(k)} \right\}\)의 multiset에 적용되며, 일대일 함수이다.
Input node feature가 셀 수 있는 경우에 injectiveness(일대일 대응) 개념이 잘 적용되므로, 이 경우만 다루도록 한다. (셀 수 없다는 것은 node feature가 연속적이라는 의미이다.)
Input feature space \(\mathcal{X}\)가 셀 수 있다고 가정하자.
\(g^{(k)}\)를 GNN의 \(k\)번째 layer로 parmeterize한 함수라 하자. (이때 \(g^{(1)}\)은 제한된 size의 multisets \(X \subset \mathcal{X}\)으로 정의한다.)
여기서 \(g^{(k)}\)의 범위, 즉 node hidden feature \(h_v^{(k)}\)의 space 또한 모든 \(k = 1, \dots , L\)에 대해 셀 수 있다.
여기서, GNN을 사용했을 때의 이점을 알 수 있다. GNN을 사용하면 서로 다른 그래프를 구별하는 것 뿐만 아니라 그래프의 구조적 유사성 (similarity of graph structures)을 알아낼 수 있다.
WL test에서는 node feature vector가 one-hot eoncding이므로 subtree의 similarity까지는 알 수 없는데, GNN에서는 subtree를 저차원 공간으로 임베딩하는 것을 학습하기 때문에 위 정리를 만족시키면서 WL test를 진행할 수 있다.
유사성을 알면 서로 다른 graph에서 subtree(neighbor 구조 정보)가 같은 경우가 드문 경우, edge feature와 node feature 중 noise가 있는 경우에 유용하다.
Computing Node Feature \(h_v\) in GIN
이제 \(\operatorname{AGGREGATE}\), \(\operatorname{COMBINE}\) 과정을 살펴보자.
GIN은 위에서 소개한 정리를 만족하면서도, 구조가 간단한 아키텍쳐이다.
먼저, 논문에서는 \(\operatorname{AGGREGATE}\) step에서 injective multiset function을 deep multisets로 모델링하였다.
Deep multisets란 neural network를 통해 전체 multiset function을 parameterize하는 방법이다.
\(\mathbf{X}\)가 셀 수 있을 때, 다음을 만족하는 function \(f: \mathcal{X} \rightarrow \mathbb{R}^n\)는 size가 제한된 각 multiset \(X\)에 대해 유일하다. (set에서는 mean aggregator를 쓰는 것이 일반적인데, multiset은 같은 요소가 여러 번 등장할 수 있으므로 mean aggregator를 사용하지 못하므로 아래와 같이 sum aggregator를 사용한다.)
\( h(X) = \sum_{x \in X} f(x) \)
그리고 multiset function \(g\)(GNN의 \(k\)번째 layer를 parameterize한 함수)는 function \(\phi\)를 사용하여 다음과 같이 분해할 수 있다.
\( g(X) = \phi \left( \sum_{x \in X} f(x) \right) \)
Aggregation 과정을 위 식을 통해 노드와 그 노드의 이웃 노드(multiset)에 대한 universal function으로 표현할 수 있고, 이에 따라 위에서 언급한 정리에서의 조건 중 injectiveness conidtion(a)를 만족시킬 수 있다.
그리고 aggregation 과정들 간의 formulation을 input feature vector, 즉 어떤 node representation인 \(c \in \mathcal{X}\)와 무리수를 포함하는 무수히 많은 수 중 하나인 \(\epsilon\) 개념을 추가하여 표현할 수 있다. (multiset \(X\)는 \(\mathcal{X}\)의 제한된 size를 갖는 subset)
\( h(c, X) = (1 + \epsilon) \cdot f(c) + \sum_{x \in X} f(x) \) (각 pair \((c, X)\)에 대해 유일)
\( g (c, X) = \varphi \left( (1 + \epsilon) \cdot f(c) + \sum_{x \in X} f(x) \right) \)
최종 식은 input feature, 즉 주어진 노드 중 하나의 representation \(c\)에 대한 \(f\)와 어떤 multiset(어떤 노드의 이웃 노드)도 함께 고려한다. \(c\)를 node \(v\)의 representation, \(X\)를 \(v\)의 이웃 노드를 나타내는 multiset으로 설정한다면 \(h_v\)를 구할 수 있을 것이다.
위 식의 \(f\)와 \(\varphi\)를 학습하기 위해 MLP를 사용한다. (universal approximation theorem)
실제 구현 시에는 \(f^{(k+1)} \circ \varphi^{(k)}\)를 하나의 MLP로 모델링하는데, 이는 MLP가 함수의 합성(composition of functions)을 표현할 수 있기 때문이다. 첫 번째 iteration에서는 input feature가 one-hot encoding인 경우 summation만 일대일 대응이기 때문에 MLP가 필요 없다. 또한 \(\epsilon\)은 고정된 scalar로 줄 수도, 학습 parameter로 줄 수도 있다.
따라서 GIN의 node representation은 다음과 같은 식으로 update된다.
\( h_v^{(k)} = \operatorname{MLP}^{(k)} \left( \left( 1 + \epsilon^{(k)} \right) \cdot h_v^{(k-1)} + \sum_{u \in \mathcal{N} (v)} h_u^{(k-1)} \right) \)
GNN에서의 \(\operatorname{AGGREGATE}, \operatorname{COMBINE}\) 두 과정을 다시 복기해보자.
\( a_v^{(k)} = \operatorname{AGGREGATE}^{(k)} \left( \left\{ h_u^{(k-1)} : u \in \mathcal{N} (v) \right\} \right)\)
\(h_v^{(k)} = \operatorname{COMBINE}^{(k)} \left( h_v^{(k-1)}, a_v^{(k)} \right) \)
비교해보면 GIN에서 \(sum\) 부분이 \(\operatorname{AGGREGATE}\)과정, \(\operatorname{MLP}\) 부분이 \(\operatorname{COLMBINE}\)과정임을 알 수 있다.
Graph-level READOUT of GIN
GIN에서 학습한 \(h_v\), 즉 node embedding은 node classification이나 link prediction task에서는 바로 쓰일 수 있다. 하지만 graph classification task에서는 readout 함수를 거치면서 node 각각의 embedding을 연결하여 그래프 전체의 embedding으로 바꿔주어야 한다.
한 가지 중요한 사실은, iteration이 증가함에 따라 node representation과 이에 대응하는 subtree structure가 점점 정제되고, global해진다는 것이다. (node representation이 갖게 되는 이웃 노드의 범위가 늘어나므로)
충분한 반복 횟수가 좋은 representational power의 핵심이긴 하지만 반복 횟수가 적을 때의 feature가 generalize 성능이 더 좋을 수도 있다.
따라서 모든 구조 정보를 고려하기 위해, 모델의 모든 iteration에서의 정보를 사용한다. 수식으로 표현하자면, 기존 readout 함수는 마지막 iteration만 고려하여 \(h_G = \operatorname{READOUT} \left( \left\{ h_v^{(K)} \vert v \in G \right\} \right)\)로 나타냈지만, GIN에서는 다음과 같은 READOUT 함수를 사용한다.
\( h_G = \operatorname{CONCAT} \left( \operatorname{READOUT} \left( \left\{ h_v^{(k)} \vert v \in G \right\} \right) \; \vert \; k = 0, 1, \dots, K \right) \)
Multiset에 대해 sum aggregator를 사용하는 이유
Aggregation 과정에서 왜 sum aggregator를 사용할 때가 가장 representational power가 좋은지 알아보자.
왼쪽과 같은 input multiset(합쳐질 이웃 노드)이 들어왔다고 가정하면, sum aggregator는 전체 multiset을 그대로 합치지만, mean은 요소들의 비율이나 분포 정보를 갖게 되고(파랑:빨강 개수가 4:2였으므로 2:1만 반영됨), max는 겹치는 요소 중 작은 것들을 모두 무시해버린다.
다라서 위 예시와 같이, 분명 다른 graph 구조이지만 mean 또는 max aggregator를 사용하면 똑같다고 취급하는(즉, 서로 다른 노드를 embedding space의 같은 위치로 맵핑해버리는) 경우가 발생하게 된다.
그 결과 representational power가 약해지는 것이다.
최근댓글