{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "HDObRj0KtYYi" }, "source": [ "# Batch support\n", "\n", "This notebook explains how to batch and unbatch graphs in Haiku Geometric.\n", "\n", "\n", "[![Open in Colab](https://img.shields.io/static/v1.svg?logo=google-colab&label=Quickstart&message=Open%20In%20Colab&color=blue)](https://colab.research.google.com/github/alexOarga/haiku-geometric/blob/main/docs/source/notebooks/batch_support.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install git+https://github.com/alexOarga/haiku-geometric.git" ] }, { "cell_type": "markdown", "metadata": { "id": "_IyJz0r0tugA" }, "source": [ "## Batching graphs\n", "\n", "The `haiku_geometric.utils.batch` function batches a list of\n", "`haiku_geometric.data.DataGraphTuple` objects into a single `haiku_geometric.data.DataGraphTuple` object.\n", "\n", "The `batch` function returns:\n", " - A single `haiku_geometric.data.DataGraphTuple` containing the batched graphs.\n", " - A `jax.numpy` array of indices that maps each node to the graph it belongs to." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "krfCUf9HtoKt", "outputId": "46f3833d-9d0c-425b-e493-548d001cb198" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] } ], "source": [ "import jax.numpy as jnp\n", "from haiku_geometric.utils import batch\n", "from haiku_geometric.datasets.base import DataGraphTuple\n", "\n", "# A graph with 3 nodes and 3 directed edges.\n", "graph1 = DataGraphTuple(\n", "    nodes=jnp.array([0.0, 0.1, 0.2]),\n", "    senders=jnp.array([0, 1, 2]),\n", "    receivers=jnp.array([2, 2, 0]),\n", "    edges=None,\n", "    n_node=jnp.array([3]),\n", "    n_edge=jnp.array([3]),\n", "    globals=None,\n", "    position=None,\n", "    y=jnp.array([0, 0, 0]),\n", "    train_mask=None,\n", ")\n", "\n", "# A graph with 2 nodes and 2 directed edges.\n", "graph2 = DataGraphTuple(\n", "    nodes=jnp.array([0.0, 0.0]),\n", "    senders=jnp.array([0, 1]),\n", "    receivers=jnp.array([1, 0]),\n", "    edges=None,\n", "    n_node=jnp.array([2]),\n", "    n_edge=jnp.array([2]),\n", "    globals=None,\n", "    position=None,\n", "    y=jnp.array([0, 0]),\n", "    train_mask=None,\n", ")\n", "\n", "batched_graph, batch_index = batch([graph1, graph2])" ] }, { "cell_type": "markdown", "metadata": { "id": "bm3sA4o1uSn3" }, "source": [ "## Unbatching graphs\n", "\n", "To unbatch a `DataGraphTuple` object created with the `batch` function,\n", "use the `haiku_geometric.utils.unbatch` function. It takes a batched `DataGraphTuple`\n", "object and returns a list of `haiku_geometric.data.DataGraphTuple` objects, one per graph." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "xyDiWJv5ugsr" }, "outputs": [], "source": [ "from haiku_geometric.utils import unbatch\n", "\n", "unbatched_graphs = unbatch(batched_graph)\n", "graph1 = unbatched_graphs[0]\n", "graph2 = unbatched_graphs[1]" ] }, { "cell_type": "markdown", "metadata": { "id": "UtgF9JP1ZZ1-" }, "source": [ "## Dynamic batching\n", "\n", "Haiku Geometric does not currently support dynamic batching. If you are working with `jraph`,\n", "you can create `jraph.GraphsTuple` objects and batch them with the `jraph.dynamically_batch` function."
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "-HqgVgEvZawQ" }, "outputs": [], "source": [ "import jax.numpy as jnp\n", "import jraph\n", "\n", "graph1 = jraph.GraphsTuple(\n", "    nodes=jnp.array([0.0, 0.1, 0.2]),\n", "    senders=jnp.array([0, 1, 2]),\n", "    receivers=jnp.array([2, 2, 0]),\n", "    edges=None,\n", "    n_node=jnp.array([3]),\n", "    n_edge=jnp.array([3]),\n", "    globals=None,\n", ")\n", "\n", "graph2 = jraph.GraphsTuple(\n", "    nodes=jnp.array([0.0, 0.0]),\n", "    senders=jnp.array([0, 1]),\n", "    receivers=jnp.array([1, 0]),\n", "    edges=None,\n", "    n_node=jnp.array([2]),\n", "    n_edge=jnp.array([2]),\n", "    globals=None,\n", ")\n", "\n", "# Every batch is padded, so the budgets must leave room for the padding.\n", "MAXIMUM_NUM_NODES = 4   # at least the largest graph (3 nodes) + 1 padding node\n", "MAXIMUM_NUM_EDGES = 3   # at least the edge count of the largest graph\n", "MAXIMUM_NUM_GRAPHS = 2  # at least 2: one real graph + 1 padding graph\n", "\n", "# Returns a generator; batches are built lazily as it is iterated.\n", "batched_generator = jraph.dynamically_batch([graph1, graph2],\n", "                                            MAXIMUM_NUM_NODES,   # max number of nodes in a batch\n", "                                            MAXIMUM_NUM_EDGES,   # max number of edges in a batch\n", "                                            MAXIMUM_NUM_GRAPHS)  # max number of graphs in a batch" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 1 }