RTGN_GAT_recurrent
- class conformer_rl.models.RTGN_GAT_recurrent.RTGNGatRecurrent(action_dim: int, hidden_dim: int, node_dim: int)
Actor-critic neural network using a graph attention network (GAT) [1] and long short-term memory (LSTM) to predict discrete torsion angles.
Supports molecules with any number of torsion angles, as well as batches containing graphs of different molecules.
- Parameters
action_dim (int) – The number of discrete action choices for each torsion angle.
hidden_dim (int) – Dimension of the hidden layer.
node_dim (int) – The dimension of each node embedding in the input graph.
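As a rough illustration of what action_dim controls, the sketch below maps a discrete action index to a torsion angle. The even spacing over 360 degrees and the helper name action_to_angle are assumptions made for illustration only; the environment that consumes the actions defines the actual mapping.

```python
def action_to_angle(action: int, action_dim: int) -> float:
    """Map a discrete action index to a torsion angle in degrees.

    Assumption (not taken from the library): the action_dim choices are
    evenly spaced over the interval [-180, 180).
    """
    if not 0 <= action < action_dim:
        raise ValueError(f"action must be in [0, {action_dim}), got {action}")
    step = 360.0 / action_dim
    return -180.0 + step * action


# With action_dim=6, each torsion can take one of six angles, 60 degrees apart.
angles = [action_to_angle(a, 6) for a in range(6)]
```

Under this assumption, a larger action_dim gives a finer angular resolution at the cost of a larger action space per torsion.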
References
- forward(obs: List[Tuple[torch_geometric.data.Batch, List[List[int]]]], states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, action: Optional[torch.Tensor] = None) → Tuple[dict, Tuple[torch.Tensor]]
- Parameters
obs (list of 2-tuples of PyTorch Geometric Batch and list of lists of int) – A batch of observations, one tuple per observation. The first element of each tuple is the PyTorch Geometric Batch object holding the graph that represents the molecule. The second element lists all torsions of the molecule; each torsion is a list of four integers giving the indices of the four atoms that define it.
states (4-tuple of torch.Tensor, optional) – Recurrent states for the LSTMs. If not specified, they are initialized to zeros.
action (batch of torch.Tensor, optional) – If specified, the returned log probabilities are those of the specified actions rather than of newly sampled actions.
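The torsion list format described for obs can be sketched without any graph library. The helper check_torsions and the six-atom chain below are hypothetical and only illustrate the expected shape of the second tuple element; they are not part of conformer_rl.

```python
from typing import List


def check_torsions(num_atoms: int, torsions: List[List[int]]) -> None:
    """Validate a torsion list: four distinct, in-range atom indices each."""
    for i, torsion in enumerate(torsions):
        if len(torsion) != 4:
            raise ValueError(f"torsion {i} has {len(torsion)} indices, expected 4")
        if len(set(torsion)) != 4:
            raise ValueError(f"torsion {i} repeats an atom index: {torsion}")
        if any(not 0 <= idx < num_atoms for idx in torsion):
            raise ValueError(f"torsion {i} has an out-of-range index: {torsion}")


# Hypothetical unbranched chain of six atoms (indices 0..5) with three torsions.
torsions = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
check_torsions(num_atoms=6, torsions=torsions)

# Assuming one discrete choice per torsion, this molecule needs 3 actions per step.
num_actions = len(torsions)
```

Checking the indices before building the observation is a cheap way to catch a mismatch between the graph and its torsion list early.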
- Returns
prediction (dict) –
prediction['v']: The value estimation.
prediction['a']: The action sampled from the distribution predicted by the network.
prediction['entropy']: The entropy of the distribution.
prediction['log_pi_a']: The log probabilities of the actions from the distribution.
states (4-tuple of torch.Tensor) – Output recurrent states from the LSTMs.
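The relationship between the sampled action, its log probability, the entropy, and the optional action argument can be sketched for a single torsion in plain Python. This is a minimal sketch of a categorical distribution over discrete choices, not the network's implementation; the names log_softmax and evaluate_torsion are hypothetical.

```python
import math
import random
from typing import List, Optional, Tuple


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def evaluate_torsion(
    logits: List[float],
    action: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[int, float, float]:
    """Return (action, log_pi_a, entropy) for one torsion.

    If action is None, a new action is sampled from the distribution;
    otherwise the log probability of the given action is evaluated.
    """
    log_probs = log_softmax(logits)
    probs = [math.exp(lp) for lp in log_probs]
    if action is None:
        rng = rng or random.Random()
        action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    entropy = -sum(p * lp for p, lp in zip(probs, log_probs))
    return action, log_probs[action], entropy


uniform_logits = [0.0] * 6  # six equally likely angle choices
sampled = evaluate_torsion(uniform_logits, rng=random.Random(0))
fixed = evaluate_torsion(uniform_logits, action=2)
```

Passing a stored action, as in the second call, is the pattern typically used when re-evaluating past actions during a policy update, while omitting it corresponds to sampling during rollout.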