RTGN_GAT

class conformer_rl.models.RTGN_GAT.RTGNGat(action_dim: int, hidden_dim: int, node_dim: int)

Actor-critic neural network using a graph attention network (GAT) [1] for predicting discrete torsion angles.

Works with molecules that have any number of torsion angles, and with batches containing graphs of different molecules.
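When PyTorch Geometric collates several graphs into one `Batch`, it concatenates their nodes and shifts each graph's `edge_index` by the number of nodes in the graphs before it. A minimal pure-Python sketch of how torsion atom indices could be shifted the same way, so torsions from different molecules point at rows of one batched node-embedding matrix, and how per-torsion outputs could be split back by molecule. This is an assumption about the general technique, not the `RTGNGat` source; `offset_torsions` and `split_by_molecule` are hypothetical helpers, and each molecule is represented only by its atom count.

```python
from typing import List, Tuple


def offset_torsions(batch: List[Tuple[int, List[List[int]]]]) -> Tuple[List[List[int]], List[int]]:
    """Shift each molecule's torsion atom indices into batched node indexing.

    batch: hypothetical (num_atoms, torsions) pairs, one per molecule.
    Returns the flattened, shifted torsions and the torsion count per molecule.
    """
    flat: List[List[int]] = []
    counts: List[int] = []
    offset = 0
    for num_atoms, torsions in batch:
        for t in torsions:
            flat.append([i + offset for i in t])
        counts.append(len(torsions))
        offset += num_atoms  # the next molecule's atoms start after this one's
    return flat, counts


def split_by_molecule(values: list, counts: List[int]) -> List[list]:
    """Split a flat per-torsion list back into one list per molecule."""
    out, start = [], 0
    for c in counts:
        out.append(values[start:start + c])
        start += c
    return out


# Two molecules: 5 atoms with 2 torsions, then 4 atoms with 1 torsion.
flat, counts = offset_torsions([(5, [[0, 1, 2, 3], [1, 2, 3, 4]]), (4, [[0, 1, 2, 3]])])
print(flat)    # [[0, 1, 2, 3], [1, 2, 3, 4], [5, 6, 7, 8]]
print(counts)  # [2, 1]
print(split_by_molecule(["a", "b", "c"], counts))  # [['a', 'b'], ['c']]
```

Keeping the per-molecule torsion counts is what lets a single flat tensor of per-torsion predictions serve molecules with different numbers of torsions.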

Parameters
  • action_dim (int) – The number of discrete action choices for each torsion angle.

  • hidden_dim (int) – Dimension of the hidden layer.

  • node_dim (int) – The dimension of each node embedding in the input graph.
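`action_dim` only fixes how many discrete choices each torsion has; turning a chosen index into an actual angle is presumably handled outside the network (for example by the environment). As a hypothetical illustration only, assuming the `action_dim` choices are evenly spaced over [-180°, 180°), the mapping could look like this (`action_to_angle` is not a conformer_rl function):

```python
def action_to_angle(action: int, action_dim: int) -> float:
    """Map a discrete action index to a torsion angle in degrees.

    Hypothetical convention: the action_dim choices are evenly spaced
    over [-180, 180), so index 0 is -180 degrees.
    """
    if not 0 <= action < action_dim:
        raise ValueError(f"action must be in [0, {action_dim}), got {action}")
    return -180.0 + action * 360.0 / action_dim


# With action_dim=6, the choices are 60 degrees apart.
print([action_to_angle(a, 6) for a in range(6)])
# [-180.0, -120.0, -60.0, 0.0, 60.0, 120.0]
```

A larger `action_dim` gives a finer angular resolution at the cost of a larger output layer per torsion.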

References

[1] GAT paper: Veličković et al., "Graph Attention Networks", ICLR 2018.

forward(obs: List[Tuple[torch_geometric.data.Batch, List[List[int]]]], action: Optional[torch.Tensor] = None) → dict
Parameters
  • obs (list of 2-tuples of PyTorch Geometric Batch objects and lists of lists of int) – Each tuple is a single observation, and the whole list is a batch of observations. The PyTorch Geometric Batch object in each tuple is the graph representing the molecule. The list of lists of integers enumerates all torsions of the molecule: each torsion is a list of four integers giving the indices of the four atoms that define it.

  • action (batch of torch.Tensor, optional) – If specified, the returned log probabilities are those of the given actions rather than of newly sampled actions.
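The nested `obs` layout is easy to get wrong, so a small validation sketch may help. It uses only the standard library and, as an assumption for illustration, stands in for each PyTorch Geometric `Batch` with the molecule's atom count (with a real graph this would come from its `num_nodes` attribute). `validate_obs` is a hypothetical helper, not part of conformer_rl.

```python
from typing import List, Tuple

# Stand-in observation: (num_atoms, torsions) instead of (Batch, torsions).
Obs = List[Tuple[int, List[List[int]]]]


def validate_obs(obs: Obs) -> None:
    """Raise ValueError unless every torsion is four distinct, in-range atom indices."""
    for mol_idx, (num_atoms, torsions) in enumerate(obs):
        for t in torsions:
            # A torsion is defined by exactly four different atoms.
            if len(t) != 4 or len(set(t)) != 4:
                raise ValueError(f"molecule {mol_idx}: bad torsion {t}")
            # Each index must refer to an atom of this molecule's graph.
            if any(not 0 <= i < num_atoms for i in t):
                raise ValueError(f"molecule {mol_idx}: index out of range in {t}")


# A batch of two observations from different molecules.
obs = [(5, [[0, 1, 2, 3], [1, 2, 3, 4]]), (4, [[0, 1, 2, 3]])]
validate_obs(obs)  # passes silently
```

Note that the atom indices in each torsion are local to that molecule's graph, which is why they are checked against that molecule's own atom count.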

Returns

prediction

  • prediction['v']: The estimated value.

  • prediction['a']: The action sampled from the distribution predicted by the network.

  • prediction['entropy']: The entropy of the distribution.

  • prediction['log_pi_a']: The log probabilities of the actions under the predicted distribution.

Return type

dict
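The returned entries follow the usual pattern of a discrete actor-critic policy. A minimal pure-Python sketch, assuming (as the parameter descriptions suggest but do not state) that the policy head produces one row of `action_dim` logits per torsion and treats each torsion as an independent categorical distribution. It reproduces only the `'a'`, `'log_pi_a'` and `'entropy'` entries, kept per torsion; `'v'` would come from a separate value head and is omitted, and whether the real network sums these quantities over torsions is not stated here. `categorical_prediction` is hypothetical and is not the `RTGNGat.forward` implementation.

```python
import math
import random
from typing import Dict, List, Optional


def categorical_prediction(logits: List[List[float]],
                           action: Optional[List[int]] = None,
                           rng: Optional[random.Random] = None) -> Dict[str, list]:
    """Per-torsion categorical policy over one row of logits per torsion."""
    rng = rng or random.Random()
    probs = []
    for row in logits:
        m = max(row)  # subtract the max logit for numerical stability
        exps = [math.exp(x - m) for x in row]
        z = sum(exps)
        probs.append([e / z for e in exps])

    if action is None:
        # No action given: sample one discrete choice per torsion.
        action = [rng.choices(range(len(p)), weights=p)[0] for p in probs]

    # Log probability of the (given or sampled) action for each torsion.
    log_pi_a = [math.log(p[a]) for p, a in zip(probs, action)]
    # Shannon entropy of each torsion's distribution.
    entropy = [-sum(q * math.log(q) for q in p if q > 0) for p in probs]
    return {"a": action, "log_pi_a": log_pi_a, "entropy": entropy}


# Sampling mode: one index in [0, action_dim) per torsion.
sampled = categorical_prediction([[0.0, 0.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]],
                                 rng=random.Random(0))
# Evaluation mode: score a given action instead of sampling one.
fixed = categorical_prediction([[0.0, 0.0, 0.0, 0.0]], action=[2])
print(fixed["log_pi_a"], fixed["entropy"])  # about [-1.386] and [1.386], i.e. -log 4 and log 4
```

Passing `action` mirrors the documented behaviour of `forward`: during a PPO-style update the stored actions are re-scored under the current policy rather than resampled.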