{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quickstart\n", "\n", "This notebook shows how to create a synthetic graph and how to train a model on a node classification task\n", "using [Haiku Geometric](https://github.com/alexOarga/haiku-geometric).\n", "\n", "[![Open in Colab](https://img.shields.io/static/v1.svg?logo=google-colab&label=Quickstart&message=Open%20In%20Colab&color=blue)](https://colab.research.google.com/github/alexOarga/haiku-geometric/blob/main/docs/source/notebooks/1_quickstart.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Haiku Geometric - Graph Neural Networks in JAX\n", "\n", "Haiku Geometric is a collection of graph neural network (GNN) implementations in [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html). It tries to provide **object-oriented** and **easy-to-use** modules for GNNs.\n", "\n", "Haiku Geometric is built on top of [Haiku](https://github.com/deepmind/dm-haiku) and [Jraph](https://github.com/deepmind/jraph).\n", "It is heavily inspired by [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric). \n", "In most cases, Haiku Geometric tries to replicate the API of PyTorch Geometric to allow code sharing between the two.\n", "\n", "Haiku Geometric is still under development and I would advise against using it in production.\n", "\n", "- You can find all the available graph neural network layers [here](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers).\n", "- See also the [Haiku Geometric documentation](https://haiku-geometric.readthedocs.io/en/latest/index.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a synthetic graph\n", "\n", "We will create the following graph: \n", "\n", "![synthetic graph](_static/graph1.png) \n", "\n", "To do so, we create the following variables: \n", "\n", "- `nodes`: a 2D array of shape `[num_nodes, num_node_features]` with the node features. 
\n", "- `senders`: a 1D array of shape `[num_edges]` with the source node index of each edge. \n", "- `receivers`: a 1D array of shape `[num_edges]` with the destination node index of each edge. \n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] } ], "source": [ "import jax.numpy as jnp\n", "\n", "nodes = jnp.array([\n", "    [0.1, 0.2, 1.0],  # node 0 features\n", "    [0.4, 0.4, 0.3],  # node 1 features\n", "    [0.8, 0.0, 0.9],  # node 2 features\n", "    [0.0, 1.0, 1.0]   # node 3 features\n", "])\n", "senders = jnp.array([0, 1, 1, 2, 2])\n", "receivers = jnp.array([1, 0, 2, 2, 3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a model\n", "\n", "We will create a model with two graph convolutional layers ([haiku_geometric.nn.GCNConv](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.conv.GCNConv))\n", "followed by a linear layer ([hk.Linear](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.Linear)).\n", "Notice that to do so we group our layers in a new [Haiku](https://dm-haiku.readthedocs.io/en/latest/) module named `MyNet`.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import haiku as hk\n", "from haiku_geometric.nn import GCNConv\n", "\n", "class MyNet(hk.Module):\n", "    def __init__(self, hidden_channels, out_channels):\n", "        super().__init__()\n", "        self.conv1 = GCNConv(hidden_channels)\n", "        self.conv2 = GCNConv(hidden_channels)\n", "        self.linear = hk.Linear(out_channels)\n", "\n", "    def __call__(self, nodes, senders, receivers):\n", "        x = self.conv1(nodes, senders, receivers)\n", "        x = jax.nn.relu(x)\n", "        x = self.conv2(x, senders, receivers)\n", "        # feed the output of the second GCN layer (not the raw input features) to the linear layer\n", "        x = self.linear(x)\n", "        return x" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Transforming the model\n", "\n", "We now define a `forward` function that instantiates the net and performs a call. \n", "This function will be transformed by Haiku and will perform a forward pass on the model.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def forward(nodes, senders, receivers):\n", " net = MyNet(16, 7)\n", " return net(nodes, senders, receivers)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we transform the `forward` function as explained in the [Haiku documentation](https://dm-haiku.readthedocs.io/en/latest/).\n", "After transforming the function, we have to initialize the model with the `init` function\n", "that receives our graph data." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "model = hk.transform(forward)\n", "model = hk.without_apply_rng(model)\n", "rng = jax.random.PRNGKey(42)\n", "params = model.init(rng, nodes=nodes, senders=senders, receivers=receivers)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After this, we are ready to perform a forward pass on the model.\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 0.00770418, -0.7566054 , 0.51024306, 0.2543769 ,\n", " 0.4244291 , 1.0645634 , -0.30671927],\n", " [-0.10649211, -0.5037036 , 0.24744353, 0.20532413,\n", " 0.06193589, 0.6883482 , 0.1389835 ],\n", " [ 0.27398756, -0.32722455, 0.59584326, -0.2710259 ,\n", " 0.59495777, 1.479022 , 0.37957942],\n", " [-0.47271663, -1.6297377 , 0.53237855, 1.0204307 ,\n", " 0.07947233, 1.1653316 , -0.5966778 ]], dtype=float32)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = model.apply(params, nodes=nodes, senders=senders, receivers=receivers)\n", "output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learning on graphs\n", "\n", 
"Lets say that we want to perform classification on the graph.\n", "We will consider the following array of ground truth labels (one class for each node)\n", "that we will try to predict:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "labels = jnp.array([0, 1, 2, 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are ready to perform learning with our model (e.g. with gradient descent). \n", "To do so we will use an optimizer from [optax](https://optax.readthedocs.io/en/latest/).\n", "In this case we will use the Adam optimizer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install optax" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import optax\n", "\n", "opt_init, opt_update = optax.adam(learning_rate=0.1)\n", "opt_state = opt_init(params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define out loss function, where we first performa a forward pass to computed\n", "the `logits`, ant the compute the loss, in this case, softmax cross entropy loss. \n", "Notice that the function is JAX compatible and we can use the `jax.jit` decorator to\n", "speed up the training.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def loss_fn(params):\n", " logits = model.apply(params, nodes=nodes, senders=senders, receivers=receivers)\n", " x_loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)\n", " return jnp.sum(x_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also define a function that computes the gradients of the loss function (\n", "by using the `jax.grad` function) and updates the model parameters." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def update(params, opt_state):\n", "    # gradients of the loss with respect to the model parameters\n", "    g = jax.grad(loss_fn)(params)\n", "    # turn the gradients into parameter updates with the Adam optimizer\n", "    updates, opt_state = opt_update(g, opt_state)\n", "    return optax.apply_updates(params, updates), opt_state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will also need a function that computes the accuracy of the model.\n", "Again, this function is compatible with the `jax.jit` decorator.\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def accuracy(params):\n", "    logits = model.apply(params, nodes=nodes, senders=senders, receivers=receivers)\n", "    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can run the training loop!\n", "We will train for 10 steps:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 0: accuracy = 0.0\n", "Step 1: accuracy = 0.25\n", "Step 2: accuracy = 0.25\n", "Step 3: accuracy = 0.25\n", "Step 4: accuracy = 0.5\n", "Step 5: accuracy = 0.75\n", "Step 6: accuracy = 1.0\n", "Step 7: accuracy = 1.0\n", "Step 8: accuracy = 1.0\n", "Step 9: accuracy = 1.0\n" ] } ], "source": [ "for step in range(10):\n", "    params, opt_state = update(params, opt_state)\n", "    acc = accuracy(params)\n", "    print(f\"Step {step}: accuracy = {acc}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 4 }