RTGN_recurrent
- class conformer_rl.models.RTGN_recurrent.RTGNRecurrent(action_dim: int, hidden_dim: int, edge_dim: int, node_dim: int)
Actor-critic neural network using a message passing neural network (MPNN) [1] and long short-term memory (LSTM) for predicting discrete torsion angles, as described in the TorsionNet paper [2].
Works with molecules that have any number of torsion angles, and with batches that contain graphs of different molecules.
- Parameters
action_dim (int) – The number of discrete action choices for each torsion angle.
hidden_dim (int) – Dimension of the hidden layer.
edge_dim (int) – The dimension of each edge embedding in the input graph.
node_dim (int) – The dimension of each node embedding in the input graph.
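To make `action_dim` concrete, here is a minimal sketch of how a discrete choice could be turned into a torsion angle. The even spacing and the starting angle of -180 degrees are assumptions made for illustration only; the actual mapping is decided by the environment that consumes the actions, not by this class, and `action_to_angle` is a hypothetical helper.

```python
def action_to_angle(action: int, action_dim: int) -> float:
    """Map a discrete choice to an angle in degrees in [-180, 180).

    Hypothetical convention: choices are evenly spaced and start at -180.
    """
    if not 0 <= action < action_dim:
        raise ValueError(f"action must be in [0, {action_dim}), got {action}")
    return -180.0 + action * (360.0 / action_dim)


# With action_dim = 6, each torsion has six candidate angles, 60 degrees apart.
angles = [action_to_angle(a, 6) for a in range(6)]
print(angles)  # [-180.0, -120.0, -60.0, 0.0, 60.0, 120.0]
```

The network emits one such choice per torsion, so a molecule with several torsions receives one index per torsion rather than a single joint index.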
References
- forward(obs: List[Tuple[torch_geometric.data.Batch, List[List[int]]]], states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, action: Optional[torch.Tensor] = None) → Tuple[dict, Tuple[torch.Tensor]]
- Parameters
obs (list of 2-tuples, each a PyTorch Geometric Batch object and a list of lists of int) – Each tuple is a single observation; the whole list forms a batch. The PyTorch Geometric Batch object holds the graph representing the molecule. The list of lists of int enumerates the molecule's torsions: each torsion is a list of four integers, the indices of the four atoms that define it.
states (4-tuple of torch.Tensor, optional) – Recurrent states for the LSTMs. If not specified, they are initialized to zeros.
action (batch of torch.Tensor, optional) – If specified, the log probabilities returned by the network are those of the specified actions rather than of newly sampled actions.
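The layout of `obs` can be sketched in plain Python. This is an illustration of the nesting only: the strings stand in for real `torch_geometric.data.Batch` objects, and `count_torsions` is a hypothetical helper, not part of conformer_rl.

```python
from typing import Any, List, Tuple

Torsion = List[int]                      # indices of the four atoms in a torsion
Observation = Tuple[Any, List[Torsion]]  # (molecule graph, torsion list)


def count_torsions(obs: List[Observation]) -> int:
    """Check that every torsion has four atom indices and count them all."""
    total = 0
    for _graph, torsions in obs:
        for torsion in torsions:
            if len(torsion) != 4:
                raise ValueError(f"a torsion needs 4 atom indices, got {torsion}")
        total += len(torsions)
    return total


# Placeholder strings replace the real Batch objects in this sketch.
mol_a: Observation = ("<graph of molecule A>", [[0, 1, 2, 3]])
mol_b: Observation = ("<graph of molecule B>", [[0, 1, 2, 3], [1, 2, 3, 4]])

batch = [mol_a, mol_b]  # two observations with different torsion counts
print(count_torsions(batch))  # 3
```

The two observations carry different numbers of torsions, which matches the statement that a batch may mix graphs of different molecules.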
- Returns
prediction (dict) –
prediction['v']: The value estimation.
prediction['a']: The action sampled from the distribution predicted by the network.
prediction['entropy']: The entropy of the distribution.
prediction['log_pi_a']: The log probabilities of the actions from the distribution.
states (4-tuple of torch.Tensor) – Output recurrent states from the LSTMs.
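The relationship between the optional `action` argument and the `'a'`, `'log_pi_a'` and `'entropy'` entries can be sketched for a single torsion with a plain categorical distribution. This is a standard-library illustration of the semantics described above, not the library's implementation; `predict_one_torsion` and its logits input are hypothetical.

```python
import math
import random
from typing import Dict, List, Optional


def softmax(logits: List[float]) -> List[float]:
    """Convert logits to probabilities, shifting by the max for stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_one_torsion(
    logits: List[float],
    action: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> Dict[str, float]:
    """Return the action, its log probability, and the distribution entropy."""
    probs = softmax(logits)
    if action is None:
        # No action supplied: sample a new one from the distribution.
        rng = rng or random.Random()
        action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return {"a": action, "log_pi_a": math.log(probs[action]), "entropy": entropy}


uniform_logits = [0.0] * 6                        # six equally likely choices
sampled = predict_one_torsion(uniform_logits)     # action is sampled
scored = predict_one_torsion(uniform_logits, 2)   # action 2 is scored as given
print(scored["a"])  # 2
```

Passing `action` is the usual way to re-evaluate stored actions under the current policy during a training update, while omitting it corresponds to acting in the environment.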