{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Creating your own dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open in Colab](https://img.shields.io/static/v1.svg?logo=google-colab&label=Creating%20your%20own%20dataset&message=Open%20In%20Colab&color=blue)](https://colab.research.google.com/github/alexOarga/haiku-geometric/blob/main/docs/source/notebooks/creating_dataset.ipynb)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "KE23rOiEM62A" }, "source": [ "## Using JAX arrays\n", "\n", "### Nodes and graph structure\n", "This notebook includes an example of how to create \n", "your own dataset for Haiku Geometric. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "le4i0LeONB-A" }, "outputs": [], "source": [ "!pip install haiku-geometric" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "1zTEYs2QM-Mw" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import haiku as hk" ] }, { "cell_type": "markdown", "metadata": { "id": "WqZ2mtMnNInQ" }, "source": [ "Currently, all GNN layers in Haiku Geometric expect the following inputs:\n", "\n", "- `nodes`: a `jax.numpy.ndarray` array of shape `[num_nodes, num_node_features]` containing the node features. \n", "- `senders`: a `jax.numpy.ndarray` array of shape `[num_edges]` containing the indices of the source nodes. \n", "- `receivers`: a `jax.numpy.ndarray` array of shape `[num_edges]` containing the indices of the destination nodes. \n", "\n", "Notice that no object is actually necessary to use Haiku Geometric. \n", "It can be used with only JAX numpy arrays. 
If you want to create an object\n", "to store the graph data, see [DataGraphTuple](#DataGraphTuple).\n", "\n", "Let's say we want to create the following graph, with 4 nodes, 5 edges and\n", "3 features for each node:\n", "\n", "![graph](_static/graph1.png)\n", "\n", "To create the nodes array we arrange the features into a 2D array\n", "with the aforementioned shape:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zwg0TdfpNJWq" }, "outputs": [], "source": [ "nodes = jnp.array([\n", " [0.1, 0.2, 1.0], # node 0 features\n", " [0.4, 0.4, 0.3], # node 1 features\n", " [0.8, 0.0, 0.9], # node 2 features\n", " [0.0, 1.0, 1.0] # node 3 features\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "p_eyjiu9Namo" }, "source": [ "To create the senders and receivers, for each directed edge of the graph\n", "we need to specify the index of the source node and the index of the destination node:\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "9-nyJhG-Nc-O" }, "outputs": [], "source": [ "senders = jnp.array([0, 1, 1, 2, 2])\n", "receivers = jnp.array([1, 0, 2, 2, 3])" ] }, { "cell_type": "markdown", "metadata": { "id": "iDGwCyr0NfXm" }, "source": [ "Notice that self-edges are represented by having the same index for the source and destination nodes.\n", "Similarly, to model undirected graphs we can use two directed edges, one in each direction.\n", "\n", "### Edge features\n", "\n", "Some GNN layers also allow the user to specify edge features. 
In that case,\n", "the layer expects, in addition to the previous arrays, the following input:\n", "\n", "- `edges`: a `jax.numpy.ndarray` array of shape `[num_edges, num_edge_features]` containing the edge features.\n", "\n", "Let us now consider this graph where each edge has 2 associated features:\n", "\n", "![graph2](_static/graph2.png)\n", "\n", "To represent these features, we create an array of shape `[num_edges, num_edge_features]`\n", "with the edge features:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "xLbhU8UmNgDF" }, "outputs": [], "source": [ "edges = jnp.array([\n", " [0.0, 0.6], # edge from 0 to 1\n", " [1.0, 0.55], # edge from 1 to 0\n", " [0.01, 0.0], # edge from 1 to 2\n", " [0.4, 1.3], # edge from 2 to 2\n", " [0.9, 0.7] # edge from 2 to 3\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "Nd6yRCAuNkLv" }, "source": [ "\n", "**Notice** that the order of the features in the array must match the order of the edges in the `senders` and `receivers` arrays.\n", "\n", "## DataGraphTuple\n", "\n", "In Haiku Geometric, the `DataGraphTuple` object is used to store the graph data.\n", "When using datasets provided by `haiku_geometric.datasets`, each individual graph of a dataset is represented by a \n", "`DataGraphTuple` object.\n", "\n", "A `DataGraphTuple` can be created as follows:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p3XUyuNFNk3O" }, "outputs": [], "source": [ "from haiku_geometric.datasets.base import DataGraphTuple\n", "\n", "DataGraphTuple(\n", " nodes=nodes,\n", " senders=senders,\n", " receivers=receivers,\n", " edges=edges,\n", " n_node=4,\n", " n_edge=5,\n", " globals=jnp.array([0.0, 0.0, 0.0]),\n", " position=None,\n", " y=jnp.array([0.0, 1.0, 0.0, 0.0]),\n", " train_mask=jnp.array([True, True, True, False]),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "dqSVjLoXN4LA" }, "source": [ "Besides the `nodes`, `senders`, `receivers` and `edges` arrays, the 
`DataGraphTuple` object also contains the following attributes:\n", "\n", "- `n_node`: the number of nodes in the graph. \n", "- `n_edge`: the number of edges in the graph. \n", "- `globals`: if available, an array containing the global features. \n", "- `position`: some datasets might also provide position features for each node. If available, an array containing the position of each node. \n", "- `y`: an array containing ground truth labels. \n", "- `train_mask`: an array containing a boolean mask. \n", "\n", "Notice that all the attributes are optional. If you don't have a specific attribute, you can simply set it to `None`." ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 1 }