haiku_geometric.nn#

Convolutional Layers#

GCNConv

The graph convolutional operator from the "Semi-supervised Classification with Graph Convolutional Networks" paper

GraphConv

The graph neural network operator from the "Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks" paper

GeneralConv

A general GNN layer adapted from the "Design Space for Graph Neural Networks" paper.

GINConv

The graph isomorphism operator from the "How Powerful are Graph Neural Networks?" paper

GINEConv

The graph isomorphism operator extended to include edge features, from the "Strategies for Pre-training Graph Neural Networks" paper

GATConv

Graph attention layer from "Graph Attention Networks" paper

SAGEConv

The GraphSAGE operator from the "Inductive Representation Learning on Large Graphs" paper

GatedGraphConv

The gated graph convolution operator from the "Gated Graph Sequence Neural Networks" paper

PNAConv

The Principal Neighbourhood Aggregation graph convolution operator from the "Principal Neighbourhood Aggregation for Graph Nets" paper.

GPSLayer

GPS layer from the "Recipe for a General, Powerful, Scalable Graph Transformer" paper.

EdgeConv

The edge convolutional operator from the "Dynamic Graph CNN for Learning on Point Clouds" paper.

MetaLayer

A meta layer for building any kind of graph network, inspired by the "Relational Inductive Biases, Deep Learning, and Graph Networks" paper.

class haiku_geometric.nn.conv.EdgeConv(nn, aggr='max')[source]#

The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper.

The node features are computed as follows:

\[\mathbf{h}^{k + 1}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{h}_i \, \Vert \, \mathbf{h}_j - \mathbf{h}_i)\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, e.g. an MLP.

Parameters
  • nn (hk.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features h of shape [-1, 2 * in_channels] to shape [-1, out_channels].

  • aggr (string, optional) – The aggregation operator ("add", "mean", "max"). (default: "max")

__call__(nodes, senders, receivers, edges=None, num_nodes=None)[source]#
Return type

ndarray

__init__(nn, aggr='max')[source]#
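Example:

A minimal usage sketch based on the signatures above. The toy graph, feature sizes and MLP widths are illustrative, and hk.nets.MLP stands in for any Haiku module with the required input and output shapes:

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.conv import EdgeConv

# Toy graph: 4 nodes with 3 features each and edges 0->1, 1->2, 2->3.
nodes = jnp.ones((4, 3))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 3])

def forward(nodes, senders, receivers):
    # h_theta maps concatenated features of shape [-1, 2 * 3] to shape [-1, 16].
    h_theta = hk.nets.MLP([32, 16])
    return EdgeConv(nn=h_theta, aggr='max')(nodes, senders, receivers)

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), nodes, senders, receivers)
out = model.apply(params, nodes, senders, receivers)  # shape [4, 16]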
class haiku_geometric.nn.conv.GATConv(out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0.0, dropout_nodes=0.0, add_self_loops=True, bias=True, init=None)[source]#

Graph attention layer from “Graph Attention Networks” paper

where each node’s output feature is computed as follows:

\[\vec{h}_{i}^{\prime}=\sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j} \mathbf{W} \vec{h}_{j}\right)\]

where the attention coefficients are computed as:

\[\alpha_{i j}=\frac{\exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{j}\right]\right)\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{k}\right]\right)\right)}\]

When multiple attention heads are used and concat=False, the output node features are averaged:

\[\vec{h}_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right)\]

If concat=True, the output features are the concatenation of the \(K\) heads' features:

\[\vec{h}_{i}^{\prime}=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right)\]
Parameters
  • out_channels (int) – Size of the output features produced by the layer for each node.

  • heads (int, optional) – Number of attention heads. (default: 1)

  • concat (bool, optional) – If True, the multi-head features are concatenated; otherwise they are averaged. (default: True)

  • negative_slope (float, optional) – scalar specifying the negative slope of the LeakyReLU. (default: 0.2)

  • add_self_loops (bool, optional) – If True, will add a self-loop for each node of the graph. (default: True)

  • dropout (float, optional) – Dropout applied to the attention weights. This dropout simulates random sampling of the neighbours. (default: 0.0)

  • dropout_nodes (float, optional) – Dropout applied initially to the input features. (default: 0.0)

  • bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)

  • init (hk.initializers.Initializer) – Weights initializer (default: hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal"))

__call__(in_nodes_features, senders, receivers, edges=None, num_nodes=None, training=False)[source]#
Return type

ndarray

__init__(out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0.0, dropout_nodes=0.0, add_self_loops=True, bias=True, init=None)[source]#
lift(scores_source, scores_target, nodes_features_matrix_proj, senders, receivers)[source]#
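Example:

A minimal usage sketch based on the signatures above; the toy graph and sizes are illustrative. An rng key is passed to apply since the layer defines dropout:

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.conv import GATConv

nodes = jnp.ones((4, 8))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 3])

def forward(nodes, senders, receivers):
    # With heads=2 and concat=True, each node gets 2 * 16 output features.
    return GATConv(out_channels=16, heads=2)(nodes, senders, receivers, training=False)

model = hk.transform(forward)
params = model.init(jax.random.PRNGKey(0), nodes, senders, receivers)
out = model.apply(params, jax.random.PRNGKey(1), nodes, senders, receivers)  # shape [4, 32]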
class haiku_geometric.nn.conv.GCNConv(out_channels, improved=False, cached=False, add_self_loops=True, normalize=True, bias=True, aggr='add')[source]#

The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper

\[H^{(l+1)} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W\]

where \(\tilde{A} = A + I_N\) is the adjacency matrix with added self-loops \(I_N\), and \(\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\) is its diagonal degree matrix.

The node-wise formulation is given by:

\[\mathbf{h}_u = W^{\top} \sum_{v \in \mathcal{N}(u) \cup \{ u \}} \frac{e_{v,u}}{\sqrt{\hat{d}_u \hat{d}_v}} \mathbf{h}_v\]

where \(e_{v,u}\) is the edge weight and \(\hat{d}_u = 1 + \sum_{v \in \mathcal{N}(u)} e_{v,u}\)

Parameters
  • out_channels (int) – Size of the output features of each node.

  • improved (bool, optional) – If True, then \(\tilde{A}\) is computed as \(A + 2 I_N\). (default: False)

  • cached (bool, optional) – If True, the value \(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\) is computed on the first execution, cached, and reused in subsequent executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If True, will add a self-loop for each node of the graph. (default: True)

  • normalize (bool, optional) – Whether to compute and apply the symmetric normalization. (default: True)

  • bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)

__call__(nodes, senders, receivers, edges=None, num_nodes=None)[source]#
Return type

ndarray

__init__(out_channels, improved=False, cached=False, add_self_loops=True, normalize=True, bias=True, aggr='add')[source]#
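Example:

A minimal usage sketch based on the signatures above; the toy graph and sizes are illustrative:

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.conv import GCNConv

# Toy graph: 3 nodes with 4 features each, edges 0->1, 1->2, 2->0.
nodes = jnp.ones((3, 4))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

def forward(nodes, senders, receivers):
    return GCNConv(out_channels=8)(nodes, senders, receivers)

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), nodes, senders, receivers)
out = model.apply(params, nodes, senders, receivers)  # shape [3, 8]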
class haiku_geometric.nn.conv.GINConv(nn, eps=0.0, train_eps=False)[source]#

The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper

The node features are computed as follows:

\[\mathbf{{h}}_{u}^{k}= \phi\left( (1 + \epsilon) \mathbf{{h}}_{u}^{k-1} + \sum_{v \in \mathcal{N}(u)} \mathbf{{h}}_{v}^{k-1}\right)\]

where \(\phi\) is a neural network (e.g. an MLP).

Parameters
  • nn (hk.Module) – A neural network \(\phi\) that produces output features of shape out_channels defined by the user.

  • eps (float, optional) – \(\epsilon\) value. (default: 0.)

  • train_eps (bool, optional) – If True, \(\epsilon\) will be a trainable parameter. (default: False)

__call__(nodes, senders, receivers, edges=None, num_nodes=None)[source]#
Return type

ndarray

__init__(nn, eps=0.0, train_eps=False)[source]#
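Example:

A minimal usage sketch; hk.nets.MLP stands in for the user-defined network \(\phi\), and all sizes are illustrative:

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.conv import GINConv

nodes = jnp.ones((4, 8))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 3])

def forward(nodes, senders, receivers):
    phi = hk.nets.MLP([32, 16])  # produces 16 output features per node
    return GINConv(nn=phi, train_eps=True)(nodes, senders, receivers)

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), nodes, senders, receivers)
out = model.apply(params, nodes, senders, receivers)  # shape [4, 16]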
class haiku_geometric.nn.conv.GINEConv(nn, eps=0.0, train_eps=False, edge_dim=None)[source]#

The graph isomorphism operator extended to include edge features, from the “Strategies for Pre-training Graph Neural Networks” paper

The node features are computed as follows:

\[\mathbf{{h}}_{u}^{k}= \phi\left( (1 + \epsilon) \mathbf{{h}}_{u}^{k-1} + \sum_{v \in \mathcal{N}(u)} ReLU(\mathbf{{h}}_{v}^{k-1} + \mathbf{e}_{u,v}) \right)\]

where \(\phi\) is a neural network (e.g. an MLP) and \(\mathbf{e}_{u,v}\) are the edge features.

Parameters
  • nn (hk.Module) – A neural network \(\phi\) that produces output features of shape out_channels defined by the user.

  • eps (float, optional) – \(\epsilon\) value. (default: 0.)

  • train_eps (bool, optional) – If True, \(\epsilon\) will be a trainable parameter. (default: False)

  • edge_dim (int, optional) – If None, the edge features are expected to have the same shape as the node features. Otherwise, the edge features are linearly transformed to match the node feature shape. (default: None)

__call__(nodes, senders, receivers, edges, num_nodes=None)[source]#
Return type

ndarray

__init__(nn, eps=0.0, train_eps=False, edge_dim=None)[source]#
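Example:

A minimal usage sketch; here illustrative edge features of size 4 are projected to the node feature size via edge_dim, and hk.nets.MLP stands in for the user-defined network \(\phi\):

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.conv import GINEConv

nodes = jnp.ones((4, 8))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 3])
edges = jnp.ones((3, 4))  # one feature vector per edge

def forward(nodes, senders, receivers, edges):
    phi = hk.nets.MLP([32, 16])
    return GINEConv(nn=phi, edge_dim=4)(nodes, senders, receivers, edges)

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), nodes, senders, receivers, edges)
out = model.apply(params, nodes, senders, receivers, edges)  # shape [4, 16]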
class haiku_geometric.nn.conv.GPSLayer(dim_h, local_gnn_type, global_model_type, act, num_heads=1, pna_degrees=None, equivstable_pe=False, dropout=0.0, attn_dropout=0.0, layer_norm=False, batch_norm=True)[source]#

GPS layer from the “Recipe for a General, Powerful, Scalable Graph Transformer” paper.

🚧: This layer is still under development and might not work as expected.

Parameters
  • dim_h (int) – Size of each output features.

  • local_gnn_type (str) – Name of a message passing neural network. Available networks are: None, "GINE", "GAT", "PNA".

  • global_model_type (str) – Name of a global attention layer. Available networks are: None, "Transformer", "Performer".

  • act (Callable) – Activation function (e.g. jax.nn.relu).

  • num_heads (int, optional) – number of heads when using multi-head attention. (default: 1).

  • pna_degrees (jnp.ndarray, optional) – Histogram of node degrees, used when local_gnn_type="PNA".

  • equivstable_pe (bool, optional) – Not implemented. (default: False).

  • dropout (float, optional) – dropout rate. (default: 0.0).

  • attn_dropout (float, optional) – dropout rate with global attention. (default: 0.0).

  • layer_norm (bool, optional) – Whether to use layer normalization. (default: False).

  • batch_norm (bool, optional) – Whether to use batch normalization. (default: True).

__call__(training, nodes, senders=None, receivers=None, edges=None, num_nodes=None)[source]#
Return type

ndarray

__init__(dim_h, local_gnn_type, global_model_type, act, num_heads=1, pna_degrees=None, equivstable_pe=False, dropout=0.0, attn_dropout=0.0, layer_norm=False, batch_norm=True)[source]#
class haiku_geometric.nn.conv.GatedGraphConv(out_channels, num_layers, aggr='add')[source]#

The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper

where the output features are computed as follows:

\[\mathbf{h}_{u}^{(0)} = \mathbf{h}_{u} \, \Vert \, \mathbf{0}\]

where the input features are padded with zeros up to out_channels, and for each layer \(k = 1, \ldots, L\):

\[\begin{split}\mathbf{{m}}_{u}^{(k)} &= \text{AGGREGATE}(\{e_{u, v} \cdot \mathbf{{W}} \cdot \mathbf{{h}}_{v}^{(k - 1)}, \forall v \in \mathcal{N}(u)\}) \\ \mathbf{{h}}_{u}^{(k)} &= GRU(\mathbf{{m}}_{u}^{(k)}, \mathbf{{h}}_{u}^{(k - 1)})\end{split}\]

with \(AGGREGATE\) being the aggregation operator (i.e. "mean", "max", or "add").

Parameters
  • out_channels (int) – Size of the output features of each node.

  • num_layers (int) – Number of layers \(L\).

  • aggr (string, optional) – Aggregation operator ("add", "mean", "max"). (default: "add")

__call__(nodes, senders, receivers, edges=None, num_nodes=None)[source]#
Return type

ndarray

__init__(out_channels, num_layers, aggr='add')[source]#
class haiku_geometric.nn.conv.GeneralConv(out_channels, in_edge_channels=None, aggr='add', skip_linear=False, directed_msg=True, heads=1, attention=False, attention_type='additive', l2_normalize=False, bias=True)[source]#

A general GNN layer adapted from the “Design Space for Graph Neural Networks” paper.

where the output features are computed as follows:

\[\mathbf{{h}}_{u}^{k}= \text{AGGREGATE}(\{\mathbf{m}_{u,v}, \forall v \in \mathcal{N}(u)\})\]

with \(AGGREGATE\) being the aggregation operator (i.e. "mean", "max", or "add") and each message \(\mathbf{m}_{u,v}\) is computed as:

\[\mathbf{m}_{u,v} = \mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1}\]

If directed_msg=False, the message is bidirectional:

\[\mathbf{m}_{u,v} = \mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1} + \mathbf{W}_2 \cdot \mathbf{{h}}_{u}^{k-1}\]

If in_edge_channels is not None, the edge features are also added to the message:

\[\mathbf{m}_{u,v} = \mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1} + \mathbf{W}_3 \cdot \mathbf{{e}}_{u, v}\]

If attention=True, attention is performed on the message computation:

\[\mathbf{m}_{u,v} = \alpha_{u,v}(\mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1})\]

where the attention coefficient \(\alpha_{u,v}\) is computed as follows:

\[\alpha_{u v}=\frac{\exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{u} + \mathbf{W} \vec{h}_{v}\right]\right)\right)}{\sum_{w \in \mathcal{N}_{u}} \exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{u} + \mathbf{W} \vec{h}_{w}\right]\right)\right)}\]

If skip_linear=True a skip connection is added to the output:

\[\mathbf{h}_{u}^{k}= \text{AGGREGATE}(\{\mathbf{m}_{u,v}, \forall v \in \mathcal{N}(u)\}) + \mathbf{W}_4 \cdot \mathbf{h}_{u}^{k-1}\]
Parameters
  • out_channels (int) – Size of the output features of a node.

  • in_edge_channels (int, optional) – Size of each edge features. (default: None)

  • aggr (string or Aggregation, optional) – The aggregation operator. Available values are: "mean", "max", or "add". (default: "add")

  • skip_linear (bool, optional) – If True, a skip connection \(\mathbf{W}_4 \cdot \mathbf{h}_{u}^{k-1}\) is added to the output as shown above. (default: False)

  • directed_msg (bool, optional) – If True, messages are only computed in the direction of the edges; if False, the bidirectional message above is used. (default: True)

  • heads (int, optional) – Number of attention heads. If attention=True and heads > 1, the multi-head features are mean-aggregated. (default: 1)

  • attention (bool, optional) – If True, attention is performed over the messages. (default: False)

  • attention_type (str, optional) – Type of attention: "additive", "dot_product". (default: "additive")

  • l2_normalize (bool, optional) – If True, output features are \(\ell_2\)-normalized. (default: False)

  • bias (bool, optional) – If True, the linear transformations also add a bias. (default: True)

__call__(nodes, senders, receivers, edges=None, num_nodes=None)[source]#
Return type

ndarray

__init__(out_channels, in_edge_channels=None, aggr='add', skip_linear=False, directed_msg=True, heads=1, attention=False, attention_type='additive', l2_normalize=False, bias=True)[source]#
class haiku_geometric.nn.conv.GraphConv(out_channels, aggr='add', bias=True)[source]#

The graph neural network operator from the “Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks” paper

The node features are computed as follows:

\[\mathbf{{h}}_{u}^{k}=\mathbf{W}_1 \cdot \mathbf{{h}}_{u}^{k-1} + \mathbf{W}_2 \cdot \text{AGGREGATE}(\{e_{u, v} \cdot \mathbf{{h}}_{v}^{k-1}, \forall v \in \mathcal{N}(u)\})\]

with \(AGGREGATE\) being the aggregation operator (i.e. "mean", "max", or "add")

Parameters
  • out_channels (int) – Size of the output features of a node.

  • aggr (string or Aggregation, optional) – The aggregation operator. Available values are: "mean", "max", or "add". (default: "add")

  • bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)

__call__(nodes=None, senders=None, receivers=None, edges=None, num_nodes=None)[source]#
Return type

Union[ndarray, GraphsTuple]

__init__(out_channels, aggr='add', bias=True)[source]#
class haiku_geometric.nn.conv.MetaLayer(edge_model=None, node_model=None, global_model=None)[source]#

A meta layer for building any kind of graph network, inspired by the “Relational Inductive Biases, Deep Learning, and Graph Networks” paper.

General graph network that takes as input the node features nodes, the edge features edge_attr, the sender node indices senders, the receiver node indices receivers and the global features of the graph globals.

It returns the updated node features nodes, the updated edge features edge_attr and the updated global features globals.

Node features are updated by calling the node model node_model, edge features by calling the edge model edge_model, and global features by calling the global model global_model.

Parameters
  • edge_model (hk.Module, optional) –

    A neural network that updates the edge features based on the features of the corresponding source and target nodes. It receives as input:

    • senders features of shape [E, F_N] where E is the number of edges and F_N the number of input node features.

    • receivers features of shape [E, F_N] where E is the number of edges and F_N the number of input node features.

    • edges features of shape [E, F_E] where E is the number of edges and F_E the number of input edge features.

    • globals features of shape [F_G] for non-batched graphs or shape [G * F_G] for batched graphs, where G is the number of graphs and F_G the shape of the global features.

    • batch indices of shape [N], where N is the number of nodes. This array indicates to which graph each node belongs.

  • node_model (hk.Module, optional) –

    A neural network that updates the node features based on the current node features, edge features and global features. It receives as input:

    • nodes features of shape [N, F_N] where N is the number of nodes and F_N the number of input node features.

    • senders features of shape [E, F_N] where E is the number of edges and F_N the number of input node features.

    • receivers features of shape [E, F_N] where E is the number of edges and F_N the number of input node features.

    • edges features of shape [E, F_E] where E is the number of edges and F_E the number of input edge features.

    • globals features of shape [F_G] for non-batched graphs or shape [G * F_G] for batched graphs, where G is the number of graphs and F_G the shape of the global features.

    • batch indices of shape [N], where N is the number of nodes. This array indicates to which graph each node belongs.

  • global_model (hk.Module, optional) –

    A neural network that updates the global features of a graph based on the current node features, edge features and global features. It receives as input:

    • nodes features of shape [N, F_N] where N is the number of nodes and F_N the number of input node features.

    • senders features of shape [E, F_N] where E is the number of edges and F_N the number of input node features.

    • receivers features of shape [E, F_N] where E is the number of edges and F_N the number of input node features.

    • edges features of shape [E, F_E] where E is the number of edges and F_E the number of input edge features.

    • globals features of shape [F_G] for non-batched graphs or shape [G * F_G] for batched graphs, where G is the number of graphs and F_G the shape of the global features.

    • batch indices of shape [N], where N is the number of nodes. This array indicates to which graph each node belongs.

Returns

  • The updated nodes features if node_model is not None.

  • The updated edges features if edge_model is not None.

  • The updated globals features if global_model is not None.

Return type

Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

Examples:

import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as tree
from haiku_geometric.nn.aggr.utils import aggregation
from haiku_geometric.nn.conv import MetaLayer

# The hidden and output sizes below are illustrative.
hidden_dim, out_dim = 64, 32

class EdgeModel(hk.Module):
    def __init__(self):
        super().__init__()
        self.mlp = hk.Sequential([hk.Linear(hidden_dim), jax.nn.relu, hk.Linear(out_dim)])

    def __call__(self, senders_features, receivers_features, edges_features, globals, batch, num_nodes=None):
        # Each edge is updated from its endpoints' features and its own features.
        h = jnp.concatenate([senders_features, receivers_features, edges_features], axis=-1)
        return self.mlp(h)

class NodeModel(hk.Module):
    def __init__(self):
        super().__init__()
        self.aggr = aggregation('mean')
        self.mlp = hk.Sequential([hk.Linear(hidden_dim), jax.nn.relu, hk.Linear(out_dim)])

    def __call__(self, nodes, senders, receivers, edge_attr, globals, batch, num_nodes=None):
        # Messages flow from senders to receivers and are mean-aggregated per node.
        h = jnp.concatenate([nodes[senders], edge_attr], axis=1)
        messages = self.mlp(h)
        total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
        return self.aggr(messages, receivers, total_num_nodes)

class GlobalModel(hk.Module):
    def __init__(self):
        super().__init__()
        self.mlp = hk.Sequential([hk.Linear(hidden_dim), jax.nn.relu, hk.Linear(out_dim)])

    def __call__(self, nodes, senders, receivers, edge_attr, globals, batch, num_nodes=None):
        return self.mlp(globals)
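These models can then be wired into a MetaLayer; a sketch following the __call__ signature below, with illustrative toy data:

def forward(nodes, senders, receivers, edge_attr, globals, batch):
    layer = MetaLayer(edge_model=EdgeModel(),
                      node_model=NodeModel(),
                      global_model=GlobalModel())
    return layer(nodes, senders, receivers, edge_attr, globals, batch)

nodes = jnp.ones((4, 8))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 3])
edge_attr = jnp.ones((3, 4))
globals = jnp.ones((2,))
batch = jnp.zeros((4,), dtype=jnp.int32)

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), nodes, senders, receivers, edge_attr, globals, batch)
nodes, edge_attr, globals = model.apply(params, nodes, senders, receivers, edge_attr, globals, batch)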
__call__(nodes, senders, receivers, edge_attr=None, globals=None, batch=None, num_nodes=None)[source]#
Return type

Tuple[ndarray, ndarray, ndarray]

__init__(edge_model=None, node_model=None, global_model=None)[source]#


class haiku_geometric.nn.conv.PNAConv(out_channels, aggregators, scalers, deg, edge_dim=None, towers=1, pre_layers=1, post_layers=1, divide_input=False, act='relu', act_kwargs=None, train_norm=False)[source]#

The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper.

Output features are computed as follows:

\[\mathbf{h}_u^{(k+1)} = U \left( \mathbf{h}_u^{(k)}, \underset{v \in \mathcal{N}(u)}{\bigoplus} M \left( \mathbf{h}_u^{(k)}, \mathbf{h}_v^{(k)} \right) \right)\]

with \(M\) and \(U\) being MLPs, and:

\[\begin{split}\bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}},\end{split}\]
__call__(nodes, senders, receivers, edges=None, num_nodes=None)[source]#
Return type

ndarray

__init__(out_channels, aggregators, scalers, deg, edge_dim=None, towers=1, pre_layers=1, post_layers=1, divide_input=False, act='relu', act_kwargs=None, train_norm=False)[source]#
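Example:

A minimal usage sketch. The aggregator and scaler names are assumptions based on the aggregation operators and the DegreeScalerAggregation documentation below; the toy graph and degree histogram are illustrative:

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.conv import PNAConv

nodes = jnp.ones((4, 8))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 3])

# In-degrees are [0, 1, 1, 1], so the degree histogram is [1, 3].
deg = jnp.bincount(jnp.bincount(receivers, length=4), length=2)

def forward(nodes, senders, receivers):
    conv = PNAConv(out_channels=16,
                   aggregators=['mean', 'max', 'min'],
                   scalers=['identity', 'amplification', 'attenuation'],
                   deg=deg)
    return conv(nodes, senders, receivers)

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), nodes, senders, receivers)
out = model.apply(params, nodes, senders, receivers)  # shape [4, 16]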
class haiku_geometric.nn.conv.SAGEConv(out_channels, aggr='mean', normalize=False, root_weight=True, project=False, bias=True)[source]#

The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper

The node features are computed as follows:

\[\mathbf{h}_{u}^{k}=\mathbf{W}_1 \cdot \mathbf{h}_{u}^{k-1} + \mathbf{W}_2 \cdot \mathbf{h}_{\mathcal{N}(u)}^{k}\]

having:

\[\mathbf{{h}}_{\mathcal{N}(u)}^{k} = \text{AGGREGATE}(\{\mathbf{{h}}_{v}^{k-1}, \forall v \in \mathcal{N}(u)\})\]

and \(AGGREGATE\) being an aggregation operator (i.e. "mean", "max", or "sum")

If project = True, the neighbour features \(\mathbf{h}_{v}^{k-1}\) are first projected via:

\[\mathbf{{h}}_{v}^{k-1}=\text{ReLU}(\mathbf{W}_3 \mathbf{{h}}_{v}^{k-1} + \mathbf{b})\]
Parameters
  • out_channels (int) – Size of the output features of a node.

  • aggr (string or Aggregation, optional) – The aggregation operator. Available values are: "mean", "max", or "sum". (default: "mean")

  • normalize (bool, optional) – If True, output features are \(\ell_2\)-normalized. (default: False)

  • root_weight (bool, optional) – If False, the linearly transformed features \(\mathbf{W}_1 \cdot \mathbf{h}_{u}^{k-1}\) are not added to the output features. (default: True)

  • project (bool, optional) – If True, neighbour features are projected before aggregation as explained above. (default: False)

  • bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)

__call__(nodes, senders, receivers, num_nodes=None)[source]#
Return type

ndarray

__init__(out_channels, aggr='mean', normalize=False, root_weight=True, project=False, bias=True)[source]#

Pooling Layers#

global_add_pool

Returns the sum of all node features of the input graph:

global_mean_pool

Returns the average of all node features of the input graph:

global_max_pool

Returns the maximum across the input features.

TopKPooling

Topk pooling operator from the "Graph U-Nets" and "Towards Sparse Hierarchical Graph Classifiers" paper.

class haiku_geometric.nn.pool.TopKPooling(in_channels, ratio=0.5, multiplier=1.0)[source]#

The top-k pooling operator from the “Graph U-Nets” and “Towards Sparse Hierarchical Graph Classifiers” papers.

Parameters
  • in_channels (int) – Dimension of input node features.

  • ratio (Union[int, float], optional) – Ratio of nodes to keep. If int, the number of nodes to keep. (default: 0.5).

  • multiplier (float, optional) – Multiplier to scale the features after pooling. (default: 1.).

__call__(x, senders, receivers, edges=None, batch=None, create_new_batch=False, batch_size=None, max_num_nodes=None)[source]#
Parameters
  • x (jnp.ndarray) – Node features of shape [num_nodes, in_channels].

  • senders (jnp.ndarray) – Sender indices.

  • receivers (jnp.ndarray) – Receiver indices.

  • edges (jnp.ndarray, optional) – Edge features of shape [num_edges, in_channels]. (default: None).

  • batch (jnp.ndarray, optional) – Batch array with batch indices for each node. Shape: [num_nodes]. Note: This array should be sorted in increasing order. (default: None).

  • create_new_batch (bool, optional) – If set to False, nodes that are not selected in the top-k, together with their edges, are removed from the graph. If set to True, the nodes are kept in the graph but assigned to a new batch with value batch_size + 1, and their corresponding edges are transformed into self-loops. Note: If True, the output sizes of x, batch, senders, receivers and edges stay the same. If False, the output sizes of x, batch, senders, receivers and edges might be reduced according to the ratio parameter. (default: False).

  • batch_size (int, optional) – Number of batched graphs. If not given, it is automatically computed as batch.max() + 1. (default: None).

  • max_num_nodes (int, optional) – Maximum number of nodes that a graph can have. If not given, it is automatically computed as batch.shape[0]. (default: None).

Returns

  • The updated nodes features.

  • The updated senders indices.

  • The updated receivers indices.

  • The updated edges features.

  • The updated batch array.

Return type

Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Observations:

To make this layer jit-able, set create_new_batch=True and provide batch_size as a static parameter.

__init__(in_channels, ratio=0.5, multiplier=1.0)[source]#
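Example:

A minimal usage sketch based on the __call__ signature above; the toy graph is illustrative:

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.pool import TopKPooling

x = jnp.ones((4, 8))
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 3])

def forward(x, senders, receivers):
    # With ratio=0.5, 2 of the 4 nodes are kept.
    return TopKPooling(in_channels=8, ratio=0.5)(x, senders, receivers)

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), x, senders, receivers)
x, senders, receivers, edges, batch = model.apply(params, x, senders, receivers)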
haiku_geometric.nn.pool.global_add_pool(x, batch=None, num_segments=None)[source]#

Returns the sum of all node features of the input graph:

\[\mathbf{r} = \sum_{i=1}^{N} \mathbf{h}_i.\]
Parameters
  • x (jax.numpy.ndarray) – Node features array.

  • batch (jax.numpy.ndarray, optional) – Batch vector with indices that indicate to which graph each node belongs. (default: None).

  • num_segments (int, optional) – Number of segments in batch. (default: None)

Returns

Array with the sum of the node features. If batch is not None, the output array will have shape [batch_size, *], where * denotes the remaining dimensions.

Return type

(jax.numpy.ndarray)

haiku_geometric.nn.pool.global_max_pool(x, batch=None, num_segments=None)[source]#

Returns the maximum across the input features. The maximum is performed individually over each channel.

\[\mathbf{r} = \max_{i=1}^{N} \mathbf{h}_i.\]
Parameters
  • x (jax.numpy.ndarray) – Node features array.

  • batch (jax.numpy.ndarray, optional) – Batch vector with indices that indicate to which graph each node belongs. (default: None).

  • num_segments (int, optional) – Number of segments in batch. (default: None)

Returns

Array with the maximum of the node features. If batch is not None, the output array will have shape [batch_size, *], where * denotes the remaining dimensions.

Return type

(jax.numpy.ndarray)

haiku_geometric.nn.pool.global_mean_pool(x, batch=None, num_segments=None)[source]#

Returns the average of all node features of the input graph:

\[\mathbf{r} = \frac{1}{N} \sum_{i=1}^{N} \mathbf{h}_i.\]
Parameters
  • x (jax.numpy.ndarray) – Node features array.

  • batch (jax.numpy.ndarray, optional) – Batch vector with indices that indicate to which graph each node belongs. (default: None).

  • num_segments (int, optional) – Number of segments in batch. (default: None)

Returns

Array with the average of the node features. If batch is not None, the output array will have shape [batch_size, *], where * denotes the remaining dimensions.

Return type

(jax.numpy.ndarray)
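Example:

A short sketch of batched pooling; the data is illustrative, and it assumes these functions can be called directly since they create no Haiku parameters:

import jax.numpy as jnp
from haiku_geometric.nn.pool import global_add_pool, global_mean_pool

# Two graphs batched together: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
x = jnp.arange(10.0).reshape(5, 2)
batch = jnp.array([0, 0, 0, 1, 1])

mean_readout = global_mean_pool(x, batch, num_segments=2)  # shape [2, 2]
sum_readout = global_add_pool(x, batch, num_segments=2)    # shape [2, 2]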

Dense Layers#

Linear

Applies a linear module to the input features

class haiku_geometric.nn.dense.Linear(out_channels, bias=True, weight_initializer=None, bias_initializer=None)[source]#

Applies a linear module to the input features

\[\mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}\]
Parameters
  • out_channels (int) – Size of each output features.

  • bias (bool, optional) – Whether to add a bias to the output. (default: True)

  • weight_initializer (Optional[Callable[[Sequence[int], Any], ndarray]]) – Optional initializer for weights. By default, uses random values from truncated normal, with stddev 1 / sqrt(fan_in).

  • bias_initializer (Optional[Callable[[Sequence[int], Any], ndarray]]) – Optional initializer for the bias. Defaults to zeros. (default: None)

__call__(x=None, graph=None)[source]#
Return type

Union[ndarray, GraphsTuple]

__init__(out_channels, bias=True, weight_initializer=None, bias_initializer=None)[source]#

Aggregation Operators#

Aggregation

SumAggregation

MeanAggregation

MaxAggregation

MinAggregation

MultiAggregation

Combines multiple aggregation operators.

DegreeScalerAggregation

Combines multiple aggregation operators and transforms their outputs with scalers, as described in the "Principal Neighbourhood Aggregation for Graph Nets" paper.

class haiku_geometric.nn.aggr.Aggregation(*args, **kwargs)[source]#
class haiku_geometric.nn.aggr.DegreeScalerAggregation(aggr, scaler, deg, train_norm=False)[source]#

Combines multiple aggregation operators and transforms their outputs with scalers, as described in the “Principal Neighbourhood Aggregation for Graph Nets” paper. This aggregation is used in the PNAConv convolution layer.

Parameters
  • aggr (string or list or Aggregation) – Aggregation or list of aggregation operators to be used.

  • scaler (str or list) – List of scaling function identifiers. Available scalers are: "identity", "amplification", "attenuation", "linear" and "inverse_linear".

  • deg (jnp.ndarray) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.

  • train_norm (bool, optional) – (default: False)

class haiku_geometric.nn.aggr.MaxAggregation(*args, **kwargs)[source]#
class haiku_geometric.nn.aggr.MeanAggregation(*args, **kwargs)[source]#
class haiku_geometric.nn.aggr.MinAggregation(*args, **kwargs)[source]#
class haiku_geometric.nn.aggr.MultiAggregation(aggrs, aggrs_kwargs=None, mode='cat', mode_kwargs=None)[source]#

Combines multiple aggregation operators. The combination is performed according to the mode parameter.

Parameters
  • aggrs (list) – List of aggregation operators.

  • aggrs_kwargs (list, optional) – Optional arguments passed to each aggregation operator. This must be a list of dictionaries of the same length as the aggrs parameter. (default: None)

  • mode (string, optional) – The combine mode used to aggregate the results of the aggregation operators. Available modes are: "cat", "proj", "sum", "mean", "max", "min", "std", "var". (default: "cat")

  • mode_kwargs (dict, optional) – Optional arguments passed to the mode function. (default: None)

class haiku_geometric.nn.aggr.SumAggregation(*args, **kwargs)[source]#

Attention Layers#

SelfAttention

Self attention with a causal mask applied.

Transformer

Transformer layer from the "Attention is all you need" paper.

class haiku_geometric.nn.attention.SelfAttention(num_heads, key_size, w_init_scale=None, *, w_init=None, value_size=None, model_size=None, name=None)[source]#

Self attention with a causal mask applied.

__call__(query, key=None, value=None, mask=None)[source]#
Return type

ndarray

class haiku_geometric.nn.attention.Transformer(num_heads, num_layers, key_size, dropout_rate, widening_factor=4)[source]#

Transformer layer from the “Attention is all you need” paper, where each layer computes the following function:

\[\begin{split}h &= \mathrm{LayerNorm}(x) \\ h_a &= x + \mathrm{Dropout}(\mathrm{MultiHeadAttention}(h, h, h)) \\ \mathrm{Transformer}(h_a) &= h_a + \mathrm{Dropout}(\mathrm{DenseBlock}(\mathrm{LayerNorm}(h_a)))\end{split}\]
__call__(embeddings, *, mask=None, is_training=True)[source]#

Transforms input embedding sequences to output embedding sequences.

Parameters
  • embeddings (jnp.ndarray) – Input embedding sequences of shape [B, T, D], where B is the batch size, T is the sequence length, and D is the embedding dimension.

  • mask (jnp.ndarray, optional) – Mask for the input embedding sequences of shape [B, T] (default: None).

  • is_training (bool, optional) – Whether the model is in training mode or not (default: True).

Return type

ndarray

__init__(num_heads, num_layers, key_size, dropout_rate, widening_factor=4)[source]#
Parameters
  • num_heads (int) – Number of attention heads.

  • num_layers (int) – Number of layers.

  • key_size (int) – Size of the key and query vector.

  • dropout_rate (float) – Dropout rate.

  • widening_factor (int, optional) – Widening factor for the DenseBlock. (default: 4)
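
Example:

A minimal usage sketch based on the signatures above. The embedding size is chosen as num_heads * key_size, an assumption about the attention model size; an rng key is passed to apply because the layer uses dropout:

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.nn.attention import Transformer

embeddings = jnp.ones((2, 5, 32))  # [B, T, D] with D = num_heads * key_size

def forward(embeddings):
    t = Transformer(num_heads=4, num_layers=2, key_size=8, dropout_rate=0.1)
    return t(embeddings, is_training=True)

model = hk.transform(forward)
params = model.init(jax.random.PRNGKey(0), embeddings)
out = model.apply(params, jax.random.PRNGKey(1), embeddings)  # shape [2, 5, 32]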