Lec 07 & 08 GNN
1. A General GNN Framework
- GNN layer
- Message
- Aggregate
- Layer connectivity
- Stack layers sequentially
- Ways of adding skip connections
- Graph Augmentation
- Raw input graph $\ne$ computational graph
- Graph feature augmentation
- Graph structure augmentation
- Learning Objective
- supervised / unsupervised
- node / edge / graph level
2. A single layer of a GNN
Idea : compress a set of vectors into a single vector
Message computation
Idea : each node creates a message, which will be sent to other nodes later
Message function :
$$
m_u^{(l)} = MSG^{(l)}(h_u^{(l-1)}),\quad u\in N(v) \cup \{v\}
$$
Aggregation
Idea : node $v$ aggregates the messages from its neighbors $u \in N(v)$ (and from itself)
Aggregation function :
$$
h_v^{(l)} = AGG^{(l)}(\{m_u^{(l)}, u\in N(v)\}, m_v^{(l)})
$$
Activation
A nonlinearity (e.g. ReLU, sigmoid) can be added to the message computation or to the aggregation
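A minimal PyTorch sketch of this message / aggregate decomposition, assuming a dense adjacency matrix `adj` that already includes self-loops; the names `SimpleGNNLayer` and `msg` are illustrative, not from the lecture:

```python
import torch
import torch.nn as nn

class SimpleGNNLayer(nn.Module):
    """Minimal message + aggregate layer (illustrative sketch)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.msg = nn.Linear(in_dim, out_dim)  # MSG^{(l)}: per-node transformation

    def forward(self, h, adj):
        # h: [N, in_dim] node embeddings, adj: [N, N] adjacency with self-loops,
        # so each node also aggregates its own message.
        m = self.msg(h)            # compute message m_u for every node u
        h_new = adj @ m            # AGG^{(l)}: sum messages over N(v) ∪ {v}
        return torch.relu(h_new)   # nonlinearity
```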
3. Various GNN Layers
GCN
$$
h_v^{(l)} = \sigma\left(W^{(l)} \sum_{u\in N(v)}\frac{h_u^{(l-1)}}{|N(v)|}\right)\\
MSG^{(l)}(\cdot) = W^{(l)}\frac{h_u^{(l-1)}}{|N(v)|}\\
AGG^{(l)}(\cdot) = \sum_{u \in N(v)} MSG^{(l)}(\cdot)
$$
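A sketch of this layer in PyTorch, using the $1/|N(v)|$ (mean) normalization written above rather than the symmetric normalization of the original GCN paper; `GCNLayer` and the dense-adjacency interface are assumptions for illustration:

```python
import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    """GCN layer matching the mean-over-neighbors formula above (sketch)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)  # W^{(l)}

    def forward(self, h, adj):
        # Divide each row by the node degree |N(v)| so a row of `adj` averages the neighbors.
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        h_agg = (adj / deg) @ h            # sum_{u in N(v)} h_u / |N(v)|
        return torch.relu(self.W(h_agg))   # sigma(W^{(l)} * aggregated)
```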
GraphSAGE
$$
h_v^{(l)} =\sigma\left(W^{(l)}\cdot CONCAT\left(h_v^{(l-1)},\ AGG(\{h_u^{(l-1)},\forall u \in N(v)\})\right)\right)\\
AGG(\cdot) \in \{Mean(\cdot), Pool(\cdot), LSTM(\cdot)\}\\
h_v^{(l)}\leftarrow \frac{h_v^{(l)}}{\|h_v^{(l)}\|_2} \quad (\ell_2\text{ normalization})
$$
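A GraphSAGE-style layer sketch with the Mean aggregator and the $\ell_2$ normalization step; the class name and the dense-adjacency interface are illustrative assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SAGELayer(nn.Module):
    """GraphSAGE-style layer with Mean aggregation (illustrative sketch)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(2 * in_dim, out_dim)  # acts on CONCAT(h_v, AGG(...))

    def forward(self, h, adj):
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        h_neigh = (adj / deg) @ h                          # Mean aggregation over N(v)
        h_new = torch.relu(self.W(torch.cat([h, h_neigh], dim=1)))
        return F.normalize(h_new, p=2, dim=1)              # l2 normalization
```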
GAT
$$
h_v^{(l)} = \sigma\left(\sum_{u\in N(v)}\alpha_{vu} W^{(l)}h_u^{(l-1)}\right)\\
\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k\in N(v)}\exp(e_{vk})}\\
e_{vu} = a(W^{(l)}h_u^{(l-1)}, W^{(l)}h_v^{(l-1)})
$$
where $a$ is the attention mechanism function (e.g. a single-layer neural network) that computes the importance $e_{vu}$ of node $u$'s message to node $v$
GAT with Multi-Head Attention
$$
h_v^{(l)}[1] = \sigma\left(\sum_{u\in N(v)}\alpha_{vu}^1 W^{(l)}h_u^{(l-1)}\right)\\
h_v^{(l)}[2] = \sigma\left(\sum_{u\in N(v)}\alpha_{vu}^2 W^{(l)}h_u^{(l-1)}\right)\\
h_v^{(l)}[3] = \sigma\left(\sum_{u\in N(v)}\alpha_{vu}^3 W^{(l)}h_u^{(l-1)}\right)\\
h_v^{(l)} = AGG(h_v^{(l)}[1], h_v^{(l)}[2], h_v^{(l)}[3])
$$
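A single-head attention layer sketch in PyTorch; a multi-head GAT would run several such heads in parallel and combine them with the AGG step above. It assumes a dense adjacency with self-loops and uses one linear layer as the attention mechanism $a$ with a LeakyReLU, following the original GAT paper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    """Single-head graph attention layer (sketch)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)  # shared W^{(l)}
        self.a = nn.Linear(2 * out_dim, 1, bias=False)   # attention mechanism a(., .)

    def forward(self, h, adj):
        Wh = self.W(h)                                   # [N, out_dim]
        N = Wh.size(0)
        # e_{vu} = a(W h_v, W h_u) for every pair (v, u); LeakyReLU as in the GAT paper.
        pair = torch.cat([Wh.unsqueeze(1).expand(N, N, -1),
                          Wh.unsqueeze(0).expand(N, N, -1)], dim=-1)
        e = F.leaky_relu(self.a(pair).squeeze(-1))       # [N, N]
        e = e.masked_fill(adj == 0, float("-inf"))       # only attend to neighbors
        alpha = torch.softmax(e, dim=1)                  # alpha_{vu}, normalized over N(v)
        return torch.relu(alpha @ Wh)                    # sum_u alpha_{vu} W h_u
```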
4. GNN Layer in Practice
Linear
Batch Normalization
Stabilizes neural network training by re-centering and re-scaling embeddings over a batch
$$
\mu_j=\frac1N\sum_{i=1}^N X_{i,j}\\
\sigma_j^2=\frac1N\sum_{i=1}^N(X_{i,j}-\mu_j)^2\\
\widehat X_{i,j} = \frac{X_{i,j}-\mu_j}{\sqrt{\sigma_j^2+\epsilon}}\\
Y_{i,j} = \gamma_j \widehat X_{i,j} + \beta_j
$$
where $\gamma_j$ and $\beta_j$ are trainable parameters
Dropout
Prevents overfitting
During training : with some probability $p$, randomly set some neurons to zero
During testing : use all the neurons for computation
Activation
ReLU
$$
ReLU(x_i) = \max(x_i,0)
$$
Sigmoid
$$
\sigma(x_i) = \frac 1 {1+e^{-x_i}}
$$
Parametric ReLU
$$
PReLU(x_i) = \max(x_i,0) + a_i\min(x_i,0)
$$
where $a_i$ is a trainable parameter; PReLU empirically performs better than ReLU
Attention / Gating
Control the importance of a message
Aggregation
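A sketch of how these components are typically composed after aggregation, using standard PyTorch modules; `build_layer_modules` is a hypothetical helper, not lecture code:

```python
import torch.nn as nn

def build_layer_modules(in_dim, out_dim, p_drop=0.5):
    """Post-aggregation modules of one GNN layer in practice (sketch)."""
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),   # linear message transformation
        nn.BatchNorm1d(out_dim),      # stabilize training
        nn.Dropout(p=p_drop),         # prevent overfitting (active only in train mode)
        nn.PReLU(),                   # activation with trainable slope a_i
    )
```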
5. Stacking Layers of a GNN
Stack layers sequentially
$$
h_v^{(0)} = x_v\\
h_v^{(l)} = GNN(h_v^{(l-1)})\\
y = h_v^{(L)}
$$
Over-Smoothing Problem
all the node embeddings converge to similar values, making nodes indistinguishable
because the shared neighborhood between nodes grows quickly as the number of hops (GNN layers) increases
Solutions
be cautious when adding GNN layers : the number of layers $L$ should be only a bit larger than the receptive field we need
How to enhance the expressive power with fewer GNN layers
within each GNN layer : make the aggregation / transformation a deep neural network
add layers that do not pass messages : pre-processing layers and post-processing layers (e.g. MLPs before / after the GNN layers)
Skip Connections (ResNet)
- Function
$$
F(x)\rightarrow F(x) + x
$$
Intuition
create a mixture of shallow and deep models : with $N$ skip connections there are $2^N$ possible paths through the network
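A minimal sketch of wrapping an arbitrary GNN layer with such a residual connection (assuming the layer keeps the embedding dimension unchanged):

```python
import torch.nn as nn

class GNNWithSkip(nn.Module):
    """Wraps any GNN layer with a residual connection: x -> F(x) + x (sketch)."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, h, adj):
        return self.layer(h, adj) + h  # F(x) + x
```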
6. Graph Manipulation
Why Manipulate Graphs
- Feature level
- the input graph lacks node features
- Structure level
- the graph is too sparse, so message passing is inefficient
- the graph is too dense, so message passing is too costly
- the graph is too large, so the computational graph cannot fit into GPU memory
Graph Feature Manipulation
Reasons
- the input graph does not have node features
- certain structures are hard for a GNN to learn, e.g. the length of a cycle in a cycle graph
Approaches
- assign constant values to nodes
- assign unique IDs to nodes
- add structural features such as clustering coefficient, PageRank, centrality, and so on
Augment Sparse Graphs
Add virtual edges
Intuition : use $A+A^2$ (i.e. add 2-hop virtual edges) instead of $A$ for message passing; see the sketch after this item
Use case : Bipartite graphs
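A sketch of adding the 2-hop virtual edges on a dense 0/1 adjacency matrix; `add_virtual_edges` is a hypothetical helper:

```python
import torch

def add_virtual_edges(adj):
    """Augment a sparse graph by message passing on A + A^2 instead of A (sketch)."""
    two_hop = (adj @ adj > 0).float()        # pairs connected by a path of length 2
    return ((adj + two_hop) > 0).float()     # binary adjacency of A + A^2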
Add virtual nodes
Approach : add a virtual node that connects to all nodes in the graph
Benefit : greatly improves message passing in sparse graphs, since all nodes become at most two hops apart
Augment Dense Graphs
Approach : neighborhood sampling, i.e. randomly sample a subset of a node’s neighbors each time its embedding is computed (see the sketch below)
Benefit : greatly reduces the computational cost
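A minimal sketch of the sampling step; `sample_neighbors` is a hypothetical helper operating on a plain neighbor list:

```python
import random

def sample_neighbors(neighbors, k):
    """Neighborhood sampling sketch: keep at most k randomly chosen neighbors.
    A different subset is drawn each time embeddings are computed, so in
    expectation all neighbors still contribute."""
    neighbors = list(neighbors)
    if len(neighbors) <= k:
        return neighbors
    return random.sample(neighbors, k)
```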
7. Prediction with GNNs
Node-Level Prediction
Input : $d$-dimensional node embeddings $\{h_v^{(L)} \in \mathbb R^d,\forall v\in G\}$
Output : $\widehat y_v = Head_{node}(h_v^{(L)})=W^{(H)}h_v^{(L)}$
Edge-Level Prediction
Output : $\widehat y_{uv} = Head_{edge}(h_u^{(L)},h_v^{(L)})$
Concatenation plus Linear
$\widehat y_{uv} = Linear(Concat(h_u^{(L)},h_v^{(L)}))$
Dot Product
1-way : $\widehat y_{uv} = (h_u^{(L)})^T h_v^{(L)}$
k-way : $\widehat y_{uv}^{(i)} = (h_u^{(L)})^T W^{(i)}h_v^{(L)},\ i\in\{1,\dots,k\}$, then $\widehat y_{uv} = Concat(\widehat y_{uv}^{(1)},\dots,\widehat y_{uv}^{(k)})\in\mathbb R^k$
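A sketch combining the three edge-level heads (concatenation + linear, 1-way dot product, k-way dot product) for batches of endpoint embeddings `hu`, `hv` of shape `[B, d]`; `EdgeHeads` is an illustrative module, not lecture code:

```python
import torch
import torch.nn as nn

class EdgeHeads(nn.Module):
    """Edge-level prediction heads (sketch)."""

    def __init__(self, d, k):
        super().__init__()
        self.lin = nn.Linear(2 * d, 1)                # concatenation + linear head
        self.W = nn.Parameter(torch.randn(k, d, d))   # k trainable matrices W^{(i)}

    def forward(self, hu, hv):
        concat_score = self.lin(torch.cat([hu, hv], dim=-1))  # Linear(Concat(h_u, h_v))
        dot_score = (hu * hv).sum(dim=-1, keepdim=True)       # (h_u)^T h_v
        kway = torch.einsum("bd,kde,be->bk", hu, self.W, hv)  # (h_u)^T W^{(i)} h_v
        return concat_score, dot_score, kway
```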
Graph-level Prediction
Output : $\widehat y_{G} = Head_{graph}(\{h_v^{(L)}\in \mathbb R^d,\forall v\in G\})$
Global Pooling
global mean / max / sum
Weakness : loses information when the graph is large
Hierarchical Global Pooling
use $ReLU(Sum(\cdot))$ to aggregate hierarchically : pool small groups of nodes first, then aggregate the group results (see the sketch below)
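A sketch of global and hierarchical pooling on an $[N, d]$ embedding matrix; grouping nodes by index order here is a simplifying assumption (real groupings come from the graph structure or learned clusters, as in DiffPool below):

```python
import torch

def global_pool(h, mode="mean"):
    """Global pooling over all node embeddings, h: [N, d] -> [d] (sketch)."""
    if mode == "mean":
        return h.mean(dim=0)
    if mode == "max":
        return h.max(dim=0).values
    return h.sum(dim=0)

def hierarchical_sum_pool(h, group_size=2):
    """Hierarchical pooling sketch: apply ReLU(Sum(.)) within small groups of
    nodes first, then aggregate the group results."""
    groups = h.split(group_size, dim=0)
    pooled = torch.stack([torch.relu(g.sum(dim=0)) for g in groups])
    return torch.relu(pooled.sum(dim=0))
```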
Differentiable Pooling
GNN
$$
Z = GNN(X,A)
$$
Differentiable Pooling GNN
$$
(A^{(l+1)},X^{(l+1)}) = DIFFPOOL(A^{(l)},Z^{(l)})
$$
- compute the (soft) cluster assignment of each node using the pooling GNN
- pool the node embeddings within each cluster, and generate the edges between the clusters
- jointly train the embedding GNN and the differentiable pooling GNN
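A sketch of one DIFFPOOL coarsening step with a soft cluster-assignment matrix $S$; the assumption that `S_logits` comes from a separate pooling GNN follows the DiffPool formulation:

```python
import torch

def diffpool_step(A, Z, S_logits):
    """One DIFFPOOL coarsening step (sketch).
    Z:        [N, d] node embeddings from the embedding GNN.
    S_logits: [N, C] cluster-assignment scores from the pooling GNN.
    Returns the coarsened adjacency A' [C, C] and features X' [C, d]."""
    S = torch.softmax(S_logits, dim=1)  # soft assignment of each node to C clusters
    X_next = S.T @ Z                    # pool embeddings within each cluster
    A_next = S.T @ A @ S                # generate edges between clusters
    return A_next, X_next
```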
8. Graph Split
Transductive setting
the input graph can be observed in all dataset splits; only the labels are split into training / validation / test sets
Inductive setting
break the edges between splits to obtain multiple independent graphs, so each split only observes its own graph