10 changes: 5 additions & 5 deletions docs/api/config.rst
@@ -101,11 +101,11 @@ The following table lists the available parameter options by optimizer type:

The following table lists the available general options:

+------------------------------+-----------+------------------------------------------------------------------------------------------------+
| Option | Data Type | Description |
+==============================+===========+================================================================================================+
| ``grad_accumulation_steps`` | integer | number of training steps to accumulate gradients between optimizer steps (default = ``1``) |
+------------------------------+-----------+------------------------------------------------------------------------------------------------+
+------------------------------+-----------+---------------------------------------------------------------------------------------------------------------+
| Option | Data Type | Description |
+==============================+===========+===============================================================================================================+
| ``grad_accumulation_steps`` | integer | number of training steps to accumulate gradients between optimizer steps (default = ``1``) |
+------------------------------+-----------+---------------------------------------------------------------------------------------------------------------+

.. _lr_schedule_properties-ref:

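Editor's aside on the option documented above: grad_accumulation_steps controls how many training steps contribute gradients before a single optimizer (and LR scheduler) step is applied. Below is a minimal, hypothetical sketch of that gating pattern using standard libtorch types; the free function and its signature are illustrative only and are not this project's internal ModelPack API.

// Illustrative sketch only: with grad_accumulation_steps = N, backward() sums
// the gradients of N consecutive training steps and one optimizer/scheduler
// step is applied at the end of each window of N steps.
#include <torch/torch.h>

void accumulated_train_step(torch::optim::Optimizer& optimizer,
                            torch::optim::LRScheduler& lr_scheduler,
                            const torch::Tensor& loss,
                            int64_t step_train_current,
                            int grad_accumulation_steps) {
  // first micro-step of the window: clear previously accumulated gradients
  if (step_train_current % grad_accumulation_steps == 0) {
    optimizer.zero_grad();
  }

  // every micro-step contributes its gradients
  loss.backward();

  // last micro-step of the window: apply the update and advance the LR schedule
  if ((step_train_current + 1) % grad_accumulation_steps == 0) {
    optimizer.step();
    lr_scheduler.step();
  }
}

With grad_accumulation_steps = 1 (the default) this reduces to the usual zero_grad / backward / step sequence on every training step.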
4 changes: 4 additions & 0 deletions src/csrc/include/internal/rl/off_policy/ddpg.h
@@ -111,6 +111,8 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const

// optimizer step
q_model.optimizer->step();

// lr scheduler step
q_model.lr_scheduler->step();
}

@@ -149,6 +151,8 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const

// optimizer step
p_model.optimizer->step();

// lr scheduler step
p_model.lr_scheduler->step();
}

39 changes: 27 additions & 12 deletions src/csrc/include/internal/rl/off_policy/sac.h
@@ -89,8 +89,13 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models

// if we are updating the entropy coefficient, do that first
torch::Tensor alpha_loss;
auto state = p_model.state;
if (alpha_optimizer) {
alpha_optimizer->zero_grad();

if (state->step_train_current % p_model.grad_accumulation_steps == 0) {
alpha_optimizer->zero_grad();
}

// compute target entropy
float targ_ent;
if (target_entropy > 0.) {
@@ -114,18 +119,26 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
alpha_loss.backward();

// reduce gradients
if (p_model.comm) {
std::vector<torch::Tensor> grads;
grads.reserve(alpha_model->parameters().size());
for (const auto& p : alpha_model->parameters()) {
grads.push_back(p.grad());
if ((state->step_train_current + 1) % p_model.grad_accumulation_steps == 0) {
if (p_model.comm) {
std::vector<torch::Tensor> grads;
grads.reserve(alpha_model->parameters().size());
for (const auto& p : alpha_model->parameters()) {
grads.push_back(p.grad());
}
p_model.comm->allreduce(grads, true);
}
p_model.comm->allreduce(grads, true);
}

alpha_optimizer->step();
if (alpha_lr_scheduler) {
alpha_lr_scheduler->step();
// I think we do not need grad clipping here
// the model is just a scalar and clipping the grad for that might hurt

// optimizer step
alpha_optimizer->step();

// lr scheduler step
if (alpha_lr_scheduler) {
alpha_lr_scheduler->step();
}
}
}

@@ -157,7 +170,7 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
torch::Tensor q_old_tensor =
torch::squeeze(q_models[0].model->forward(std::vector<torch::Tensor>{state_old_tensor, action_old_tensor})[0], 1);
torch::Tensor q_loss_tensor = q_loss_func->forward(q_old_tensor, y_tensor);
auto state = q_models[0].state;
state = q_models[0].state;
if (state->step_train_current % q_models[0].grad_accumulation_steps == 0) {
q_models[0].optimizer->zero_grad();
}
@@ -252,6 +265,8 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models

// optimizer step
p_model.optimizer->step();

// scheduler step
p_model.lr_scheduler->step();
}

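For readers following the new gating in train_sac above: the allreduce of the alpha gradients now happens only on the last micro-step of an accumulation window, immediately before the optimizer and scheduler steps. Here is a hedged, self-contained sketch of that placement; Communicator, its allreduce signature, and the meaning of its flag are stand-ins for illustration, not the project's actual communicator interface.

// Sketch of the reduce-then-step placement: on the window's last micro-step the
// locally accumulated gradients are allreduced across ranks, then the optimizer
// and LR scheduler step once. Only the structure mirrors the diff above.
#include <memory>
#include <vector>
#include <torch/torch.h>

struct Communicator {
  void allreduce(std::vector<torch::Tensor>& grads, bool average) {
    // no-op placeholder for the sketch; a real communicator would reduce
    // the gradient tensors in place across ranks
    (void)grads;
    (void)average;
  }
};

void finish_accumulation_window(torch::nn::Module& model,
                                torch::optim::Optimizer& optimizer,
                                torch::optim::LRScheduler& lr_scheduler,
                                const std::shared_ptr<Communicator>& comm,
                                int64_t step_train_current,
                                int grad_accumulation_steps) {
  // only act on the last micro-step of the accumulation window
  if ((step_train_current + 1) % grad_accumulation_steps != 0) {
    return;
  }

  if (comm) {
    // collect the accumulated per-parameter gradients and reduce them across ranks
    std::vector<torch::Tensor> grads;
    grads.reserve(model.parameters().size());
    for (const auto& p : model.parameters()) {
      grads.push_back(p.grad());
    }
    comm->allreduce(grads, true);
  }

  optimizer.step();
  lr_scheduler.step();
}

Reducing once per window rather than once per micro-step keeps the communication volume independent of grad_accumulation_steps.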
4 changes: 4 additions & 0 deletions src/csrc/include/internal/rl/off_policy/td3.h
@@ -132,6 +132,8 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const

// optimizer step
q_model.optimizer->step();

// lr scheduler step
q_model.lr_scheduler->step();
}
}
@@ -180,6 +182,8 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const

// optimizer step
p_model.optimizer->step();

// lr scheduler step
p_model.lr_scheduler->step();
}

3 changes: 2 additions & 1 deletion src/csrc/include/internal/rl/on_policy/ppo.h
@@ -184,8 +184,9 @@ void train_ppo(const ACPolicyPack& pq_model, torch::Tensor state_tensor, torch::
}

// optimizer step
// policy
pq_model.optimizer->step();

// lr scheduler step
pq_model.lr_scheduler->step();
}
}
12 changes: 12 additions & 0 deletions src/csrc/rl/off_policy/ddpg.cpp
@@ -178,6 +178,18 @@ DDPGSystem::DDPGSystem(const char* name, const YAML::Node& system_node, int mode
THROW_INVALID_USAGE("Missing optimizer block in configuration file.");
}

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
check_params(supported_params, params.keys());
try {
p_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
q_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
} catch (std::out_of_range) {
// default
}
}

// get schedulers
// policy model
if (system_node["policy_lr_scheduler"]) {
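The same optimizer/general parsing block is added to the SAC, TD3, and PPO constructors below. As a stand-alone reference, here is a sketch of the expected configuration shape, written against yaml-cpp directly rather than the internal get_params/check_params helpers used in the diff; the surrounding keys and the value shown are illustrative only.

// Stand-alone sketch: grad_accumulation_steps lives under optimizer -> general
// in the YAML configuration and defaults to 1 when the option is absent.
#include <iostream>
#include <yaml-cpp/yaml.h>

int main() {
  const char* config_text = R"(
optimizer:
  general:
    grad_accumulation_steps: 4
)";

  YAML::Node system_node = YAML::Load(config_text);

  int grad_accumulation_steps = 1;  // default used when the option is not set
  if (system_node["optimizer"] && system_node["optimizer"]["general"] &&
      system_node["optimizer"]["general"]["grad_accumulation_steps"]) {
    grad_accumulation_steps =
        system_node["optimizer"]["general"]["grad_accumulation_steps"].as<int>();
  }

  std::cout << "grad_accumulation_steps = " << grad_accumulation_steps << std::endl;
  return 0;
}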
14 changes: 14 additions & 0 deletions src/csrc/rl/off_policy/sac.cpp
@@ -187,6 +187,20 @@ SACSystem::SACSystem(const char* name, const YAML::Node& system_node, int model_
THROW_INVALID_USAGE("Missing optimizer block in configuration file.");
}

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
check_params(supported_params, params.keys());
try {
p_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
for (auto& q_model : q_models_) {
q_model.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
}
} catch (std::out_of_range) {
// default
}
}

// in this case we want to optimize the entropy coefficient
if (system_node["alpha_optimizer"]) {
// register alpha as a new parameter
14 changes: 14 additions & 0 deletions src/csrc/rl/off_policy/td3.cpp
@@ -188,6 +188,20 @@ TD3System::TD3System(const char* name, const YAML::Node& system_node, int model_
THROW_INVALID_USAGE("Missing optimizer block in configuration file.");
}

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
check_params(supported_params, params.keys());
try {
p_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
for (auto& q_model : q_models_) {
q_model.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
}
} catch (std::out_of_range) {
// default
}
}

// get schedulers
// policy model
if (system_node["policy_lr_scheduler"]) {
11 changes: 11 additions & 0 deletions src/csrc/rl/on_policy/ppo.cpp
@@ -133,6 +133,17 @@ PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node, int model_
THROW_INVALID_USAGE("Missing optimizer block in configuration file.");
}

if (system_node["optimizer"]["general"]) {
auto params = get_params(system_node["optimizer"]["general"]);
std::set<std::string> supported_params{"grad_accumulation_steps"};
check_params(supported_params, params.keys());
try {
pq_model_.grad_accumulation_steps = params.get_param<int>("grad_accumulation_steps")[0];
} catch (std::out_of_range) {
// default
}
}

// get schedulers
// policy model
if (system_node["lr_scheduler"]) {