Batch support
This notebook explains how to use the batch support in Haiku Geometric.
!pip install git+https://github.com/alexOarga/haiku-geometric.git
Batching graphs
The haiku_geometric.utils.batch function can be used to batch a list of haiku_geometric.data.DataGraphTuple objects into a single haiku_geometric.data.DataGraphTuple object.
The batch function returns:
- A single haiku_geometric.data.DataGraphTuple with the batched graphs.
- A jax.numpy.Array with indices indicating which graph each node belongs to.
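Conceptually, batching turns a list of graphs into one large, disconnected graph. Node arrays are concatenated, each graph's edge indices are shifted by the number of nodes that come before it, and a batch index records which graph each node came from. Below is a minimal sketch of that idea. It uses a hypothetical `ToyGraph` container and a hypothetical `batch_sketch` helper, neither of which is Haiku Geometric's implementation. The library's `batch` also has to handle the other `DataGraphTuple` fields, such as `edges`, `y`, or `train_mask`.

```python
from typing import List, NamedTuple, Tuple

import jax.numpy as jnp


class ToyGraph(NamedTuple):
    """Hypothetical minimal graph container (not Haiku Geometric's DataGraphTuple)."""
    nodes: jnp.ndarray      # [num_nodes, ...] node features
    senders: jnp.ndarray    # [num_edges] source node indices
    receivers: jnp.ndarray  # [num_edges] target node indices
    n_node: jnp.ndarray     # [num_graphs] number of nodes per graph
    n_edge: jnp.ndarray     # [num_graphs] number of edges per graph


def batch_sketch(graphs: List[ToyGraph]) -> Tuple[ToyGraph, jnp.ndarray]:
    """Hypothetical helper: merge graphs into one graph plus a node-to-graph index."""
    sizes = [int(g.nodes.shape[0]) for g in graphs]
    # Offset of each graph = total number of nodes in the graphs before it.
    offsets = [sum(sizes[:i]) for i in range(len(graphs))]
    merged = ToyGraph(
        nodes=jnp.concatenate([g.nodes for g in graphs]),
        # Shift edge endpoints so they index rows of the merged node array.
        senders=jnp.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
        receivers=jnp.concatenate([g.receivers + o for g, o in zip(graphs, offsets)]),
        n_node=jnp.concatenate([g.n_node for g in graphs]),
        n_edge=jnp.concatenate([g.n_edge for g in graphs]),
    )
    # batch_index[i] is the id of the graph that node i came from.
    batch_index = jnp.repeat(
        jnp.arange(len(graphs)), jnp.array(sizes), total_repeat_length=sum(sizes)
    )
    return merged, batch_index


# The same two toy graphs used in this notebook.
g1 = ToyGraph(jnp.array([0.0, 0.1, 0.2]), jnp.array([0, 1, 2]),
              jnp.array([2, 2, 0]), jnp.array([3]), jnp.array([3]))
g2 = ToyGraph(jnp.array([0.0, 0.0]), jnp.array([0, 1]),
              jnp.array([1, 0]), jnp.array([2]), jnp.array([2]))
merged, batch_index = batch_sketch([g1, g2])
print(merged.senders.tolist())    # [0, 1, 2, 3, 4]
print(merged.receivers.tolist())  # [2, 2, 0, 4, 3]
print(batch_index.tolist())       # [0, 0, 0, 1, 1]
```

Since the second graph's edges are shifted by 3, it stays disconnected from the first. A message-passing layer can therefore process the whole batch in one call without mixing information between graphs.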
import jax.numpy as jnp
from haiku_geometric.utils import batch
from haiku_geometric.datasets.base import DataGraphTuple
graph1 = DataGraphTuple(
nodes=jnp.array([0.0, 0.1, 0.2]),
senders=jnp.array([0, 1, 2]),
receivers=jnp.array([2, 2, 0]),
edges=None,
n_node=jnp.array([3]),
n_edge=jnp.array([3]),
globals=None,
position=None,
y=jnp.array([0, 0, 0]),
train_mask=None,
)
graph2 = DataGraphTuple(
nodes=jnp.array([0.0, 0.0]),
senders=jnp.array([0, 1]),
receivers=jnp.array([1, 0]),
edges=None,
n_node=jnp.array([2]),
n_edge=jnp.array([2]),
globals=None,
position=None,
y=jnp.array([0, 0]),
train_mask=None,
)
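# batched_graph holds both graphs as one disconnected graph;
# batch_index maps every node to the graph it came from.
batched_graph, batch_index = batch([graph1, graph2])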
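The batch index is what lets a model turn node-level outputs back into graph-level outputs after batching. As an illustration, assume the batch index for the two graphs above is `[0, 0, 0, 1, 1]`, with nodes listed graph by graph. Mean pooling per graph can then be written with `jax.ops.segment_sum`. Here `mean_pool` is a hypothetical helper and not part of Haiku Geometric, and the arrays are written out by hand so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

# Node features of the two toy graphs after batching (3 + 2 nodes) and the
# assumed node-to-graph index for that batch.
nodes = jnp.array([0.0, 0.1, 0.2, 0.0, 0.0])
batch_index = jnp.array([0, 0, 0, 1, 1])
num_graphs = 2


def mean_pool(x, segment_ids, num_segments):
    """Hypothetical helper: average node features per graph (a graph-level readout)."""
    sums = jax.ops.segment_sum(x, segment_ids, num_segments=num_segments)
    counts = jax.ops.segment_sum(jnp.ones_like(x), segment_ids, num_segments=num_segments)
    # Guard against division by zero for graphs with no nodes.
    return sums / jnp.maximum(counts, 1.0)


graph_features = mean_pool(nodes, batch_index, num_graphs)
print(graph_features)  # one value per graph: about 0.1 for graph 0, 0.0 for graph 1
```

Passing `num_segments` explicitly keeps the output shape static, which `jax.jit` requires.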
Unbatching graphs
To unbatch a DataGraphTuple object created with the batch function, we can use the haiku_geometric.utils.unbatch function. This function takes the batched DataGraphTuple object and returns a list of haiku_geometric.data.DataGraphTuple objects, one per original graph.
from haiku_geometric.utils import unbatch
unbatched_graphs = unbatch(batched_graph)
graph1 = unbatched_graphs[0]
graph2 = unbatched_graphs[1]
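Unbatching reverses the bookkeeping done by batching. `n_node` and `n_edge` give the start and end of each graph's slice, and subtracting each graph's node offset restores local edge indices. Below is a minimal sketch on plain arrays. The `unbatch_sketch` helper is hypothetical and only handles nodes, senders, and receivers, whereas the library's `unbatch` has to deal with every `DataGraphTuple` field.

```python
import jax.numpy as jnp


def unbatch_sketch(nodes, senders, receivers, n_node, n_edge):
    """Hypothetical helper: split a batched graph into per-graph (nodes, senders, receivers)."""
    graphs = []
    node_start = edge_start = 0
    for num_nodes, num_edges in zip(n_node.tolist(), n_edge.tolist()):
        node_end = node_start + num_nodes
        edge_end = edge_start + num_edges
        graphs.append((
            nodes[node_start:node_end],
            # Undo the offset added at batching time so indices are local again.
            senders[edge_start:edge_end] - node_start,
            receivers[edge_start:edge_end] - node_start,
        ))
        node_start, edge_start = node_end, edge_end
    return graphs


# A batched version of the two toy graphs from this notebook.
parts = unbatch_sketch(
    nodes=jnp.array([0.0, 0.1, 0.2, 0.0, 0.0]),
    senders=jnp.array([0, 1, 2, 3, 4]),
    receivers=jnp.array([2, 2, 0, 4, 3]),
    n_node=jnp.array([3, 2]),
    n_edge=jnp.array([3, 2]),
)
for x, s, r in parts:
    print(x.shape, s.tolist(), r.tolist())
```

The slice bounds depend on the values in `n_node` and `n_edge`, so the output shapes are data-dependent. For that reason this kind of splitting is done outside `jax.jit`, typically after a model has produced its outputs.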
Dynamic batching
Unfortunately, Haiku Geometric does not currently support dynamic batching. If you are working with jraph, you can build jraph.GraphsTuple objects and pass them to jraph.dynamically_batch. This function groups graphs into batches padded to fixed node, edge, and graph budgets.
import jax.numpy as jnp
import jraph
graph1 = jraph.GraphsTuple(
nodes=jnp.array([0.0, 0.1, 0.2]),
senders=jnp.array([0, 1, 2]),
receivers=jnp.array([2, 2, 0]),
edges=None,
n_node=jnp.array([3]),
n_edge=jnp.array([3]),
globals=None,
)
graph2 = jraph.GraphsTuple(
nodes=jnp.array([0.0, 0.0]),
senders=jnp.array([0, 1]),
receivers=jnp.array([1, 0]),
edges=None,
n_node=jnp.array([2]),
n_edge=jnp.array([2]),
globals=None,
)
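# jraph reserves one padding node and one padding graph in every batch, so the
# node budget must exceed the largest graph (3 nodes) by at least one and the
# graph budget must be at least 2.
MAXIMUM_NUM_NODES = 4
MAXIMUM_NUM_EDGES = 3
MAXIMUM_NUM_GRAPHS = 2
# dynamically_batch returns a lazy generator: batches are built, and size
# errors raised, only when it is iterated.
batched_generator = jraph.dynamically_batch([graph1, graph2],
MAXIMUM_NUM_NODES,  # max number of nodes in a batch
MAXIMUM_NUM_EDGES,  # max number of edges in a batch
MAXIMUM_NUM_GRAPHS)  # max number of graphs in a batch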