I had wanted to do something with JAX for a while, so I started by checking the examples in the main repository and tried making a couple of changes to them. The examples are easy to follow, but I wanted to get a deeper understanding of the library, so after a choppy attempt with some RL algorithms, I decided to work on something I had implemented before and went for two different Graph Neural Network papers. I have to say that for an early-stage library (version 0.1.59 as I started writing this), the process was reasonably smooth. In this post I’ll talk about my experience building and training Graph Neural Networks (GNNs) with JAX.

I focus on the implementation of Graph Convolutional Networks and Graph Attention Networks, and I assume familiarity with the topic, but not with JAX. For details about the models, check their papers (GCN, GAT) or their accompanying posts (GCN, GAT). For those not familiar with JAX, I start with an introduction and a simple linear regression example.

## 1. What is JAX

JAX is a Machine Learning library which I would describe (very vaguely) as NumPy with automatic differentiation that you can execute on GPUs (and TPUs too!). Additionally, it has XLA compilation and built-in vectorization and parallelization capabilities.
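As a quick, standalone taste of those capabilities (this toy snippet is not part of the GNN code below and the function names are made up): jit compiles a function with XLA, and vmap vectorizes it over a batch dimension.

import jax.numpy as np
import numpy.random as npr
from jax import jit, vmap

def predict(w, x):
    # a toy prediction function on a single example
    return np.tanh(np.dot(x, w))

fast_predict = jit(predict)                          # compiled with XLA
batched_predict = vmap(predict, in_axes=(None, 0))   # vectorized over a batch of x

w = npr.randn(10)
x_batch = npr.randn(32, 10)
print(batched_predict(w, x_batch).shape)  # (32,)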

Before jumping straight to the Graph Neural Network specifics, let me review the basics of JAX. Feel free to skip this section if you already know it.

The most important JAX feature to implement neural networks is autodiff, which lets us easily compute derivatives of our functions.

First of all, let’s see how to use it for a simple function like ReLU.

import jax.numpy as np

def relu(x):
    return np.maximum(0, x)



By calling grad() on a function, JAX returns another function that computes the gradient of the first function, which we can use for any input value.

For example:

from jax import grad

x = 5.0
print(relu(x))
print(grad(relu)(x))


Will print 5.0 and 1.0, which are the values of relu(5.0) and the derivative of ReLU at x=5.0. And not just that, but we can also easily compute higher order derivatives by chaining multiple calls of grad(). For example, we can compute the 2nd derivative like this:

relu_2nd = grad(grad(relu))
print(relu_2nd(x))


As we would expect, relu_2nd(x) will evaluate to 0. for any value of x, as ReLU is a piecewise linear function without curvature.

In the same way, with jax.grad() we can compute derivatives of a function with respect to its parameters, which is a building block for training neural networks. For example, let’s take a look at the following simple linear model and see how to compute the derivatives w.r.t its parameters for an input value of x = 5.0:

def linear(params, x):
    w, b = params
    return w*x + b

# initial parameters
w, b = 1.5, 2.
# gradient of linear w.r.t its first argument, evaluated at x = 5.0
params_grad = grad(linear)((w, b), 5.0)


What we are doing in the previous code snippet is the following: first we define the function linear(params, x), which gets as arguments the parameters of the linear model (w, b) and the data point x to predict on. Then, calling grad(linear) gives us a function that computes the gradient of linear w.r.t its first argument (the linear model parameters). Thus, params_grad will contain two values: the first one is the derivative w.r.t w and the second one is the derivative w.r.t b. We can go one step further and compute the gradient of a loss function w.r.t the linear model parameters:

def loss(params, dataset):
    x, y = dataset
    pred = linear(params, x)
    return np.square(pred - y).mean()

# (5.0, 2.0) are the made up values for x and y
loss_grad = grad(loss)
print(loss_grad((w, b), (5.0, 2.0)))


For the example I’ve set x=5.0 and y=2.0 arbitrarily, and then compute the gradient of the loss function for that arbitrary label. Notice how loss is a function of linear as well, and loss_grad() will compute the gradient w.r.t the parameters, chaining the gradient of loss and the gradient of linear.

With all these pieces, we can write a small script that trains a linear model:

import numpy.random as npr
import jax.numpy as np
from jax import grad

# first generate some random data
X = npr.uniform(0, 1, 300)
true_w, true_b = 2, 1
# add some noise to the labels
Y = X*true_w + true_b + 0.4*npr.randn(300)

# the linear model
def linear(params, x):
    w, b = params
    return w*x + b

def loss(params, dataset):
    x, y = dataset
    pred = linear(params, x)
    return np.square(pred - y).mean()

iterations = 500
step_size = 0.01
dataset = (X, Y)
w, b = 1.5, 2.  # initial values for the parameters
for i in range(iterations):
    params = (w, b)
    loss_ = loss(params, dataset)
    # compute gradient w.r.t model parameters
    grads = grad(loss)(params, dataset)
    # update parameters
    w -= step_size * grads[0]
    b -= step_size * grads[1]
    print(loss_)


We can visualize the final model (orange line) and compare it to the true data generating model (red line), and we can see that the learned model ends up quite close to the true one.
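A minimal sketch of how such a comparison plot could be produced with matplotlib, assuming w and b hold the final values from the training loop above:

import matplotlib.pyplot as plt

plt.scatter(X, Y, s=8, alpha=0.5, label="data")
plt.plot(X, true_w*X + true_b, color="red", label="true model")
plt.plot(X, w*X + b, color="orange", label="learned model")
plt.legend()
plt.show()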

Well, I guess that’s enough JAX basics; in the next sections you’ll see that the GNN implementations are not that different from this simple example.

## 2. Graph Neural Networks

Graph Neural Networks (GNNs) are neural networks that we can apply to graph structured data in a similar way that we apply Convolutional Neural Networks (CNNs) to images or Recurrent Neural Networks (RNNs) to sequences.

In this post I’ll focus on Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs), but there are several more models. Check this paper for a nice overview of the different kinds of models; it is very useful to get an understanding of them.

### 2.1 Graph Convolutional Networks

First defined in the paper Semi-Supervised Classification with Graph Convolutional Networks, these models are fast and easy to implement: they use the normalized adjacency matrix of the graph to propagate information between neighbouring nodes. You can check this post for more details.

Basically, a graph convolutional layer is defined as:

$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} \Theta^{(l)}\right)$$

where $$H^{(l)}$$ is the input to the layer, a $$N \times C$$ matrix with as many rows and columns as nodes and input features respectively, $$\hat{A}$$ is the degree normalized adjacency matrix, $$\Theta^{(l)}$$ is a $$C \times F$$ matrix of learnable parameters and $$\sigma$$ is a non-linearity.
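As a side note, $$\hat{A}$$ can be precomputed from the raw adjacency matrix with the renormalization trick from the GCN paper, $$\hat{A} = \tilde{D}^{-1/2}(A + I)\tilde{D}^{-1/2}$$. A minimal sketch of that computation (the helper name normalize_adj is made up for illustration; in practice this is done during data preprocessing, which I don’t cover in this post):

import jax.numpy as np

def normalize_adj(adj):
    # add self-loops and symmetrically normalize by the node degrees
    adj = adj + np.eye(adj.shape[0])
    deg = np.sum(adj, axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    return adj * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]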

### 2.2 Graph Attention Networks

This is another family of GNNs that we proposed in the paper Graph Attention Networks. Instead of using the values in the normalized adjacency matrix to propagate information, these models compute attention coefficients between neighbouring nodes, and use those coefficients when propagating the nodes’ features. Check this post for more details.

A layer can still be defined as:

$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} \Theta^{(l)}\right)$$

but this time, the propagation matrix $$\hat{A}$$ is not the normalized adjacency matrix, but a matrix whose elements are attention coefficients computed between neighbouring nodes:

$$\hat{A}_{ij} = \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}$$

with the unnormalized coefficient $$e_{ij}$$ between each pair of nodes $$(i,j)$$ being a function of their features:

$$e_{ij} = \text{LeakyReLU}\left(\vec{a}^{\,T} \left[\Theta \vec{h}_i \,\Vert\, \Theta \vec{h}_j\right]\right)$$

The other difference between GAT layers and GCN layers is that GAT layers use a multi-head approach, similar to the attention proposed in the original Transformer paper. With this approach, each layer consists of several independent attention heads whose outputs are concatenated.
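In equation form, the output features of node $$i$$ at a layer with $$K$$ heads are the concatenation of the individual heads’ outputs, as in the GAT paper:

$$\vec{h}'_i = \Big\Vert_{k=1}^{K} \sigma\Big(\sum_{j \in \mathcal{N}_i} \alpha^{k}_{ij}\, \Theta^{k} \vec{h}_j\Big)$$

except in the last layer, where the heads’ outputs are averaged instead of concatenated.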

## 3. JAX implementation

For my GNN implementations, I based my code structure on the stax package in jax.experimental, which implements several neural network layers. The whole idea is that a layer decouples its parameters from its forward computation, and the parameters are always passed as an argument to the forward function. If you are familiar with PyTorch, this would be like calling layer(params, x) instead of layer(x) to compute the forward pass of a layer.

Therefore, to define a layer in JAX we have to define two functions:

1. init_fun: initializes the parameters of the layer.
2. apply_fun: defines the forward computation function.

Something like this:

def Layer():
    def init_fun(*args, **kwargs):
        # initialize parameters and compute output shape
        return output_shape, params
    def apply_fun(params, x, *args, **kwargs):
        # process the input and return the output
        return out
    return init_fun, apply_fun


A model is nothing more than a collection of layers, so it is defined in the same way: an init_fun that calls the initializers of its layers, and an apply_fun that defines the forward computation of the model from input to output by chaining the different layers and activation functions.
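This is the same pattern that the built-in stax layers follow. For example, a small (non-GNN) multi-layer perceptron can be built by chaining stax layers with stax.serial, which returns the familiar init_fun/apply_fun pair:

from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu

# serial chains the layers and returns an (init_fun, apply_fun) pair for the whole model
init_fun, apply_fun = stax.serial(Dense(64), Relu, Dense(10))

rng = random.PRNGKey(0)
output_shape, params = init_fun(rng, (-1, 128))  # the input has 128 features
# apply_fun(params, inputs) computes the forward pass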

Let’s see how to do that for the first model, GCNs.

### 3.1 Graph Convolutional Networks

Let’s start by defining a graph convolutional layer, which is the building block of a GCN. As I said before, we need to define an initialization function and a forward function for the layer. The initialization function takes care of initializing the parameters of the layer, which in this case are a linear projection matrix and a vector of biases.

out_dim = 64
def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W_init, b_init = glorot_uniform(), zeros
    W = W_init(k1, (input_shape[-1], out_dim))
    b = b_init(k2, (out_dim,)) if bias else None
    return output_shape, (W, b)


The initialization function only needs two arguments:

• rng: a random key used to generate random numbers
• input_shape: a tuple indicating the input shape, necessary to know the shape of the layer parameters.

And returns two values:

• output_shape: the shape of the output computed by the layer.
• (W, b): the initialized parameters.

The reason for receiving and returning input and output shapes is that we will chain these initialization functions when creating a model with multiple layers, so each layer will know exactly which input shape to expect.

Finally, the forward function can be easily defined as:

def apply_fun(params, x, adj, **kwargs):
    W, b = params
    support = np.dot(x, W)
    out = np.matmul(adj, support)
    if bias:
        out += b
    return out


Notice how the parameters are an argument to the function? This makes it a pure function, since it only depends on its inputs instead of using values stored in global variables or class attributes. params is a tuple of parameters, like the one returned by init_fun. Additionally, a GCN layer needs the adjacency matrix of the graph to propagate information between nodes, which is done with np.matmul(adj, support).

After defining these two functions, we can put them together to form a Graph Convolutional Layer:

def GraphConvolution(out_dim, bias=False):
    """
    Layer constructor function for a Graph Convolution layer
    as the one in https://arxiv.org/abs/1609.02907
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W_init, b_init = glorot_uniform(), zeros
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim,)) if bias else None
        return output_shape, (W, b)

    def apply_fun(params, x, adj, **kwargs):
        W, b = params
        support = np.dot(x, W)
        out = np.matmul(adj, support)
        if bias:
            out += b
        return out

    return init_fun, apply_fun


On the forward step, this layer will project the input nodes’ features using a learned projection defined by W and b and then propagate them according to the normalized adjacency matrix.
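To make the pattern concrete, here is a hypothetical usage sketch of this layer on its own, with made-up shapes and an identity matrix standing in for the normalized adjacency:

from jax import random

rng = random.PRNGKey(0)
init_fun, apply_fun = GraphConvolution(out_dim=64)

n_nodes, n_feats = 5, 16
_, params = init_fun(rng, (-1, n_feats))

x = random.normal(rng, (n_nodes, n_feats))  # made-up node features
adj = np.eye(n_nodes)                       # stand-in for the normalized adjacency matrix
out = apply_fun(params, x, adj)             # (n_nodes, 64) node representations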

To use these layers to create a Graph Convolutional Network, we follow the same approach, defining an init_fun and an apply_fun for the model.

First, we call the layer functions like this:

gc1_init, gc1_fun = GraphConvolution(nhid)
_, drop_fun = Dropout(dropout)
gc2_init, gc2_fun = GraphConvolution(nclass)


The GraphConvolution function is called twice to generate two pairs of initialization and forward functions, one pair per layer. You can also see that we are using a Dropout layer. I’m not going to explain it here since it follows the exact same pattern as the GraphConvolution layer, but notice that I discard its initializer, since it doesn’t have any parameters to initialize. You can check the exact implementation of the Dropout layer here. Keep in mind that up to this point we haven’t initialized or used these layers yet; we have just instantiated their initialization and forward functions.
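For reference, here is a sketch of what such a Dropout layer could look like; the actual implementation linked above may differ in the details:

def Dropout(rate):
    """
    Sketch of a Dropout layer following the same init_fun/apply_fun pattern.
    """
    def init_fun(rng, input_shape):
        # no learnable parameters to initialize
        return input_shape, ()

    def apply_fun(params, x, is_training=False, rng=None, **kwargs):
        if not is_training:
            return x
        keep = random.bernoulli(rng, 1 - rate, x.shape)
        # inverted dropout: scale the kept activations to preserve the expected value
        return np.where(keep, x / (1 - rate), 0.)

    return init_fun, apply_fun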

With that, we have all we need to define the init_fun for the GCN model:

def init_fun(rng, input_shape):
    init_funs = [gc1_init, gc2_init]
    params = []
    for init_fun in init_funs:
        rng, layer_rng = random.split(rng)
        input_shape, param = init_fun(layer_rng, input_shape)
        params.append(param)
    return input_shape, params


This function is initializing the model layers and storing their parameters in params. Then, following the same pattern as the layers’ init_fun function, it returns the output shape and the parameters of the model, which are stored in a list with one item per layer.

The other function that we have to define is the apply_fun for the GCN model:

def apply_fun(params, x, adj, is_training=False, **kwargs):
    rng = kwargs.pop('rng', None)
    k1, k2, k3, k4 = random.split(rng, 4)
    x = drop_fun(None, x, is_training=is_training, rng=k1)
    x = gc1_fun(params[0], x, adj, rng=k2)
    x = nn.relu(x)
    x = drop_fun(None, x, is_training=is_training, rng=k3)
    x = gc2_fun(params[1], x, adj, rng=k4)
    x = nn.log_softmax(x)
    return x


This function uses the gc1_fun, drop_fun and gc2_fun that we have obtained before, and basically defines the forward pass of the full model. Easy, right? Also notice how I used an additional argument called is_training. This is a boolean flag used by Dropout to change its behaviour at eval time.

Putting all these pieces together, we can build a GCN model like this:

def GCN(nhid: int, nclass: int, dropout: float):
    """
    This function implements the GCN model that uses 2 Graph Convolutional layers.
    """
    gc1_init, gc1_fun = GraphConvolution(nhid)
    _, drop_fun = Dropout(dropout)
    gc2_init, gc2_fun = GraphConvolution(nclass)

    init_funs = [gc1_init, gc2_init]

    def init_fun(rng, input_shape):
        params = []
        for init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            input_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        k1, k2, k3, k4 = random.split(rng, 4)
        x = drop_fun(None, x, is_training=is_training, rng=k1)
        x = gc1_fun(params[0], x, adj, rng=k2)
        x = nn.relu(x)
        x = drop_fun(None, x, is_training=is_training, rng=k3)
        x = gc2_fun(params[1], x, adj, rng=k4)
        x = nn.log_softmax(x)
        return x

    return init_fun, apply_fun


Let’s see how to use the same pattern for Graph Attention Networks now.

### 3.2 Graph Attention Networks

For Graph Attention Networks we follow the exact same pattern, but the layer and model definitions are slightly more complex, since a Graph Attention layer requires a few more operations and parameters. This time, similar to the PyTorch implementations of Attention and MultiHeadAttention layers, the layer definition is split in two:

1. GraphAttentionLayer: implements a single attention layer, equivalent to one head.
2. MultiHeadLayer: implements the multi-head logic, using several GraphAttentionLayer layers.

def GraphAttentionLayer(out_dim, dropout):
    """
    Layer constructor function for a Graph Attention layer.
    """
    _, drop_fun = Dropout(dropout)

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2, k3, k4 = random.split(rng, 4)
        W_init = glorot_uniform()
        # projection
        W = W_init(k1, (input_shape[-1], out_dim))

        a_init = glorot_uniform()
        a1 = a_init(k2, (out_dim, 1))
        a2 = a_init(k3, (out_dim, 1))

        return output_shape, (W, a1, a2)

    def apply_fun(params, x, adj, rng, activation=nn.elu, is_training=False,
                  **kwargs):
        W, a1, a2 = params
        k1, k2, k3 = random.split(rng, 3)
        x = drop_fun(None, x, is_training=is_training, rng=k1)
        x = np.dot(x, W)

        f_1 = np.dot(x, a1)
        f_2 = np.dot(x, a2)
        logits = f_1 + f_2.T
        coefs = nn.softmax(
            nn.leaky_relu(logits, negative_slope=0.2) + np.where(adj, 0., -1e9))

        coefs = drop_fun(None, coefs, is_training=is_training, rng=k2)
        x = drop_fun(None, x, is_training=is_training, rng=k3)

        ret = np.matmul(coefs, x)

        return activation(ret)

    return init_fun, apply_fun


The layer looks a bit more complicated, but it isn’t much different from the previous GraphConvolution layer. First, it projects the input nodes to a new space with W, and then it propagates the nodes’ features with np.matmul(coefs, x), as we did before. The main difference is that the values in coefs are attention coefficients computed from the node features, instead of coming from the adjacency matrix. coefs is built by computing an attention coefficient between each pair of nodes, and then normalizing them with a softmax over all the coefficients of each node. The input to the softmax is masked so that only the direct neighbours of each node are considered.
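To see what the masking term is doing, here is a tiny made-up example: positions of logits that correspond to non-neighbours get a very large negative value added, so after the softmax their coefficients are effectively zero.

import jax.numpy as np
from jax import nn

adj = np.array([[1., 1., 0.],
                [1., 1., 1.],
                [0., 1., 1.]])   # made-up adjacency with self-loops
logits = np.ones((3, 3))         # made-up unnormalized attention logits
masked = nn.leaky_relu(logits, negative_slope=0.2) + np.where(adj, 0., -1e9)
coefs = nn.softmax(masked)
# every row of coefs sums to 1 and is effectively non-zero only for the neighbours of that node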

With this layer, we can easily build the multi-head mechanism following the same pattern of writing an init_fun and an apply_fun:

def MultiHeadLayer(nheads: int, nhid: int, dropout: float, last_layer: bool=False):
    layer_funs, layer_inits = [], []
    for head_i in range(nheads):
        att_init, att_fun = GraphAttentionLayer(nhid, dropout=dropout)
        layer_inits.append(att_init)
        layer_funs.append(att_fun)

    def init_fun(rng, input_shape):
        params = []
        for att_init_fun in layer_inits:
            rng, layer_rng = random.split(rng)
            layer_shape, param = att_init_fun(layer_rng, input_shape)
            params.append(param)
        input_shape = layer_shape
        if not last_layer:
            # multiply by the number of heads
            input_shape = input_shape[:-1] + (input_shape[-1]*len(layer_inits),)
        return input_shape, params

    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        layer_outs = []
        for i, att_fun in enumerate(layer_funs):
            rng, head_rng = random.split(rng)
            layer_outs.append(
                att_fun(params[i], x, adj, rng=head_rng, is_training=is_training))
        if not last_layer:
            # concatenate the outputs of the heads
            x = np.concatenate(layer_outs, axis=1)
        else:
            # average the outputs of the heads in the last layer
            x = np.mean(np.stack(layer_outs), axis=0)

        return x

    return init_fun, apply_fun


As you can see, we instantiate as many GraphAttentionLayer layers as the number of heads, and store their att_init and att_fun functions in two lists. Then, we write an init_fun for the MultiHeadLayer which initializes each head and computes the appropriate output shape. The apply_fun is not very different from the others; in this case we run the input through each head and then concatenate their outputs (or average them in the last layer).

With these two pieces, the Graph Attention Network model definition is quite straightforward:

def GAT(nheads: List[int], nhid: List[int], nclass: int, dropout: float):
    """
    Graph Attention Network model definition.
    """

    init_funs = []
    attn_funs = []

    nhid += [nclass]
    for layer_i in range(len(nhid)):
        last = layer_i == len(nhid) - 1
        layer_init, layer_fun = MultiHeadLayer(nheads[layer_i], nhid[layer_i],
                                               dropout=dropout, last_layer=last)
        attn_funs.append(layer_fun)
        init_funs.append(layer_init)

    def init_fun(rng, input_shape):
        params = []
        for i, init_fun in enumerate(init_funs):
            rng, layer_rng = random.split(rng)
            layer_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
            input_shape = layer_shape
        return input_shape, params

    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        rngs = random.split(rng, len(attn_funs))

        for i, layer_fun in enumerate(attn_funs):
            x = layer_fun(params[i], x, adj, rng=rngs[i], is_training=is_training)

        return nn.log_softmax(x)

    return init_fun, apply_fun


Something worth mentioning here is the difference between the params argument for the GAT model and the params for the GCN model. Did you notice anything different?

The difference is their structure. For GCNs, params is a List of Tuple, where each Tuple holds the (W, b) values for each layer. Therefore, it looks something like this:

params = [
    (W_0, b_0), # first layer
    (W_1, b_1)  # output layer
]


For GATs, the structure is a bit different, because we have different heads at each layer. Therefore, the structure of the params argument is a List of List of Tuple like this:

params = [
    [ # first layer
        (W_0, a1_0, a2_0), # first head
        (W_1, a1_1, a2_1), # second head
        ...
        (W_k, a1_k, a2_k)  # k-th head
    ],
    [ # output layer
        ...
    ]
]


You might think that this change in structure means that we have to adapt our code to support both of them. However, that is not the case, because jax.grad() returns the gradients of the loss w.r.t params with exactly the same structure as params, and the optimizers only require params and grads to share a structure, whatever it is, so we don’t have to worry about it.
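A quick toy check of that claim: the gradient of a function whose parameters are a list of lists of tuples has exactly the same nested structure.

from jax import grad
import jax.numpy as np

def toy_loss(params):
    # params mimics the GAT structure: a list of layers, each a list of per-head tuples
    return sum(np.sum(W) + np.sum(a) for layer in params for (W, a) in layer)

params = [
    [(np.ones((2, 2)), np.ones(2)), (np.ones((2, 2)), np.ones(2))],  # 2 heads
    [(np.ones((2, 2)), np.ones(2))],                                 # 1 head
]
grads = grad(toy_loss)(params)
# grads is again a list of lists of tuples, with the same array shapes as params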

### 3.3 Main loop

If you made it to this point, you should have an idea of how to define the models: via an init_fun and an apply_fun. Once we have defined the model functions, all we have to do to use them is write the standard training loop, as well as the loss function and the optimizer.

Let’s start by instantiating and initializing a model, for example a GCN.

init_fun, predict_fun = GCN(nhid=hidden,
                            nclass=labels.shape[1],
                            dropout=dropout)
input_shape = (-1, n_nodes, n_feats)
rng_key, init_key = random.split(rng_key)
_, init_params = init_fun(init_key, input_shape)


As we did individually for each layer, we first call the model function GCN() with the desired configuration, and we get the initialization function init_fun and the forward computation function predict_fun. Next, we get the initial model parameters init_params as a List by using the init_fun function.

Now we only have a few things left: define the loss function, the optimizer, and the main loop.

The loss functions in JAX should also be pure functions. They get the model parameters and the input data as arguments, like this:

def loss(params, batch):
    """
    The idxes of the batch indicate which nodes are used to compute the loss.
    """
    inputs, targets, adj, is_training, rng, idx = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training, rng=rng)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    l2_loss = 5e-4 * optimizers.l2_norm(params)
    return ce_loss + l2_loss


Here the input data consists of the input node features and their true labels, along with the adjacency matrix, the is_training flag, the random key and the set of node indexes that we want to compute the loss on. These indexes change between train and eval time, so they also have to be an input to the loss function. Notice how the parameters are the first argument of the loss function, while everything else is packed as a second argument. The reason is that when computing the gradient of the loss function w.r.t the parameters, JAX assumes by default that the first argument holds the parameters. This works for most use cases, but it can always be changed to suit one’s specific needs.
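For example, if the parameters were not the first argument, grad() accepts an argnums argument to select which argument to differentiate with respect to; a toy example:

from jax import grad

def f(a, b):
    return a * b ** 2

df_db = grad(f, argnums=1)  # differentiate w.r.t the second argument, b
print(df_db(3.0, 2.0))      # 2 * a * b = 12.0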

The reason for passing the random key around as an argument is that this way, the functions depend uniquely on their arguments, not on an external random key defined somewhere else, making them true pure functions.
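If you haven’t used JAX’s random module before: keys are created explicitly and split whenever fresh randomness is needed, which is why the code in this post keeps splitting rng. A tiny example:

from jax import random

rng_key = random.PRNGKey(42)                   # create an explicit random key
rng_key, dropout_key = random.split(rng_key)   # split it to obtain a fresh subkey
noise = random.normal(dropout_key, (3,))       # use the subkey to sample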

For the optimizer I used the optimizers package from jax.experimental because I wanted to use Adam to replicate the papers, but we could easily write our own SGD optimizer, similar to what I showed for the linear regression model. The optimizer is defined as:

opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)


See how we create the optimizer state from the initial values of the model parameters. Then, whenever we want to retrieve the parameters, to use them as an argument for example, we can use get_params(opt_state). Finally, opt_update is the function that we will use to update the parameters based on the gradient of the loss, and is used like this:

opt_state = opt_update(i, grad(loss)(params, batch), opt_state)


where i is the iteration number and grad(loss)(params, batch) are the gradients of the loss function w.r.t the parameters of the model, as we saw before in the linear regression example.

All that’s left is to wrap the loss computation and parameter update into a single function:

def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)


And with this we can write a simple training loop that will train our graph neural network models:

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    # define training batch
    batch = (features, labels, adj, True, rng_key, idx_train)
    # update parameters
    opt_state = update(epoch, opt_state, batch)
    epoch_time = time.time() - start_time

    # validate
    params = get_params(opt_state)
    eval_batch = (features, labels, adj, False, rng_key, idx_val)
    val_acc = accuracy(params, eval_batch)
    val_loss = loss(params, eval_batch)
    print(f"Iter {epoch}/{num_epochs} ({epoch_time:.4f} s) val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}")

    # new random key at each iteration, otherwise dropout always uses the same mask
    rng_key, _ = random.split(rng_key)


I’m not showing how to load the dataset and preprocess the data, since that is dataset specific, but you can check my full implementation of Graph Convolutional Networks in JAX to see the full training script using the Cora dataset, as well as the use of @jax.jit to speed up the training. Also, I haven’t explained how to use GATs instead of GCNs because the way to use them is the same. If you want to see it, check my repository implementing Graph Attention Networks in JAX as well.

## 4. Other JAX resources

I hope this post has helped you understand JAX and how to use it. While I was working on my implementation of these two models, a resource that I found very useful and educational was Sabrina Mielke’s post “From PyTorch to JAX: towards neural net frameworks that purify stateful code”. I encourage everyone to give it a good read. If you want to check other JAX codebases, you can start with Flax and Haiku, two neural network libraries built on JAX by Google and DeepMind respectively. Additionally, if you are interested in Reinforcement Learning, RLax uses JAX to implement some RL algorithms.

### Get in Touch

I will gladly answer any questions or discuss anything about the code, you can contact me at gcucurull at gmail dot com.