4 changes: 2 additions & 2 deletions docs/api/config.rst
@@ -106,6 +106,8 @@ The following table lists the available general options:
+==============================+===========+===============================================================================================================+
| ``grad_accumulation_steps`` | integer | number of training steps to accumulate gradients between optimizer steps (default = ``1``) |
+------------------------------+-----------+---------------------------------------------------------------------------------------------------------------+
| ``max_grad_norm`` | float | maximum gradient norm for gradient clipping. A value of 0.0 means clipping is disabled (default = ``0.0``) |
+------------------------------+-----------+---------------------------------------------------------------------------------------------------------------+

.. _lr_schedule_properties-ref:

@@ -326,8 +328,6 @@ The following table lists the available options by algorithm type:
+ + +------------------------------+------------+-------------------------------------------------------------------------------------------+
| | | ``value_loss_coefficient`` | float | value loss coefficient: weight for value estimate component of the loss function |
+ + +------------------------------+------------+-------------------------------------------------------------------------------------------+
| | | ``max_grad_norm`` | float | maximum gradient norm for gradient clipping |
+ + +------------------------------+------------+-------------------------------------------------------------------------------------------+
| | | ``normalize_advantage`` | boolean | if set to true, advantage values are normalized over all buffer entries |
+----------------+-------------+------------------------------+------------+-------------------------------------------------------------------------------------------+

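Note: as documented above, ``max_grad_norm`` now lives in the optimizer ``general`` block rather than in the algorithm options. A minimal config sketch using only the keys confirmed by this PR (other optimizer keys omitted):

optimizer:
  # ... optimizer type and hyperparameters as usual ...
  general:
    grad_accumulation_steps: 1   # optional, default 1
    max_grad_norm: 0.5           # 0.0 (default) disables clipping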
1 change: 1 addition & 0 deletions src/csrc/include/internal/model_pack.h
@@ -37,6 +37,7 @@ struct ModelPack {
std::shared_ptr<Comm> comm;
std::shared_ptr<ModelState> state;
int grad_accumulation_steps = 1;
float max_grad_norm = 0.0;
};

void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true);
10 changes: 10 additions & 0 deletions src/csrc/include/internal/rl/off_policy/ddpg.h
@@ -109,6 +109,11 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const
q_model.comm->allreduce(grads, true);
}

// gradient clipping
if (q_model.max_grad_norm > 0.0) {
torch::nn::utils::clip_grad_norm_(q_model.model->parameters(), q_model.max_grad_norm);
}

// optimizer step
q_model.optimizer->step();

@@ -149,6 +154,11 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const
p_model.comm->allreduce(grads, true);
}

// gradient clipping
if (p_model.max_grad_norm > 0.0) {
torch::nn::utils::clip_grad_norm_(p_model.model->parameters(), p_model.max_grad_norm);
}

// optimizer step
p_model.optimizer->step();

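For reference, ``torch::nn::utils::clip_grad_norm_`` rescales all gradients in place so that their combined L2 norm does not exceed the given threshold; the guard above skips the call entirely when ``max_grad_norm`` is 0.0. A rough standalone sketch of the behaviour (illustrative only, not the libtorch implementation):

#include <torch/torch.h>
#include <cmath>
#include <vector>

// Illustrative sketch: scale every gradient by a common factor so the global
// L2 norm over all gradients is at most max_norm (what the guarded
// clip_grad_norm_ call in this PR achieves).
void clip_gradients_sketch(const std::vector<torch::Tensor>& parameters, double max_norm) {
  double total_sq = 0.0;
  for (const auto& p : parameters) {
    if (p.grad().defined()) {
      total_sq += p.grad().norm(2).pow(2).item<double>();
    }
  }
  const double total_norm = std::sqrt(total_sq);
  if (total_norm > max_norm) {
    const double scale = max_norm / (total_norm + 1e-6);
    for (const auto& p : parameters) {
      if (p.grad().defined()) {
        p.grad().mul_(scale);  // in-place rescaling, like the library call
      }
    }
  }
}

In the training code in this PR the clip is applied after allreduce/gradient accumulation and immediately before ``optimizer->step()``, so the optimizer always consumes the clipped gradients.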
12 changes: 12 additions & 0 deletions src/csrc/include/internal/rl/off_policy/sac.h
@@ -202,8 +202,15 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
q_model.comm->allreduce(grads, true);
}

// gradient clipping
if (q_model.max_grad_norm > 0.0) {
torch::nn::utils::clip_grad_norm_(q_model.model->parameters(), q_model.max_grad_norm);
}

// optimizer step
q_model.optimizer->step();

// lr scheduler step
q_model.lr_scheduler->step();
}
}
@@ -263,6 +270,11 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
p_model.comm->allreduce(grads, true);
}

// gradient clipping
if (p_model.max_grad_norm > 0.0) {
torch::nn::utils::clip_grad_norm_(p_model.model->parameters(), p_model.max_grad_norm);
}

// optimizer step
p_model.optimizer->step();

10 changes: 10 additions & 0 deletions src/csrc/include/internal/rl/off_policy/td3.h
@@ -130,6 +130,11 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const
q_model.comm->allreduce(grads, true);
}

// gradient clipping
if (q_model.max_grad_norm > 0.0) {
torch::nn::utils::clip_grad_norm_(q_model.model->parameters(), q_model.max_grad_norm);
}

// optimizer step
q_model.optimizer->step();

@@ -180,6 +185,11 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const
p_model.comm->allreduce(grads, true);
}

// gradient clipping
if (p_model.max_grad_norm > 0.0) {
torch::nn::utils::clip_grad_norm_(p_model.model->parameters(), p_model.max_grad_norm);
}

// optimizer step
p_model.optimizer->step();

9 changes: 4 additions & 5 deletions src/csrc/include/internal/rl/on_policy/ppo.h
@@ -45,8 +45,8 @@ template <typename T>
void train_ppo(const ACPolicyPack& pq_model, torch::Tensor state_tensor, torch::Tensor action_tensor,
torch::Tensor q_tensor, torch::Tensor log_p_tensor, torch::Tensor adv_tensor, torch::Tensor ret_tensor,
const T& epsilon, const T& clip_q, const T& entropy_loss_coeff, const T& q_loss_coeff,
const T& max_grad_norm, const T& target_kl_divergence, bool normalize_advantage, T& p_loss_val,
T& q_loss_val, T& kl_divergence, T& clip_fraction, T& explained_var) {
const T& target_kl_divergence, bool normalize_advantage, T& p_loss_val, T& q_loss_val, T& kl_divergence,
T& clip_fraction, T& explained_var) {

// nvtx marker
torchfort::nvtx::rangePush("torchfort_train_ppo");
@@ -179,8 +179,8 @@ void train_ppo(const ACPolicyPack& pq_model, torch::Tensor state_tensor, torch::
}

// clip
if (max_grad_norm > 0.) {
torch::nn::utils::clip_grad_norm_(pq_model.model->parameters(), max_grad_norm);
if (pq_model.max_grad_norm > 0.) {
torch::nn::utils::clip_grad_norm_(pq_model.model->parameters(), pq_model.max_grad_norm);
}

// optimizer step
@@ -325,7 +325,6 @@ class PPOSystem : public RLOnPolicySystem, public std::enable_shared_from_this<R
float target_kl_divergence_, current_kl_divergence_, explained_variance_;
float clip_fraction_;
float a_low_, a_high_;
float max_grad_norm_;
bool normalize_advantage_;
ActorNormalizationMode actor_normalization_mode_;
};
2 changes: 2 additions & 0 deletions src/csrc/include/internal/rl/policy.h
@@ -96,6 +96,7 @@ struct PolicyPack {
std::shared_ptr<Comm> comm;
std::shared_ptr<ModelState> state;
int grad_accumulation_steps = 1;
float max_grad_norm = 0.0;
};

class GaussianPolicy : public Policy, public std::enable_shared_from_this<Policy> {
Expand Down Expand Up @@ -157,6 +158,7 @@ struct ACPolicyPack {
std::shared_ptr<Comm> comm;
std::shared_ptr<ModelState> state;
int grad_accumulation_steps = 1;
float max_grad_norm = 0.0;
};

class GaussianACPolicy : public ACPolicy, public std::enable_shared_from_this<ACPolicy> {
9 changes: 8 additions & 1 deletion src/csrc/rl/off_policy/ddpg.cpp
@@ -180,14 +180,21 @@ DDPGSystem::DDPGSystem(const char* name, const YAML::Node& system_node, int mode

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
check_params(supported_params, params.keys());
try {
p_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
q_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
} catch (std::out_of_range) {
// default
}

try {
p_model_.max_grad_norm = params.get_param<float>("max_grad_norm")[0];
q_model_.max_grad_norm = params.get_param<float>("max_grad_norm")[0];
} catch (std::out_of_range) {
// default
}
}

// get schedulers
11 changes: 10 additions & 1 deletion src/csrc/rl/off_policy/sac.cpp
@@ -189,7 +189,7 @@ SACSystem::SACSystem(const char* name, const YAML::Node& system_node, int model_

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
check_params(supported_params, params.keys());
try {
p_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
@@ -199,6 +199,15 @@ SACSystem::SACSystem(const char* name, const YAML::Node& system_node, int model_
} catch (std::out_of_range) {
// default
}

try {
p_model_.max_grad_norm = params.get_param<float>("max_grad_norm")[0];
for (auto& q_model : q_models_) {
q_model.max_grad_norm = params.get_param<float>("max_grad_norm")[0];
}
} catch (std::out_of_range) {
// default
}
}

// in this case we want to optimize the entropy coefficient
11 changes: 10 additions & 1 deletion src/csrc/rl/off_policy/td3.cpp
@@ -190,7 +190,7 @@ TD3System::TD3System(const char* name, const YAML::Node& system_node, int model_

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
check_params(supported_params, params.keys());
try {
p_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
@@ -200,6 +200,15 @@ TD3System::TD3System(const char* name, const YAML::Node& system_node, int model_
} catch (std::out_of_range) {
// default
}

try {
p_model_.max_grad_norm = params.get_param<float>("max_grad_norm")[0];
for (auto& q_model : q_models_) {
q_model.max_grad_norm = params.get_param<float>("max_grad_norm")[0];
}
} catch (std::out_of_range) {
// default
}
}

// get schedulers
14 changes: 9 additions & 5 deletions src/csrc/rl/on_policy/ppo.cpp
@@ -43,7 +43,6 @@ PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node, int model_
"target_kl_divergence",
"entropy_loss_coefficient",
"value_loss_coefficient",
"max_grad_norm",
"normalize_advantage"};
check_params(supported_params, params.keys());
batch_size_ = params.get_param<int>("batch_size")[0];
@@ -52,7 +51,6 @@ PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node, int model_
target_kl_divergence_ = params.get_param<float>("target_kl_divergence")[0];
epsilon_ = params.get_param<float>("epsilon", 0.2)[0];
clip_q_ = params.get_param<float>("clip_q", 0.)[0];
max_grad_norm_ = params.get_param<float>("max_grad_norm", 0.5)[0];
entropy_loss_coeff_ = params.get_param<float>("entropy_loss_coefficient", 0.0)[0];
value_loss_coeff_ = params.get_param<float>("value_loss_coefficient", 0.5)[0];
normalize_advantage_ = params.get_param<bool>("normalize_advantage", true)[0];
@@ -135,13 +133,19 @@ PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node, int model_

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
check_params(supported_params, params.keys());
try {
pq_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
} catch (std::out_of_range) {
// default
}

try {
pq_model_.max_grad_norm = params.get_param<float>("max_grad_norm")[0];
} catch (std::out_of_range) {
// default
}
}

// get schedulers
@@ -454,8 +458,8 @@ void PPOSystem::trainStep(float& p_loss_val, float& q_loss_val) {

// train step
train_ppo(pq_model_, s, a, q, logp, adv, ret, epsilon_, clip_q_, entropy_loss_coeff_, value_loss_coeff_,
max_grad_norm_, target_kl_divergence_, normalize_advantage_, p_loss_val, q_loss_val, current_kl_divergence_,
clip_fraction_, explained_variance_);
target_kl_divergence_, normalize_advantage_, p_loss_val, q_loss_val, current_kl_divergence_, clip_fraction_,
explained_variance_);

// system logging
if ((system_state_->report_frequency > 0) && (train_step_count_ % system_state_->report_frequency == 0)) {
7 changes: 6 additions & 1 deletion src/csrc/torchfort.cpp
@@ -99,13 +99,18 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f

if (config["optimizer"]["general"]) {
auto params = get_params(config["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
std::set<std::string> supported_params{"grad_accumulation_steps", "max_grad_norm"};
check_params(supported_params, params.keys());
try {
models[name].grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
} catch (std::out_of_range) {
// default
}
try {
models[name].max_grad_norm = params.get_param<float>("max_grad_norm")[0];
} catch (std::out_of_range) {
// default
}
}
}

4 changes: 4 additions & 0 deletions src/csrc/training.cpp
@@ -149,6 +149,10 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo
models[name].comm->allreduce(*loss_val, true);
}

if (models[name].max_grad_norm > 0.0) {
torch::nn::utils::clip_grad_norm_(model->parameters(), models[name].max_grad_norm);
}

opt->step();
if (models[name].lr_scheduler) {
models[name].lr_scheduler->step();
3 changes: 2 additions & 1 deletion tests/rl/configs/ppo.yaml
@@ -14,7 +14,6 @@ algorithm:
target_kl_divergence: 0.02
entropy_loss_coefficient: 0.
value_loss_coefficient: 0.5
max_grad_norm: 0.5
normalize_advantage: True

actor:
@@ -47,6 +46,8 @@ optimizer:
weight_decay: 0
eps: 1e-6
amsgrad: 0
general:
max_grad_norm: 0.5

lr_scheduler:
type: linear