haiku_geometric.utils
batch – Batch a list of graphs into a single graph.
coalesce – Returns the unique edges in a graph.
get_laplacian – Returns the Laplacian of a graph as edge indices and weights.
get_laplacian_matrix – Returns the Laplacian of a graph, computed from its adjacency matrix.
eigv_laplacian – Returns the top-k eigenvectors of the Laplacian of a graph.
eigv_magnetic_laplacian – Returns k non-trivial complex eigenvectors associated with the k smallest eigenvalues of the magnetic Laplacian.
pad_graph – Pads the given graph until it has the given number of nodes and edges.
random_walk – Random walk on the input graph.
to_undirected – Returns the undirected version of a graph.
unbatch – Unbatch a graph into a list of graphs.
- haiku_geometric.utils.batch(graphs)
Batch a list of graphs into a single graph.
- Parameters
graphs (Sequence[DataGraphTuple]) – List of haiku_geometric.datasets.base.DataGraphTuple.
- Return type
Tuple[DataGraphTuple, ndarray]
- Returns
A single haiku_geometric.datasets.base.DataGraphTuple containing the batched graphs.
A jax.numpy.ndarray with indices indicating the graph to which each node belongs.
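Batching disjoint graphs usually means shifting each graph's node indices by the number of nodes in the graphs before it, then recording which graph every node came from. The sketch below shows that idea on raw edge arrays only; `batch_edges_sketch` is a hypothetical helper, not part of haiku_geometric, and it does not build a DataGraphTuple (whose fields are not documented here).

```python
import jax.numpy as jnp

def batch_edges_sketch(num_nodes_list, senders_list, receivers_list):
    """Hypothetical helper: merge per-graph edge lists into one disjoint graph."""
    offsets, total = [], 0
    for n in num_nodes_list:
        offsets.append(total)
        total += n
    # Shift each graph's node indices by the number of nodes preceding it.
    senders = jnp.concatenate([s + o for s, o in zip(senders_list, offsets)])
    receivers = jnp.concatenate([r + o for r, o in zip(receivers_list, offsets)])
    # batch[i] is the index of the graph that node i belongs to.
    batch = jnp.repeat(jnp.arange(len(num_nodes_list)), jnp.array(num_nodes_list))
    return senders, receivers, batch

# Two graphs: 0->1 (2 nodes) and 0->1->2 (3 nodes).
s, r, b = batch_edges_sketch(
    [2, 3], [jnp.array([0]), jnp.array([0, 1])], [jnp.array([1]), jnp.array([1, 2])]
)
# s -> [0, 2, 3], r -> [1, 3, 4], b -> [0, 0, 1, 1, 1]
```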
- haiku_geometric.utils.coalesce(senders, receivers, edge_attr=None, num_nodes=None, is_sorted=False, sort_by_row=True)
Returns the unique edges in a graph.
- Parameters
senders (jnp.ndarray) – The senders of the edges.
receivers (jnp.ndarray) – The receivers of the edges.
edge_attr (jnp.ndarray, optional) – The edge attributes. (default: None)
num_nodes (int, optional) – The number of nodes in the graph. (default: None)
is_sorted (bool, optional) – Whether senders and receivers are already sorted row-wise. (default: False)
sort_by_row (bool, optional) – Whether to sort the edges by row. If False, the edges are sorted by column. (default: True)
- Returns
Tuple(jnp.ndarray, jnp.ndarray) with senders and receivers if edge_attr is None.
Tuple(jnp.ndarray, jnp.ndarray, jnp.ndarray) with senders, receivers and edge_attr otherwise.
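A common way to find unique edges is to encode each (row, column) pair as one integer key, `row * num_nodes + col`, and deduplicate the keys; sorting the keys also sorts the edges row-wise (or column-wise if the roles are swapped). This is a minimal sketch of that technique, not the library's implementation: `coalesce_sketch` is hypothetical, ignores edge attributes, and runs eagerly because `jnp.unique` needs a static `size` under `jax.jit`.

```python
import jax.numpy as jnp

def coalesce_sketch(senders, receivers, num_nodes, sort_by_row=True):
    """Hypothetical helper: deduplicate and sort edges (edge attributes ignored)."""
    # Encode each edge as a single integer key; the major index decides the sort order.
    if sort_by_row:
        keys = senders * num_nodes + receivers
    else:
        keys = receivers * num_nodes + senders
    keys = jnp.unique(keys)  # sorted and deduplicated
    if sort_by_row:
        return keys // num_nodes, keys % num_nodes
    return keys % num_nodes, keys // num_nodes

s, r = coalesce_sketch(jnp.array([1, 0, 1, 0]), jnp.array([0, 1, 0, 2]), num_nodes=3)
# s -> [0, 0, 1], r -> [1, 2, 0]  (the duplicate edge 1->0 is removed)
```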
- haiku_geometric.utils.eigv_laplacian(senders, receivers, edge_weight=None, normalization=None, num_nodes=None, k=5, eigv_norm='L2')
Returns the top-k eigenvalues and eigenvectors of the Laplacian of a graph.
- Parameters
senders (jnp.ndarray) – The senders of the edges.
receivers (jnp.ndarray) – The receivers of the edges.
edge_weight (jnp.ndarray, optional) – The weight of each edge. (default: None)
normalization (str, optional) – The normalization to apply to the Laplacian. (default: None). Available options are: None, "sym", "rw".
num_nodes (int, optional) – The number of nodes in the graph. (default: None)
k (int, optional) – The number of eigenvectors to return. (default: 5)
eigv_norm (str, optional) – The normalization to apply to the eigenvectors. (default: "L2"). Available options are:
1. None: No normalization.
2. "L2": Normalize the eigenvectors to have unit L2 norm.
- Returns
(jnp.ndarray) The eigenvalues, of shape (k,).
(jnp.ndarray) The eigenvector values per node, of shape (num_nodes, k).
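Laplacian eigenvectors can be obtained densely with `jnp.linalg.eigh`, which returns eigenvalues in ascending order and unit-L2-norm eigenvectors as columns (matching `eigv_norm="L2"`). The sketch below assumes an undirected graph with both edge directions listed, no normalization, and that "top-k" means the eigenpairs with the k smallest eigenvalues, the usual choice for positional encodings; the ordering the library uses is an assumption here. `laplacian_eigvecs_sketch` is a hypothetical name.

```python
import jax.numpy as jnp

def laplacian_eigvecs_sketch(senders, receivers, num_nodes, k):
    """Hypothetical helper: k smallest eigenpairs of L = D - A (dense, eager)."""
    adj = jnp.zeros((num_nodes, num_nodes)).at[senders, receivers].add(1.0)
    lap = jnp.diag(adj.sum(axis=1)) - adj
    # eigh: ascending eigenvalues, orthonormal eigenvectors stored as columns.
    eigvals, eigvecs = jnp.linalg.eigh(lap)
    return eigvals[:k], eigvecs[:, :k]

# Undirected path 0 - 1 - 2; its Laplacian has eigenvalues 0, 1 and 3.
vals, vecs = laplacian_eigvecs_sketch(jnp.array([0, 1, 1, 2]), jnp.array([1, 0, 2, 1]), 3, k=2)
```

The first eigenvector of a connected graph's unnormalized Laplacian is constant, which is why positional-encoding schemes often drop it as trivial.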
- haiku_geometric.utils.eigv_magnetic_laplacian(senders, receivers, n_node, k, k_excl, q=0.25, q_absolute=False, norm_comps_sep=False, l2_norm=True, sign_rotate=True, use_symmetric_norm=False)
Returns k non-trivial complex eigenvectors associated with the k smallest eigenvalues of the magnetic Laplacian.
This implementation is from the paper “Transformers Meet Directed Graphs”.
- Parameters
senders (jnp.ndarray) – Origin of the edges, of shape [m].
receivers (jnp.ndarray) – Target of the edges, of shape [m].
n_node (int) – Number of nodes in the graph.
k (int) – Number of (top k) eigenvectors to return.
k_excl (int) – Number of top (trivial) eigenvalues/eigenvectors to exclude.
q (float, optional) – Factor in the magnetic Laplacian. (default: 0.25)
q_absolute (bool, optional) – If True, q is used as given; otherwise q / m_imag / 2 is used. (default: False)
norm_comps_sep (bool, optional) – If True, the imaginary part is first normalized separately. (default: False)
l2_norm (bool, optional) – If True, L2 normalization is used; otherwise the eigenvectors are normalized by their maximum absolute value. Treated as False if norm_comps_sep is True. (default: True)
sign_rotate (bool, optional) – If True, the sign is chosen based on the maximum real values and the imaginary part is rotated. (default: True)
use_symmetric_norm (bool, optional) – Whether to use symmetric (True) or row (False) normalization. (default: False)
- Returns
(jnp.float64) List with arrays of shape [<= k] containing the k eigenvalues.
(jnp.complex128) List with arrays of shape [n_node, <= k] containing the k eigenvectors.
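For intuition, the unnormalized magnetic Laplacian in its common form is L = D_s − A_s ⊙ exp(iΘ), with A_s = (A + Aᵀ)/2, D_s its degree matrix and Θ = 2πq(A − Aᵀ). Because Θ is antisymmetric, L is Hermitian, so `jnp.linalg.eigh` gives real eigenvalues and complex eigenvectors. This is a minimal dense sketch under those assumptions: `magnetic_laplacian_sketch` is hypothetical, it uses q directly (as with q_absolute=True), and it omits the library's symmetric/row normalization, k_excl, and the sign/rotation and normalization post-processing.

```python
import jax.numpy as jnp

def magnetic_laplacian_sketch(senders, receivers, n_node, q=0.25):
    """Hypothetical helper: dense unnormalized magnetic Laplacian D_s - A_s * exp(i*Theta)."""
    adj = jnp.zeros((n_node, n_node)).at[senders, receivers].set(1.0)
    a_sym = 0.5 * (adj + adj.T)                # symmetrized adjacency A_s
    theta = 2 * jnp.pi * q * (adj - adj.T)     # antisymmetric phase matrix Theta
    deg = jnp.diag(a_sym.sum(axis=1))          # degree matrix D_s
    return deg - a_sym * jnp.exp(1j * theta)   # Hermitian by construction

# Directed path 0 -> 1 -> 2.
lap = magnetic_laplacian_sketch(jnp.array([0, 1]), jnp.array([1, 2]), n_node=3)
eigvals, eigvecs = jnp.linalg.eigh(lap)  # real eigenvalues, complex eigenvectors
```

The phase exp(iΘ) is what encodes edge direction: with q = 0 the sketch reduces to the ordinary Laplacian of the symmetrized graph.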
- haiku_geometric.utils.get_laplacian(senders, receivers, edge_weight=None, normalization=None, num_nodes=None)
Returns the Laplacian of a graph. Unlike get_laplacian_matrix(), this function operates on the edge indices of the graph rather than on the adjacency matrix. Consequently, it returns the indices and weights of the Laplacian.
- Parameters
senders (jnp.ndarray) – The senders of the edges.
receivers (jnp.ndarray) – The receivers of the edges.
edge_weight (jnp.ndarray, optional) – The weight of each edge. (default: None)
normalization (str, optional) – The normalization to apply to the Laplacian, where \(\mathbf{A}\) is the adjacency matrix, \(\mathbf{D}\) the diagonal degree matrix and \(\mathbf{I}\) the identity. (default: None). Available options are:
1. None: No normalization. \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
2. "sym": Symmetric normalization. \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
3. "rw": Random-walk normalization. \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
num_nodes (int, optional) – The number of nodes in the graph. (default: None)
- Returns
Senders, receivers and weights of the Laplacian.
- Return type
Tuple(jnp.ndarray, jnp.ndarray, jnp.ndarray)
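In edge-list form, each Laplacian variant above becomes one negated, rescaled weight per edge plus one diagonal entry per node. The sketch below illustrates the three formulas under stated assumptions: `get_laplacian_sketch` is hypothetical, degrees are weighted out-degrees computed from `senders`, the input has no self-loops, and nodes with zero degree get \(\mathbf{D}^{-1}\) and \(\mathbf{D}^{-1/2}\) entries of 0. The library may treat these cases differently.

```python
import jax.numpy as jnp

def get_laplacian_sketch(senders, receivers, num_nodes, edge_weight=None, normalization=None):
    """Hypothetical helper: Laplacian as (senders, receivers, weights), diagonal appended."""
    if edge_weight is None:
        edge_weight = jnp.ones(senders.shape[0])
    # Weighted out-degree of every node.
    deg = jnp.zeros(num_nodes).at[senders].add(edge_weight)
    if normalization is None:                         # L = D - A
        off_diag, diag = -edge_weight, deg
    elif normalization == "sym":                      # L = I - D^-1/2 A D^-1/2
        inv_sqrt = jnp.where(deg > 0, deg ** -0.5, 0.0)
        off_diag = -inv_sqrt[senders] * edge_weight * inv_sqrt[receivers]
        diag = jnp.ones(num_nodes)
    elif normalization == "rw":                       # L = I - D^-1 A
        inv = jnp.where(deg > 0, 1.0 / deg, 0.0)
        off_diag = -inv[senders] * edge_weight
        diag = jnp.ones(num_nodes)
    else:
        raise ValueError(f"Unknown normalization: {normalization}")
    loops = jnp.arange(num_nodes)
    return (jnp.concatenate([senders, loops]),
            jnp.concatenate([receivers, loops]),
            jnp.concatenate([off_diag, diag]))

# Undirected path 0 - 1 - 2, densified for inspection.
s, r = jnp.array([0, 1, 1, 2]), jnp.array([1, 0, 2, 1])
ls, lr, lw = get_laplacian_sketch(s, r, 3)
dense = jnp.zeros((3, 3)).at[ls, lr].add(lw)
# dense -> [[1, -1, 0], [-1, 2, -1], [0, -1, 1]]
```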
- haiku_geometric.utils.get_laplacian_matrix(senders, receivers, edge_weight=None, normalization=None, num_nodes=None)
Returns the Laplacian of a graph, computed from its adjacency matrix.
- Parameters
senders (jnp.ndarray) – The senders of the edges.
receivers (jnp.ndarray) – The receivers of the edges.
edge_weight (jnp.ndarray, optional) – The weight of each edge. (default: None)
normalization (str, optional) – The normalization to apply to the Laplacian, where \(\mathbf{A}\) is the adjacency matrix, \(\mathbf{D}\) the diagonal degree matrix and \(\mathbf{I}\) the identity. (default: None). Available options are:
1. None: No normalization. \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
2. "sym": Symmetric normalization. \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
3. "rw": Random-walk normalization. \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
num_nodes (int, optional) – The number of nodes in the graph. (default: None)
- Returns
The Laplacian of the graph.
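Working over the adjacency matrix, the three normalizations reduce to a few dense array operations. This sketch assumes the result is a dense (num_nodes, num_nodes) matrix, which the description above suggests but does not state; `get_laplacian_matrix_sketch` is hypothetical, and zero-degree nodes get \(\mathbf{D}^{-1}\) and \(\mathbf{D}^{-1/2}\) entries of 0.

```python
import jax.numpy as jnp

def get_laplacian_matrix_sketch(senders, receivers, num_nodes, edge_weight=None, normalization=None):
    """Hypothetical helper: dense Laplacian built from the adjacency matrix."""
    if edge_weight is None:
        edge_weight = jnp.ones(senders.shape[0])
    adj = jnp.zeros((num_nodes, num_nodes)).at[senders, receivers].add(edge_weight)
    deg = adj.sum(axis=1)
    eye = jnp.eye(num_nodes)
    if normalization is None:                              # L = D - A
        return jnp.diag(deg) - adj
    if normalization == "sym":                             # L = I - D^-1/2 A D^-1/2
        d = jnp.where(deg > 0, deg ** -0.5, 0.0)
        return eye - d[:, None] * adj * d[None, :]
    if normalization == "rw":                              # L = I - D^-1 A
        d = jnp.where(deg > 0, 1.0 / deg, 0.0)
        return eye - d[:, None] * adj
    raise ValueError(f"Unknown normalization: {normalization}")

s, r = jnp.array([0, 1, 1, 2]), jnp.array([1, 0, 2, 1])
lap_rw = get_laplacian_matrix_sketch(s, r, 3, normalization="rw")
# Each row of the random-walk Laplacian sums to 0 for nodes with nonzero degree.
```

The dense form is simple but needs O(num_nodes²) memory, which is presumably why an edge-list variant such as get_laplacian() is also offered.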
- haiku_geometric.utils.pad_graph(data_graph, n_nodes, n_edges, batch=None)
Pads the given graph until it has the given number of nodes and edges. The new nodes and edges are added as a new batch.
- Parameters
data_graph (haiku_geometric.datasets.base.DataGraphTuple) – The graph to pad.
n_nodes (int) – The number of nodes to pad to.
n_edges (int) – The number of edges to pad to.
batch (int, optional) – Batch indexes of the given graph. They are updated to include the new batch.
- Returns
A single haiku_geometric.datasets.base.DataGraphTuple containing the padded graph.
A jax.numpy.ndarray with boolean values indicating which nodes are old (True) and which are new (False).
A jax.numpy.ndarray with the new batch indexes.
An int with the new number of batches (i.e. the old number of batches + 1).
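Padding to fixed sizes keeps array shapes static, so a jitted model is not recompiled for every graph size. The sketch below mirrors the four documented outputs on raw arrays instead of a DataGraphTuple, whose fields are not documented here. `pad_graph_sketch` is hypothetical, and routing every padding edge to the first padding node is an assumed design choice (it keeps padding isolated from real nodes); it requires at least one padding node.

```python
import jax.numpy as jnp

def pad_graph_sketch(nodes, senders, receivers, batch, n_nodes, n_edges):
    """Hypothetical helper: pad nodes/edges into one extra graph.

    Requires n_nodes > nodes.shape[0] and n_edges >= senders.shape[0].
    """
    num_old_nodes = nodes.shape[0]
    num_new_nodes = n_nodes - num_old_nodes
    num_new_edges = n_edges - senders.shape[0]
    new_batch_id = int(batch.max()) + 1
    # Zero features for the padding nodes.
    pad_nodes = jnp.zeros((num_new_nodes,) + nodes.shape[1:], nodes.dtype)
    nodes = jnp.concatenate([nodes, pad_nodes])
    # Padding edges are self-loops on the first padding node, away from real nodes.
    pad_idx = jnp.full((num_new_edges,), num_old_nodes, dtype=senders.dtype)
    senders = jnp.concatenate([senders, pad_idx])
    receivers = jnp.concatenate([receivers, pad_idx])
    batch = jnp.concatenate([batch, jnp.full((num_new_nodes,), new_batch_id, dtype=batch.dtype)])
    mask = jnp.arange(n_nodes) < num_old_nodes  # True for original nodes
    return nodes, senders, receivers, mask, batch, new_batch_id + 1

out = pad_graph_sketch(jnp.ones((3, 2)), jnp.array([0, 1]), jnp.array([1, 2]),
                       jnp.array([0, 0, 1]), n_nodes=5, n_edges=4)
# batch -> [0, 0, 1, 2, 2]; number of batches -> 3
```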
- haiku_geometric.utils.random_walk(senders, receivers, walk_length, p=1, q=1, num_nodes=None)
Performs a random walk on the input graph.
- Parameters
senders (jnp.ndarray) – Array of sender nodes.
receivers (jnp.ndarray) – Array of receiver nodes.
walk_length (int) – Length of the random walk.
p (float, optional) – Return parameter, controlling the likelihood of immediately returning to the previous node in the walk. (default: 1)
q (float, optional) – In-out parameter, interpolating between breadth-first and depth-first exploration strategies. (default: 1)
num_nodes (int, optional) – Number of nodes in the graph. (default: None)
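Assuming p and q follow node2vec, as their descriptions suggest, setting p = q = 1 makes every transition weight equal to the edge weight, so the walk becomes a plain uniform random walk. The sketch covers only that case: it sorts edges by sender into a CSR-style layout and samples one neighbour per step with `jax.lax.scan`. `uniform_random_walk_sketch` is hypothetical; it counts `walk_length` as steps (returning `walk_length + 1` nodes), which may differ from the library's convention, and it assumes every visited node has an outgoing edge.

```python
import jax
import jax.numpy as jnp

def uniform_random_walk_sketch(key, senders, receivers, start, walk_length, num_nodes):
    """Hypothetical helper: uniform random walk (the p = q = 1 case)."""
    # CSR layout: neighbours of u are receivers[rowptr[u]:rowptr[u + 1]].
    order = jnp.argsort(senders)
    senders, receivers = senders[order], receivers[order]
    rowptr = jnp.searchsorted(senders, jnp.arange(num_nodes + 1))

    def step(node, k):
        lo, hi = rowptr[node], rowptr[node + 1]
        nxt = receivers[jax.random.randint(k, (), lo, hi)]  # uniform neighbour
        return nxt, nxt

    start = jnp.asarray(start, dtype=receivers.dtype)
    keys = jax.random.split(key, walk_length)
    _, path = jax.lax.scan(step, start, keys)
    return jnp.concatenate([start[None], path])

# Directed cycle 0 -> 1 -> 2 -> 0: every step is forced.
walk = uniform_random_walk_sketch(jax.random.PRNGKey(0), jnp.array([0, 1, 2]),
                                  jnp.array([1, 2, 0]), start=0, walk_length=4, num_nodes=3)
# walk -> [0, 1, 2, 0, 1]
```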
- haiku_geometric.utils.to_undirected(senders, receivers, edge_attr=None, num_nodes=None)
Returns the undirected version of a graph.
- Parameters
senders (jnp.ndarray) – The senders of the edges.
receivers (jnp.ndarray) – The receivers of the edges.
edge_attr (jnp.ndarray, optional) – The edge attributes. (default: None)
num_nodes (int, optional) – The number of nodes in the graph. (default: None)
- Returns
Tuple(jnp.ndarray, jnp.ndarray) with senders and receivers if edge_attr is None.
Tuple(jnp.ndarray, jnp.ndarray, jnp.ndarray) with senders, receivers and edge_attr otherwise.
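Making a graph undirected typically means appending the reverse of every edge and then removing the duplicates that creates (for example, when both u->v and v->u were already present). This is a minimal sketch of that technique without edge attributes; `to_undirected_sketch` is hypothetical and, like any `jnp.unique`-based approach without a fixed `size`, runs eagerly rather than under `jax.jit`.

```python
import jax.numpy as jnp

def to_undirected_sketch(senders, receivers, num_nodes):
    """Hypothetical helper: add reverse edges, then deduplicate (row-wise order)."""
    rows = jnp.concatenate([senders, receivers])
    cols = jnp.concatenate([receivers, senders])
    keys = jnp.unique(rows * num_nodes + cols)  # one integer key per (row, col) pair
    return keys // num_nodes, keys % num_nodes

s, r = to_undirected_sketch(jnp.array([0, 1]), jnp.array([1, 2]), num_nodes=3)
# s -> [0, 1, 1, 2], r -> [1, 0, 2, 1]
```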
- haiku_geometric.utils.unbatch(graph)
Unbatch a graph into a list of graphs.
- Parameters
graph (DataGraphTuple) – A haiku_geometric.datasets.base.DataGraphTuple to unbatch.
- Return type
Sequence[DataGraphTuple]
- Returns
A list of haiku_geometric.datasets.base.DataGraphTuple containing the unbatched graphs.
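Unbatching reverses the node-index offsets applied when graphs are merged: a per-node batch vector tells which graph each node belongs to, an edge belongs to the graph of its sender, and each graph's offset is the number of nodes before it. The sketch below shows this on raw edge arrays only; `unbatch_edges_sketch` is hypothetical, does not rebuild DataGraphTuples, assumes nodes are stored contiguously per graph, and uses boolean masking, so it runs eagerly.

```python
import jax.numpy as jnp

def unbatch_edges_sketch(senders, receivers, batch):
    """Hypothetical helper: split a batched edge list into per-graph edge lists."""
    num_graphs = int(batch.max()) + 1
    counts = jnp.bincount(batch, length=num_graphs)  # nodes per graph
    offsets = jnp.concatenate([jnp.zeros(1, counts.dtype), jnp.cumsum(counts)[:-1]])
    graph_of_edge = batch[senders]  # an edge belongs to its sender's graph
    out = []
    for g in range(num_graphs):
        m = graph_of_edge == g
        # Undo the offset so each graph's node indices start at 0 again.
        out.append((senders[m] - offsets[g], receivers[m] - offsets[g]))
    return out

parts = unbatch_edges_sketch(jnp.array([0, 2, 3]), jnp.array([1, 3, 4]),
                             jnp.array([0, 0, 1, 1, 1]))
# parts -> [([0], [1]), ([0, 1], [1, 2])]
```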