GNN-Notes-LEC-07-08

Lec 07 & 08 GNN

1. A General GNN Framework

  • GNN layer
    • Message
    • Aggregate
  • Layer connectivity
    • Stack layers sequentially
    • Ways of adding skip connections
  • Graph Augmentation
    • Raw input graph $\ne$ computational graph
    • Graph feature augmentation
    • Graph structure augmentation
  • Learning Objective
    • supervised / unsupervised
    • node / edge / graph level

2. A single layer of a GNN

Idea : compress a set of vectors into a single vector

Message computation

Idea : each node creates a message, which will be sent to other nodes later

Message function :
$$
m_u^{(l)} = MSG^{(l)}\big(h_u^{(l-1)}\big),\quad u\in N(v) \cup \{v\}
$$

Aggregation

Idea : each node $v$ aggregates the messages from its neighbors $N(v)$ (and its own message)

Aggregation function :
$$
h_v^{(l)} = AGG^{(l)}\big(\{m_u^{(l)}, u\in N(v)\},\ m_v^{(l)}\big)
$$

Activation

A nonlinear activation $\sigma(\cdot)$ can be added to either the message computation or the aggregation step.
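
A minimal sketch of this message / aggregation decomposition in PyTorch; the adjacency-list input, the linear `MSG`, and the sum `AGG` are illustrative choices, not the only options:

```python
import torch
import torch.nn as nn

class SimpleGNNLayer(nn.Module):
    """One GNN layer: message -> aggregate -> activation (illustrative sketch)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.msg = nn.Linear(in_dim, out_dim)      # MSG^{(l)}: a linear map per node

    def forward(self, h, neighbors):
        # h: [num_nodes, in_dim] embeddings from layer l-1
        # neighbors: dict {v: list of neighbor indices of v}
        m = self.msg(h)                            # message m_u^{(l)} of every node u
        h_new = torch.zeros(h.size(0), m.size(1))
        for v, nbrs in neighbors.items():
            idx = torch.tensor(nbrs + [v])         # N(v) plus v itself
            h_new[v] = m[idx].sum(dim=0)           # AGG^{(l)}: sum of the messages
        return torch.relu(h_new)                   # optional activation

# usage: 4 nodes on a path graph 0-1-2-3
h = torch.randn(4, 8)
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
print(SimpleGNNLayer(8, 16)(h, neighbors).shape)   # torch.Size([4, 16])
```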

3. Various GNN

GCN

$$
h_v^{(l)} = \sigma\Big(W^{(l)} \sum_{u\in N(v)}\frac{h_u^{(l-1)}}{|N(v)|}\Big)\\
MSG^{(l)}(\cdot) = W^{(l)}\frac{h_u^{(l-1)}}{|N(v)|}\\
AGG^{(l)}(\cdot) = \sum_{u \in N(v)} m_u^{(l)}
$$
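
A dense-matrix sketch of this GCN update, assuming an unweighted adjacency matrix `A` with no isolated nodes; note that this is the neighbor-averaging form written above, not the symmetrically normalized variant:

```python
import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)   # W^{(l)}

    def forward(self, A, h):
        # A: [N, N] adjacency matrix, h: [N, in_dim] embeddings from layer l-1
        deg = A.sum(dim=1, keepdim=True)                  # |N(v)| for every node v
        agg = (A @ h) / deg                               # sum_u h_u^{(l-1)} / |N(v)|
        return torch.relu(self.W(agg))                    # sigma(W^{(l)} . aggregation)

A = torch.tensor([[0., 1., 1.], [1., 0., 1.], [1., 1., 0.]])
h = torch.randn(3, 8)
print(GCNLayer(8, 4)(A, h).shape)                         # torch.Size([3, 4])
```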

GraphSAGE

$$
h_v^{(l)} = \sigma\Big(W^{(l)}\cdot CONCAT\big(h_v^{(l-1)},\ AGG(\{h_u^{(l-1)},\forall u \in N(v)\})\big)\Big)\\
AGG(\cdot) \in \{Mean(\cdot),\ Pool(\cdot),\ LSTM(\cdot)\}\\
h_v^{(l)} \leftarrow \frac{h_v^{(l)}}{\|h_v^{(l)}\|_2}
$$
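
A sketch of this GraphSAGE update using the `Mean` aggregator (one of the listed options) and the final $\ell_2$ normalization of the embeddings, which is usually treated as optional:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SAGELayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(2 * in_dim, out_dim)        # acts on CONCAT(h_v, AGG(...))

    def forward(self, A, h):
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)
        agg = (A @ h) / deg                            # Mean aggregator over N(v)
        out = torch.relu(self.W(torch.cat([h, agg], dim=1)))
        return F.normalize(out, p=2, dim=1)            # optional l2 normalization

A = torch.tensor([[0., 1.], [1., 0.]])
h = torch.randn(2, 8)
print(SAGELayer(8, 4)(A, h).shape)                     # torch.Size([2, 4])
```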

GAT

$$
h_v^{(l)} = \sigma\Big(\sum_{u\in N(v)}\alpha_{vu} W^{(l)}h_u^{(l-1)}\Big)\\
\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k\in N(v)}\exp(e_{vk})}\\
e_{vu} = a\big(W^{(l)}h_u^{(l-1)},\ W^{(l)}h_v^{(l-1)}\big)
$$

where $a(\cdot,\cdot)$ is the attention mechanism (e.g. a simple single-layer neural network), whose parameters are trained jointly with the weight matrices $W^{(l)}$

GAT with Multi-Attention

$$
h_v^{(l)}[1] = \sigma\Big(\sum_{u\in N(v)}\alpha_{vu}^1 W^{(l)}h_u^{(l-1)}\Big)\\
h_v^{(l)}[2] = \sigma\Big(\sum_{u\in N(v)}\alpha_{vu}^2 W^{(l)}h_u^{(l-1)}\Big)\\
h_v^{(l)}[3] = \sigma\Big(\sum_{u\in N(v)}\alpha_{vu}^3 W^{(l)}h_u^{(l-1)}\Big)\\
h_v^{(l)} = AGG\big(h_v^{(l)}[1],\ h_v^{(l)}[2],\ h_v^{(l)}[3]\big)
$$
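
A sketch of a multi-head GAT layer where the attention mechanism $a$ is a single linear layer on the concatenated transformed embeddings and the head-level `AGG` is concatenation; both are common but illustrative choices:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, heads=3):
        super().__init__()
        self.W = nn.ModuleList(nn.Linear(in_dim, out_dim, bias=False) for _ in range(heads))
        self.a = nn.ModuleList(nn.Linear(2 * out_dim, 1, bias=False) for _ in range(heads))

    def forward(self, A, h):
        n, outs = h.size(0), []
        for W, a in zip(self.W, self.a):
            z = W(h)                                           # W^{(l)} h_u
            # e_vu = a(W h_v, W h_u) for all pairs (v, u)
            pair = torch.cat([z.unsqueeze(1).expand(n, n, -1),
                              z.unsqueeze(0).expand(n, n, -1)], dim=-1)
            e = a(pair).squeeze(-1)                            # [N, N]
            e = e.masked_fill(A == 0, float('-inf'))           # attend only to neighbors
            alpha = F.softmax(e, dim=1)                        # softmax over N(v)
            outs.append(torch.relu(alpha @ z))                 # one attention head
        return torch.cat(outs, dim=1)                          # AGG: concatenate the heads

A = torch.tensor([[0., 1., 1.], [1., 0., 1.], [1., 1., 0.]])
h = torch.randn(3, 8)
print(GATLayer(8, 4, heads=3)(A, h).shape)                     # torch.Size([3, 12])
```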

4. GNN Layer in Practice

  • Linear

  • Batch Normalization

    Stabilize neural network training
    $$
    \mu_j=\frac{1}{N}\sum_{i=1}^N X_{i,j}\\
    \sigma_j^2=\frac{1}{N}\sum_{i=1}^N(X_{i,j}-\mu_j)^2\\
    \widehat X_{i,j} = \frac{X_{i,j}-\mu_j}{\sqrt{\sigma_j^2+\epsilon}}\\
    Y_{i,j} = \gamma_j \widehat X_{i,j} + \beta_j
    $$

  • Dropout

    prevent overfitting

    During training : with some probability $p$, randomly set some neurons to zero

    During testing : use all neurons for the computation (no dropout)

  • Activation

    • ReLU
      $$
      ReLU(x_i) = \max(x_i,0)
      $$

    • Sigmoid
      $$
      \sigma(x_i) = \frac 1 {1+e^{-x_i}}
      $$

    • Parametric ReLU
      $$
      PReLU(x_i) = \max(x_i,0) + a_i\min(x_i,0)
      $$
      where $a_i$ is a trainable parameter; empirically, PReLU performs better than ReLU

  • Attention / Gating

    Control the importance of a message

  • Aggregation
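
The modules above are typically composed into the transformation part of a GNN layer, e.g. Linear → BatchNorm → Dropout → activation; a minimal sketch with a mean aggregation and arbitrary sizes:

```python
import torch
import torch.nn as nn

class GNNLayerInPractice(nn.Module):
    """Mean aggregation followed by the standard deep-learning modules listed above."""

    def __init__(self, in_dim, out_dim, p_drop=0.5):
        super().__init__()
        self.transform = nn.Sequential(
            nn.Linear(in_dim, out_dim),       # Linear
            nn.BatchNorm1d(out_dim),          # Batch Normalization
            nn.Dropout(p=p_drop),             # Dropout (active only in train() mode)
            nn.PReLU(),                       # Activation with a trainable a_i
        )

    def forward(self, A, h):
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)
        return self.transform((A @ h) / deg)  # aggregate, then transform

A = torch.tensor([[0., 1., 1.], [1., 0., 1.], [1., 1., 0.]])
h = torch.randn(3, 8)
layer = GNNLayerInPractice(8, 16)
layer.train()                                 # dropout on, BatchNorm uses batch statistics
print(layer(A, h).shape)                      # torch.Size([3, 16])
```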

5. Stacking Layers of a GNN

Stack layers sequentially

$$
h_v^{(0)} = x_v\\
h_v^{(l)} = GNN\big(h_v^{(l-1)}\big)\\
y = h_v^{(L)}
$$
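
A minimal sketch of sequential stacking, using a mean-aggregation layer as the per-layer `GNN` (an illustrative choice):

```python
import torch
import torch.nn as nn

class StackedGNN(nn.Module):
    """h^(0) = x;  h^(l) = GNN layer applied to h^(l-1);  output y = h^(L)."""

    def __init__(self, in_dim, hid_dim, num_layers=3):
        super().__init__()
        dims = [in_dim] + [hid_dim] * num_layers
        self.W = nn.ModuleList(nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers))

    def forward(self, A, x):
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)
        h = x                                       # h^(0) = x_v
        for W in self.W:
            h = torch.relu(W((A @ h) / deg))        # h^(l) = GNN(h^(l-1))
        return h                                    # y = h^(L)

A = torch.tensor([[0., 1., 1.], [1., 0., 1.], [1., 1., 0.]])
x = torch.randn(3, 8)
print(StackedGNN(8, 16, num_layers=3)(A, x).shape)  # torch.Size([3, 16])
```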

Over-Smoothing Problem

all the node embeddings converge to the same value

due to the number of shared neighbors growing quickly (overlapping receptive fields) as the number of hops / GNN layers increases

Solutions

  • be cautious when adding GNN layers : the number of layers $L$ should be just a bit larger than the receptive field we need

    But how can we enhance the expressive power with fewer GNN layers?

    • within a GNN layer : make the aggregation / transformation into a deep neural network (DNN)

    • add layers that do not pass messages : pre-processing and post-processing layers, e.g. MLPs before / after the GNN layers (see the sketch after this list)

  • Skip Connections (ResNet)

    • Function

    $$
    F(x)\rightarrow F(x) + x
    $$

    • Intuition

      create a mixture of models : $N$ skip connections give $2^N$ possible paths, i.e. an implicit ensemble of shallower and deeper GNNs
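
A sketch combining the two ideas above: pre- and post-processing MLP layers that do not pass messages, plus a skip connection $F(x) + x$ around every message-passing layer (the mean aggregation and all sizes are illustrative):

```python
import torch
import torch.nn as nn

class GNNWithSkip(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, num_gnn_layers=3):
        super().__init__()
        self.pre = nn.Linear(in_dim, hid_dim)       # pre-processing: no message passing
        self.gnn = nn.ModuleList(nn.Linear(hid_dim, hid_dim) for _ in range(num_gnn_layers))
        self.post = nn.Linear(hid_dim, out_dim)     # post-processing: no message passing

    def forward(self, A, x):
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)
        h = torch.relu(self.pre(x))
        for W in self.gnn:
            f = torch.relu(W((A @ h) / deg))        # F(x): one message-passing layer
            h = f + h                               # skip connection: F(x) + x
        return self.post(h)

A = torch.tensor([[0., 1.], [1., 0.]])
x = torch.randn(2, 8)
print(GNNWithSkip(8, 16, 4)(A, x).shape)            # torch.Size([2, 4])
```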

6. Graph Manipulation

Why Manipulate Graphs

  • Feature Level
    • the input graph lacks features
  • Structure Level
    • the graph is too sparse, so message passing is inefficient
    • the graph is too dense, so message passing is too costly
    • the graph is too large, so the computational graph cannot fit into GPU memory

Graph Feature Manipulation

Reason

  • input graph does not have node features
  • certain structures are hard for a GNN to learn, e.g. a GNN cannot tell cycle graphs of different lengths apart when all nodes have identical features

Approaches

  • assign constant values to nodes
  • assign unique IDs to nodes (e.g. one-hot encodings)
  • add structural features such as the clustering coefficient, PageRank, centrality, and so on (see the sketch below)
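
A small sketch of feature augmentation on a featureless cycle graph, using networkx to compute the structural features named above (the particular feature set is an arbitrary choice):

```python
import networkx as nx
import torch

G = nx.cycle_graph(5)                   # toy input graph without node features

cc = nx.clustering(G)                   # clustering coefficient per node
pr = nx.pagerank(G)                     # PageRank per node
ce = nx.degree_centrality(G)            # a simple centrality measure

# constant value plus structural features; unique IDs would instead be
# one-hot vectors, e.g. torch.eye(G.number_of_nodes())
x = torch.tensor([[1.0, cc[v], pr[v], ce[v]] for v in G.nodes])
print(x.shape)                          # torch.Size([5, 4])
```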

Augment sparse graphs

Add virtual edges

Intuition : use $A+A^2$ instead of using $A$

Use case : bipartite graphs, e.g. an author-paper graph, where $A^2$ adds author-author (co-authorship) and paper-paper edges
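
A tiny illustration of the $A + A^2$ idea on a made-up author-paper bipartite graph; the 2-hop term adds, for example, a co-authorship edge between two authors who share a paper:

```python
import torch

# Bipartite graph: nodes 0-1 are authors, nodes 2-3 are papers (made up for illustration)
A = torch.tensor([[0., 0., 1., 1.],
                  [0., 0., 0., 1.],
                  [1., 0., 0., 0.],
                  [1., 1., 0., 0.]])

A2 = A @ A                          # number of 2-hop paths between each pair of nodes
A_aug = ((A + A2) > 0).float()      # use A + A^2 instead of A (binarized)
A_aug.fill_diagonal_(0)             # drop the self-loops introduced by A^2
print(A_aug)                        # authors 0 and 1 are now connected via paper 3
```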

Add virtual nodes

approach : add a virtual node that connects to all nodes in the graph

Benefits : greatly improves message passing in sparse graphs, since every pair of nodes is at most two hops apart

Augment Dense Graphs

approach : Neighborhood Sampling, i.e. randomly sample a subset of a node's neighbors each time we compute its embedding

Benefits : greatly reduces the computational cost
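
A minimal sketch of neighborhood sampling; the fan-out of 2 is an arbitrary choice, and a fresh sample is drawn every time embeddings are computed:

```python
import random

def sample_neighbors(neighbors, v, fanout=2):
    """Randomly keep at most `fanout` neighbors of v for this round of message passing."""
    nbrs = neighbors[v]
    if len(nbrs) <= fanout:
        return list(nbrs)
    return random.sample(nbrs, fanout)

neighbors = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}
print(sample_neighbors(neighbors, 0))   # e.g. [3, 1] -- a different subset each time
```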

7. Predict With GNN

Node-Level Prediction

Input : d-dim node embeddings $\{h_v^{(L)} \in \mathbb R^d,\ \forall v\in G\}$

Output : $\widehat y_v = Head_{node}(h_v^{(L)})=W^{(H)}h_v^{(L)}$

Edge-Level Prediction

Output : $\widehat y_{uv} = Head_{edge}(h_u^{(L)},h_v^{(L)})$

Concatenation plus Linear

$\widehat y_{uv} = Linear(Concat(h_u^{(L)},h_v^{(L)}))$

Dot Product

1-way : $\widehat y_{uv} = (h_u^{(L)})^T h_v^{(L)}$

k-way : $\widehat y_{uv}^{(i)} = (h_u^{(L)})^T W^{(i)}h_v^{(L)},\ i\in\{1,\dots,k\}$, then concatenate: $\widehat y_{uv} = Concat(\widehat y_{uv}^{(1)},\dots,\widehat y_{uv}^{(k)})$
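
A sketch of the three edge-level heads above; the embedding dimension $d = 16$ and $k = 4$ are arbitrary:

```python
import torch
import torch.nn as nn

d, k = 16, 4
h_u, h_v = torch.randn(d), torch.randn(d)       # final-layer embeddings of nodes u and v

# Concatenation plus Linear
linear = nn.Linear(2 * d, 1)
y_concat = linear(torch.cat([h_u, h_v]))

# 1-way dot product
y_dot = h_u @ h_v

# k-way prediction: one trainable W^(i) per output dimension, then concatenate
W = torch.randn(k, d, d)
y_kway = torch.stack([h_u @ W[i] @ h_v for i in range(k)])

print(y_concat.shape, y_dot.shape, y_kway.shape)  # [1], scalar, [4]
```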

Graph-level Prediction

Output : $\widehat y_{G} = Head_{graph}(\{h_v^{(L)}\in \mathbb R^d,\ \forall v\in G\})$

Global Pooling

global mean / max / sum

Weakness : loses information on large graphs, e.g. global sum pooling can map very different sets of node embeddings to the same pooled value

Hierarchical Global Pooling

use $ReLU(Sum(\cdot))$ to aggregate small groups of nodes first, then hierarchically aggregate the intermediate results
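
A tiny numeric illustration of the weakness of global sum pooling and of the hierarchical $ReLU(Sum(\cdot))$ fix; the two 1-dimensional embedding sets are made up for illustration:

```python
import torch

# Two graphs with very different (1-dimensional) node embeddings
g1 = torch.tensor([-1., -2., 0., 1., 2.])
g2 = torch.tensor([-10., -20., 0., 10., 20.])

# Global sum pooling cannot tell them apart
print(g1.sum(), g2.sum())                   # tensor(0.), tensor(0.)

# Hierarchical pooling: ReLU(Sum(.)) over two groups of nodes, then over the results
def hier_pool(x):
    a = torch.relu(x[:2].sum())             # first group
    b = torch.relu(x[2:].sum())             # second group
    return torch.relu(a + b)

print(hier_pool(g1), hier_pool(g2))         # tensor(3.), tensor(30.)
```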

Differentiable Pooling

GNN
$$
Z = GNN(X,A)
$$
Differentiable Pooling GNN
$$
(A^{(l+1)},X^{(l+1)}) = DIFFPOOL(A^{(l)},Z^{(l)})
$$

  • compute the cluster that each node belongs to (a soft assignment, so the step stays differentiable)
  • pool the node embeddings within each cluster; the edges between clusters are generated from the original edges
  • jointly train the embedding GNN and the differentiable pooling (cluster-assignment) GNN
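
Concretely, the coarsening step can be written as follows (a sketch following the DiffPool paper), where $S^{(l)}$ is the soft cluster-assignment matrix produced by the second, jointly trained GNN:
$$
S^{(l)} = \mathrm{softmax}\big(GNN_{pool}(A^{(l)}, X^{(l)})\big)\\
X^{(l+1)} = {S^{(l)}}^T Z^{(l)}\\
A^{(l+1)} = {S^{(l)}}^T A^{(l)} S^{(l)}
$$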

8. Graph Split

Transductive setting

the input graph can be observed in all dataset splits, i.e. only the labels are split into training / validation / test

Inductive setting

break the edges between splits to obtain multiple independent graphs, so each split only observes its own graph(s)