haiku_geometric.nn
Convolutional Layers#
GCNConv | The graph convolutional operator from the "Semi-supervised Classification with Graph Convolutional Networks" paper.
GraphConv | The graph neural network operator from the "Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks" paper.
GeneralConv | A general GNN layer adapted from the "Design Space for Graph Neural Networks" paper.
GINConv | The graph isomorphism operator from the "How Powerful are Graph Neural Networks?" paper.
GINEConv | Graph isomorphism operator extended to include edge features, from the "Strategies for Pre-training Graph Neural Networks" paper.
GATConv | Graph attention layer from the "Graph Attention Networks" paper.
SAGEConv | The GraphSAGE operator from the "Inductive Representation Learning on Large Graphs" paper.
GatedGraphConv | The gated graph convolution operator from the "Gated Graph Sequence Neural Networks" paper.
PNAConv | The Principal Neighbourhood Aggregation graph convolution operator from the "Principal Neighbourhood Aggregation for Graph Nets" paper.
GPSLayer | GPS layer from the "Recipe for a General, Powerful, Scalable Graph Transformer" paper.
EdgeConv | The edge convolutional operator from the "Dynamic Graph CNN for Learning on Point Clouds" paper.
MetaLayer | A meta layer for building any kind of graph network, inspired by the "Relational Inductive Biases, Deep Learning, and Graph Networks" paper.
- class haiku_geometric.nn.conv.EdgeConv(nn, aggr='max')[source]#
The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper.
The node features are computed as follows:
\[\mathbf{h}^{k + 1}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{h}_i \, \Vert \, \mathbf{h}_j - \mathbf{h}_i)\]
where \(h_{\mathbf{\Theta}}\) denotes a neural network, e.g. an MLP. The formula is written with a sum; the aggregation actually applied is selected by the aggr parameter.
- Parameters
nn (hk.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features h of shape [-1, 2 * in_channels] to shape [-1, out_channels].
aggr (string, optional) – The aggregation operator ("add", "mean", "max"). (default: "max")
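The EdgeConv update above can be sketched in plain JAX. This is only an illustration of the formula, not the library's implementation: the function edge_conv, the stand-in network nn and the convention that messages flow from senders to receivers are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def edge_conv(h, senders, receivers, nn, num_nodes):
    """Plain-JAX sketch of the EdgeConv update with max aggregation."""
    h_i = h[receivers]  # features of the node being updated
    h_j = h[senders]    # features of its neighbour
    # Edge input is h_i || (h_j - h_i), of shape [num_edges, 2 * in_channels].
    messages = nn(jnp.concatenate([h_i, h_j - h_i], axis=-1))
    # Keep, per channel, the largest message arriving at each receiver.
    return jax.ops.segment_max(messages, receivers, num_segments=num_nodes)

# Hypothetical stand-in for h_Theta: a fixed linear map followed by ReLU.
W = jnp.ones((4, 3))  # 2 * in_channels = 4 -> out_channels = 3
nn = lambda z: jax.nn.relu(z @ W)

h = jnp.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
senders = jnp.array([1, 2, 0, 2, 0, 1])
receivers = jnp.array([0, 0, 1, 1, 2, 2])
out = edge_conv(h, senders, receivers, nn, num_nodes=3)
```

Replacing jax.ops.segment_max with jax.ops.segment_sum gives the "add" variant that matches the formula literally.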
- class haiku_geometric.nn.conv.GATConv(out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0.0, dropout_nodes=0.0, add_self_loops=True, bias=True, init=None)[source]#
Graph attention layer from “Graph Attention Networks” paper
where each node’s output feature is computed as follows:
\[\vec{h}_{i}^{\prime}=\sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j} \mathbf{W} \vec{h}_{j}\right)\]
where the attention coefficients are computed as:
\[\alpha_{i j}=\frac{\exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{j}\right]\right)\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{k}\right]\right)\right)}\]
When multiple attention heads are used and concat=False, the output node features are averaged over the \(K\) heads:
\[\vec{h}_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right)\]
If concat=True, the output feature is the concatenation of the \(K\) head features:
\[\vec{h}_{i}^{\prime}=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right)\]
- Parameters
out_channels (int) – Size of the output features produced by the layer for each node.
heads (int, optional) – Number of attention heads. (default: 1)
concat (bool, optional) – If False, the multi-head features are averaged; otherwise they are concatenated. (default: True)
negative_slope (float, optional) – Scalar specifying the negative slope of the LeakyReLU. (default: 0.2)
add_self_loops (bool, optional) – If True, will add a self-loop for each node of the graph. (default: True)
dropout (float, optional) – Dropout applied to the attention weights. This dropout simulates random sampling of the neighbours. (default: 0.0)
dropout_nodes (float, optional) – Dropout applied initially to the input features. (default: 0.0)
bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)
init (hk.initializers.Initializer) – Weights initializer. (default: hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal"))
- __call__(in_nodes_features, senders, receivers, edges=None, num_nodes=None, training=False)[source]#
- Return type
ndarray
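As a rough illustration of the attention coefficients defined above, the following plain-JAX sketch computes a single-head softmax over the incoming edges of each node. The helper gat_attention, the random weights W and a, and the sender-to-receiver message direction are assumptions for this example; the output activation \(\sigma\) and dropout are left out.

```python
import jax
import jax.numpy as jnp

def gat_attention(h, senders, receivers, W, a, num_nodes, negative_slope=0.2):
    """Single-head coefficients alpha_ij, normalised per receiving node i."""
    wh = h @ W  # W h for every node
    # Unnormalised score a^T [W h_i || W h_j] for each edge j -> i.
    z = jnp.concatenate([wh[receivers], wh[senders]], axis=-1)
    e = jax.nn.leaky_relu(z @ a, negative_slope)
    # Softmax over the incoming edges of each receiver (max-shifted for stability).
    e_max = jax.ops.segment_max(e, receivers, num_segments=num_nodes)
    exp_e = jnp.exp(e - e_max[receivers])
    denom = jax.ops.segment_sum(exp_e, receivers, num_segments=num_nodes)
    alpha = exp_e / denom[receivers]
    # Attention-weighted sum of the transformed neighbour features.
    out = jax.ops.segment_sum(alpha[:, None] * wh[senders], receivers,
                              num_segments=num_nodes)
    return alpha, out

key_w, key_a = jax.random.split(jax.random.PRNGKey(0))
h = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
W = jax.random.normal(key_w, (2, 4))
a = jax.random.normal(key_a, (8,))
# Edges j -> i, including one self-loop per node (as with add_self_loops=True).
senders = jnp.array([0, 1, 2, 1, 0, 2])
receivers = jnp.array([0, 1, 2, 0, 1, 1])
alpha, out = gat_attention(h, senders, receivers, W, a, num_nodes=3)
```

The coefficients of all edges entering the same node sum to one, so a node whose only incoming edge is its self-loop gets a coefficient of exactly 1.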
- class haiku_geometric.nn.conv.GCNConv(out_channels, improved=False, cached=False, add_self_loops=True, normalize=True, bias=True, aggr='add')[source]#
The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper
\[H^{(l+1)} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W\]
where \(\tilde{A} = A + I_N\) is the adjacency matrix with added self-loops \(I_N\), and \(\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\).
The node-wise formulation is given by:
\[\mathbf{h}_u = W^{\top} \sum_{v \in \mathcal{N}(u) \cup \{ u \}} \frac{e_{v,u}}{\sqrt{\hat{d}_u \hat{d}_v}} \mathbf{h}_v\]
where \(e_{v,u}\) is the edge weight and \(\hat{d}_u = 1 + \sum_{v \in \mathcal{N}(u)} e_{v,u}\).
- Parameters
out_channels (int) – Size of the output features of each node.
improved (bool, optional) – If True, then \(\tilde{A}\) is computed as \(A + 2 I_N\). (default: False)
cached (bool, optional) – If True, the value \(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\) is computed on the first execution, cached, and reused in further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)
add_self_loops (bool, optional) – If True, will add a self-loop for each node of the graph. (default: True)
normalize (bool, optional) – Whether to compute and apply the symmetric normalization. (default: True)
bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)
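To make the link between the matrix form and the node-wise form concrete, here is a small plain-JAX sketch that evaluates both on a three-node path graph with unit edge weights. The helper names gcn_dense and gcn_node_wise are hypothetical, and the sketch ignores the bias, improved and cached options.

```python
import jax
import jax.numpy as jnp

def gcn_dense(A, H, W):
    """Matrix form: D~^{-1/2} A~ D~^{-1/2} H W with A~ = A + I."""
    A_tilde = A + jnp.eye(A.shape[0])
    d_inv_sqrt = 1.0 / jnp.sqrt(A_tilde.sum(axis=1))
    return (d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]) @ H @ W

def gcn_node_wise(senders, receivers, H, W, num_nodes):
    """Node-wise form with unit edge weights e_{v,u} = 1."""
    loops = jnp.arange(num_nodes)
    s = jnp.concatenate([senders, loops])    # add one self-loop per node
    r = jnp.concatenate([receivers, loops])
    # d_u = 1 + number of incoming edges (the 1 comes from the self-loop).
    deg = jax.ops.segment_sum(jnp.ones_like(s, dtype=H.dtype), r,
                              num_segments=num_nodes)
    norm = 1.0 / jnp.sqrt(deg[s] * deg[r])
    agg = jax.ops.segment_sum(norm[:, None] * H[s], r, num_segments=num_nodes)
    return agg @ W

# Undirected path graph 0 - 1 - 2, stored as two directed edges per link.
A = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
senders = jnp.array([0, 1, 1, 2])
receivers = jnp.array([1, 0, 2, 1])
H = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
W = jnp.array([[1.0, 2.0], [3.0, 4.0]])

dense_out = gcn_dense(A, H, W)
node_out = gcn_node_wise(senders, receivers, H, W, num_nodes=3)
```

Both functions return the same array up to floating-point error, which is the equivalence the two formulas express.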
- class haiku_geometric.nn.conv.GINConv(nn, eps=0.0, train_eps=False)[source]#
The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper
The node features are computed as follows:
\[\mathbf{{h}}_{u}^{k}= \phi\left( (1 + \epsilon) \mathbf{{h}}_{u}^{k-1} + \sum_{v \in \mathcal{N}(u)} \mathbf{{h}}_{v}^{k-1}\right)\]
where \(\phi\) is a neural network (e.g. an MLP).
- Parameters
nn (hk.Module) – A neural network \(\phi\) that produces output features of shape out_channels defined by the user.
eps (float, optional) – \(\epsilon\) value. (default: 0.)
train_eps (bool, optional) – If True, \(\epsilon\) will be a trainable parameter. (default: False)
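A minimal plain-JAX sketch of the GIN update, assuming messages flow from senders to receivers. The function gin_update and the fixed matrix used as \(\phi\) are hypothetical stand-ins chosen so the numbers are easy to follow; they are not part of the library.

```python
import jax
import jax.numpy as jnp

def gin_update(h, senders, receivers, phi, eps=0.0):
    """phi((1 + eps) * h_u + sum of neighbour features)."""
    neighbour_sum = jax.ops.segment_sum(h[senders], receivers,
                                        num_segments=h.shape[0])
    return phi((1.0 + eps) * h + neighbour_sum)

# Hypothetical stand-in for phi: one linear map followed by ReLU.
M = jnp.array([[1.0, -1.0], [0.5, 2.0]])
phi = lambda z: jax.nn.relu(z @ M)

h = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
senders = jnp.array([1, 2, 0])
receivers = jnp.array([0, 0, 1])

pre_activation = gin_update(h, senders, receivers, lambda z: z, eps=0.5)
out = gin_update(h, senders, receivers, phi, eps=0.5)
```

Node 2 has no incoming edges here, so its update only sees its own features scaled by \(1 + \epsilon\).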
- class haiku_geometric.nn.conv.GINEConv(nn, eps=0.0, train_eps=False, edge_dim=None)[source]#
Graph Isomorphism operator introduced to include edge features from “Strategies for Pre-training Graph Neural Networks” paper
The node features are computed as follows:
\[\mathbf{{h}}_{u}^{k}= \phi\left( (1 + \epsilon) \mathbf{{h}}_{u}^{k-1} + \sum_{v \in \mathcal{N}(u)} \text{ReLU}(\mathbf{{h}}_{v}^{k-1} + \mathbf{e}_{u,v}) \right)\]
where \(\phi\) is a neural network (e.g. an MLP) and \(\mathbf{e}_{u,v}\) are the edge features.
- Parameters
nn (hk.Module) – A neural network \(\phi\) that produces output features of shape out_channels defined by the user.
eps (float, optional) – \(\epsilon\) value. (default: 0.)
train_eps (bool, optional) – If True, \(\epsilon\) will be a trainable parameter. (default: False)
edge_dim (int, optional) – If None, the shapes of the edge and node features are expected to match. Otherwise, the edge features are linearly transformed to match the shape of the node features. (default: None)
- class haiku_geometric.nn.conv.GPSLayer(dim_h, local_gnn_type, global_model_type, act, num_heads=1, pna_degrees=None, equivstable_pe=False, dropout=0.0, attn_dropout=0.0, layer_norm=False, batch_norm=True)[source]#
GPS layer from the “Recipe for a General, Powerful, Scalable Graph Transformer” paper.
🚧: This layer is still under development and might not work as expected.
- Parameters
dim_h (int) – Size of the output features.
local_gnn_type (str) – Name of a message passing neural network. Available networks are: None, "GINE", "GAT", "PNA".
global_model_type (str) – Name of a global attention layer. Available networks are: None, "Transformer", "Performer".
act (Callable) – Activation function (e.g. jax.nn.relu).
num_heads (int, optional) – Number of heads when using multi-head attention. (default: 1)
pna_degrees (jnp.ndarray, optional) – Array with the degree histogram, used with PNA.
equivstable_pe (bool, optional) – Not implemented. (default: False)
dropout (float, optional) – Dropout rate. (default: 0.0)
attn_dropout (float, optional) – Dropout rate used with global attention. (default: 0.0)
layer_norm (bool, optional) – Whether to use layer normalization. (default: False)
batch_norm (bool, optional) – Whether to use batch normalization. (default: True)
- class haiku_geometric.nn.conv.GatedGraphConv(out_channels, num_layers, aggr='add')[source]#
The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper
where the output features are computed as follows:
\[\mathbf{{h}}_{u}^{(0)} = \mathbf{{x}}_{u} \Vert \mathbf{0}\]
where \(\mathbf{{x}}_{u}\) are the input features of node \(u\), zero-padded to size out_channels. Then, for layer \(k: 1,...,L\):
\[\begin{split}\mathbf{{m}}_{u}^{(k)} &= \text{AGGREGATE}(\{e_{u, v} \cdot \mathbf{{W}} \cdot \mathbf{{h}}_{v}^{(k - 1)}, \forall v \in \mathcal{N}(u)\}) \\ \mathbf{{h}}_{u}^{(k)} &= GRU(\mathbf{{m}}_{u}^{(k)}, \mathbf{{h}}_{u}^{(k - 1)})\end{split}\]
with \(AGGREGATE\) being the aggregation operator (i.e. "mean", "max", or "add").
- Parameters
out_channels (int) – Size of the output features of each node.
num_layers (int) – Number of layers \(L\).
aggr (string, optional) – Aggregation operator ("add", "mean", "max"). (default: "add")
- class haiku_geometric.nn.conv.GeneralConv(out_channels, in_edge_channels=None, aggr='add', skip_linear=False, directed_msg=True, heads=1, attention=False, attention_type='additive', l2_normalize=False, bias=True)[source]#
A general GNN layer adapted from the “Design Space for Graph Neural Networks” paper.
where the output features are computed as follows:
\[\mathbf{{h}}_{u}^{k}= \text{AGGREGATE}(\{\mathbf{m}_{u,v}, \forall v \in \mathcal{N}(u)\})\]
with \(AGGREGATE\) being the aggregation operator (i.e. "mean", "max", or "add") and each message \(\mathbf{m}_{u,v}\) computed as:
\[\mathbf{m}_{u,v} = \mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1}\]
If directed_msg=False, the message is bidirectional:
\[\mathbf{m}_{u,v} = \mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1} + \mathbf{W}_2 \cdot \mathbf{{h}}_{u}^{k-1}\]
If in_edge_channels is not None, the edge features are also added to the message:
\[\mathbf{m}_{u,v} = \mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1} + \mathbf{W}_3 \cdot \mathbf{{e}}_{u, v}\]
If attention=True, attention is applied to the message computation:
\[\mathbf{m}_{u,v} = \alpha_{u,v}(\mathbf{W}_1 \cdot \mathbf{{h}}_{v}^{k-1})\]
where the attention coefficient \(\alpha_{u,v}\) is computed as follows (shown for additive attention):
\[\alpha_{u,v}=\frac{\exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \mathbf{{h}}_{u} + \mathbf{W} \mathbf{{h}}_{v}\right]\right)\right)}{\sum_{w \in \mathcal{N}(u)} \exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \mathbf{{h}}_{u} + \mathbf{W} \mathbf{{h}}_{w}\right]\right)\right)}\]
If skip_linear=True, a skip connection is added to the output:
\[\mathbf{{h}}_{u}^{k}= \text{AGGREGATE}(\{\mathbf{m}_{u,v}, \forall v \in \mathcal{N}(u)\}) + \mathbf{W}_4 \cdot \mathbf{{h}}_{u}^{k-1}\]
- Parameters
out_channels (int) – Size of the output features of a node.
in_edge_channels (int, optional) – Size of the features of each edge. (default: None)
aggr (string or Aggregation, optional) – The aggregation operator. Available values are: "mean", "max", or "add". (default: "add")
skip_linear (bool, optional) – If True, the linear skip connection \(\mathbf{W}_4 \cdot \mathbf{{h}}_{u}^{k-1}\) is added to the output. (default: False)
directed_msg (bool, optional) – If False, the bidirectional message defined above is used. (default: True)
heads (int, optional) – Number of attention heads. If attention=True and heads > 1, the multi-head features are mean aggregated. (default: 1)
attention (bool, optional) – Whether to perform attention over the messages. (default: False)
attention_type (str, optional) – Type of attention: "additive", "dot_product". (default: "additive")
l2_normalize (bool, optional) – If True, output features are \(\ell_2\)-normalized. (default: False)
bias (bool, optional) – If True, the linear transformations also add a bias. (default: True)
- class haiku_geometric.nn.conv.GraphConv(out_channels, aggr='add', bias=True)[source]#
The graph neural network operator from the “Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks” paper
The node features are computed as follows:
\[\mathbf{{h}}_{u}^{k}=\mathbf{W}_1 \cdot \mathbf{{h}}_{u}^{k-1} + \mathbf{W}_2 \cdot \text{AGGREGATE}(\{e_{u, v} \cdot \mathbf{{h}}_{v}^{k-1}, \forall v \in \mathcal{N}(u)\})\]
with \(AGGREGATE\) being the aggregation operator (i.e. "mean", "max", or "add").
- Parameters
out_channels (int) – Size of the output features of a node.
aggr (string or Aggregation, optional) – The aggregation operator. Available values are: "mean", "max", or "add". (default: "add")
bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)
- class haiku_geometric.nn.conv.MetaLayer(edge_model=None, node_model=None, global_model=None)[source]#
A meta layer for building any kind of graph network, inspired by the “Relational Inductive Biases, Deep Learning, and Graph Networks” paper.
General graph network that takes as input the nodes features nodes, the edges features edge_attr, the sender node indices senders, the receiver node indices receivers and the global features of the graph globals.
It returns the updated nodes features nodes, the updated edges features edge_attr and the updated global features globals.
Nodes features are updated by calling the node model node_model, edges features are updated by calling the edge model edge_model and global features are updated by calling the global model global_model.
- Parameters
edge_model (hk.Module, optional) – A neural network that updates the edge features based on the features of their source and target nodes. It receives as input:
  - senders features of shape [E, F_N], where E is the number of edges and F_N the number of input node features.
  - receivers features of shape [E, F_N], where E is the number of edges and F_N the number of input node features.
  - edges features of shape [E, F_E], where E is the number of edges and F_E the number of input edge features.
  - globals features of shape [F_G] for non-batched graphs or shape [G * F_G] for batched graphs, where G is the number of graphs and F_G the size of the global features.
  - batch indices of shape [N], where N is the number of nodes. This array indicates to which graph each node belongs.
node_model (hk.Module, optional) – A neural network that updates the nodes features based on the current node features, edge features and global features. It receives as input:
  - nodes features of shape [N, F_N], where N is the number of nodes and F_N the number of input node features.
  - senders features of shape [E, F_N], where E is the number of edges and F_N the number of input node features.
  - receivers features of shape [E, F_N], where E is the number of edges and F_N the number of input node features.
  - edges features of shape [E, F_E], where E is the number of edges and F_E the number of input edge features.
  - globals features of shape [F_G] for non-batched graphs or shape [G * F_G] for batched graphs, where G is the number of graphs and F_G the size of the global features.
  - batch indices of shape [N], where N is the number of nodes. This array indicates to which graph each node belongs.
global_model (hk.Module, optional) – A neural network that updates the global features of a graph based on the current nodes features, edges features and global features. It receives as input:
  - nodes features of shape [N, F_N], where N is the number of nodes and F_N the number of input node features.
  - senders features of shape [E, F_N], where E is the number of edges and F_N the number of input node features.
  - receivers features of shape [E, F_N], where E is the number of edges and F_N the number of input node features.
  - edges features of shape [E, F_E], where E is the number of edges and F_E the number of input edge features.
  - globals features of shape [F_G] for non-batched graphs or shape [G * F_G] for batched graphs, where G is the number of graphs and F_G the size of the global features.
  - batch indices of shape [N], where N is the number of nodes. This array indicates to which graph each node belongs.
- Returns
The updated nodes features if node_model is not None.
The updated edges features if edge_model is not None.
The updated globals features if global_model is not None.
- Return type
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
Examples:
    import haiku as hk
    import jax
    import jax.numpy as jnp
    from jax import tree_util as tree
    from haiku_geometric.nn.aggr.utils import aggregation

    # hk.Linear(...) is left elided: fill in the layer sizes for your model.

    class EdgeModel(hk.Module):
        def __init__(self):
            super().__init__()
            self.mlp = hk.Sequential([hk.Linear(...), jax.nn.relu, hk.Linear(...)])

        def __call__(self, senders_features, receivers_features, edges_features, globals, batch, num_nodes=None):
            # Update each edge from its endpoint features and its own features.
            h = jnp.concatenate([senders_features, receivers_features, edges_features], axis=-1)
            return self.mlp(h)

    class NodeModel(hk.Module):
        def __init__(self):
            super().__init__()
            self.aggr = aggregation('mean')
            self.mlp = hk.Sequential([hk.Linear(...), jax.nn.relu, hk.Linear(...)])

        def __call__(self, nodes, senders, receivers, edge_attr, globals, batch, num_nodes=None):
            # Build one message per edge, then average the messages per receiver.
            h = jnp.concatenate([nodes[senders], edge_attr], axis=1)
            messages = self.mlp(h)
            total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
            return self.aggr(messages, receivers, total_num_nodes)

    class GlobalModel(hk.Module):
        def __init__(self):
            super().__init__()
            self.mlp = hk.Sequential([hk.Linear(...), jax.nn.relu, hk.Linear(...)])

        def __call__(self, nodes, senders, receivers, edge_attr, globals, batch, num_nodes=None):
            return self.mlp(globals)
- __call__(nodes, senders, receivers, edge_attr=None, globals=None, batch=None, num_nodes=None)[source]#
- Return type
Tuple[ndarray, ndarray, ndarray]
- __init__(edge_model=None, node_model=None, global_model=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- class haiku_geometric.nn.conv.PNAConv(out_channels, aggregators, scalers, deg, edge_dim=None, towers=1, pre_layers=1, post_layers=1, divide_input=False, act='relu', act_kwargs=None, train_norm=False)[source]#
The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper.
The output features are computed as follows:
\[\mathbf{h}_u^{(k+1)} = U \left( \mathbf{h}_u^{(k)}, \underset{v \in \mathcal{N}(u)}{\bigoplus} M \left( \mathbf{h}_u^{(k)}, \mathbf{h}_v^{(k)} \right) \right)\]
with \(M\) and \(U\) being MLPs, and:
\[\begin{split}\bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}}\end{split}\]
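The \(\bigoplus\) operator can be sketched in plain JAX as follows. The scaler definition \(S(d, \alpha) = (\log(d + 1) / \delta)^{\alpha}\), with \(\delta\) the average of \(\log(d + 1)\) over the training nodes, is taken from the PNA paper rather than from this page. The function pna_aggregate is hypothetical and receives \(\delta\) directly instead of a degree histogram.

```python
import jax
import jax.numpy as jnp

def pna_aggregate(messages, receivers, num_nodes, delta):
    """4 aggregators x 3 degree scalers, giving 12 * F features per node."""
    ones = jnp.ones((messages.shape[0],), dtype=messages.dtype)
    deg = jax.ops.segment_sum(ones, receivers, num_segments=num_nodes)
    safe_deg = jnp.maximum(deg, 1.0)[:, None]
    mean = jax.ops.segment_sum(messages, receivers, num_segments=num_nodes) / safe_deg
    mean_sq = jax.ops.segment_sum(messages ** 2, receivers, num_segments=num_nodes) / safe_deg
    std = jnp.sqrt(jnp.maximum(mean_sq - mean ** 2, 0.0) + 1e-5)
    # Note: segment_max / segment_min return +-inf for nodes without messages.
    mx = jax.ops.segment_max(messages, receivers, num_segments=num_nodes)
    mn = jax.ops.segment_min(messages, receivers, num_segments=num_nodes)
    aggregated = jnp.concatenate([mean, std, mx, mn], axis=-1)  # [N, 4F]
    # Scalers S(d, alpha) = (log(d + 1) / delta) ** alpha for alpha in {0, 1, -1}.
    log_deg = jnp.log(safe_deg + 1.0)
    scalers = [jnp.ones_like(log_deg), log_deg / delta, delta / log_deg]
    return jnp.concatenate([s * aggregated for s in scalers], axis=-1)  # [N, 12F]

h = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
senders = jnp.array([1, 2, 0, 2, 0])
receivers = jnp.array([0, 0, 1, 1, 2])
messages = h[senders]  # every node receives at least one message

deg = jax.ops.segment_sum(jnp.ones(5), receivers, num_segments=3)
delta = jnp.mean(jnp.log(deg + 1.0))
out = pna_aggregate(messages, receivers, num_nodes=3, delta=delta)
```

The first 4F columns are the unscaled aggregations, the next 4F are amplified for high-degree nodes, and the last 4F are attenuated.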
- class haiku_geometric.nn.conv.SAGEConv(out_channels, aggr='mean', normalize=False, root_weight=True, project=False, bias=True)[source]#
The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper
The node features are computed as follows:
\[\mathbf{{h}}_{u}^{k}=\mathbf{W}_1 \cdot \mathbf{{h}}_{u}^{k-1} + \mathbf{W}_2 \cdot \text{CONCAT}(\mathbf{{h}}_{u}^{k-1}, \mathbf{{h}}_{\mathcal{N}(u)}^{k})\]
having:
\[\mathbf{{h}}_{\mathcal{N}(u)}^{k} = \text{AGGREGATE}(\{\mathbf{{h}}_{v}^{k-1}, \forall v \in \mathcal{N}(u)\})\]
and \(AGGREGATE\) being an aggregation operator (i.e. "mean", "max", or "sum").
If project=True, the neighbour features \(\mathbf{{h}}_{v}^{k-1}\) are first projected via:
\[\mathbf{{h}}_{v}^{k-1}=\text{ReLU}(\mathbf{W}_3 \mathbf{{h}}_{v}^{k-1} + \mathbf{b})\]
- Parameters
out_channels (int) – Size of the output features of a node.
aggr (string or Aggregation, optional) – The aggregation operator. Available values are: "mean", "max", or "sum". (default: "mean")
normalize (bool, optional) – If True, output features are \(\ell_2\)-normalized. (default: False)
root_weight (bool, optional) – If False, the linearly transformed features \(\mathbf{W}_1 \cdot \mathbf{{h}}_{u}^{k-1}\) are not added to the output features. (default: True)
project (bool, optional) – If True, neighbour features are projected before aggregation as explained above. (default: False)
bias (bool, optional) – If True, the layer will add an additive bias to the output. (default: True)
Pooling Layers#
global_add_pool | Returns the sum of all node features of the input graph.
global_mean_pool | Returns the average of all node features of the input graph.
global_max_pool | Returns the maximum across the input features.
TopKPooling | Top-k pooling operator from the "Graph U-Nets" and "Towards Sparse Hierarchical Graph Classifiers" papers.
- class haiku_geometric.nn.pool.TopKPooling(in_channels, ratio=0.5, multiplier=1.0)[source]#
Topk pooling operator from the “Graph U-Nets” and “Towards Sparse Hierarchical Graph Classifiers” paper.
- Parameters
in_channels (int) – Dimension of input node features.
ratio (Union[int, float], optional) – Ratio of nodes to keep. If int, the number of nodes to keep. (default: 0.5)
multiplier (float, optional) – Multiplier to scale the features after pooling. (default: 1.)
- __call__(x, senders, receivers, edges=None, batch=None, create_new_batch=False, batch_size=None, max_num_nodes=None)[source]#
- Parameters
x (jnp.ndarray) – Node features of shape [num_nodes, in_channels].
senders (jnp.ndarray) – Sender indices.
receivers (jnp.ndarray) – Receiver indices.
edges (jnp.ndarray, optional) – Edge features of shape [num_edges, in_channels]. (default: None)
batch (jnp.ndarray, optional) – Batch array with the batch index of each node. Shape: [num_nodes]. Note: this array should be sorted in increasing order. (default: None)
create_new_batch (bool, optional) – If set to False, nodes that are not selected by the top-k and their edges are removed from the graph. If set to True, these nodes are kept in the graph, but they are assigned to a new batch with value batch_size + 1, and their corresponding edges are transformed into self-loops. Note: if True, the output sizes of x, batch, senders, receivers and edges stay the same. If False, the output sizes of x, batch, senders, receivers and edges might be reduced according to the ratio parameter. (default: False)
batch_size (int, optional) – Number of batched graphs. If not given, it is automatically computed as batch.max() + 1. (default: None)
max_num_nodes (int, optional) – Maximum number of nodes that a graph can have. If not given, it is automatically computed as batch.shape[0]. (default: None)
- Returns
The updated nodes features.
The updated senders indices.
The updated receivers indices.
The updated edges features.
The updated batch array.
- Return type
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
- Observations:
To make this layer jit-able, create_new_batch=True must be set and batch_size must be provided as a static parameter.
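This page does not spell out the selection rule, so the sketch below follows the formulation of the "Graph U-Nets" paper for a single, non-batched graph: nodes are scored by projecting their features onto a vector p, the best-scoring nodes are kept, and their features are gated by tanh of the score. The function topk_pool, the rounding of ratio with ceil and the handling of edges are assumptions for illustration, not the library's code.

```python
import math
import jax
import jax.numpy as jnp

def topk_pool(x, senders, receivers, p, ratio=0.5):
    """Keep the ceil(ratio * N) best-scoring nodes of one graph."""
    num_nodes = x.shape[0]
    k = math.ceil(ratio * num_nodes)
    score = x @ p / jnp.linalg.norm(p)   # projection score per node
    _, keep = jax.lax.top_k(score, k)    # indices of the k best nodes
    x_new = x[keep] * jnp.tanh(score[keep])[:, None]
    # Relabel kept nodes 0..k-1 and mark removed nodes with -1.
    new_index = jnp.full((num_nodes,), -1).at[keep].set(jnp.arange(k))
    s, r = new_index[senders], new_index[receivers]
    # Drop every edge that touches a removed node (data-dependent shape).
    edge_mask = (s >= 0) & (r >= 0)
    return x_new, s[edge_mask], r[edge_mask], keep

x = jnp.array([[1.0, 0.0], [0.0, 2.0], [3.0, 1.0], [0.5, 0.5]])
p = jnp.array([1.0, 0.0])
senders = jnp.array([0, 1, 2, 3, 2])
receivers = jnp.array([2, 0, 0, 2, 1])
x_new, s, r, keep = topk_pool(x, senders, receivers, p, ratio=0.5)
```

The boolean edge mask produces arrays whose size depends on the data, which jax.jit cannot trace. Keeping all nodes and edges at a fixed size, as create_new_batch=True does, avoids that problem.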
- haiku_geometric.nn.pool.global_add_pool(x, batch=None, num_segments=None)[source]#
Returns the sum of all node features of the input graph:
\[\mathbf{r} = \sum_{i=1}^{N} \mathbf{h}_i.\]
- Parameters
x (jax.numpy.ndarray) – Node features array.
batch (jax.numpy.ndarray, optional) – Batch vector with indices that indicate to which graph each node belongs. (default: None)
num_segments (int, optional) – Number of segments in batch. (default: None)
- Returns
Array with the sum of the nodes features. If batch is not None, the output array will have shape [batch_size, *], where * denotes the remaining dimensions.
- Return type
jax.numpy.ndarray
- haiku_geometric.nn.pool.global_max_pool(x, batch=None, num_segments=None)[source]#
Returns the maximum across the input features. The maximum is performed individually over each channel.
\[\mathbf{r} = \max_{i=1}^{N} \mathbf{h}_i.\]
- Parameters
x (jax.numpy.ndarray) – Node features array.
batch (jax.numpy.ndarray, optional) – Batch vector with indices that indicate to which graph each node belongs. (default: None)
num_segments (int, optional) – Number of segments in batch. (default: None)
- Returns
Array with the channel-wise maximum of the nodes features. If batch is not None, the output array will have shape [batch_size, *], where * denotes the remaining dimensions.
- Return type
jax.numpy.ndarray
- haiku_geometric.nn.pool.global_mean_pool(x, batch=None, num_segments=None)[source]#
Returns the average of all node features of the input graph:
\[\mathbf{r} = \frac{1}{N} \sum_{i=1}^{N} \mathbf{h}_i.\]
- Parameters
x (jax.numpy.ndarray) – Node features array.
batch (jax.numpy.ndarray, optional) – Batch vector with indices that indicate to which graph each node belongs. (default: None)
num_segments (int, optional) – Number of segments in batch. (default: None)
- Returns
Array with the average of the nodes features. If batch is not None, the output array will have shape [batch_size, *], where * denotes the remaining dimensions.
- Return type
jax.numpy.ndarray
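The three global pooling formulas reduce to segment reductions over the batch vector. The sketch below shows that idea with jax.ops directly; the function global_pools is a hypothetical helper written for this example and does not call the library functions documented above.

```python
import jax
import jax.numpy as jnp

def global_pools(x, batch, num_segments):
    """Per-graph sum, mean and channel-wise max of the node features."""
    total = jax.ops.segment_sum(x, batch, num_segments=num_segments)
    counts = jax.ops.segment_sum(jnp.ones((x.shape[0], 1), x.dtype), batch,
                                 num_segments=num_segments)
    mean = total / jnp.maximum(counts, 1.0)  # avoid dividing by zero
    maximum = jax.ops.segment_max(x, batch, num_segments=num_segments)
    return total, mean, maximum

# Three nodes: the first two belong to graph 0, the last one to graph 1.
x = jnp.array([[1.0, 4.0], [3.0, 2.0], [5.0, 0.0]])
batch = jnp.array([0, 0, 1])
total, mean, maximum = global_pools(x, batch, num_segments=2)
```

Passing num_segments explicitly keeps the output shape static, which matters when the reduction runs under jax.jit.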
Dense Layers#
Linear | Applies a linear module on the input features.
- class haiku_geometric.nn.dense.Linear(out_channels, bias=True, weight_initializer=None, bias_initializer=None)[source]#
Applies a linear module on the input features
\[\mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}\]
- Parameters
out_channels (int) – Size of the output features.
bias (bool, optional) – Whether to add a bias to the output. (default: True)
weight_initializer (Optional[Callable[[Sequence[int], Any], ndarray]]) – Optional initializer for the weights. By default, uses random values from a truncated normal with stddev 1 / sqrt(fan_in). (default: None)
bias_initializer (Optional[Callable[[Sequence[int], Any], ndarray]]) – Optional initializer for the bias. Defaults to zeros. (default: None)
Aggregation Operators#
MultiAggregation | Performs the aggregation of multiple aggregation operators.
DegreeScalerAggregation | Combines multiple aggregation operators and transforms their outputs with scalers, as described in the "Principal Neighbourhood Aggregation for Graph Nets" paper.
- class haiku_geometric.nn.aggr.DegreeScalerAggregation(aggr, scaler, deg, train_norm=False)[source]#
Combines multiple aggregation operators and transforms their outputs with scalers, as described in the “Principal Neighbourhood Aggregation for Graph Nets” paper. This aggregation is used in the PNAConv convolution layer.
- Parameters
aggr (string or list or Aggregation) – Aggregation or list of aggregation operators to be used.
scaler (str or list) – List of scaling function identifiers. Available scalers are: "identity", "amplification", "attenuation", "linear" and "inverse_linear".
deg (jnp.ndarray) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.
train_norm (bool, optional) – (default: False)
- class haiku_geometric.nn.aggr.MultiAggregation(aggrs, aggrs_kwargs=None, mode='cat', mode_kwargs=None)[source]#
Performs the aggregation of multiple aggregation operators. Their results are combined according to the mode parameter.
- Parameters
aggrs (list) – List of aggregation operators.
aggrs_kwargs (list, optional) – Optional arguments passed to the aggregation operators. This must be a list of dictionaries with the same length as the aggrs parameter. (default: None)
mode (string, optional) – The combine mode used to merge the results of the aggregation operators. Available modes are: "cat", "proj", "sum", "mean", "max", "min", "std", "var". (default: "cat")
mode_kwargs (dict, optional) – Optional arguments passed to the mode function. (default: None)
Attention Layers#
SelfAttention | Self attention with a causal mask applied.
Transformer | Transformer layer from the "Attention is all you need" paper.
- class haiku_geometric.nn.attention.SelfAttention(num_heads, key_size, w_init_scale=None, *, w_init=None, value_size=None, model_size=None, name=None)[source]#
Self attention with a causal mask applied.
- class haiku_geometric.nn.attention.Transformer(num_heads, num_layers, key_size, dropout_rate, widening_factor=4)[source]#
Transformer layer from the “Attention is all you need” paper, where each layer computes the following function of its input \(x\):
\[\begin{split}h &= \mathrm{LayerNorm}(x) \\ h_a &= x + \mathrm{Dropout}(\mathrm{MultiHeadAttention}(h, h, h)) \\ \mathrm{Transformer}(x) &= h_a + \mathrm{Dropout}(\mathrm{DenseBlock}(\mathrm{LayerNorm}(h_a)))\end{split}\]
- __call__(embeddings, *, mask=None, is_training=True)[source]#
Transforms input embedding sequences to output embedding sequences.
- Parameters
embeddings (jnp.ndarray) – Input embedding sequences of shape [B, T, D], where B is the batch size, T is the sequence length, and D is the embedding dimension.
mask (jnp.ndarray, optional) – Mask for the input embedding sequences of shape [B, T]. (default: None)
is_training (bool, optional) – Whether the model is in training mode or not. (default: True)
- Return type
ndarray
- __init__(num_heads, num_layers, key_size, dropout_rate, widening_factor=4)[source]#
- Parameters
num_heads (int) – Number of attention heads.
num_layers (int) – Number of layers.
key_size (int) – Size of the key and query vector.
dropout_rate (float) – Dropout rate.
widening_factor (int, optional) – Widening factor for the DenseBlock. (default: 4)
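The per-layer function of the Transformer class can be illustrated with a stripped-down plain-JAX block. This sketch makes several simplifying assumptions that are not taken from the library: a single attention head, no mask, no dropout (as in evaluation mode), layer normalization without learned scale and offset, and a two-layer DenseBlock with a GELU activation. All function and parameter names are hypothetical.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    """Normalise the last axis to zero mean and unit variance."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def self_attention(h, wq, wk, wv, wo):
    """Single-head scaled dot-product attention over the sequence axis."""
    q, k, v = h @ wq, h @ wk, h @ wv
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v @ wo

def transformer_block(x, params):
    h = layer_norm(x)
    h_a = x + self_attention(h, *params["attn"])            # attention residual
    w1, w2 = params["mlp"]
    return h_a + jax.nn.gelu(layer_norm(h_a) @ w1) @ w2     # dense residual

D, key_size, widening_factor = 8, 4, 4
keys = jax.random.split(jax.random.PRNGKey(0), 6)
attn_shapes = [(D, key_size), (D, key_size), (D, key_size), (key_size, D)]
params = {
    "attn": tuple(0.1 * jax.random.normal(k, s) for k, s in zip(keys[:4], attn_shapes)),
    "mlp": (0.1 * jax.random.normal(keys[4], (D, widening_factor * D)),
            0.1 * jax.random.normal(keys[5], (widening_factor * D, D))),
}
x = jax.random.normal(jax.random.PRNGKey(1), (2, 5, D))  # [B, T, D]
out = transformer_block(x, params)
```

Because both sub-layers are residual, the block preserves the [B, T, D] shape of its input, and it reduces to the identity when the output projections are zero.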