Graph Neural Networks and Graph Convolution

A no-brainer primer on graph neural networks and graph convolution to get a quick sense of how they work

December 4, 2020 • Pratulya Bubna

This post is intended to provide a very high-level intuition about the workings of Graph Convolutional Networks (GCNs) in the spatial domain. The goal is to give a rough overall idea of GNNs and GCNs without delving into the meaty math.

An illustration of a graph-based network. (image credits )

Introduction

Images, text and audio data can easily be represented as a grid — a regular and structured format. Given this representation, Convolutional Neural Networks (CNNs) can be successfully applied to process this data.

However, non-Euclidean data such as social networks, 3D meshes, molecular data and knowledge graphs (inherently graphs) cannot be represented in a grid-like structure and instead lie in an irregular domain. Consequently, the convolution operation defined in the Euclidean case (CNNs) does not carry over to non-Euclidean domains.

The learning architecture designed to handle unstructured, non-Euclidean graph data is the Graph Neural Network (GNN). GNNs do not have strict structural requirements, unlike regular neural networks that operate on fixed-dimension inputs (e.g., CNNs built over the MNIST dataset require all input images to be of size 28×28). This means that the number of vertices and edges can vary across input graphs.

Graph Convolutional Networks (GCNs) are a variant of GNNs that evolved from work on generalizing convolutions to the graph domain.


A graph is a simple way of encoding pairwise relationships among a set of objects.

A graph \(G\) consists of a pair of sets \((V, E)\): a collection \(V\) of nodes and a collection \(E\) of edges. Each edge \(e \in E\) joins two nodes and is thus a two-element subset of \(V\): \(e = \{u, v\}\) for some \(u, v \in V\).
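To make the definition concrete, here is a minimal sketch in plain Python (the node labels and edges are made up purely for illustration) of a graph stored directly as the pair \((V, E)\), with a small helper to look up a node's neighbors:

```python
# A tiny graph G = (V, E): a set of nodes and a set of two-element edges.
V = {"a", "b", "c", "d"}
E = {frozenset({"a", "b"}),   # each edge is a two-element subset {u, v} of V
     frozenset({"b", "c"}),
     frozenset({"c", "d"}),
     frozenset({"a", "c"})}

def neighbors(u, edges):
    """Return the set of nodes that share an edge with u."""
    return {v for e in edges if u in e for v in e if v != u}

print(neighbors("c", E))   # {'a', 'b', 'd'}
```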

Graph Convolutional Networks

CNNs learn features hierarchically, building from simpler ones to more complex ones. Desirable properties of CNNs include weight sharing, local connectivity, non-linearity and pooling layers, which have helped them achieve outstanding performance in a variety of tasks.

In graphs, the notion of neighborhoods and connectivity is ill-defined, making the application of convolution and pooling operations non-trivial. There are two approaches to tackle this: spectral and non-spectral. Whereas spectral approaches work with spectral representations of graphs (e.g., the graph Laplacian), non-spectral (spatial) approaches define convolutions directly on the graph.

That is, spatial approaches assume that graphs are embedded in a Euclidean space where nodes can have coordinates and weights. This assumption helps in generalizing pixels (in the case of CNNs) to nodes of the graph, over which we can take a similar sliding-window approach to aggregate the features of a neighborhood.

Graph Convolution

We’ll stick to GCNs in the spatial domain, in which convolution is usually implemented through non-linear functions (e.g., an MLP) over spatial neighbors, and pooling may be adopted to produce a new, coarsened graph by aggregating information from each node’s neighbors.

Message Passing and Neighborhood Aggregation

A simple illustration. (credits: Zak Jost, "Graph Convolutional Networks")

We can choose a fixed-size neighborhood (the node degree) or a variable one, which requires defining an operator that can work with different-sized neighborhoods. After determining the neighborhood, an aggregation is performed over it, for instance, taking the mean of all the neighbors’ feature vectors.

The aggregation operator is generally chosen to be permutation invariant; that is, the aggregated output is invariant to the order in which the neighbors are picked. The usual choices are sum, mean and max.
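As a quick illustration of permutation invariance, here is a minimal NumPy sketch (the neighbor feature vectors are made up): shuffling the order of the neighbors does not change the aggregated result for sum, mean or max.

```python
import numpy as np

# Made-up feature vectors for three neighbors (each of dimension 2).
neighbor_feats = np.array([[1.0, 2.0],
                           [3.0, 0.0],
                           [5.0, 4.0]])

# The same neighbors, visited in a different order.
permuted = neighbor_feats[[2, 0, 1]]

for aggregate in (np.sum, np.mean, np.max):
    # The aggregated output is identical regardless of neighbor order.
    assert np.allclose(aggregate(neighbor_feats, axis=0),
                       aggregate(permuted, axis=0))
```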

An illustration of graph convolution. (credits: Zak Jost, "Graph Convolutional Networks")

Graph Convolution: Recipe

(left to right, 1–4) The 4-step recipe. (credits: Nikhila Ravi, "3D Deep Learning with PyTorch3D", ICML 2020)
  1. Select a vertex
  2. Find its neighboring vertices using the edges
  3. Aggregate (e.g., sum) features of the neighbors
  4. Update features of the selected node
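Below is a minimal NumPy sketch of this recipe for a single node, using a made-up edge list and feature matrix; the aggregator is a sum, and the "update" here simply averages the node's own features with the aggregate (real GCN layers use learned weights instead).

```python
import numpy as np

# Made-up toy graph: 4 nodes, undirected edges given as index pairs.
edges = [(0, 1), (1, 2), (2, 3), (0, 2)]
feats = np.array([[1.0, 0.0],      # feature vector of node 0
                  [0.0, 2.0],      # node 1
                  [3.0, 1.0],      # node 2
                  [1.0, 1.0]])     # node 3

i = 2                                                   # 1. select a vertex
nbrs = [v for u, v in edges if u == i] + \
       [u for u, v in edges if v == i]                  # 2. find its neighbors via the edges
aggregate = feats[nbrs].sum(axis=0)                     # 3. aggregate (sum) neighbor features
feats[i] = 0.5 * (feats[i] + aggregate)                 # 4. update the selected node's features
print(feats[i])                                         # [2.5, 2.0]
```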

GCNs in an equation

The equation encapsulates message passing and neighborhood aggregation in GCNs.

$$ \mathbf{x}_i^{(k)} = \gamma^{(k)} \biggl( \mathbf{x}_i^{(k-1)} , \square_{j \in \mathcal{N}(i)} \quad \phi^{(k)} \Bigl( \mathbf{x}_i^{(k-1)} , \mathbf{x}_j^{(k-1)} , \mathbf{e}_{j,i} \Bigr) \biggr) \quad \text{where,} $$ $$ \mathbf{x}_i^{(k)} \in \mathbb{R}^{F'} $$ $$ \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} \in \mathbb{R}^{F} $$ $$ \mathbf{e}_{j,i} \in \mathbb{R}^{D} $$

in words

Calculating the features of node \(\mathbf{x}_i\) in the \(k^{th}\) layer is a non-linear function of:

  1. its features in the previous layer, i.e. \(\mathbf{x}_i^{(k-1)}\); and,
  2. the aggregated output over its neighbors \(j \in \mathcal{N}(i)\)
    • each “output” is a non-linear function of \(\mathbf{x}_i^{(k-1)},\) neighbor features \(\mathbf{x}_j^{(k-1)}\) and (optionally) of the edge features \(\mathbf{e}_{j,i}\) between them (all in the previous layer \(k-1\))
Detailed equation (credits: PyTorch Geometric)
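To tie the equation to code, here is a minimal NumPy sketch of one message-passing layer (not PyTorch Geometric's actual implementation): \(\phi\) and \(\gamma\) are taken to be small linear maps followed by a ReLU, the aggregator \(\square\) is a sum, edge features \(\mathbf{e}_{j,i}\) are omitted, and the toy graph and weights are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
relu = lambda x: np.maximum(x, 0.0)

# Toy graph: 4 nodes with F = 2 input features, undirected edges as index pairs.
edges = [(0, 1), (1, 2), (2, 3), (0, 2)]
X = rng.normal(size=(4, 2))                    # x_i^{(k-1)} for every node i

F_out = 3
W_phi   = rng.normal(size=(4, F_out))          # phi acts on [x_i ; x_j]  (2 + 2 dims)
W_gamma = rng.normal(size=(2 + F_out, F_out))  # gamma acts on [x_i ; aggregate]

def message_passing_layer(X, edges):
    """One round: x_i^{(k)} = gamma(x_i^{(k-1)}, sum_j phi(x_i^{(k-1)}, x_j^{(k-1)}))."""
    X_new = np.zeros((X.shape[0], F_out))
    for i in range(X.shape[0]):
        nbrs = [v for u, v in edges if u == i] + [u for u, v in edges if v == i]
        # messages phi(x_i, x_j) from every neighbor j
        msgs = [relu(np.concatenate([X[i], X[j]]) @ W_phi) for j in nbrs]
        agg = np.sum(msgs, axis=0)                               # the square-operator = sum
        X_new[i] = relu(np.concatenate([X[i], agg]) @ W_gamma)   # gamma(x_i, aggregate)
    return X_new

print(message_passing_layer(X, edges).shape)   # (4, 3): each node now has F' = 3 features
```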
