PPO_agent

class conformer_rl.agents.PPO.PPO_agent.PPOAgent(config: conformer_rl.config.agent_config.Config)

Bases: conformer_rl.agents.base_ac_agent.BaseACAgent

Implements an agent that uses the PPO (proximal policy optimization) algorithm [1].

Parameters

config (Config) – Configuration object for the agent. See notes for a list of config parameters used by this agent.

Notes

Config parameters: The following parameters are required in the config object. See Config for more details on the parameters.

  • tag

  • train_env

  • eval_env

  • optimizer_fn

  • network

  • rollout_length

  • max_steps

  • save_interval

  • eval_interval

  • eval_episodes

  • optimization_epochs

  • mini_batch_size

  • discount

  • use_gae

  • gae_lambda

  • entropy_weight

  • value_loss_coefficient

  • gradient_clip

  • ppo_ratio_clip

  • data_dir

  • use_tensorboard
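The discount, use_gae and gae_lambda parameters control how advantages are estimated from a rollout. The sketch below shows the standard generalized advantage estimation (GAE) recursion for a single environment. When use_gae is False, it falls back to discounted returns minus value estimates. This illustrates the technique under stated assumptions and is not conformer_rl's implementation: the function name compute_advantages, the list-based inputs, and the convention that dones[t] marks an episode ending at step t are all hypothetical.

```python
from typing import List


def compute_advantages(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float,
    discount: float = 0.99,
    use_gae: bool = True,
    gae_lambda: float = 0.95,
) -> List[float]:
    """Hypothetical helper: advantage estimates for one rollout of one environment.

    values[t] is V(s_t), last_value is V(s_T) for the state after the rollout,
    and dones[t] is True if the episode ended at step t (assumed convention).
    """
    T = len(rewards)
    advantages = [0.0] * T
    if use_gae:
        gae = 0.0
        next_value = last_value
        for t in reversed(range(T)):
            mask = 0.0 if dones[t] else 1.0  # no bootstrapping past an episode end
            # One-step TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = rewards[t] + discount * next_value * mask - values[t]
            # GAE recursion: A_t = delta_t + gamma * lambda * A_{t+1}
            gae = delta + discount * gae_lambda * mask * gae
            advantages[t] = gae
            next_value = values[t]
    else:
        ret = last_value
        for t in reversed(range(T)):
            mask = 0.0 if dones[t] else 1.0
            # Discounted return, reset at episode boundaries
            ret = rewards[t] + discount * mask * ret
            advantages[t] = ret - values[t]
    return advantages
```

With gae_lambda = 1 the GAE estimate reduces to the discounted return minus the value estimate, and with gae_lambda = 0 it reduces to the one-step TD error, so gae_lambda trades variance against bias.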

Logged values: The following values are logged during training:

  • advantages

  • loss

  • policy_loss

  • entropy_loss

  • value_loss

  • episodic_return_eval (total rewards per episode for eval episodes)

  • episodic_return_train (total rewards per episode for training episodes)
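The logged policy_loss, value_loss, entropy_loss and loss values correspond to the components of the PPO objective. The following is a minimal sketch of the standard clipped surrogate loss from the PPO paper. It assumes that ppo_ratio_clip is the clipping parameter epsilon, that value_loss_coefficient weights a mean squared error value loss, and that entropy_weight weights an entropy bonus. The function ppo_loss, its arguments, and the exact form of each term (for example, whether the value loss carries a factor of 0.5) are assumptions and may differ from conformer_rl's code.

```python
import math
from typing import Dict, Sequence


def ppo_loss(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    values: Sequence[float],
    returns: Sequence[float],
    entropies: Sequence[float],
    ppo_ratio_clip: float = 0.2,
    value_loss_coefficient: float = 0.5,
    entropy_weight: float = 0.01,
) -> Dict[str, float]:
    """Hypothetical helper computing the PPO loss terms for one mini-batch."""
    n = len(advantages)
    policy_sum = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s), computed from log-probabilities
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - ppo_ratio_clip), 1.0 + ppo_ratio_clip)
        # Clipped surrogate objective, negated because the loss is minimized
        policy_sum += -min(ratio * adv, clipped_ratio * adv)
    policy_loss = policy_sum / n

    # Mean squared error between value predictions and return targets
    value_loss = sum((r - v) ** 2 for v, r in zip(values, returns)) / n

    # Negative mean entropy: minimizing it pushes entropy up (encourages exploration)
    entropy_loss = -sum(entropies) / n

    loss = policy_loss + value_loss_coefficient * value_loss + entropy_weight * entropy_loss
    return {
        "loss": loss,
        "policy_loss": policy_loss,
        "value_loss": value_loss,
        "entropy_loss": entropy_loss,
    }
```

Taking the minimum of the clipped and unclipped terms makes the policy loss a pessimistic bound, so an update gains nothing from pushing the probability ratio outside [1 - ppo_ratio_clip, 1 + ppo_ratio_clip].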

References

[1] PPO Paper

step() → None

Performs one training iteration: collects a rollout of samples from the training environment, then trains the network on the collected samples.
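In PPO, training on the collected samples typically means several passes (optimization_epochs) over the rollout, each split into shuffled mini-batches of size mini_batch_size. The generator below sketches only that index bookkeeping. The name minibatch_indices, the assumption that the rollout holds rollout_length times the number of parallel environments samples, and keeping a final smaller batch when the sizes do not divide evenly are illustrative choices, not confirmed behavior of step().

```python
import random
from typing import Iterator, List, Optional


def minibatch_indices(
    num_samples: int,
    mini_batch_size: int,
    optimization_epochs: int,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Hypothetical helper yielding sample indices for each mini-batch update."""
    rng = random.Random(seed)
    indices = list(range(num_samples))
    for _ in range(optimization_epochs):
        # Reshuffle every epoch so mini-batches differ between passes
        rng.shuffle(indices)
        for start in range(0, num_samples, mini_batch_size):
            # Slicing copies, so later shuffles do not alter yielded batches
            yield indices[start:start + mini_batch_size]
```

For example, list(minibatch_indices(10, 4, 2, seed=0)) yields six batches of sizes 4, 4, 2, 4, 4, 2, and each epoch visits every sample exactly once.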

evaluate() → None

Evaluates the agent on the evaluation environment.

The information dict returned by the environment’s conformer_rl.environments.conformer_env.ConformerEnv.step() method is logged by the eval_logger and saved.
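As a rough sketch of what an evaluation pass involves, the code below runs eval_episodes episodes with a fixed policy, accumulates each episode's total reward (the quantity logged as episodic_return_eval), and collects the info dicts returned by step(). ToyEnv, the policy callable, and the Gym-style (observation, reward, done, info) return value of step() are hypothetical stand-ins. The real method uses the configured eval_env and the eval_logger, neither of which is reproduced here.

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyEnv:
    """Hypothetical environment with a Gym-style reset()/step() interface."""

    def __init__(self, episode_length: int = 3):
        self.episode_length = episode_length
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.t += 1
        reward = float(action)
        done = self.t >= self.episode_length
        info = {"t": self.t}  # stand-in for the environment's information dict
        return self.t, reward, done, info


def evaluate(
    env: ToyEnv, policy: Callable[[int], int], eval_episodes: int
) -> Tuple[List[float], List[Dict[str, Any]]]:
    """Runs eval_episodes episodes; returns per-episode returns and all info dicts."""
    episodic_returns: List[float] = []
    infos: List[Dict[str, Any]] = []
    for _ in range(eval_episodes):
        obs = env.reset()
        done = False
        episodic_return = 0.0
        while not done:
            obs, reward, done, info = env.step(policy(obs))
            episodic_return += reward
            infos.append(info)  # the real agent passes these to its eval_logger
        episodic_returns.append(episodic_return)
    return episodic_returns, infos
```

For example, evaluate(ToyEnv(3), lambda obs: 1, 2) returns the episodic returns [3.0, 3.0] together with six info dicts.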

load(filename: str) → None

Loads saved weights from a file into the neural network.

Parameters

filename (str) – The path where the neural network weights are saved.

run_steps() → None

Trains the agent.

Trains the agent until the maximum number of steps (config.max_steps) is reached. Also periodically saves the neural network parameters and evaluates the agent, at the intervals given by config.save_interval and config.eval_interval, if these are specified.
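The scheduling described above can be sketched as a loop over step() calls. This version assumes that each call to step() advances a step counter by a fixed amount (steps_per_iteration, a hypothetical parameter), that saving and evaluation trigger when the counter is a multiple of save_interval or eval_interval (so the intervals should be multiples of steps_per_iteration), and that an interval of 0 disables that action. The callbacks step, save and evaluate stand in for the agent's methods; this is not conformer_rl's actual run_steps().

```python
from typing import Callable


def run_steps(
    step: Callable[[], None],
    save: Callable[[int], None],
    evaluate: Callable[[int], None],
    max_steps: int,
    steps_per_iteration: int,
    save_interval: int = 0,
    eval_interval: int = 0,
) -> int:
    """Hypothetical training loop; returns the final step count."""
    total_steps = 0
    while total_steps < max_steps:
        # An interval of 0 is treated as "disabled" (assumption)
        if save_interval and total_steps % save_interval == 0:
            save(total_steps)
        if eval_interval and total_steps % eval_interval == 0:
            evaluate(total_steps)
        step()  # one rollout plus one round of PPO updates
        total_steps += steps_per_iteration
    return total_steps
```

With max_steps=100, steps_per_iteration=20, save_interval=40 and eval_interval=60, this sketch calls step() five times, saves at steps 0, 40 and 80, and evaluates at steps 0 and 60.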

save(filename: str) → None

Saves the neural network weights to a file.

Parameters

filename (str) – The path where the neural network weights are to be saved.