Graph_components
Modularized graph neural network components using PyTorch Geometric.
- class conformer_rl.models.graph_components.MPNN(hidden_dim: int, edge_dim: int, node_dim: int, message_passing_steps: int = 6)
Implements a basic unit of the message passing neural network (MPNN) [1].
- Parameters
hidden_dim (int) – Dimension of the hidden layer.
edge_dim (int) – Dimension of the edge embeddings in the input graph.
node_dim (int) – Dimension of the node embeddings in the input graph.
message_passing_steps (int) – Number of message passing steps to execute. See [1] for more details.
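To illustrate what the parameters above control, here is a minimal, dependency-free sketch of an MPNN forward pass in plain Python. It is not the conformer_rl implementation, which is built on PyTorch Geometric. The function `mpnn_sketch`, the random untrained weights, the edge-conditioned linear message A(e_vw) h_w, sum aggregation, and the tanh update are all illustrative assumptions chosen in the spirit of the MPNN framework. Node features have length `node_dim`, edge features have length `edge_dim`, and node states are projected to `hidden_dim` and refined for `message_passing_steps` rounds.

```python
import math
import random
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply matrix m (rows x cols) by vector v (length cols)."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def mpnn_sketch(
    node_feats: List[Vector],
    edges: Dict[Tuple[int, int], Vector],  # directed (src, dst) -> edge features
    hidden_dim: int,
    message_passing_steps: int = 6,
    seed: int = 0,
) -> List[Vector]:
    """Hypothetical MPNN sketch with random (untrained) weights."""
    rng = random.Random(seed)
    node_dim = len(node_feats[0])
    edge_dim = len(next(iter(edges.values()))) if edges else 0
    # Input projection: node_dim -> hidden_dim.
    w_in = random_matrix(hidden_dim, node_dim, rng)
    # Edge network: A(e) = sum_k e_k * W_k, a hidden_dim x hidden_dim matrix.
    w_edge = [random_matrix(hidden_dim, hidden_dim, rng) for _ in range(edge_dim)]

    h = [[math.tanh(x) for x in matvec(w_in, f)] for f in node_feats]
    for _ in range(message_passing_steps):
        messages = [[0.0] * hidden_dim for _ in h]
        for (src, dst), e in edges.items():
            # Edge-conditioned matrix A(e_vw); the message sent to dst is A(e) h_src.
            a = [[sum(e[k] * w_edge[k][i][j] for k in range(edge_dim))
                  for j in range(hidden_dim)] for i in range(hidden_dim)]
            msg = matvec(a, h[src])
            messages[dst] = [m + x for m, x in zip(messages[dst], msg)]
        # Update: combine each node's state with the sum of its incoming messages.
        h = [[math.tanh(s + m) for s, m in zip(state, msg_sum)]
             for state, msg_sum in zip(h, messages)]
    return h


# Usage: a 3-node path graph with both edge directions listed.
nodes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # node_dim = 2
bond = [1.0, 0.0, 0.0]  # edge_dim = 3
edges = {(0, 1): bond, (1, 0): bond, (1, 2): bond, (2, 1): bond}
out = mpnn_sketch(nodes, edges, hidden_dim=4, message_passing_steps=3)
print(len(out), len(out[0]))  # 3 4
```

Each step lets information travel one more hop along the edges, so `message_passing_steps` roughly bounds how far apart two atoms can be and still influence each other's embeddings.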
References
- class conformer_rl.models.graph_components.GAT(hidden_dim: int, node_dim: int, num_layers: int = 6)
Implements a basic unit of the graph attention network (GAT) [2].
- Parameters
hidden_dim (int) – Dimension of the hidden layer.
node_dim (int) – Dimension of the node embeddings in the input graph.
num_layers (int) – Number of GAT convolution layers. See [2] for more details.
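For intuition about what each of the `num_layers` layers computes, the following is a minimal plain-Python sketch of stacked single-head graph attention layers. It is not the conformer_rl implementation: `gat_sketch`, `gat_layer`, and `attention` are hypothetical helpers, the weights are random and untrained, and the single attention head is a simplification (GAT as usually described can use several heads). Each layer applies a shared linear map z_i = W h_i, scores every neighbor j (including the node itself) with LeakyReLU(a^T [z_i || z_j]), normalizes the scores with a softmax, and aggregates the neighbors' z_j with those weights before an ELU nonlinearity.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x


def elu(x: float) -> float:
    return x if x > 0 else math.exp(x) - 1.0


def attention(z: List[Vector], i: int, cand: List[int], a: Vector) -> List[float]:
    """Softmax-normalized attention weights of node i over the nodes in cand."""
    # z[i] + z[j] is list concatenation, i.e. [z_i || z_j].
    scores = [leaky_relu(sum(ak * v for ak, v in zip(a, z[i] + z[j]))) for j in cand]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def gat_layer(h: List[Vector], neighbors: List[List[int]], w: Matrix, a: Vector) -> List[Vector]:
    """One single-head graph attention layer (hypothetical sketch)."""
    z = [[sum(wij * x for wij, x in zip(row, hi)) for row in w] for hi in h]
    out = []
    for i, nbrs in enumerate(neighbors):
        cand = [i] + [j for j in nbrs if j != i]  # attend to self and neighbors
        alphas = attention(z, i, cand, a)
        agg = [sum(al * z[j][d] for al, j in zip(alphas, cand)) for d in range(len(z[i]))]
        out.append([elu(x) for x in agg])
    return out


def gat_sketch(node_feats: List[Vector], neighbors: List[List[int]],
               hidden_dim: int, num_layers: int = 6, seed: int = 0) -> List[Vector]:
    rng = random.Random(seed)
    h = node_feats
    for _ in range(num_layers):
        in_dim = len(h[0])  # node_dim for the first layer, hidden_dim afterwards
        w = [[rng.uniform(-0.5, 0.5) for _ in range(in_dim)] for _ in range(hidden_dim)]
        a = [rng.uniform(-0.5, 0.5) for _ in range(2 * hidden_dim)]
        h = gat_layer(h, neighbors, w, a)
    return h


# Usage: path graph 0-1-2 with node_dim = 3.
nodes = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 0.0]]
neighbors = [[1], [0, 2], [1]]
out = gat_sketch(nodes, neighbors, hidden_dim=4, num_layers=2)
print(len(out), len(out[0]))  # 3 4
```

Unlike the MPNN, this layer does not consume edge features; the only per-edge quantity is the learned attention weight, which is consistent with the GAT signature above taking `node_dim` but no `edge_dim`.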
References