Batch support

This notebook explains how to use the batch support in Haiku Geometric.

[ ]:
!pip install git+https://github.com/alexOarga/haiku-geometric.git

Batching graphs

The haiku_geometric.utils.batch function can be used to batch a list of haiku_geometric.data.DataGraphTuple objects into a single haiku_geometric.data.DataGraphTuple object.

The batch function returns:

- A single haiku_geometric.data.DataGraphTuple containing the batched graphs.
- A jax.numpy array of indices indicating which graph each node belongs to.

[2]:
import jax.numpy as jnp
from haiku_geometric.utils import batch
from haiku_geometric.datasets.base import DataGraphTuple

graph1 = DataGraphTuple(
    nodes=jnp.array([0.0, 0.1, 0.2]),
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([2, 2, 0]),
    edges=None,
    n_node=jnp.array([3]),
    n_edge=jnp.array([3]),
    globals=None,
    position=None,
    y=jnp.array([0, 0, 0]),
    train_mask=None,
)

graph2 = DataGraphTuple(
    nodes=jnp.array([0.0, 0.0]),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 0]),
    edges=None,
    n_node=jnp.array([2]),
    n_edge=jnp.array([2]),
    globals=None,
    position=None,
    y=jnp.array([0, 0]),
    train_mask=None,
)

batched_graph, batch_index = batch([graph1, graph2])
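
To make the two return values concrete, here is a minimal sketch of what batching does to the two graphs above. It is not the library's implementation: `batch_sketch` is a hypothetical helper written with plain Python lists so that it runs without any dependencies, and it assumes the usual convention (the one `jraph.batch` follows) that node arrays are concatenated and the edge indices of later graphs are shifted by the number of nodes placed before them.

```python
def batch_sketch(graphs):
    """Merge graphs given as dicts with 'nodes', 'senders' and 'receivers'."""
    nodes, senders, receivers, batch_index = [], [], [], []
    for graph_id, graph in enumerate(graphs):
        offset = len(nodes)  # number of nodes already placed in the batch
        nodes.extend(graph["nodes"])
        # Shift edge endpoints so they keep pointing at the same nodes.
        senders.extend(s + offset for s in graph["senders"])
        receivers.extend(r + offset for r in graph["receivers"])
        # Record, for every node, the graph it came from.
        batch_index.extend([graph_id] * len(graph["nodes"]))
    return nodes, senders, receivers, batch_index


g1 = {"nodes": [0.0, 0.1, 0.2], "senders": [0, 1, 2], "receivers": [2, 2, 0]}
g2 = {"nodes": [0.0, 0.0], "senders": [0, 1], "receivers": [1, 0]}

nodes, senders, receivers, batch_index = batch_sketch([g1, g2])
print(senders)      # [0, 1, 2, 3, 4]
print(receivers)    # [2, 2, 0, 4, 3]
print(batch_index)  # [0, 0, 0, 1, 1]
```

An index of this form is what per-graph reductions typically consume; in JAX, for example, it can be passed as `segment_ids` to `jax.ops.segment_sum` to pool node features per graph.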

Unbatching graphs

To unbatch a DataGraphTuple object created with the batch function, use the haiku_geometric.utils.unbatch function. It takes a batched DataGraphTuple and returns a list of haiku_geometric.data.DataGraphTuple objects, one per original graph.

[3]:
from haiku_geometric.utils import unbatch

unbatched_graphs = unbatch(batched_graph)
graph1 = unbatched_graphs[0]
graph2 = unbatched_graphs[1]
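
As a rough illustration of the reverse operation, the sketch below splits a batched graph back into its parts using the per-graph node and edge counts (`n_node`, `n_edge`). The helper `unbatch_sketch` is hypothetical and uses plain Python lists; it assumes batched edge indices were offset by the number of preceding nodes, and it is not how haiku_geometric.utils.unbatch is necessarily implemented.

```python
def unbatch_sketch(nodes, senders, receivers, n_node, n_edge):
    """Split a batched graph into a list of per-graph dicts."""
    graphs = []
    node_start, edge_start = 0, 0
    for num_nodes, num_edges in zip(n_node, n_edge):
        node_end = node_start + num_nodes
        edge_end = edge_start + num_edges
        graphs.append({
            "nodes": nodes[node_start:node_end],
            # Undo the offset so indices are local to the graph again.
            "senders": [s - node_start for s in senders[edge_start:edge_end]],
            "receivers": [r - node_start for r in receivers[edge_start:edge_end]],
        })
        node_start, edge_start = node_end, edge_end
    return graphs


graphs = unbatch_sketch(
    nodes=[0.0, 0.1, 0.2, 0.0, 0.0],
    senders=[0, 1, 2, 3, 4],
    receivers=[2, 2, 0, 4, 3],
    n_node=[3, 2],
    n_edge=[3, 2],
)
print(graphs[1])  # {'nodes': [0.0, 0.0], 'senders': [0, 1], 'receivers': [1, 0]}
```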

Dynamic batching

Haiku Geometric does not currently support dynamic batching. If you are working with jraph, you can instead build jraph.GraphsTuple objects and pass them to jraph.dynamically_batch, which groups graphs into padded batches that fit within fixed node, edge, and graph budgets.

[6]:
import jax.numpy as jnp
import jraph

graph1 = jraph.GraphsTuple(
    nodes=jnp.array([0.0, 0.1, 0.2]),
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([2, 2, 0]),
    edges=None,
    n_node=jnp.array([3]),
    n_edge=jnp.array([3]),
    globals=None,
)

graph2 = jraph.GraphsTuple(
    nodes=jnp.array([0.0, 0.0]),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 0]),
    edges=None,
    n_node=jnp.array([2]),
    n_edge=jnp.array([2]),
    globals=None,
)

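# Budgets include padding: one spare node and one spare graph per batch.
MAXIMUM_NUM_NODES = 6   # 3 + 2 real nodes, plus 1 padding node
MAXIMUM_NUM_EDGES = 5   # 3 + 2 real edges
MAXIMUM_NUM_GRAPHS = 3  # 2 real graphs, plus 1 padding graph

batched_generator = jraph.dynamically_batch(
    [graph1, graph2],
    MAXIMUM_NUM_NODES,   # max number of nodes in a batch
    MAXIMUM_NUM_EDGES,   # max number of edges in a batch
    MAXIMUM_NUM_GRAPHS,  # max number of graphs in a batch
)

# The result is a generator; iterate over it to obtain the padded batches.
for padded_batch in batched_generator:
    print(padded_batch.n_node)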
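
Choosing the three budgets is the part that most often goes wrong, so a small pre-check can help. The sketch below encodes the lower bounds described in the jraph documentation for dynamically_batch as I understand them: the node budget must be at least the largest graph plus one padding node, the edge budget at least the largest edge count, and the graph budget at least 2. The helpers `minimum_budgets` and `budgets_are_valid` are hypothetical names for illustration, not part of jraph or Haiku Geometric.

```python
def minimum_budgets(graph_sizes):
    """Smallest budgets that can hold every graph on its own.

    graph_sizes is a list of (num_nodes, num_edges) pairs.
    """
    max_nodes = max(num_nodes for num_nodes, _ in graph_sizes)
    max_edges = max(num_edges for _, num_edges in graph_sizes)
    return {
        "n_node": max_nodes + 1,  # one extra node reserved for padding
        "n_edge": max_edges,
        "n_graph": 2,             # one real graph plus one padding graph
    }


def budgets_are_valid(graph_sizes, n_node, n_edge, n_graph):
    """Check candidate budgets against the lower bounds above."""
    low = minimum_budgets(graph_sizes)
    return (n_node >= low["n_node"]
            and n_edge >= low["n_edge"]
            and n_graph >= low["n_graph"])


# (num_nodes, num_edges) of a 3-node and a 2-node graph.
sizes = [(3, 3), (2, 2)]
print(minimum_budgets(sizes))             # {'n_node': 4, 'n_edge': 3, 'n_graph': 2}
print(budgets_are_valid(sizes, 2, 3, 2))  # False: a 3-node graph cannot fit
print(budgets_are_valid(sizes, 6, 5, 3))  # True: both graphs fit in one batch
```

Budgets at the minimum place each graph in its own batch; larger budgets let several graphs share a batch at the cost of more padding when batches are not full.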