haiku_geometric.datasets

GraphDataset

Container class containing a list of graphs.

DataGraphTuple

Container class containing an individual graph data.

ToyGraphDataset

Toy dataset from DeepMind's intro to graph nets tutorial with jraph.

KarateClub

Zachary's karate club network from the "An Information Flow Model for Conflict and Fission in Small Groups" paper.

OGB

Interface for the datasets from the "Open Graph Benchmark".

Planetoid

The Planetoid dataset from the "Revisiting Semi-Supervised Learning with Graph Embeddings" paper.

GNNBenchmarkDataset

Interface for the datasets from the "Benchmarking Graph Neural Networks" paper.

TUDataset

The TUDataset from the "TUDataset: A collection of benchmark datasets for learning with graphs" paper.

class haiku_geometric.datasets.DataGraphTuple(nodes: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], edges: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], receivers: Optional[jax._src.numpy.ndarray.ndarray], senders: Optional[jax._src.numpy.ndarray.ndarray], n_node: jax._src.numpy.ndarray.ndarray, n_edge: jax._src.numpy.ndarray.ndarray, globals: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], position: Optional[jax._src.numpy.ndarray.ndarray], y: Optional[jax._src.numpy.ndarray.ndarray], train_mask: Optional[jax._src.numpy.ndarray.ndarray])

Bases: tuple

Container class containing an individual graph data.

Attributes:

  • nodes: If available, array of node features.

  • edges: If available, array of edge features.

  • receivers: Array of receiver node indices.

  • senders: Array of sender node indices.

  • globals: If available, array of global features.

  • n_node: Number of nodes in the graph.

  • n_edge: Number of edges in the graph.

  • y: If available, ground truth for each node, edge or whole graph.

  • position: If available, array of node positions.

  • train_mask: If available, array of booleans indicating which elements are in the train set.
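Edges are stored as two parallel index arrays: edge i runs from node senders[i] to node receivers[i]. The minimal sketch below builds each node's list of incoming neighbours from these two arrays. It uses plain Python lists as stand-ins for the JAX arrays and takes n_node as a plain int; the helper name incoming_neighbors is hypothetical and not part of haiku_geometric.

```python
from typing import Dict, List, Sequence


def incoming_neighbors(
    senders: Sequence[int], receivers: Sequence[int], n_node: int
) -> Dict[int, List[int]]:
    """Map each node index to the sender indices of its incoming edges.

    Hypothetical helper: edge i is read as senders[i] -> receivers[i].
    """
    if len(senders) != len(receivers):
        raise ValueError("senders and receivers must have the same length")
    neighbors: Dict[int, List[int]] = {node: [] for node in range(n_node)}
    for src, dst in zip(senders, receivers):
        neighbors[dst].append(src)
    return neighbors


# Three directed edges: 0 -> 1, 0 -> 2, 2 -> 1.
print(incoming_neighbors([0, 0, 2], [1, 2, 1], n_node=3))
# {0: [], 1: [0, 2], 2: [0]}
```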

edges: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]]

Alias for field number 1

globals: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]]

Alias for field number 6

n_edge: jax._src.numpy.ndarray.ndarray

Alias for field number 5

n_node: jax._src.numpy.ndarray.ndarray

Alias for field number 4

nodes: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]]

Alias for field number 0

position: Optional[jax._src.numpy.ndarray.ndarray]

Alias for field number 7

receivers: Optional[jax._src.numpy.ndarray.ndarray]

Alias for field number 2

senders: Optional[jax._src.numpy.ndarray.ndarray]

Alias for field number 3

train_mask: Optional[jax._src.numpy.ndarray.ndarray]

Alias for field number 9

y: Optional[jax._src.numpy.ndarray.ndarray]

Alias for field number 8
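Because DataGraphTuple is documented as a tuple subclass whose attributes are aliases for numbered fields, it appears to behave like a Python NamedTuple. The sketch below uses a hypothetical stand-in class, GraphTupleSketch (not the library class), with the same field order as the aliases above, to show that named and positional access agree and that an updated copy is made with _replace rather than by assignment.

```python
from typing import Any, NamedTuple, Optional


class GraphTupleSketch(NamedTuple):
    """Hypothetical stand-in with the same field order as DataGraphTuple."""
    nodes: Optional[Any]       # field 0
    edges: Optional[Any]       # field 1
    receivers: Optional[Any]   # field 2
    senders: Optional[Any]     # field 3
    n_node: Any                # field 4
    n_edge: Any                # field 5
    globals: Optional[Any]     # field 6
    position: Optional[Any]    # field 7
    y: Optional[Any]           # field 8
    train_mask: Optional[Any]  # field 9


g = GraphTupleSketch(
    nodes=[[0.0], [1.0]], edges=None, receivers=[1], senders=[0],
    n_node=[2], n_edge=[1], globals=None, position=None, y=None, train_mask=None,
)

# Named and positional access refer to the same object.
assert g.senders is g[3]
print(GraphTupleSketch._fields.index("train_mask"))  # 9

# NamedTuples are immutable; _replace returns a modified copy.
g2 = g._replace(y=[0, 1])
print(g.y, g2.y)  # None [0, 1]
```

If DataGraphTuple is indeed a NamedTuple, the same positional indexing and _replace should work on it directly.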

class haiku_geometric.datasets.GNNBenchmarkDataset(name, root, split='train')

Bases: haiku_geometric.datasets.base.GraphDataset

Interface for the datasets from the “Benchmarking Graph Neural Networks” paper.

Note

Using this dataset requires the PyTorch Geometric package to be installed first.

Parameters
  • name (str) – Name of the GNN Benchmark dataset. Available datasets are: PATTERN, CLUSTER, MNIST, CIFAR10, TSP and CSL.

  • root (str) – Root directory where the dataset should be saved.

  • split (str) – Split of the dataset. One of 'train', 'valid' or 'test'. (default: 'train')

Attributes:

  • data (List[DataGraphTuple]): List of graph tuples.
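Each GNNBenchmarkDataset object holds a single split, so training, validation and testing usually need three instances. The sketch below builds all three with one call. The helper load_all_splits is hypothetical; it only assumes that the class it receives accepts the name, root and split keyword arguments shown in the signature above.

```python
from typing import Callable, Dict, TypeVar

T = TypeVar("T")

SPLITS = ("train", "valid", "test")


def load_all_splits(dataset_cls: Callable[..., T], name: str, root: str) -> Dict[str, T]:
    """Build one dataset object per split, sharing the same name and root.

    Hypothetical helper; dataset_cls must accept the keyword arguments
    name, root and split, as GNNBenchmarkDataset does.
    """
    return {split: dataset_cls(name=name, root=root, split=split) for split in SPLITS}
```

With haiku_geometric and PyTorch Geometric installed, load_all_splits(GNNBenchmarkDataset, "MNIST", "data/") would return three datasets whose data lists hold the graphs of each split. Passing the class in as an argument also lets the helper be exercised with a lightweight fake, without downloading anything.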

class haiku_geometric.datasets.GraphDataset(data=[], y=None)

Bases: object

Container class containing a list of graphs.

Attributes:

  • data: List of DataGraphTuple.

  • y: If available, ground truth for each graph.
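When y holds one label per graph, a natural way to consume a GraphDataset is to walk data and y together. The sketch below is a hypothetical helper, labeled_graphs, that assumes data and y are aligned by position; a SimpleNamespace stands in for a real GraphDataset so the example runs without the library.

```python
from types import SimpleNamespace
from typing import Any, Iterator, Tuple


def labeled_graphs(dataset: Any) -> Iterator[Tuple[Any, Any]]:
    """Yield (graph, label) pairs from a GraphDataset-like object.

    Assumes dataset.data is a list of graphs and dataset.y holds one label
    per graph, in the same order. Errors are raised on first iteration.
    """
    if dataset.y is None:
        raise ValueError("this dataset has no graph-level labels (y is None)")
    if len(dataset.data) != len(dataset.y):
        raise ValueError("data and y must have the same length")
    yield from zip(dataset.data, dataset.y)


# SimpleNamespace stands in for a real GraphDataset here.
toy = SimpleNamespace(data=["graph_a", "graph_b"], y=[0, 1])
print(list(labeled_graphs(toy)))  # [('graph_a', 0), ('graph_b', 1)]
```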

class haiku_geometric.datasets.KarateClub

Bases: haiku_geometric.datasets.base.GraphDataset

Zachary’s karate club network from the “An Information Flow Model for Conflict and Fission in Small Groups” paper.

Attributes:
  • data: (List[DataGraphTuple]): List of length 1 containing a single graph.

Stats:

  #nodes   #edges   #features   node features size   edge features size   #classes
  34       156      34          34                   0                    4
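The #edges figure appears to count directed edges: Zachary's network has 78 friendship ties, so 156 is consistent with each tie being stored once in each direction. That reading is an inference from the numbers, not something this page states. The hypothetical helper below counts distinct unordered node pairs from senders and receivers, which is one way to check it on a loaded graph.

```python
from typing import Sequence


def count_undirected_edges(senders: Sequence[int], receivers: Sequence[int]) -> int:
    """Count distinct unordered node pairs among directed edges (self-loops count once)."""
    return len({frozenset((s, r)) for s, r in zip(senders, receivers)})


# A triangle stored with both directions: 6 directed edges, 3 undirected ones.
senders = [0, 1, 1, 2, 2, 0]
receivers = [1, 0, 2, 1, 0, 2]
print(len(senders), count_undirected_edges(senders, receivers))  # 6 3
```

Applied to the KarateClub graph's senders and receivers, a result of 78 would confirm the both-directions reading.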

class haiku_geometric.datasets.OGB(name, root=None)

Bases: haiku_geometric.datasets.base.GraphDataset

Interface for the datasets from the “Open Graph Benchmark”.

Note

Using this dataset requires the ogb package to be installed first:

    pip install ogb

Parameters
  • name (str) – Name of the OGB dataset.

  • root (str, optional) – Root directory where the dataset should be saved. (default: None)

Attributes:

  • data: (List[DataGraphTuple]): List of graph tuples.

  • splits: (Dict[str, List[int]]): Dictionary with the indices of the train, validation and test sets:

    {
        'train': [...],
        'valid': [...],
        'test': [...]
    }
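For a graph-level OGB dataset, where data holds many graphs, the split indices can be read as positions in data. The sketch below, a hypothetical helper named graphs_in_split, selects the graphs of one split under that assumption. For a single-graph, node-level dataset the indices more likely refer to nodes rather than graphs, so check the task type before using it.

```python
from types import SimpleNamespace
from typing import Any, List


def graphs_in_split(dataset: Any, split: str) -> List[Any]:
    """Return the graphs whose indices are listed in dataset.splits[split].

    Assumes dataset.data is a list of graphs and dataset.splits maps
    'train', 'valid' and 'test' to integer indices into that list.
    """
    return [dataset.data[i] for i in dataset.splits[split]]


# SimpleNamespace stands in for a loaded OGB dataset here.
fake = SimpleNamespace(
    data=["g0", "g1", "g2", "g3"],
    splits={"train": [0, 2], "valid": [1], "test": [3]},
)
print(graphs_in_split(fake, "train"))  # ['g0', 'g2']
```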
    
class haiku_geometric.datasets.Planetoid(name, root, split='public', num_train_per_class=20, num_val=500, num_test=1000)

Bases: haiku_geometric.datasets.base.GraphDataset

The Planetoid dataset from the “Revisiting Semi-Supervised Learning with Graph Embeddings” paper.

Parameters
  • name (str) – Name of the dataset. Can be one of 'cora', 'citeseer' or 'pubmed'.

  • root (str) – Root directory where the dataset will be saved.

  • split (str) – Which split to use. Can be one of 'public', 'full' or 'random'.

  • num_train_per_class (int) – Number of training examples per class for the 'random' split.

  • num_val (int) – Number of validation examples. Only used for the 'random' split.

  • num_test (int) – Number of test examples. Only used for the 'random' split.
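The 'random' split parameters suggest a per-class sampling scheme. The sketch below models it on how PyTorch Geometric describes its own 'random' Planetoid split: num_train_per_class training nodes are drawn from each class, then num_val validation nodes and num_test test nodes are drawn from the remaining nodes. The function random_planetoid_split is hypothetical, and haiku_geometric's own sampling may differ in detail, for example in how randomness is seeded.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Set


def random_planetoid_split(
    labels: Sequence[int],
    num_train_per_class: int = 20,
    num_val: int = 500,
    num_test: int = 1000,
    seed: int = 0,
) -> Dict[str, List[bool]]:
    """Sketch of a 'random' split: a fixed number of training nodes per class,
    then validation and test nodes drawn from the remaining nodes."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for node, label in enumerate(labels):
        by_class[label].append(node)

    # Sample training nodes class by class.
    train: Set[int] = set()
    for nodes in by_class.values():
        train.update(rng.sample(nodes, min(num_train_per_class, len(nodes))))

    # Shuffle everything else and cut off validation, then test nodes.
    rest = [n for n in range(len(labels)) if n not in train]
    rng.shuffle(rest)
    val = set(rest[:num_val])
    test = set(rest[num_val:num_val + num_test])

    def to_mask(chosen: Set[int]) -> List[bool]:
        return [n in chosen for n in range(len(labels))]

    return {"train": to_mask(train), "val": to_mask(val), "test": to_mask(test)}


masks = random_planetoid_split([0, 0, 0, 1, 1, 1, 2, 2, 2, 2],
                               num_train_per_class=1, num_val=3, num_test=3)
print([sum(m) for m in masks.values()])  # [3, 3, 3]
```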

Attributes:

  • data: (List[DataGraphTuple]): List of graph tuples (in this case only one graph).

  • train_mask: (List[bool]): Boolean mask for the training set.

  • val_mask: (List[bool]): Boolean mask for the validation set.

  • test_mask: (List[bool]): Boolean mask for the test set.

  • num_classes: (int): Number of classes.
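Since Planetoid holds a single graph, the three masks are what separate training, validation and test nodes. Assuming each mask has one entry per node, a metric for one split only looks at positions where that mask is True. The sketch below shows this with Python lists and a hypothetical helper, masked_accuracy; with JAX arrays the same idea is usually written as a masked mean.

```python
from typing import Sequence


def masked_accuracy(
    predictions: Sequence[int], labels: Sequence[int], mask: Sequence[bool]
) -> float:
    """Accuracy computed only over nodes whose mask entry is True."""
    selected = [p == y for p, y, m in zip(predictions, labels, mask) if m]
    if not selected:
        raise ValueError("mask selects no nodes")
    return sum(selected) / len(selected)


preds = [0, 1, 1, 2]
labels = [0, 1, 2, 2]
test_mask = [False, True, True, False]
print(masked_accuracy(preds, labels, test_mask))  # 0.5
```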

Stats:

  Name       #nodes   #edges   #node features   #classes
  Cora       2,708    10,858   1,433            7
  CiteSeer   3,312    9,464    3,703            6
  PubMed     19,717   88,676   500              3

class haiku_geometric.datasets.TUDataset(name, root, use_node_attr=False, use_edge_attr=False)

Bases: haiku_geometric.datasets.base.GraphDataset

The TUDataset from the “TUDataset: A collection of benchmark datasets for learning with graphs” paper.

Parameters
  • name (str) – Name of a TUDataset, e.g. 'ENZYMES', 'PROTEINS', 'MUTAG'.

  • root (str) – Root directory where the dataset will be saved.

  • use_node_attr (bool) – If True, the node attributes will be included in the graphs. (default: False)

  • use_edge_attr (bool) – If True, the edge attributes will be included in the graphs. (default: False)

Attributes:
  • data: (List[DataGraphTuple]): List of graph tuples, one per graph in the dataset.

  • y: (jnp.ndarray): Graph labels.
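This page lists no split attribute for TUDataset, so one common approach is to shuffle the graph indices yourself and cut them into train and test lists; dataset.data[i] and dataset.y[i] can then be looked up for each index. The sketch below, with a hypothetical helper named shuffled_split and an assumed 80/20 ratio, shows the idea.

```python
import random
from typing import List, Tuple


def shuffled_split(
    num_graphs: int, train_fraction: float = 0.8, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Shuffle graph indices and cut them into train and test index lists."""
    indices = list(range(num_graphs))
    random.Random(seed).shuffle(indices)
    cut = int(num_graphs * train_fraction)
    return indices[:cut], indices[cut:]


train_idx, test_idx = shuffled_split(188)  # MUTAG has 188 graphs
print(len(train_idx), len(test_idx))  # 150 38
```

For small, imbalanced datasets a stratified split or k-fold cross-validation over these indices is often preferred; the plain shuffle is kept here for brevity.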

Stats:

  Name       #graphs   #avg nodes   #avg edges   #node features   #edge features   #classes
  PROTEINS   1113      39.06        72.82        4                0                2
  ENZYMES    600       32.63        62.14        21               0                6
  MUTAG      188       17.93        19.79        7                4                2
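The averages in the table can be cross-checked from the per-graph n_node and n_edge counts. The sketch below is a hypothetical helper, average_size, that accepts either a scalar count or a one-element array per graph. Whether an edge count treats each undirected edge once or twice depends on how it was computed, so a direct average of n_edge may come out at roughly twice the listed figure if edges are stored in both directions.

```python
from types import SimpleNamespace
from typing import Any, Sequence, Tuple


def _as_count(value: Any) -> int:
    """Accept either a scalar count or a one-element sequence/array."""
    try:
        return int(sum(value))
    except TypeError:  # not iterable, e.g. a plain int or a 0-d array
        return int(value)


def average_size(graphs: Sequence[Any]) -> Tuple[float, float]:
    """Mean node and edge counts over graphs exposing n_node and n_edge."""
    if not graphs:
        raise ValueError("need at least one graph")
    total_nodes = sum(_as_count(g.n_node) for g in graphs)
    total_edges = sum(_as_count(g.n_edge) for g in graphs)
    return total_nodes / len(graphs), total_edges / len(graphs)


graphs = [SimpleNamespace(n_node=[3], n_edge=[4]), SimpleNamespace(n_node=[5], n_edge=[8])]
print(average_size(graphs))  # (4.0, 6.0)
```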

class haiku_geometric.datasets.ToyGraphDataset

Bases: haiku_geometric.datasets.base.GraphDataset

Toy dataset from DeepMind’s intro to graph nets tutorial with jraph.

Attributes:
  • data: (List[DataGraphTuple]): List of graph tuples containing only one graph.

Stats:

  #nodes   #edges   #node features   #edge features   #global features   #classes
  4        5        1                1                1                  0
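The stats tables on this page can be read directly off a graph's fields: the node count from the nodes array, the edge count from senders, and each feature size from the width of the corresponding [num_items, width] array. The sketch below uses a hypothetical helper, graph_stats, on a stand-in graph shaped like the ToyGraphDataset row; the feature values and edge list are illustrative only and are not the dataset's actual contents.

```python
from types import SimpleNamespace
from typing import Any, Dict


def feature_width(features: Any) -> int:
    """Width of a [num_items, width] feature array; 0 when features are absent."""
    if features is None or len(features) == 0:
        return 0
    return len(features[0])


def graph_stats(graph: Any) -> Dict[str, int]:
    """Summarise one graph in the terms of the stats table.

    Assumes node features are present, since the node count is taken from them.
    """
    return {
        "nodes": len(graph.nodes),
        "edges": len(graph.senders),
        "node_features": feature_width(graph.nodes),
        "edge_features": feature_width(graph.edges),
        "global_features": feature_width(graph.globals),
    }


# Illustrative values only, shaped like the table (not the dataset's contents).
toy = SimpleNamespace(
    nodes=[[0.0], [1.0], [2.0], [3.0]],
    edges=[[1.0], [1.0], [1.0], [1.0], [1.0]],
    senders=[0, 1, 2, 3, 0],
    receivers=[1, 2, 0, 0, 3],
    globals=[[1.0]],
)
print(graph_stats(toy))
# {'nodes': 4, 'edges': 5, 'node_features': 1, 'edge_features': 1, 'global_features': 1}
```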