haiku_geometric.models

MLP

The Haiku MLP (haiku.nets.MLP) extended with optional layer normalization.

Node2Vec

The node2vec model from the paper "node2vec: Scalable Feature Learning for Networks".

class haiku_geometric.models.MLP(output_sizes, w_init=None, b_init=None, with_bias=True, with_layer_norm=False, activation=jax.nn.relu, activate_final=False, name=None)

The Haiku MLP (haiku.nets.MLP) extended with optional layer normalization.

Parameters
  • output_sizes (Iterable[int]) – Sequence of layer sizes.

  • w_init (Optional[Callable[[Sequence[int], Any], ndarray]]) – Initializer for haiku.Linear weights.

  • b_init (Optional[Callable[[Sequence[int], Any], ndarray]]) – Initializer for haiku.Linear bias. Must be None if with_bias=False.

  • with_bias (bool) – Whether or not to apply a bias in each layer.

  • with_layer_norm (bool) – Whether or not to apply layer normalization in each layer.

  • activation (Callable[[ndarray], ndarray]) – Activation function to apply between Linear layers. Defaults to ReLU.

  • activate_final (bool) – Whether or not to activate the final layer of the MLP.

  • name (Optional[str]) – Optional name for this module.

Raises

ValueError – If with_bias is False and b_init is not None.
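To make the parameters above concrete, here is a minimal sketch of the forward pass that output_sizes, with_bias, with_layer_norm and activate_final describe. It is written in plain Python so that it runs without JAX or Haiku, and it is not the library's implementation: the helper names (linear, layer_norm, mlp_forward) are hypothetical, applying layer normalization before the activation is an assumption, and the learnable scale and offset of a real layer-norm module are left out.

```python
import math


def relu(v):
    return [max(0.0, x) for x in v]


def linear(v, weights, bias=None):
    """Compute v @ weights (+ bias) for a single input vector.

    `weights` has shape [input_size, output_size]; `bias` may be None,
    which mirrors with_bias=False.
    """
    out = [sum(x * w for x, w in zip(v, col)) for col in zip(*weights)]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


def layer_norm(v, eps=1e-5):
    """Normalize one vector to zero mean and unit variance over its features."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def mlp_forward(v, layers, with_layer_norm=False, activate_final=False):
    """`layers` is a list of (weights, bias) pairs, one per entry of output_sizes."""
    for i, (weights, bias) in enumerate(layers):
        v = linear(v, weights, bias)
        is_last = i == len(layers) - 1
        if not is_last or activate_final:
            if with_layer_norm:
                # Assumed placement: normalize, then activate.
                v = layer_norm(v)
            v = relu(v)
    return v


# Two layers, i.e. output_sizes=[3, 2], applied to a 2-feature input.
layers = [
    ([[1.0, -1.0, 0.5], [0.0, 2.0, 0.5]], [0.0, 0.0, 0.0]),  # 2 -> 3, with bias
    ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], None),            # 3 -> 2, no bias
]
out = mlp_forward([1.0, 2.0], layers)
out_ln = mlp_forward([1.0, 2.0], layers, with_layer_norm=True)
```

With activate_final=False the last linear output is returned as is, so the final layer can produce negative values such as logits.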

__call__(inputs, dropout_rate=None, rng=None)

Connects the module to some inputs.

Parameters
  • inputs (ndarray) – A Tensor of shape [batch_size, input_size].

  • dropout_rate (Optional[float]) – Optional dropout rate.

  • rng – Optional RNG key. Required when using dropout.

Return type

ndarray

Returns

The output of the model of size [batch_size, output_size].
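The reason rng must accompany dropout_rate is that dropout draws a fresh random mask on every call. The sketch below illustrates this with the usual inverted-dropout convention, in plain Python. The function dropout and the ValueError it raises are hypothetical and only stand in for the behavior described above, and random.Random plays the role of a JAX PRNG key.

```python
import random


def dropout(values, rate, rng):
    """Inverted dropout on a list of floats.

    Each entry is zeroed with probability `rate`; survivors are scaled by
    1 / (1 - rate) so the expected value of every entry is unchanged.
    """
    if rng is None:
        # Sketch behavior only: without a source of randomness no mask can be drawn.
        raise ValueError("rng is required when dropout_rate is set")
    keep = 1.0 - rate
    return [x / keep if rng.random() < keep else 0.0 for x in values]


rng = random.Random(0)  # stands in for a JAX PRNG key
activations = [1.0] * 10_000
dropped = dropout(activations, rate=0.25, rng=rng)
mean_after = sum(dropped) / len(dropped)  # stays close to 1.0
```

At evaluation time dropout_rate is normally left as None, in which case no rng is needed.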

__init__(output_sizes, w_init=None, b_init=None, with_bias=True, with_layer_norm=False, activation=jax.nn.relu, activate_final=False, name=None)

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

class haiku_geometric.models.Node2Vec(senders, receivers, embedding_dim, walk_length, context_size, walks_per_node=1, p=1.0, q=1.0, num_negative_samples=1, num_nodes=None, rng=jax.random.PRNGKey(42))

The node2vec model from the paper "node2vec: Scalable Feature Learning for Networks".

Parameters
  • senders (jnp.ndarray) – The source nodes of the graph.

  • receivers (jnp.ndarray) – The target nodes of the graph.

  • embedding_dim (int) – The dimensionality of the node embeddings.

  • walk_length (int) – The length of the random walk.

  • context_size (int) – Context size considered for positive sampling.

  • walks_per_node (int, optional) – Number of walks per node. (default: 1)

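  • p (float, optional) – Return parameter of the node2vec paper. It controls how likely the walk is to immediately revisit the previous node; larger values make revisiting less likely. (default: 1.0).

  • q (float, optional) – In-out parameter of the node2vec paper. It interpolates between breadth-first and depth-first exploration: q > 1 keeps the walk close to the previous node (breadth-first-like), q < 1 pushes it outward (depth-first-like). (default: 1.0).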

  • num_negative_samples (int, optional) – Number of negative samples for each positive sample. (default: 1).

  • num_nodes (int, optional) – The number of nodes in the graph. (default: None).

  • rng (jax.random.PRNGKey, optional) – The random number generator seed. (default: jax.random.PRNGKey(42)).

Attributes

  • embedding (jnp.ndarray) – Embeddings of the node2vec model.
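The roles of p and q are easiest to see in the second-order random walk defined in the node2vec paper. The following is a minimal plain-Python sketch of that walk, not the library's code: build_adjacency and biased_walk are hypothetical names, walk_length is taken here to mean the number of nodes in the walk (an assumption), and random.Random stands in for a JAX PRNG key.

```python
import random


def build_adjacency(senders, receivers):
    """Turn parallel edge lists into {node: [neighbors]}."""
    adj = {}
    for s, r in zip(senders, receivers):
        adj.setdefault(s, []).append(r)
        adj.setdefault(r, [])  # make sure every node has an entry
    return adj


def biased_walk(adj, start, walk_length, p=1.0, q=1.0, rng=None):
    """One node2vec walk. Unnormalized transition weights from the paper:
    1/p to return to the previous node, 1 to a neighbor of the previous
    node, 1/q to any other neighbor of the current node.
    """
    if rng is None:
        rng = random.Random(42)
    walk = [start]
    while len(walk) < walk_length:
        cur = walk[-1]
        neighbors = adj[cur]
        if not neighbors:
            break  # dead end
        if len(walk) == 1:
            # The first step has no previous node, so it is uniform.
            walk.append(rng.choice(neighbors))
            continue
        prev = walk[-2]
        weights = []
        for nxt in neighbors:
            if nxt == prev:
                weights.append(1.0 / p)
            elif nxt in adj[prev]:
                weights.append(1.0)
            else:
                weights.append(1.0 / q)
        walk.append(rng.choices(neighbors, weights=weights, k=1)[0])
    return walk


# A 4-cycle 0-1-2-3-0 with both edge directions listed.
senders = [0, 1, 1, 2, 2, 3, 3, 0]
receivers = [1, 0, 2, 1, 3, 2, 0, 3]
adj = build_adjacency(senders, receivers)
walk = biased_walk(adj, start=0, walk_length=5)
# A very large p makes stepping straight back practically impossible.
no_return = biased_walk(adj, start=0, walk_length=6, p=1e9, rng=random.Random(1))
```

With p = q = 1 all three weights are equal and the walk reduces to a uniform random walk.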

__call__()

Computes the loss of the node2vec model and returns it together with the current embeddings.

Returns

  • Current embeddings of the model (jnp.ndarray).

  • Loss computed in this forward call.
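As an illustration of what such a loss can look like, the sketch below implements the standard skip-gram objective with negative sampling in plain Python. Whether the library computes exactly this expression is an assumption. The names sgns_loss, pos_walks and neg_walks are hypothetical, and each walk is treated as a center node followed by its context nodes.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sgns_loss(embedding, pos_walks, neg_walks, eps=1e-15):
    """Skip-gram with negative sampling.

    `embedding` maps node id -> vector. In every walk the first node is the
    center and the remaining nodes are its (positive or negative) context.
    """
    def mean_term(walks, positive):
        terms = []
        for walk in walks:
            center = embedding[walk[0]]
            for node in walk[1:]:
                prob = sigmoid(dot(center, embedding[node]))
                if not positive:
                    prob = 1.0 - prob  # negatives should score low
                terms.append(-math.log(prob + eps))
        return sum(terms) / len(terms)

    return mean_term(pos_walks, True) + mean_term(neg_walks, False)


pos_walks = [[0, 1]]  # node 1 appeared near node 0 in a walk
neg_walks = [[0, 2]]  # node 2 was drawn at random

untrained = {0: [0.0, 0.0], 1: [0.0, 0.0], 2: [0.0, 0.0]}
trained = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [-1.0, 0.0]}

loss_untrained = sgns_loss(untrained, pos_walks, neg_walks)
loss_trained = sgns_loss(trained, pos_walks, neg_walks)
```

The loss falls when co-occurring nodes get similar embeddings and randomly drawn nodes get dissimilar ones, which is what gradient descent on the embedding table drives toward.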

__init__(senders, receivers, embedding_dim, walk_length, context_size, walks_per_node=1, p=1.0, q=1.0, num_negative_samples=1, num_nodes=None, rng=jax.random.PRNGKey(42))

neg_sample(batch)

Returns negative samples.

pos_sample(batch)

Returns positive samples.

sample(batch)
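The sampling methods are documented only briefly, so the following plain-Python sketch shows one common way positive and negative samples are formed in node2vec-style models. It is an assumption about the general technique, not a description of this library's internals. The helpers context_windows and negative_window are hypothetical.

```python
import random


def context_windows(walk, context_size):
    """Positive samples: slide a window of `context_size` nodes over one walk."""
    return [walk[i:i + context_size]
            for i in range(len(walk) - context_size + 1)]


def negative_window(start, num_nodes, context_size, rng):
    """Negative sample: keep the start node and fill the rest of the window
    with nodes drawn uniformly at random from the whole graph."""
    return [start] + [rng.randrange(num_nodes) for _ in range(context_size - 1)]


walk = [0, 1, 2, 3, 0]  # one random walk over a 4-node graph
pos = context_windows(walk, context_size=3)
neg = negative_window(start=0, num_nodes=4, context_size=3, rng=random.Random(0))
```

In this scheme a walk of n nodes yields n - context_size + 1 positive windows, and num_negative_samples would control how many random windows are drawn per positive one.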