haiku_geometric.posenc#

LaplacianEncoder

Graph Laplacian Positional Encoder described in the "Rethinking Graph Transformers with Spectral Attention" paper.

MagLaplacianEncoder

MagLapNet Positional Encoder described in the "Transformers Meet Directed Graphs" paper.

class haiku_geometric.posenc.LaplacianEncoder(dim, model, model_dropout=0.0, layers=1, heads=1, post_layers=1, norm=None, norm_decay=0.9)[source]#

Graph Laplacian Positional Encoder described in the “Rethinking Graph Transformers with Spectral Attention” paper.

Usage:

from haiku_geometric.utils import eigv_laplacian
from haiku_geometric.posenc import LaplacianEncoder

# Compute the eigenvectors of the Laplacian matrix
eigenvalues, eigenvectors = eigv_laplacian(
    senders=senders,
    receivers=receivers,
    k=...)

# The function that you will transform with Haiku
def your_forward_function(...):

    # Create the encoder model
    model = LaplacianEncoder(...)

    # Encode the eigenvalues and eigenvectors
    h = model(eigenvalues, eigenvectors, is_training)
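
A more complete sketch (hedged: the toy graph, k=3, dim=16, and model="DeepSet" are illustrative choices, not values prescribed by the API; the transform/init/apply workflow is standard Haiku):

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.utils import eigv_laplacian
from haiku_geometric.posenc import LaplacianEncoder

# Toy directed 4-cycle
senders = jnp.array([0, 1, 2, 3])
receivers = jnp.array([1, 2, 3, 0])

# Keep the k smallest eigenpairs of the graph Laplacian
eigenvalues, eigenvectors = eigv_laplacian(
    senders=senders, receivers=receivers, k=3)

def forward(eigenvalues, eigenvectors, is_training):
    model = LaplacianEncoder(dim=16, model="DeepSet")
    return model(eigenvalues, eigenvectors, is_training)

# Pass an RNG key to apply, since the encoder may use dropout internally
forward_t = hk.transform(forward)
rng = jax.random.PRNGKey(0)
params = forward_t.init(rng, eigenvalues, eigenvectors, True)
h = forward_t.apply(params, rng, eigenvalues, eigenvectors, True)  # [N, 16]
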
__call__(eigenvalues, eigenvectors, is_training, call_args=None)[source]#
Parameters
  • eigenvalues (jnp.ndarray) – Eigenvalues of the Laplacian matrix with shape [K,].

  • eigenvectors (jnp.ndarray) – Eigenvectors of the Laplacian matrix with shape [N, K].

  • is_training (bool) – Whether the model is in training mode.

Returns

Encoded features with shape [N, dim].

Return type

jnp.ndarray

__init__(dim, model, model_dropout=0.0, layers=1, heads=1, post_layers=1, norm=None, norm_decay=0.9)[source]#
Parameters
  • dim (int) – Dimension of the output features.

  • model (str) – Model to use for the encoder. Can be either "Transformer" or "DeepSet".

  • model_dropout (float, optional) – Dropout rate for the model. (default: 0.0).

  • layers (int, optional) – Number of layers for the model. (default: 1).

  • heads (int, optional) – Number of heads for the model. Only used if model="Transformer". (default: 1).

  • post_layers (int, optional) – Number of post-processing layers applied after the model. (default: 1).

  • norm (str, optional) – Normalization layer to use. Can be either "batchnorm" or None. (default: None).

  • norm_decay (float, optional) – Decay rate for the normalization layer. (default: 0.9).
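
As a hedged illustration of these constructor options (the values are arbitrary, and the module must be created inside a function that Haiku transforms):

# A Transformer-based variant with batch normalization (illustrative values)
encoder = LaplacianEncoder(
    dim=64,               # output feature dimension
    model="Transformer",  # "Transformer" or "DeepSet"
    model_dropout=0.1,
    layers=2,
    heads=4,              # only used when model="Transformer"
    post_layers=2,
    norm="batchnorm",     # "batchnorm" or None
    norm_decay=0.9,
)
# Note: a "batchnorm" layer carries state, so hk.transform_with_state
# is likely required instead of hk.transform.
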

class haiku_geometric.posenc.MagLaplacianEncoder(d_model_elem=32, d_model_aggr=256, num_heads=4, n_layers=1, dropout=0.2, activation=jax.nn.relu, return_real_output=True, consider_im_part=True, use_signnet=True, use_gnn=False, use_attention=False, concatenate_eigenvalues=False, norm=None)[source]#

MagLapNet Positional Encoder described in the “Transformers Meet Directed Graphs” paper. Positional encodings are computed using the Magnetic Laplacian matrix (Hermitian matrix).

Usage:

from haiku_geometric.utils import eigv_magnetic_laplace
from haiku_geometric.posenc import MagLaplacianEncoder

# Compute the eigenvectors of the Magnetic Laplacian matrix
eigenvalues, eigenvectors = eigv_magnetic_laplace(
    senders=senders,
    receivers=receivers,
    k=...)

# The function that you will transform with Haiku
def your_forward_function(...):

    # Create the encoder model
    model = MagLaplacianEncoder(...)

    # Encode the eigenvalues and eigenvectors
    h = model(senders, receivers, eigenvalues, eigenvectors, is_training)
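
A more complete sketch (hedged: the toy graph, k=3, and the reduced model sizes are illustrative; the transform/init/apply workflow is standard Haiku):

import jax
import jax.numpy as jnp
import haiku as hk
from haiku_geometric.utils import eigv_magnetic_laplace
from haiku_geometric.posenc import MagLaplacianEncoder

# Toy directed 4-cycle
senders = jnp.array([0, 1, 2, 3])
receivers = jnp.array([1, 2, 3, 0])

eigenvalues, eigenvectors = eigv_magnetic_laplace(
    senders=senders, receivers=receivers, k=3)

def forward(senders, receivers, eigenvalues, eigenvectors, is_training):
    model = MagLaplacianEncoder(d_model_elem=8, d_model_aggr=32)
    return model(senders, receivers, eigenvalues, eigenvectors, is_training)

forward_t = hk.transform(forward)
rng = jax.random.PRNGKey(0)
params = forward_t.init(rng, senders, receivers, eigenvalues, eigenvectors, True)
h = forward_t.apply(params, rng, senders, receivers, eigenvalues, eigenvectors, True)
# h has shape [N, d_model_aggr]
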
__call__(senders, receivers, eigenvalues, eigenvectors, is_training, call_args=None)[source]#
Parameters
  • senders (jnp.ndarray) – Indices of the sender nodes.

  • receivers (jnp.ndarray) – Indices of the receiver nodes.

  • eigenvalues (jnp.ndarray) – Eigenvalues of the Magnetic Laplacian matrix with shape [K,].

  • eigenvectors (jnp.ndarray) – Eigenvectors of the Magnetic Laplacian matrix with shape [N, K].

  • is_training (bool) – Whether the model is in training mode.

Returns

Encoded features with shape [N, d_model_aggr].

Return type

jnp.ndarray

__init__(d_model_elem=32, d_model_aggr=256, num_heads=4, n_layers=1, dropout=0.2, activation=jax.nn.relu, return_real_output=True, consider_im_part=True, use_signnet=True, use_gnn=False, use_attention=False, concatenate_eigenvalues=False, norm=None)[source]#
Parameters
  • d_model_elem (int, optional) – Embedding dimension for each element of the eigenvectors. (default: 32).

  • d_model_aggr (int, optional) – Dimension of the aggregated embedding over all elements of the eigenvectors. (default: 256).

  • num_heads (int, optional) – Number of attention heads. (default: 4).

  • n_layers (int, optional) – Number of layers for the MLPs. (default: 1).

  • dropout (float, optional) – Dropout rate. (default: 0.2).

  • activation (Callable[[jnp.ndarray], jnp.ndarray], optional) – Activation function. (default: jax.nn.relu).

  • return_real_output (bool, optional) – Whether to return only the real part of the output. (default: True).

  • consider_im_part (bool, optional) – Whether to consider the imaginary part of the eigenvectors. (default: True).

  • use_signnet (bool, optional) – Whether to use SignNet, that is, to process each eigenvector \(\gamma_i\) as \(f_{elem}(\gamma_i) + f_{elem}(-\gamma_i)\), where \(f_{elem}\) is an MLP or a GNN; see the sketch after this parameter list. (default: True).

  • use_gnn (bool, optional) – Whether to use a GNN instead of an MLP to aggregate the embeddings of the eigenvectors. (default: False).

  • use_attention (bool, optional) – Whether to apply a multi-head attention layer to the embeddings. (default: False).

  • concatenate_eigenvalues (bool, optional) – Whether to initially concatenate the eigenvalues to the eigenvectors. (default: False).

  • norm (Callable[[jnp.ndarray], jnp.ndarray], optional) – Normalization layer. (default: None).
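
The sign-invariance behind use_signnet can be illustrated in a few lines of plain JAX. Here f_elem is a hypothetical stand-in for the encoder's internal per-element MLP/GNN, not part of the library's API:

import jax.numpy as jnp

def f_elem(v, w1, w2):
    # Hypothetical per-element MLP, for illustration only
    return jnp.tanh(v @ w1) @ w2

def sign_invariant_embed(eigenvectors, w1, w2):
    # eigenvectors: [N, K]; each scalar entry is embedded independently
    v = eigenvectors[..., None]                    # [N, K, 1]
    # f(v) + f(-v) is unchanged when any eigenvector flips sign, which
    # removes the arbitrary sign ambiguity of the eigendecomposition
    return f_elem(v, w1, w2) + f_elem(-v, w1, w2)  # [N, K, d]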