diff --git a/docs/api/config.rst b/docs/api/config.rst
index 93edf7d..72539d1 100644
--- a/docs/api/config.rst
+++ b/docs/api/config.rst
@@ -31,18 +31,33 @@ The block in the configuration file defining general properties takes the follow
 The following table lists the available options:
 
-+-----------------------+-----------+------------------------------------------------------------------------------------------------+
-| Option                | Data Type | Description                                                                                    |
-+=======================+===========+================================================================================================+
-| ``report_frequency``  | integer   | frequency of reported TorchFort training/validation output lines to terminal (default = ``0``) |
-+-----------------------+-----------+------------------------------------------------------------------------------------------------+
-| ``enable_wandb_hook`` | boolean   | flag to control whether wandb hook is active (default = ``false``)                             |
-+-----------------------+-----------+------------------------------------------------------------------------------------------------+
-| ``verbose``           | boolean   | flag to control verbose output from TorchFort (default = ``false``)                           |
-+-----------------------+-----------+------------------------------------------------------------------------------------------------+
++------------------------+-----------+------------------------------------------------------------------------------------------------+
+| Option                 | Data Type | Description                                                                                    |
++========================+===========+================================================================================================+
+| ``report_frequency``   | integer   | frequency of reported TorchFort training/validation output lines to terminal (default = ``0``) |
++------------------------+-----------+------------------------------------------------------------------------------------------------+
+| ``enable_wandb_hook``  | boolean   | flag to control whether wandb hook is active (default = ``false``)                             |
++------------------------+-----------+------------------------------------------------------------------------------------------------+
+| ``verbose``            | boolean   | flag to control verbose output from TorchFort (default = ``false``)                           |
++------------------------+-----------+------------------------------------------------------------------------------------------------+
+| ``enable_cuda_graphs`` | boolean   | flag to enable CUDA graph capture for training and inference (default = ``false``). See below. |
++------------------------+-----------+------------------------------------------------------------------------------------------------+
 
 For more information about the wandb hook, see :ref:`wandb_support-ref`.
 
+CUDA Graphs
+^^^^^^^^^^^
+When ``enable_cuda_graphs`` is set to ``true``, TorchFort captures CUDA graphs for the forward pass (inference)
+and for the combined forward + loss + backward pass (training). CUDA graphs can significantly reduce kernel
+launch overhead and improve performance for models with many small operations.
+
+**Requirements and limitations:**
+
+- Input tensors must be on the GPU and must have consistent data pointers, shapes, and dtypes across all
+  training/inference calls with the captured model. If inputs change after graph capture, an error is thrown.
+- The optimizer step and learning rate scheduler updates are not captured in the graph.
+- A warmup period of 3 iterations is performed before graph capture to ensure stable execution.
+
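+For example, a minimal ``general`` block enabling CUDA graph capture might look as follows (the values are
+illustrative):
+
+.. code-block:: yaml
+
+   general:
+     report_frequency: 100
+     enable_cuda_graphs: true
+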
 
 .. _optimizer_properties-ref:
 
 Optimizer Properties
@@ -613,4 +628,4 @@ Refer to the :ref:`lr_schedule_properties-ref` for available scheduler types and
 
 General Remarks
 ~~~~~~~~~~~~~~~
-Example YAML files for training the different algorithms are available in the `tests/rl/configs <<../../tests/rl/configs/>>`_ directory.
\ No newline at end of file
+Example YAML files for training the different algorithms are available in the `tests/rl/configs <../../tests/rl/configs/>`_ directory.
diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h
new file mode 100644
index 0000000..307b395
--- /dev/null
+++ b/src/csrc/include/internal/cuda_graphs.h
@@ -0,0 +1,405 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#ifdef ENABLE_GPU
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#include <cuda_runtime.h>
+#endif
+
+#include <cstring>
+#include <sstream>
+#include <vector>
+
+#include <torch/torch.h>
+
+#include "internal/defines.h"
+#include "internal/exceptions.h"
+
+namespace torchfort {
+
+// Action to take for the current iteration
+enum class GraphAction {
+  WARMUP,  // Run eager execution, increment warmup count
+  CAPTURE, // Run eager execution with graph capture
+  REPLAY   // Skip eager execution, replay captured graph
+};
+
+#ifdef ENABLE_GPU
+
+// Number of warmup iterations before CUDA graph capture
+constexpr int kCudaGraphWarmupIters = 3;
+
+// RAII wrapper for cudaGraph_t
+class CudaGraph {
+public:
+  CudaGraph() : graph_(nullptr) {}
+  ~CudaGraph() {
+    if (graph_) {
+      cudaGraphDestroy(graph_);
+    }
+  }
+
+  // Non-copyable
+  CudaGraph(const CudaGraph&) = delete;
+  CudaGraph& operator=(const CudaGraph&) = delete;
+
+  // Movable
+  CudaGraph(CudaGraph&& other) noexcept : graph_(other.graph_) { other.graph_ = nullptr; }
+  CudaGraph& operator=(CudaGraph&& other) noexcept {
+    if (this != &other) {
+      if (graph_)
+        cudaGraphDestroy(graph_);
+      graph_ = other.graph_;
+      other.graph_ = nullptr;
+    }
+    return *this;
+  }
+
+  cudaGraph_t& get() { return graph_; }
+  cudaGraph_t get() const { return graph_; }
+  bool valid() const { return graph_ != nullptr; }
+
+private:
+  cudaGraph_t graph_;
+};
+
+// RAII wrapper for cudaGraphExec_t
+class CudaGraphExec {
+public:
+  CudaGraphExec() : exec_(nullptr) {}
+  ~CudaGraphExec() {
+    if (exec_) {
+      cudaGraphExecDestroy(exec_);
+    }
+  }
+
+  // Non-copyable
+  CudaGraphExec(const CudaGraphExec&) = delete;
+  CudaGraphExec& operator=(const CudaGraphExec&) = delete;
+
+  // Movable
+  CudaGraphExec(CudaGraphExec&& other) noexcept : exec_(other.exec_) { other.exec_ = nullptr; }
+  CudaGraphExec& operator=(CudaGraphExec&& other) noexcept {
+    if (this != &other) {
+      if (exec_)
+        cudaGraphExecDestroy(exec_);
+      exec_ = other.exec_;
+      other.exec_ = nullptr;
+    }
+    return *this;
+  }
+
+  cudaGraphExec_t& get() { return exec_; }
+  cudaGraphExec_t get() const { return exec_; }
+  bool valid() const { return exec_ != nullptr; }
+
+private:
+  cudaGraphExec_t exec_;
+};
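+
+// Illustrative usage of the RAII wrappers above (a sketch, not a call site in
+// this file): capture work on a stream into a CudaGraph, instantiate it into a
+// CudaGraphExec, and launch the executable graph. Both handles are released
+// automatically when the wrappers go out of scope.
+//
+//   CudaGraph graph;
+//   CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
+//   // ... enqueue kernels on stream ...
+//   CHECK_CUDA(cudaStreamEndCapture(stream, &graph.get()));
+//   CudaGraphExec exec;
+//   CHECK_CUDA(cudaGraphInstantiate(&exec.get(), graph.get(), nullptr, nullptr, 0));
+//   CHECK_CUDA(cudaGraphLaunch(exec.get(), stream));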
+
+// Graph state for inference
+class InferenceGraphState {
+public:
+  InferenceGraphState(const char* context = "inference") : context_(context) {}
+
+  // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture
+  // Returns the action to take. Call begin_capture() after this if action == CAPTURE.
+  GraphAction prepare(const std::vector<torch::Tensor>& inputs) {
+    InputSignature current_sig = make_input_signature(inputs);
+
+    if (captured_) {
+      validate_inputs(current_sig);
+      return GraphAction::REPLAY;
+    }
+
+    if (warmup_count_ == kCudaGraphWarmupIters) {
+      input_signature_ = std::move(current_sig);
+      return GraphAction::CAPTURE;
+    }
+
+    return GraphAction::WARMUP;
+  }
+
+  // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work
+  void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard,
+                     int device_index) {
+    CHECK_CUDA(cudaStreamSynchronize(user_stream));
+    auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index);
+    guard.reset_stream(capture_c10_stream);
+    CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal));
+  }
+
+  // Finalize after forward pass - handles capture end or warmup increment
+  void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream,
+                c10::cuda::OptionalCUDAStreamGuard& guard, int device_index,
+                const std::vector<torch::Tensor>& outputs) {
+    if (action == GraphAction::CAPTURE) {
+      end_capture(capture_stream, user_stream, guard, device_index);
+      static_outputs_ = outputs;
+    } else if (action == GraphAction::WARMUP) {
+      warmup_count_++;
+    }
+  }
+
+  // Launch captured graph on the given stream
+  void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); }
+
+  // Get static outputs (valid after CAPTURE or REPLAY)
+  const std::vector<torch::Tensor>& get_outputs() const { return static_outputs_; }
+
+  bool is_captured() const { return captured_; }
+
+private:
+  // Input signature for validating consistent inputs
+  struct InputSignature {
+    std::vector<void*> ptrs;
+    std::vector<std::vector<int64_t>> shapes;
+    std::vector<c10::ScalarType> dtypes;
+
+    bool operator!=(const InputSignature& other) const {
+      return ptrs != other.ptrs || shapes != other.shapes || dtypes != other.dtypes;
+    }
+  };
+
+  static InputSignature make_input_signature(const std::vector<torch::Tensor>& tensors) {
+    InputSignature sig;
+    sig.ptrs.reserve(tensors.size());
+    sig.shapes.reserve(tensors.size());
+    sig.dtypes.reserve(tensors.size());
+    for (const auto& t : tensors) {
+      sig.ptrs.push_back(t.data_ptr());
+      sig.shapes.push_back(t.sizes().vec());
+      sig.dtypes.push_back(t.scalar_type());
+    }
+    return sig;
+  }
+
+  void validate_inputs(const InputSignature& current_sig) const {
+    if (input_signature_ != current_sig) {
+      std::stringstream ss;
+      ss << "CUDA graph input mismatch in " << context_ << ". "
+         << "When enable_cuda_graphs is set, input tensors must have consistent "
+         << "data pointers, shapes, and dtypes across all calls. "
+         << "If you need to change inputs, disable enable_cuda_graphs.";
+      THROW_INVALID_USAGE(ss.str());
+    }
+  }
+
+  void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard,
+                   int device_index) {
+    CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get()));
+    instantiate_graph();
+    captured_ = true;
+    auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index);
+    guard.reset_stream(user_c10_stream);
+  }
+
+  void instantiate_graph() {
+    cudaGraphNode_t error_node;
+    char log_buffer[1024] = {0};
+    cudaError_t result =
+        cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), &error_node, log_buffer, sizeof(log_buffer));
+    if (result != cudaSuccess) {
+      std::stringstream ss;
+      ss << "CUDA graph instantiation failed in " << context_ << ": " << cudaGetErrorString(result);
+      if (std::strlen(log_buffer) > 0) {
+        ss << " Log: " << log_buffer;
+      }
+      THROW_INTERNAL_ERROR(ss.str());
+    }
+  }
+
+  const char* context_;
+  int warmup_count_ = 0;
+  bool captured_ = false;
+  InputSignature input_signature_;
+  CudaGraph graph_;
+  CudaGraphExec graph_exec_;
+  std::vector<torch::Tensor> static_outputs_;
+};
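+
+// Typical per-call sequence for InferenceGraphState, mirroring the call site
+// in inference_multiarg (training.cpp); `state`, `inputs`, `outputs`, and the
+// streams below are stand-ins for the caller's variables:
+//
+//   GraphAction action = state.prepare(inputs);
+//   if (action == GraphAction::CAPTURE) {
+//     state.begin_capture(capture_stream, user_stream, guard, device_index);
+//   }
+//   if (action != GraphAction::REPLAY) {
+//     outputs = model->forward(inputs); // recorded, not executed, on CAPTURE
+//   }
+//   state.finalize(action, capture_stream, user_stream, guard, device_index, outputs);
+//   if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) {
+//     state.launch(user_stream);
+//     outputs = state.get_outputs();
+//   }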
" + << "If you need to change inputs, disable enable_cuda_graphs."; + THROW_INVALID_USAGE(ss.str()); + } + } + + void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); + instantiate_graph(); + captured_ = true; + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index); + guard.reset_stream(user_c10_stream); + } + + void instantiate_graph() { + cudaGraphNode_t error_node; + char log_buffer[1024]; + cudaError_t result = + cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), &error_node, log_buffer, sizeof(log_buffer)); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "CUDA graph instantiation failed in " << context_ << ": " << cudaGetErrorString(result); + if (std::strlen(log_buffer) > 0) { + ss << " Log: " << log_buffer; + } + THROW_INTERNAL_ERROR(ss.str()); + } + } + + const char* context_; + int warmup_count_ = 0; + bool captured_ = false; + InputSignature input_signature_; + CudaGraph graph_; + CudaGraphExec graph_exec_; + std::vector static_outputs_; +}; + +// Graph state for training (single graph for forward + loss + backward) +class TrainingGraphState { +public: + TrainingGraphState(const char* context = "training") : context_(context) {} + + // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture + // Returns the action to take. Call begin_capture() after this if action == CAPTURE. + GraphAction prepare(const std::vector& inputs, const std::vector& labels, + const std::vector& extra_args) { + InputSignature current_sig = make_input_signature(inputs, labels, extra_args); + + if (captured_) { + validate_inputs(current_sig); + return GraphAction::REPLAY; + } + + if (warmup_count_ == kCudaGraphWarmupIters) { + input_signature_ = std::move(current_sig); + return GraphAction::CAPTURE; + } + + return GraphAction::WARMUP; + } + + // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work + void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); + guard.reset_stream(capture_c10_stream); + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + } + + // Finalize after forward+loss+backward pass - handles capture end or warmup increment + void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, int device_index, const torch::Tensor& loss) { + if (action == GraphAction::CAPTURE) { + end_capture(capture_stream, user_stream, guard, device_index); + static_loss_ = loss; + } else if (action == GraphAction::WARMUP) { + warmup_count_++; + } + } + + // Launch captured graph on the given stream + void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); } + + // Get static loss (valid after CAPTURE or REPLAY) + const torch::Tensor& get_loss() const { return static_loss_; } + + bool is_captured() const { return captured_; } + +private: + // Input signature for validating consistent inputs + struct InputSignature { + std::vector ptrs; + std::vector> shapes; + std::vector dtypes; + + bool operator!=(const InputSignature& other) const { + return ptrs != other.ptrs || shapes 
+
+// Graph state for a model, including the capture stream
+class ModelGraphState {
+public:
+  InferenceGraphState inference{"inference"};
+  TrainingGraphState training{"training"};
+
+  ModelGraphState(int device_index = 0) : capture_stream_(nullptr), device_index_(device_index) {
+    // Create a non-blocking stream for graph capture
+    CHECK_CUDA(cudaSetDevice(device_index_));
+    CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking));
+  }
+
+  ~ModelGraphState() {
+    if (capture_stream_) {
+      cudaStreamDestroy(capture_stream_);
+    }
+  }
+
+  // Non-copyable
+  ModelGraphState(const ModelGraphState&) = delete;
+  ModelGraphState& operator=(const ModelGraphState&) = delete;
+
+  cudaStream_t capture_stream() const { return capture_stream_; }
+  int device_index() const { return device_index_; }
+
+private:
+  cudaStream_t capture_stream_;
+  int device_index_;
+};
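+
+// Design note: capture runs on a dedicated non-blocking stream owned by
+// ModelGraphState rather than on the caller's stream, so that unrelated work
+// already queued by the application is not pulled into the captured graph;
+// begin_capture() synchronizes the user stream before switching over.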
"internal/model_state.h" #include "internal/model_wrapper.h" +#ifdef ENABLE_GPU +#include "internal/cuda_graphs.h" +#endif namespace torchfort { @@ -38,6 +41,11 @@ struct ModelPack { std::shared_ptr state; int grad_accumulation_steps = 1; float max_grad_norm = 0.0; + +#ifdef ENABLE_GPU + // CUDA graph state (initialized if enable_cuda_graphs is true) + std::shared_ptr graph_state; +#endif }; void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true); diff --git a/src/csrc/include/internal/model_state.h b/src/csrc/include/internal/model_state.h index c4c2111..ff63aba 100644 --- a/src/csrc/include/internal/model_state.h +++ b/src/csrc/include/internal/model_state.h @@ -36,6 +36,9 @@ struct ModelState { bool verbose; std::filesystem::path report_file; + // CUDA graph settings + bool enable_cuda_graphs = false; + void save(const std::string& fname); void load(const std::string& fname); }; diff --git a/src/csrc/setup.cpp b/src/csrc/setup.cpp index 7f0c296..c052065 100644 --- a/src/csrc/setup.cpp +++ b/src/csrc/setup.cpp @@ -262,7 +262,7 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ if (state_node["general"]) { auto params = get_params(state_node["general"]); - std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose"}; + std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose", "enable_cuda_graphs"}; check_params(supported_params, params.keys()); state->report_frequency = params.get_param("report_frequency")[0]; try { @@ -305,6 +305,12 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ } catch (std::out_of_range) { state->verbose = false; } + + try { + state->enable_cuda_graphs = params.get_param("enable_cuda_graphs")[0]; + } catch (std::out_of_range) { + state->enable_cuda_graphs = false; + } } return state; diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index e52f830..8ed2445 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -125,6 +125,13 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f // Setting up general options models[name].state = get_state(name, config); + +#ifdef ENABLE_GPU + // Initialize graph state if CUDA graphs are enabled + if (models[name].state->enable_cuda_graphs && models[name].model->device().is_cuda()) { + models[name].graph_state = std::make_shared(models[name].model->device().index()); + } +#endif } catch (const BaseException& e) { std::cerr << e.what(); return e.getResult(); diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 58469f2..488062e 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -26,8 +26,10 @@ #endif #include +#include "internal/cuda_graphs.h" #include "internal/defines.h" #include "internal/logging.h" +#include "internal/model_pack.h" #include "internal/nvtx.h" #include "internal/tensor_list.h" #include "internal/utils.h" @@ -55,7 +57,7 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor auto model = models[name].model; -#if ENABLE_GPU +#ifdef ENABLE_GPU c10::cuda::OptionalCUDAStreamGuard guard; if (model->device().is_cuda()) { auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index()); @@ -66,9 +68,44 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor inputs->to(model->device()); model->eval(); - auto results = model->forward(inputs->tensors); - for (int i = 0; i < results.size(); ++i) { + std::vector results; + + GraphAction action = 
+#ifdef ENABLE_GPU
+  InferenceGraphState* graph_state = nullptr;
+
+  if (models[name].state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) {
+    graph_state = &models[name].graph_state->inference;
+    action = graph_state->prepare(inputs->tensors);
+
+    if (action == GraphAction::CAPTURE) {
+      graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard,
+                                 model->device().index());
+    }
+  }
+#endif
+
+  // Forward pass
+  if (action != GraphAction::REPLAY) {
+    results = model->forward(inputs->tensors);
+  }
+
+#ifdef ENABLE_GPU
+  if (graph_state) {
+    graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard,
+                          model->device().index(), results);
+
+    if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) {
+      graph_state->launch(ext_stream);
+      results = graph_state->get_outputs();
+    }
+  }
+#endif
+
+  // Copy results to output tensors
+  for (size_t i = 0; i < results.size(); ++i) {
     outputs->tensors[i].copy_(results[i].reshape(outputs->tensors[i].sizes()));
   }
 
   models[name].state->step_inference++;
@@ -120,21 +157,64 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo
 
   auto opt = models[name].optimizer;
   auto state = models[name].state;
 
-  // fwd pass
-  auto results = model->forward(inputs->tensors);
-  auto loss = models[name].loss->forward(results, labels->tensors,
-                                         (extra_loss_args) ? extra_loss_args->tensors : std::vector<torch::Tensor>());
+  torch::Tensor loss;
-
-  // extract loss
-  *loss_val = loss.item<float>();
+
+  GraphAction action = GraphAction::WARMUP;
+
+#ifdef ENABLE_GPU
+  TrainingGraphState* graph_state = nullptr;
+
+  if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) {
+    graph_state = &models[name].graph_state->training;
+    std::vector<torch::Tensor> extra_args_vec =
+        extra_loss_args ? extra_loss_args->tensors : std::vector<torch::Tensor>();
+    action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec);
+  }
+#endif
 
-  // bwd pass
   if (state->step_train_current % models[name].grad_accumulation_steps == 0) {
-    opt->zero_grad();
+    // Only run zero_grad on non-replay steps or if gradient accumulation is active
+    if (action != GraphAction::REPLAY || models[name].grad_accumulation_steps > 1) {
+      if (models[name].grad_accumulation_steps > 1) {
+#ifdef ENABLE_GPU
+        // With graphs and grad accumulation active, gradients must be persistent (set_to_none = false)
+        opt->zero_grad(/*set_to_none=*/(graph_state == nullptr));
+#else
+        opt->zero_grad(/*set_to_none=*/true);
+#endif
+      } else {
+        opt->zero_grad(/*set_to_none=*/true);
+      }
+    }
   }
-  loss.backward();
+
+#ifdef ENABLE_GPU
+  if (action == GraphAction::CAPTURE) {
+    graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index());
+  }
+#endif
+
+  // Forward + loss + backward
+  if (action != GraphAction::REPLAY) {
+    auto fwd_results = model->forward(inputs->tensors);
+    loss = models[name].loss->forward(fwd_results, labels->tensors,
+                                      (extra_loss_args) ? extra_loss_args->tensors : std::vector<torch::Tensor>());
+    loss.backward();
+  }
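+
+  // Note: on a CAPTURE step the forward/loss/backward work above was only
+  // recorded into the graph, not executed; the launch below performs the
+  // actual computation.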
+#ifdef ENABLE_GPU
+  if (graph_state) {
+    graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard,
+                          model->device().index(), loss);
+
+    if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) {
+      graph_state->launch(ext_stream);
+      loss = graph_state->get_loss();
+    }
+  }
+#endif
+
+  // Optimizer step and related operations
   if ((state->step_train_current + 1) % models[name].grad_accumulation_steps == 0) {
     // allreduce (average) gradients (if running distributed)
     if (models[name].comm) {
@@ -144,9 +224,6 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo
         grads.push_back(p.grad());
       }
       models[name].comm->allreduce(grads, true);
-
-      // average returned loss value
-      models[name].comm->allreduce(*loss_val, true);
     }
 
     if (models[name].max_grad_norm > 0.0) {
@@ -159,6 +236,13 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo
     }
   }
 
+  // Extract loss value
+  *loss_val = loss.item<float>();
+  if (models[name].comm) {
+    // average returned loss value (if running distributed)
+    models[name].comm->allreduce(*loss_val, true);
+  }
+
   state->step_train++;
   state->step_train_current++;
   if (state->report_frequency > 0 && state->step_train % state->report_frequency == 0) {