diff --git a/docs/api/config.rst b/docs/api/config.rst
index 73b7511..93edf7d 100644
--- a/docs/api/config.rst
+++ b/docs/api/config.rst
@@ -106,6 +106,8 @@ The following table lists the available general options:
 +==============================+===========+===============================================================================================================+
 | ``grad_accumulation_steps``  | integer   | number of training steps to accumulate gradients between optimizer steps (default = ``1``)                   |
 +------------------------------+-----------+---------------------------------------------------------------------------------------------------------------+
+| ``max_grad_norm``            | float     | maximum gradient norm for gradient clipping. A value of 0.0 means clipping is disabled (default = ``0.0``)    |
++------------------------------+-----------+---------------------------------------------------------------------------------------------------------------+
 
 .. _lr_schedule_properties-ref:
 
@@ -326,8 +328,6 @@ The following table lists the available options by algorithm type:
 +                +             +------------------------------+------------+-------------------------------------------------------------------------------------------+
 |                |             | ``value_loss_coefficient``   | float      | value loss coefficient: weight for value estimate component of the loss function           |
 +                +             +------------------------------+------------+-------------------------------------------------------------------------------------------+
-|                |             | ``max_grad_norm``            | float      | maximum gradient norm for gradient clipping                                                 |
-+                +             +------------------------------+------------+-------------------------------------------------------------------------------------------+
 |                |             | ``normalize_advantage``      | boolean    | if set to true, advantage values are normalized over all buffer entries                    |
 +----------------+-------------+------------------------------+------------+-------------------------------------------------------------------------------------------+
 
diff --git a/src/csrc/include/internal/model_pack.h b/src/csrc/include/internal/model_pack.h
index 43bc9eb..351e96d 100644
--- a/src/csrc/include/internal/model_pack.h
+++ b/src/csrc/include/internal/model_pack.h
@@ -37,6 +37,7 @@ struct ModelPack {
   std::shared_ptr comm;
   std::shared_ptr state;
   int grad_accumulation_steps = 1;
+  float max_grad_norm = 0.0;
 };
 
 void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true);
diff --git a/src/csrc/include/internal/rl/off_policy/ddpg.h b/src/csrc/include/internal/rl/off_policy/ddpg.h
index 806bb50..3788a1a 100644
--- a/src/csrc/include/internal/rl/off_policy/ddpg.h
+++ b/src/csrc/include/internal/rl/off_policy/ddpg.h
@@ -109,6 +109,11 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const
     q_model.comm->allreduce(grads, true);
   }
 
+  // gradient clipping
+  if (q_model.max_grad_norm > 0.0) {
+    torch::nn::utils::clip_grad_norm_(q_model.model->parameters(), q_model.max_grad_norm);
+  }
+
   // optimizer step
   q_model.optimizer->step();
 
@@ -149,6 +154,11 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const
     p_model.comm->allreduce(grads, true);
   }
 
+  // gradient clipping
+  if (p_model.max_grad_norm > 0.0) {
+    torch::nn::utils::clip_grad_norm_(p_model.model->parameters(), p_model.max_grad_norm);
+  }
+
   // optimizer step
   p_model.optimizer->step();
 
diff --git a/src/csrc/include/internal/rl/off_policy/sac.h b/src/csrc/include/internal/rl/off_policy/sac.h
index ef459d0..4455485 100644
--- a/src/csrc/include/internal/rl/off_policy/sac.h
+++ b/src/csrc/include/internal/rl/off_policy/sac.h
@@ -202,8 +202,15 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
       q_model.comm->allreduce(grads, true);
     }
 
+    // gradient clipping
+    if (q_model.max_grad_norm > 0.0) {
+      torch::nn::utils::clip_grad_norm_(q_model.model->parameters(), q_model.max_grad_norm);
+    }
+
     // optimizer step
     q_model.optimizer->step();
+
+    // lr scheduler step
     q_model.lr_scheduler->step();
   }
 }
@@ -263,6 +270,11 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
     p_model.comm->allreduce(grads, true);
   }
 
+  // gradient clipping
+  if (p_model.max_grad_norm > 0.0) {
+    torch::nn::utils::clip_grad_norm_(p_model.model->parameters(), p_model.max_grad_norm);
+  }
+
   // optimizer step
   p_model.optimizer->step();
 
diff --git a/src/csrc/include/internal/rl/off_policy/td3.h b/src/csrc/include/internal/rl/off_policy/td3.h
index 708a613..4600105 100644
--- a/src/csrc/include/internal/rl/off_policy/td3.h
+++ b/src/csrc/include/internal/rl/off_policy/td3.h
@@ -130,6 +130,11 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const
     q_model.comm->allreduce(grads, true);
   }
 
+  // gradient clipping
+  if (q_model.max_grad_norm > 0.0) {
+    torch::nn::utils::clip_grad_norm_(q_model.model->parameters(), q_model.max_grad_norm);
+  }
+
   // optimizer step
   q_model.optimizer->step();
 
@@ -180,6 +185,11 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const
     p_model.comm->allreduce(grads, true);
   }
 
+  // gradient clipping
+  if (p_model.max_grad_norm > 0.0) {
+    torch::nn::utils::clip_grad_norm_(p_model.model->parameters(), p_model.max_grad_norm);
+  }
+
   // optimizer step
   p_model.optimizer->step();
 
diff --git a/src/csrc/include/internal/rl/on_policy/ppo.h b/src/csrc/include/internal/rl/on_policy/ppo.h
index b873121..5a16d76 100644
--- a/src/csrc/include/internal/rl/on_policy/ppo.h
+++ b/src/csrc/include/internal/rl/on_policy/ppo.h
@@ -45,8 +45,8 @@ template <typename T>
 void train_ppo(const ACPolicyPack& pq_model, torch::Tensor state_tensor, torch::Tensor action_tensor,
                torch::Tensor q_tensor, torch::Tensor log_p_tensor, torch::Tensor adv_tensor, torch::Tensor ret_tensor,
                const T& epsilon, const T& clip_q, const T& entropy_loss_coeff, const T& q_loss_coeff,
-               const T& max_grad_norm, const T& target_kl_divergence, bool normalize_advantage, T& p_loss_val,
-               T& q_loss_val, T& kl_divergence, T& clip_fraction, T& explained_var) {
+               const T& target_kl_divergence, bool normalize_advantage, T& p_loss_val, T& q_loss_val, T& kl_divergence,
+               T& clip_fraction, T& explained_var) {
 
   // nvtx marker
   torchfort::nvtx::rangePush("torchfort_train_ppo");
@@ -179,8 +179,8 @@ void train_ppo(const ACPolicyPack& pq_model, torch::Tensor state_tensor, torch::
   }
 
   // clip
-  if (max_grad_norm > 0.) {
-    torch::nn::utils::clip_grad_norm_(pq_model.model->parameters(), max_grad_norm);
+  if (pq_model.max_grad_norm > 0.) {
+    torch::nn::utils::clip_grad_norm_(pq_model.model->parameters(), pq_model.max_grad_norm);
   }
 
   // optimizer step
@@ -325,7 +325,6 @@ class PPOSystem : public RLOnPolicySystem, public std::enable_shared_from_this
   std::shared_ptr comm;
   std::shared_ptr state;
   int grad_accumulation_steps = 1;
+  float max_grad_norm = 0.0;
 };
 
 class GaussianPolicy : public Policy, public std::enable_shared_from_this<GaussianPolicy> {
@@ -157,6 +158,7 @@ struct ACPolicyPack {
   std::shared_ptr comm;
   std::shared_ptr state;
   int grad_accumulation_steps = 1;
+  float max_grad_norm = 0.0;
 };
 
 class GaussianACPolicy : public ACPolicy, public std::enable_shared_from_this<GaussianACPolicy> {
diff --git a/src/csrc/rl/off_policy/ddpg.cpp b/src/csrc/rl/off_policy/ddpg.cpp
index 76ceca8..d90efe0 100644
--- a/src/csrc/rl/off_policy/ddpg.cpp
+++ b/src/csrc/rl/off_policy/ddpg.cpp
@@ -180,7 +180,7 @@ DDPGSystem::DDPGSystem(const char* name, const YAML::Node& system_node, int mode
 
   if (system_node["optimizer"]["general"]) {
     auto params = get_params(system_node["optimizer"]["general"]);
-    std::set<std::string> supported_params{"grad_accumulation_steps"};
+    std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
     check_params(supported_params, params.keys());
     try {
       p_model_.grad_accumulation_steps = params.get_param("grad_accumulation_steps")[0];
@@ -188,6 +188,13 @@ DDPGSystem::DDPGSystem(const char* name, const YAML::Node& system_node, int mode
     } catch (std::out_of_range) {
       // default
     }
+
+    try {
+      p_model_.max_grad_norm = params.get_param("max_grad_norm")[0];
+      q_model_.max_grad_norm = params.get_param("max_grad_norm")[0];
+    } catch (std::out_of_range) {
+      // default
+    }
   }
 
   // get schedulers
diff --git a/src/csrc/rl/off_policy/sac.cpp b/src/csrc/rl/off_policy/sac.cpp
index 4f3bc54..c1905ee 100644
--- a/src/csrc/rl/off_policy/sac.cpp
+++ b/src/csrc/rl/off_policy/sac.cpp
@@ -189,7 +189,7 @@ SACSystem::SACSystem(const char* name, const YAML::Node& system_node, int model_
 
   if (system_node["optimizer"]["general"]) {
     auto params = get_params(system_node["optimizer"]["general"]);
-    std::set<std::string> supported_params{"grad_accumulation_steps"};
+    std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
     check_params(supported_params, params.keys());
     try {
       p_model_.grad_accumulation_steps = params.get_param("grad_accumulation_steps")[0];
@@ -199,6 +199,15 @@ SACSystem::SACSystem(const char* name, const YAML::Node& system_node, int model_
     } catch (std::out_of_range) {
       // default
     }
+
+    try {
+      p_model_.max_grad_norm = params.get_param("max_grad_norm")[0];
+      for (auto& q_model : q_models_) {
+        q_model.max_grad_norm = params.get_param("max_grad_norm")[0];
+      }
+    } catch (std::out_of_range) {
+      // default
+    }
   }
 
   // in this case we want to optimize the entropy coefficient
diff --git a/src/csrc/rl/off_policy/td3.cpp b/src/csrc/rl/off_policy/td3.cpp
index eca74d2..7233d75 100644
--- a/src/csrc/rl/off_policy/td3.cpp
+++ b/src/csrc/rl/off_policy/td3.cpp
@@ -190,7 +190,7 @@ TD3System::TD3System(const char* name, const YAML::Node& system_node, int model_
 
   if (system_node["optimizer"]["general"]) {
     auto params = get_params(system_node["optimizer"]["general"]);
-    std::set<std::string> supported_params{"grad_accumulation_steps"};
+    std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
     check_params(supported_params, params.keys());
     try {
       p_model_.grad_accumulation_steps = params.get_param("grad_accumulation_steps")[0];
@@ -200,6 +200,15 @@ TD3System::TD3System(const char* name, const YAML::Node& system_node, int model_
     } catch (std::out_of_range) {
       // default
     }
+
+    try {
params.get_param("max_grad_norm")[0]; + for (auto& q_model : q_models_) { + q_model.max_grad_norm = params.get_param("max_grad_norm")[0]; + } + } catch (std::out_of_range) { + // default + } } // get schedulers diff --git a/src/csrc/rl/on_policy/ppo.cpp b/src/csrc/rl/on_policy/ppo.cpp index 7f140fd..70999e5 100644 --- a/src/csrc/rl/on_policy/ppo.cpp +++ b/src/csrc/rl/on_policy/ppo.cpp @@ -43,7 +43,6 @@ PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node, int model_ "target_kl_divergence", "entropy_loss_coefficient", "value_loss_coefficient", - "max_grad_norm", "normalize_advantage"}; check_params(supported_params, params.keys()); batch_size_ = params.get_param("batch_size")[0]; @@ -52,7 +51,6 @@ PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node, int model_ target_kl_divergence_ = params.get_param("target_kl_divergence")[0]; epsilon_ = params.get_param("epsilon", 0.2)[0]; clip_q_ = params.get_param("clip_q", 0.)[0]; - max_grad_norm_ = params.get_param("max_grad_norm", 0.5)[0]; entropy_loss_coeff_ = params.get_param("entropy_loss_coefficient", 0.0)[0]; value_loss_coeff_ = params.get_param("value_loss_coefficient", 0.5)[0]; normalize_advantage_ = params.get_param("normalize_advantage", true)[0]; @@ -135,13 +133,19 @@ PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node, int model_ if (system_node["optimizer"]["general"]) { auto params = get_params(system_node["optimizer"]["general"]); - std::set supported_params{"grad_accumulation_steps"}; + std::set supported_params{"grad_accumulation_steps", "max_grad_norm"}; check_params(supported_params, params.keys()); try { pq_model_.grad_accumulation_steps = params.get_param("grad_accumulation_steps")[0]; } catch (std::out_of_range) { // default } + + try { + pq_model_.max_grad_norm = params.get_param("max_grad_norm")[0]; + } catch (std::out_of_range) { + // default + } } // get schedulers @@ -454,8 +458,8 @@ void PPOSystem::trainStep(float& p_loss_val, float& q_loss_val) { // train step train_ppo(pq_model_, s, a, q, logp, adv, ret, epsilon_, clip_q_, entropy_loss_coeff_, value_loss_coeff_, - max_grad_norm_, target_kl_divergence_, normalize_advantage_, p_loss_val, q_loss_val, current_kl_divergence_, - clip_fraction_, explained_variance_); + target_kl_divergence_, normalize_advantage_, p_loss_val, q_loss_val, current_kl_divergence_, clip_fraction_, + explained_variance_); // system logging if ((system_state_->report_frequency > 0) && (train_step_count_ % system_state_->report_frequency == 0)) { diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index e3b4e3b..e52f830 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -99,13 +99,18 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f if (config["optimizer"]["general"]) { auto params = get_params(config["optimizer"]["general"]); - std::set supported_params{"grad_accumulation_steps"}; + std::set supported_params{"grad_accumulation_steps", "max_grad_norm"}; check_params(supported_params, params.keys()); try { models[name].grad_accumulation_steps = params.get_param("grad_accumulation_steps")[0]; } catch (std::out_of_range) { // default } + try { + models[name].max_grad_norm = params.get_param("max_grad_norm")[0]; + } catch (std::out_of_range) { + // default + } } } diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 5339496..58469f2 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -149,6 +149,10 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, 
     models[name].comm->allreduce(*loss_val, true);
   }
 
+  if (models[name].max_grad_norm > 0.0) {
+    torch::nn::utils::clip_grad_norm_(model->parameters(), models[name].max_grad_norm);
+  }
+
   opt->step();
   if (models[name].lr_scheduler) {
     models[name].lr_scheduler->step();
diff --git a/tests/rl/configs/ppo.yaml b/tests/rl/configs/ppo.yaml
index bcacd47..c7396fc 100644
--- a/tests/rl/configs/ppo.yaml
+++ b/tests/rl/configs/ppo.yaml
@@ -14,7 +14,6 @@ algorithm:
   target_kl_divergence: 0.02
   entropy_loss_coefficient: 0.
   value_loss_coefficient: 0.5
-  max_grad_norm: 0.5
   normalize_advantage: True
 
 actor:
@@ -47,6 +46,8 @@ optimizer:
     weight_decay: 0
     eps: 1e-6
     amsgrad: 0
+  general:
+    max_grad_norm: 0.5
 
 lr_scheduler:
   type: linear
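
For reference, a minimal sketch of where the relocated option is expected to live in a config file after this change. Only the optimizer/general keys and the two option names come from this diff and the documentation table; the surrounding entries are illustrative placeholders, not taken from the patch:

  optimizer:
    # ... optimizer type and hyperparameters unchanged ...
    general:
      grad_accumulation_steps: 1   # optional, default 1
      max_grad_norm: 0.5           # optional; the default 0.0 disables gradient clipping

Note that the effective default changes for PPO: the removed algorithm-level option defaulted to 0.5, while the new general option defaults to 0.0 (clipping disabled), so existing PPO configs must set optimizer.general.max_grad_norm explicitly to keep gradient clipping enabled.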