haiku_geometric.datasets
| Class | Description |
| --- | --- |
| GraphDataset | Container class containing a list of graphs. |
| DataGraphTuple | Container class containing the data of an individual graph. |
| ToyGraphDataset | Toy dataset from DeepMind's intro to graph nets tutorial with jraph. |
| KarateClub | Zachary's karate club network from the "An Information Flow Model for Conflict and Fission in Small Groups" paper. |
| OGB | Interface for the datasets from the "Open Graph Benchmark". |
| Planetoid | The Planetoid dataset from the "Revisiting Semi-Supervised Learning with Graph Embeddings" paper. |
| GNNBenchmarkDataset | Interface for the datasets from the "Benchmarking Graph Neural Networks" paper. |
| TUDataset | The TUDataset from the "TUDataset: A collection of benchmark datasets for learning with graphs" paper. |
- class haiku_geometric.datasets.DataGraphTuple(nodes: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], edges: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], receivers: Optional[jax._src.numpy.ndarray.ndarray], senders: Optional[jax._src.numpy.ndarray.ndarray], n_node: jax._src.numpy.ndarray.ndarray, n_edge: jax._src.numpy.ndarray.ndarray, globals: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], position: Optional[jax._src.numpy.ndarray.ndarray], y: Optional[jax._src.numpy.ndarray.ndarray], train_mask: Optional[jax._src.numpy.ndarray.ndarray])
Bases: tuple
Container class containing the data of an individual graph.
Attributes:
nodes: If available, array of node features.
edges: If available, array of edge features.
receivers: Array of receiver node indices.
senders: Array of sender node indices.
globals: If available, array of global features.
n_node: Number of nodes in the graph.
n_edge: Number of edges in the graph.
y: If available, ground truth for each node, edge or whole graph.
position: If available, array of node positions.
train_mask: If available, array of booleans indicating which elements are in the training set.
- edges: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]]
Alias for field number 1
- globals: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]]
Alias for field number 6
- n_edge: jax._src.numpy.ndarray.ndarray
Alias for field number 5
- n_node: jax._src.numpy.ndarray.ndarray
Alias for field number 4
- nodes: Optional[Union[jax._src.numpy.ndarray.ndarray, Iterable[Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[jax._src.numpy.ndarray.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]]
Alias for field number 0
- position: Optional[jax._src.numpy.ndarray.ndarray]
Alias for field number 7
- receivers: Optional[jax._src.numpy.ndarray.ndarray]
Alias for field number 2
- senders: Optional[jax._src.numpy.ndarray.ndarray]
Alias for field number 3
- train_mask: Optional[jax._src.numpy.ndarray.ndarray]
Alias for field number 9
- y: Optional[jax._src.numpy.ndarray.ndarray]
Alias for field number 8
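Because `DataGraphTuple` derives from `tuple`, each field can be read either by name or by its position. The sketch below illustrates this with a stand-in `NamedTuple` that copies the documented field order. The class `GraphTupleSketch` is hypothetical and is not the library class, and plain Python lists stand in for JAX arrays.

```python
from typing import Any, NamedTuple, Optional


class GraphTupleSketch(NamedTuple):
    """Hypothetical stand-in that mirrors the documented field order."""
    nodes: Optional[Any]       # field 0
    edges: Optional[Any]       # field 1
    receivers: Optional[Any]   # field 2
    senders: Optional[Any]     # field 3
    n_node: Any                # field 4
    n_edge: Any                # field 5
    globals: Optional[Any]     # field 6
    position: Optional[Any]    # field 7
    y: Optional[Any]           # field 8
    train_mask: Optional[Any]  # field 9


# A three-node path graph 0 -> 1 -> 2 with one feature per node.
graph = GraphTupleSketch(
    nodes=[[1.0], [2.0], [3.0]],
    edges=None,
    receivers=[1, 2],
    senders=[0, 1],
    n_node=[3],
    n_edge=[2],
    globals=None,
    position=None,
    y=[0, 1, 0],
    train_mask=[True, True, False],
)

# Access by name and by index returns the same object.
same_object = graph.senders is graph[3]

# Tuples are immutable, so _replace builds a new tuple with one field changed.
masked = graph._replace(train_mask=[True, False, False])
```

Optional fields that a dataset does not provide can simply be set to `None`, as `edges`, `globals` and `position` are here.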
- class haiku_geometric.datasets.GNNBenchmarkDataset(name, root, split='train')
Bases: haiku_geometric.datasets.base.GraphDataset
Interface for the datasets from the “Benchmarking Graph Neural Networks” paper.
Note
Using this dataset requires first installing the PyTorch Geometric package.
- Parameters
name (str) – Name of the GNN Benchmark dataset. Available datasets are: PATTERN, CLUSTER, MNIST, CIFAR10, TSP and CSL.
root (str) – Root directory where the dataset should be saved.
split (str) – Split of the dataset. Can be one of train, valid and test.
Attributes:
data (List[DataGraphTuple]): List of graph tuples.
- class haiku_geometric.datasets.GraphDataset(data=[], y=None)
Bases: object
Container class containing a list of graphs.
Attributes:
data: List of DataGraphTuple.
y: If available, ground truth for each graph.
- class haiku_geometric.datasets.KarateClub
Bases: haiku_geometric.datasets.base.GraphDataset
Zachary’s karate club network from the “An Information Flow Model for Conflict and Fission in Small Groups” paper.
- Attributes:
data (List[DataGraphTuple]): List of length 1 containing a single graph.
- Stats:
| #nodes | #edges | #features | node features size | edge features size | #classes |
| --- | --- | --- | --- | --- | --- |
| 34 | 156 | 34 | 34 | 0 | 4 |
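The edge count of 156 is consistent with storing each of the network's 78 undirected ties once per direction. The sketch below shows that convention, assuming (the source does not confirm this) that `senders` and `receivers` hold one entry per direction. The helper `to_directed` is hypothetical and the three edges are illustrative, not karate club data.

```python
from typing import List, Tuple


def to_directed(undirected: List[Tuple[int, int]]) -> Tuple[List[int], List[int]]:
    """Expand each undirected edge (u, v) into the pair u->v and v->u."""
    senders: List[int] = []
    receivers: List[int] = []
    for u, v in undirected:
        senders += [u, v]
        receivers += [v, u]
    return senders, receivers


# A triangle: three undirected edges become six directed ones.
senders, receivers = to_directed([(0, 1), (0, 2), (1, 2)])
print(len(senders))  # 6
```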
- class haiku_geometric.datasets.OGB(name, root=None)
Bases: haiku_geometric.datasets.base.GraphDataset
Interface for the datasets from the “Open Graph Benchmark”.
Note
Using this dataset requires first installing the ogb package:
pip install ogb
- Parameters
name – Name of the OGB dataset.
root – Root directory where the dataset should be saved.
Attributes:
data: (List[DataGraphTuple]): List of graph tuples.
splits: (Dict[str, List[int]]): Dictionary with the indices of the train, validation and test set:
{ 'train': [...], 'valid': [...], 'test': [...] }
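Since `splits` maps each split name to a list of integer indices, selecting the graphs of one split amounts to indexing into `data`. A minimal sketch with stand-in values follows. The helper `select_split` is hypothetical, and strings take the place of `DataGraphTuple` objects so the example runs without the library or a download.

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def select_split(data: Sequence[T], splits: Dict[str, List[int]], key: str) -> List[T]:
    """Return the items of `data` whose indices are listed under splits[key]."""
    return [data[i] for i in splits[key]]


# Stand-ins for dataset.data and dataset.splits.
data = ["g0", "g1", "g2", "g3", "g4"]
splits = {"train": [0, 3], "valid": [1], "test": [2, 4]}

train_graphs = select_split(data, splits, "train")
test_graphs = select_split(data, splits, "test")
print(train_graphs)  # ['g0', 'g3']
```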
- class haiku_geometric.datasets.Planetoid(name, root, split='public', num_train_per_class=20, num_val=500, num_test=1000)
Bases: haiku_geometric.datasets.base.GraphDataset
The Planetoid dataset from the “Revisiting Semi-Supervised Learning with Graph Embeddings” paper.
- Parameters
name (str) – Name of the dataset. Can be one of 'cora', 'citeseer' or 'pubmed'.
root (str) – Root directory where the dataset will be saved.
split (str) – Which split to use. Can be one of 'public', 'full' or 'random'.
num_train_per_class (int) – Number of training examples per class. Only used for the 'random' split.
num_val (int) – Number of validation examples. Only used for the 'random' split.
num_test (int) – Number of test examples. Only used for the 'random' split.
Attributes:
data: (List[DataGraphTuple]): List of graph tuples (in this case only one graph).
train_mask: (List[bool]): Boolean mask for the training set.
val_mask: (List[bool]): Boolean mask for the validation set.
test_mask: (List[bool]): Boolean mask for the test set.
num_classes: (int): Number of classes.
- Stats:
| Name | #nodes | #edges | #node features | #classes |
| --- | --- | --- | --- | --- |
| Cora | 2,708 | 10,858 | 1,433 | 7 |
| CiteSeer | 3,312 | 9,464 | 3,703 | 6 |
| PubMed | 19,717 | 88,676 | 500 | 3 |
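The parameters of the 'random' split suggest a procedure like the one sketched below: draw `num_train_per_class` nodes from every class for training, then take `num_val` and `num_test` nodes from the remainder. This is an assumption based on the parameter descriptions, not the library's documented procedure. The helper `random_split_masks` and its `seed` argument are hypothetical.

```python
import random
from typing import List, Tuple


def random_split_masks(
    labels: List[int],
    num_train_per_class: int,
    num_val: int,
    num_test: int,
    seed: int = 0,
) -> Tuple[List[bool], List[bool], List[bool]]:
    """Build disjoint boolean train/val/test masks over the nodes."""
    rng = random.Random(seed)
    n = len(labels)
    train_mask = [False] * n
    val_mask = [False] * n
    test_mask = [False] * n

    # Training set: a fixed number of nodes drawn from every class.
    for c in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == c]
        rng.shuffle(idx)
        for i in idx[:num_train_per_class]:
            train_mask[i] = True

    # Validation and test sets: drawn from the nodes not used for training.
    remaining = [i for i in range(n) if not train_mask[i]]
    rng.shuffle(remaining)
    for i in remaining[:num_val]:
        val_mask[i] = True
    for i in remaining[num_val:num_val + num_test]:
        test_mask[i] = True
    return train_mask, val_mask, test_mask


# Twelve nodes in three classes: 2 per class for training, 3 for validation, 3 for test.
labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
train_mask, val_mask, test_mask = random_split_masks(labels, 2, 3, 3)
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 6 3 3
```

The resulting masks have the same shape as the `train_mask`, `val_mask` and `test_mask` attributes listed above: one boolean per node.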
- class haiku_geometric.datasets.TUDataset(name, root, use_node_attr=False, use_edge_attr=False)
Bases: haiku_geometric.datasets.base.GraphDataset
The TUDataset from the “TUDataset: A collection of benchmark datasets for learning with graphs” paper.
- Parameters
name (str) – Name of a TUDataset, e.g. 'ENZYMES', 'PROTEINS', 'MUTAG'.
root (str) – Root directory where the dataset will be saved.
use_node_attr (bool) – If True, the node attributes will be included in the graphs. (default: False)
use_edge_attr (bool) – If True, the edge attributes will be included in the graphs. (default: False)
- Attributes:
data: (List[DataGraphTuple]): List of graph tuples, one per graph in the dataset.
y: (jnp.ndarray): Graph labels.
- Stats:
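| Name | #graphs | #avg nodes | #avg edges | #node features | #edge features | #classes |
| --- | --- | --- | --- | --- | --- | --- |
| PROTEINS | 1113 | 39.06 | 72.82 | 4 | 0 | 2 |
| ENZYMES | 600 | 32.63 | 62.14 | 21 | 0 | 6 |
| MUTAG | 188 | 17.93 | 19.79 | 7 | 4 | 2 |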
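The `#avg nodes` and `#avg edges` columns are per-graph means. The sketch below shows how such figures could be computed, assuming the per-graph sizes are available as plain integers (for example, read from each graph's `n_node` and `n_edge`). The helper `average_sizes` is hypothetical and the sizes are illustrative. The source does not say whether an undirected edge is counted once or once per direction.

```python
from typing import List, Tuple


def average_sizes(n_nodes: List[int], n_edges: List[int]) -> Tuple[float, float]:
    """Mean number of nodes and edges per graph, rounded to two decimals."""
    num_graphs = len(n_nodes)
    return (
        round(sum(n_nodes) / num_graphs, 2),
        round(sum(n_edges) / num_graphs, 2),
    )


# Sizes of three illustrative graphs (not taken from any TUDataset).
avg_nodes, avg_edges = average_sizes([4, 6, 5], [6, 10, 8])
print(avg_nodes, avg_edges)  # 5.0 8.0
```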
- class haiku_geometric.datasets.ToyGraphDataset
Bases: haiku_geometric.datasets.base.GraphDataset
Toy dataset from DeepMind’s intro to graph nets tutorial with jraph.
- Attributes:
data: (List[DataGraphTuple]): List of graph tuples containing only one graph.
- Stats:
| #nodes | #edges | #node features | #edge features | #global features | #classes |
| --- | --- | --- | --- | --- | --- |
| 4 | 5 | 1 | 1 | 1 | 0 |