Graph_components

Modularized graph neural network components using PyTorch Geometric.

class conformer_rl.models.graph_components.MPNN(hidden_dim: int, edge_dim: int, node_dim: int, message_passing_steps: int = 6)

Implements a basic unit of the message passing neural network (MPNN) [1].

Parameters
  • hidden_dim (int) – Dimension of the hidden layer.

  • edge_dim (int) – Dimension of the edge embeddings in the input graph.

  • node_dim (int) – Dimension of the node embeddings in the input graph.

  • message_passing_steps (int) – Number of message passing steps to execute. See [1] for more details.
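One MPNN unit can be sketched in plain PyTorch. This is a minimal sketch, assuming the edge-network variant of the MPNN framework: each edge feature is mapped to a hidden_dim x hidden_dim matrix that transforms the source node's state, incoming messages are mean-aggregated, and a GRU cell updates each node for message_passing_steps rounds. The class name `MPNNSketch`, the `(x, edge_index, edge_attr)` forward signature, the 64-unit edge-network width and mean aggregation are illustrative assumptions, not the conformer_rl implementation. Aggregation uses `Tensor.index_add` so the example runs without PyTorch Geometric installed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MPNNSketch(nn.Module):
    """Hypothetical MPNN unit: edge-network messages plus GRU node updates."""

    def __init__(self, hidden_dim: int, edge_dim: int, node_dim: int,
                 message_passing_steps: int = 6):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.steps = message_passing_steps
        # Project raw node features into the hidden space.
        self.embed = nn.Linear(node_dim, hidden_dim)
        # Edge network: maps each edge feature vector to a hidden_dim x hidden_dim matrix.
        self.edge_net = nn.Sequential(
            nn.Linear(edge_dim, 64), nn.ReLU(),
            nn.Linear(64, hidden_dim * hidden_dim),
        )
        # GRU cell merges the aggregated message into the previous node state.
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        # x: [N, node_dim], edge_index: [2, E] as (source, target), edge_attr: [E, edge_dim]
        h = F.relu(self.embed(x))
        src, dst = edge_index
        # One H x H matrix per edge, computed once and reused at every step.
        weights = self.edge_net(edge_attr).view(-1, self.hidden_dim, self.hidden_dim)
        # In-degree of each node for mean aggregation (clamped to avoid division by zero).
        deg = torch.zeros(h.size(0)).index_add(0, dst, torch.ones(dst.size(0))).clamp(min=1)
        for _ in range(self.steps):
            # Message along edge j -> i is A(e_ij) @ h_j.
            msg = torch.bmm(weights, h[src].unsqueeze(-1)).squeeze(-1)
            # Mean of incoming messages per target node.
            agg = torch.zeros_like(h).index_add(0, dst, msg) / deg.unsqueeze(-1)
            h = self.gru(F.relu(agg), h)
        return h


# Usage on a toy directed 3-cycle: 0 -> 1 -> 2 -> 0.
torch.manual_seed(0)
x = torch.randn(3, 5)                        # node_dim = 5
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_attr = torch.randn(3, 4)                # edge_dim = 4
model = MPNNSketch(hidden_dim=16, edge_dim=4, node_dim=5, message_passing_steps=3)
out = model(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([3, 16])
```

The output keeps one hidden_dim vector per node. Mean aggregation is a choice made for this sketch; sum aggregation is the other common option and changes only the division by `deg`.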

References

[1] MPNN paper

class conformer_rl.models.graph_components.GAT(hidden_dim: int, node_dim: int, num_layers: int = 6)

Implements a basic unit of the graph attention network (GAT) [2].

Parameters
  • hidden_dim (int) – Dimension of the hidden layer.

  • node_dim (int) – Dimension of the node embeddings in the input graph.

  • num_layers (int) – Number of GAT conv layers. See [2] for more details.
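The attention mechanism behind each GAT layer can be sketched in plain PyTorch. This is a minimal sketch, assuming single-head attention with self-loops, LeakyReLU slope 0.2 and an ELU between layers, as described in the GAT paper; it omits multi-head attention and dropout. The names `GATLayerSketch` and `GATSketch` and the `(x, edge_index)` forward signature are hypothetical and not the conformer_rl API. It needs PyTorch 1.12 or later for `Tensor.scatter_reduce`, and does not require PyTorch Geometric.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayerSketch(nn.Module):
    """Hypothetical single-head graph attention layer."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)  # shared weight W
        # Attention vector a, split into the halves applied to target (i) and source (j).
        self.att_dst = nn.Parameter(0.1 * torch.randn(out_dim))
        self.att_src = nn.Parameter(0.1 * torch.randn(out_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        n = x.size(0)
        # Add self-loops so every node also attends to itself.
        loops = torch.arange(n)
        src = torch.cat([edge_index[0], loops])
        dst = torch.cat([edge_index[1], loops])
        wh = self.lin(x)  # [N, out_dim]
        # e_ij = LeakyReLU(a^T [W h_i || W h_j])
        scores = F.leaky_relu(
            (wh[dst] * self.att_dst).sum(-1) + (wh[src] * self.att_src).sum(-1), 0.2)
        # Softmax over each target node's incoming edges; the per-node max shift is
        # only for numerical stability, so it is detached from the autograd graph.
        node_max = torch.full((n,), float("-inf")).scatter_reduce(
            0, dst, scores.detach(), reduce="amax")
        exp = torch.exp(scores - node_max[dst])
        alpha = exp / torch.zeros(n).index_add(0, dst, exp)[dst]
        # h_i' = sum_j alpha_ij W h_j
        return torch.zeros_like(wh).index_add(0, dst, alpha.unsqueeze(-1) * wh[src])


class GATSketch(nn.Module):
    """Hypothetical stack of num_layers attention layers: node_dim -> hidden_dim."""

    def __init__(self, hidden_dim: int, node_dim: int, num_layers: int = 6):
        super().__init__()
        dims = [node_dim] + [hidden_dim] * num_layers
        self.layers = nn.ModuleList(
            [GATLayerSketch(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])])

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = F.elu(layer(x, edge_index))  # ELU, the hidden-layer nonlinearity in the GAT paper
        return x


torch.manual_seed(0)
x = torch.randn(4, 5)                                     # node_dim = 5
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])   # directed 4-cycle
model = GATSketch(hidden_dim=8, node_dim=5, num_layers=2)
out = model(x, edge_index)
print(out.shape)  # torch.Size([4, 8])
```

Unlike the MPNN unit, this layer takes no edge features: the weight each neighbor receives is learned from the node states alone, which matches GAT's constructor having no edge_dim parameter.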

References

[2] GAT paper