haiku_geometric.posenc
LaplacianEncoder | Graph Laplacian Positional Encoder described in the "Rethinking Graph Transformers with Spectral Attention" paper.
MagLaplacianEncoder | MagLapNet Positional Encoder described in the "Transformers Meet Directed Graphs" paper.
- class haiku_geometric.posenc.LaplacianEncoder(dim, model, model_dropout=0.0, layers=1, heads=1, post_layers=1, norm=None, norm_decay=0.9)
Graph Laplacian Positional Encoder described in the “Rethinking Graph Transformers with Spectral Attention” paper.
Usage:
    from haiku_geometric.utils import eigv_laplacian
    from haiku_geometric.posenc import LaplacianEncoder

    # Compute the eigenvectors of the Laplacian matrix
    eigenvalues, eigenvectors = eigv_laplacian(
        senders=senders,
        receivers=receivers,
        k=...)

    # The function that you will transform with Haiku
    def your_forward_function(...):
        # Create the encoder model
        model = LaplacianEncoder(...)
        # Encode the eigenvalues and eigenvectors
        h = model(eigenvalues, eigenvectors, is_training)
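To show what the `eigenvalues` ([K,]) and `eigenvectors` ([N, K]) inputs look like, the NumPy sketch below computes the k smallest eigenpairs of the symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2} from `senders`/`receivers` arrays. It is only an assumption about what `eigv_laplacian` returns (the library may use a different normalization or eigensolver), and `laplacian_eigenpairs` is a hypothetical helper, not part of haiku_geometric.

```python
import numpy as np

def laplacian_eigenpairs(senders, receivers, num_nodes, k):
    """Return the k smallest eigenpairs of the symmetric normalized Laplacian.

    Hypothetical helper: a simplified stand-in for what eigv_laplacian is
    assumed to compute; the library may normalize differently.
    """
    adj = np.zeros((num_nodes, num_nodes))
    adj[senders, receivers] = 1.0
    adj = np.maximum(adj, adj.T)  # treat the graph as undirected
    deg = adj.sum(axis=1)
    # D^{-1/2}, with isolated nodes (degree 0) mapped to 0
    inv_sqrt_deg = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    # L = I - D^{-1/2} A D^{-1/2}
    lap = np.eye(num_nodes) - inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]
    eigenvalues, eigenvectors = np.linalg.eigh(lap)  # ascending eigenvalues
    return eigenvalues[:k], eigenvectors[:, :k]

# Path graph 0 - 1 - 2 - 3
senders = np.array([0, 1, 2])
receivers = np.array([1, 2, 3])
eigenvalues, eigenvectors = laplacian_eigenpairs(senders, receivers, num_nodes=4, k=2)
print(eigenvalues.shape, eigenvectors.shape)  # (2,) (4, 2)
```

Keeping only the smallest eigenvalues is the usual choice because the low-frequency eigenvectors vary smoothly over the graph and so carry coarse positional information.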
- __call__(eigenvalues, eigenvectors, is_training, call_args=None)
- Parameters
eigenvalues (jnp.ndarray) – Eigenvalues of the Laplacian matrix with shape [K,].
eigenvectors (jnp.ndarray) – Eigenvectors of the Laplacian matrix with shape [N, K].
is_training (bool) – Whether the model is in training mode.
- Returns
Encoded features with shape [N, dim].
- Return type
jnp.ndarray
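The sketch below illustrates how inputs of shape [K,] and [N, K] can become an output of shape [N, dim] in the spirit of the spectral-attention paper with model="DeepSet": each node gets K tokens (eigenvalue, eigenvector entry), a per-token layer embeds them, and a sum over K pools them. It is a simplified NumPy illustration with random placeholder weights; `deepset_laplacian_pe` is hypothetical, and LaplacianEncoder's learned layers, dropout and normalization are not reproduced.

```python
import numpy as np

def deepset_laplacian_pe(eigenvalues, eigenvectors, dim, rng):
    """Simplified DeepSet-style encoding: [K] and [N, K] -> [N, dim].

    The weights are random placeholders (hypothetical); LaplacianEncoder
    learns its own parameters and may differ in its details.
    """
    n, k = eigenvectors.shape
    # Pair every eigenvector entry with its eigenvalue: tokens of shape [N, K, 2]
    tokens = np.stack([np.broadcast_to(eigenvalues, (n, k)), eigenvectors], axis=-1)
    w_in = rng.normal(size=(2, dim))
    w_out = rng.normal(size=(dim, dim))
    hidden = np.maximum(tokens @ w_in, 0.0)  # per-token layer with ReLU: [N, K, dim]
    pooled = hidden.sum(axis=1)              # permutation-invariant sum over K: [N, dim]
    return pooled @ w_out                    # post layer: [N, dim]

rng = np.random.default_rng(0)
eigenvalues = np.array([0.0, 0.5, 1.5])
eigenvectors = rng.normal(size=(5, 3))
h = deepset_laplacian_pe(eigenvalues, eigenvectors, dim=8, rng=rng)
print(h.shape)  # (5, 8)
```

Summing over K makes the result independent of the order in which the eigenpairs are listed; with model="Transformer" the tokens would instead attend to each other before pooling.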
- __init__(dim, model, model_dropout=0.0, layers=1, heads=1, post_layers=1, norm=None, norm_decay=0.9)
- Parameters
dim (int) – Dimension of the output features.
model (str) – Model to use for the encoder. Can be either "Transformer" or "DeepSet".
model_dropout (float, optional) – Dropout rate for the model. (default: 0.0)
layers (int, optional) – Number of layers for the model. (default: 1)
heads (int, optional) – Number of heads for the model. Only used if model="Transformer". (default: 1)
post_layers (int, optional) – Number of post layers after the model. (default: 1)
norm (str, optional) – Normalization layer to use. Can be either "batchnorm" or None. (default: None)
norm_decay (float, optional) – Decay rate for the normalization layer. (default: 0.9)
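Assuming `norm_decay` is the decay rate of the running statistics kept by the batch-norm layer (analogous to the `decay_rate` argument of Haiku's `hk.BatchNorm`), the sketch below shows what a decay of 0.9 means: each update keeps 90% of the old running mean and mixes in 10% of the current batch mean. This is a simplified illustration; `update_running_mean` is hypothetical, and zero-debiasing and variance tracking are omitted.

```python
import numpy as np

def update_running_mean(running_mean, batch, decay=0.9):
    """One exponential-moving-average update of a batch-norm running mean.

    Simplified illustration (assumption: norm_decay plays this role);
    no zero-debiasing or variance tracking is shown.
    """
    batch_mean = batch.mean(axis=0)
    return decay * running_mean + (1.0 - decay) * batch_mean

running = np.zeros(2)
batch = np.array([[1.0, 3.0], [3.0, 5.0]])  # batch mean is [2.0, 4.0]
for _ in range(3):
    running = update_running_mean(running, batch)
print(running)
```

After three identical batches the running mean has only reached (1 - 0.9^3) = 27.1% of the batch mean, so a decay closer to 1 gives smoother but slower-adapting statistics.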
- class haiku_geometric.posenc.MagLaplacianEncoder(d_model_elem=32, d_model_aggr=256, num_heads=4, n_layers=1, dropout=0.2, activation=jax.nn.relu, return_real_output=True, consider_im_part=True, use_signnet=True, use_gnn=False, use_attention=False, concatenate_eigenvalues=False, norm=None)
MagLapNet Positional Encoder described in the “Transformers Meet Directed Graphs” paper. Positional encodings are computed from the eigenvectors of the Magnetic Laplacian, a Hermitian matrix.
Usage:
    from haiku_geometric.utils import eigv_magnetic_laplace
    from haiku_geometric.posenc import MagLaplacianEncoder

    # Compute the eigenvectors of the Magnetic Laplacian matrix
    eigenvalues, eigenvectors = eigv_magnetic_laplace(
        senders=senders,
        receivers=receivers,
        k=...)

    # The function that you will transform with Haiku
    def your_forward_function(...):
        # Create the encoder model
        model = MagLaplacianEncoder(...)
        # Encode the eigenvalues and eigenvectors
        h = model(senders, receivers, eigenvalues, eigenvectors, is_training)
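The sketch below builds the unnormalized Magnetic Laplacian L_q = D_s - A_s ⊙ exp(iΘ_q), where A_s = (A + A^T)/2 is the symmetrized adjacency, D_s its degree matrix, and Θ_q = 2πq(A - A^T) encodes edge direction as a phase. Because Θ_q is antisymmetric, L_q is Hermitian, so its eigenvalues are real while its eigenvectors are complex. The helper `magnetic_laplacian` and the potential value q=0.25 are illustrative assumptions, not the implementation of `eigv_magnetic_laplace`; the paper also discusses a normalized variant.

```python
import numpy as np

def magnetic_laplacian(senders, receivers, num_nodes, q=0.25):
    """Unnormalized magnetic Laplacian L_q = D_s - A_s * exp(i * Theta_q).

    q is the "potential" of the magnetic Laplacian; the value 0.25 and this
    helper are illustrative assumptions, not eigv_magnetic_laplace itself.
    """
    adj = np.zeros((num_nodes, num_nodes))
    adj[senders, receivers] = 1.0
    adj_sym = (adj + adj.T) / 2.0              # A_s: symmetrized adjacency
    deg = np.diag(adj_sym.sum(axis=1))         # D_s: degrees of A_s
    theta = 2.0 * np.pi * q * (adj - adj.T)    # antisymmetric phase matrix
    return deg - adj_sym * np.exp(1j * theta)  # Hermitian because theta is antisymmetric

# Directed path 0 -> 1 -> 2
lap = magnetic_laplacian(np.array([0, 1]), np.array([1, 2]), num_nodes=3)
eigenvalues, eigenvectors = np.linalg.eigh(lap)  # real eigenvalues, complex eigenvectors
```

With q = 0 the phases vanish and L_q reduces to the ordinary Laplacian of the symmetrized graph, so direction information enters only through q > 0.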
- __call__(senders, receivers, eigenvalues, eigenvectors, is_training, call_args=None)
- Parameters
senders (jnp.ndarray) – Indices of the sender nodes.
receivers (jnp.ndarray) – Indices of the receiver nodes.
eigenvalues (jnp.ndarray) – Eigenvalues of the Magnetic Laplacian matrix with shape [K,].
eigenvectors (jnp.ndarray) – Eigenvectors of the Magnetic Laplacian matrix with shape [N, K].
is_training (bool) – Whether the model is in training mode.
- Returns
Encoded features with shape [N, d_model_aggr].
- Return type
jnp.ndarray
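The following shape-level sketch shows one way complex eigenvectors of shape [N, K] can become features of shape [N, d_model_aggr], loosely following the MagLapNet description: each complex entry is split into its real and imaginary parts, embedded by a per-element layer of width d_model_elem, and the K embeddings of a node are concatenated and projected. The function `maglap_shapes_sketch` and its random weights are hypothetical; SignNet, attention, eigenvalue concatenation and normalization are left out.

```python
import numpy as np

def maglap_shapes_sketch(eigenvectors, d_model_elem=32, d_model_aggr=256, seed=0):
    """Shape-level sketch of a MagLapNet-style encoder with random weights.

    Hypothetical: illustrates [N, K] complex -> [N, d_model_aggr]; it is not
    MagLaplacianEncoder's implementation.
    """
    rng = np.random.default_rng(seed)
    n, k = eigenvectors.shape
    # Split each complex entry into (real, imaginary): [N, K, 2]
    elems = np.stack([eigenvectors.real, eigenvectors.imag], axis=-1)
    w_elem = rng.normal(size=(2, d_model_elem))
    emb = np.maximum(elems @ w_elem, 0.0)        # per-element layer: [N, K, d_model_elem]
    flat = emb.reshape(n, k * d_model_elem)      # concatenate the K embeddings per node
    w_aggr = rng.normal(size=(k * d_model_elem, d_model_aggr))
    return flat @ w_aggr                         # aggregation: [N, d_model_aggr]

rng = np.random.default_rng(1)
vecs = rng.normal(size=(6, 4)) + 1j * rng.normal(size=(6, 4))
out = maglap_shapes_sketch(vecs)
print(out.shape)  # (6, 256)
```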
- __init__(d_model_elem=32, d_model_aggr=256, num_heads=4, n_layers=1, dropout=0.2, activation=jax.nn.relu, return_real_output=True, consider_im_part=True, use_signnet=True, use_gnn=False, use_attention=False, concatenate_eigenvalues=False, norm=None)
- Parameters
d_model_elem (int, optional) – Embedding dimension for each element of the eigenvectors. (default: 32)
d_model_aggr (int, optional) – Dimension of the aggregation of all the elements of the eigenvectors. (default: 256)
num_heads (int, optional) – Number of attention heads. (default: 4)
n_layers (int, optional) – Number of layers for the MLPs. (default: 1)
dropout (float, optional) – Dropout rate. (default: 0.2)
activation (Callable[[jnp.ndarray], jnp.ndarray], optional) – Activation function. (default: jax.nn.relu)
return_real_output (bool, optional) – Whether to return only the real part of the output. (default: True)
consider_im_part (bool, optional) – Whether to consider the imaginary part of the eigenvectors. (default: True)
use_signnet (bool, optional) – Whether to use SignNet, that is, each eigenvector \(\gamma_i\) is processed as \(f_{elem}(\gamma_i) + f_{elem}(-\gamma_i)\), where \(f_{elem}\) is an MLP or a GNN. (default: True)
use_gnn (bool, optional) – Whether to use a GNN to aggregate the embeddings of the eigenvectors instead of an MLP. (default: False)
use_attention (bool, optional) – Whether to apply a multi-head attention layer to the embeddings. (default: False)
concatenate_eigenvalues (bool, optional) – Whether to initially concatenate the eigenvalues to the eigenvectors. (default: False)
norm (Callable[[jnp.ndarray], jnp.ndarray], optional) – Normalization layer. (default: None)
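The `use_signnet` option relies on the fact that an eigenvector \(\gamma_i\) is only defined up to sign, so the encoder should give the same result for \(\gamma_i\) and \(-\gamma_i\). The NumPy sketch below checks this property for \(f_{elem}(\gamma_i) + f_{elem}(-\gamma_i)\) with a tiny MLP; `f_elem`, `signnet_embed` and the random weights are hypothetical stand-ins for the learned networks.

```python
import numpy as np

def f_elem(x, w1, w2):
    """Tiny two-layer MLP applied to each eigenvector entry (hypothetical weights)."""
    return np.maximum(x[..., None] * w1, 0.0) @ w2  # [N] -> [N, 8] -> [N, 4]

def signnet_embed(eigvec, w1, w2):
    # SignNet-style: f_elem(gamma) + f_elem(-gamma) is unchanged when gamma -> -gamma
    return f_elem(eigvec, w1, w2) + f_elem(-eigvec, w1, w2)

rng = np.random.default_rng(0)
w1, w2 = rng.normal(size=(8,)), rng.normal(size=(8, 4))
gamma = rng.normal(size=5)  # one eigenvector, N = 5
same = np.allclose(signnet_embed(gamma, w1, w2), signnet_embed(-gamma, w1, w2))
print(same)  # True
```

Without the symmetrization, f_elem alone generally maps \(\gamma_i\) and \(-\gamma_i\) to different embeddings, so the model would have to learn the sign ambiguity from data.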