From 5a7a566b9c052922ec77145892179068d5c638da Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Tue, 2 Dec 2025 12:30:32 -0800 Subject: [PATCH 1/6] Adding support for CUDA graph capture. Signed-off-by: Josh Romero --- docs/api/config.rst | 36 +++- src/csrc/include/internal/cuda_graphs.h | 246 ++++++++++++++++++++++++ src/csrc/include/internal/model_pack.h | 8 + src/csrc/include/internal/model_state.h | 3 + src/csrc/setup.cpp | 9 +- src/csrc/torchfort.cpp | 8 + src/csrc/training.cpp | 200 +++++++++++++++++-- 7 files changed, 487 insertions(+), 23 deletions(-) create mode 100644 src/csrc/include/internal/cuda_graphs.h diff --git a/docs/api/config.rst b/docs/api/config.rst index 93edf7d..8a06f6a 100644 --- a/docs/api/config.rst +++ b/docs/api/config.rst @@ -31,18 +31,34 @@ The block in the configuration file defining general properties takes the follow The following table lists the available options: -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ -| Option | Data Type | Description | -+=======================+===========+================================================================================================+ -| ``report_frequency`` | integer | frequency of reported TorchFort training/validation output lines to terminal (default = ``0``) | -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ -| ``enable_wandb_hook`` | boolean | flag to control whether wandb hook is active (default = ``false``) | -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ -| ``verbose`` | boolean | flag to control verbose output from TorchFort (default = ``false``) | -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| Option | Data Type | Description | ++========================+===========+=================================================================================================+ +| ``report_frequency`` | integer | frequency of reported TorchFort training/validation output lines to terminal (default = ``0``) | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| ``enable_wandb_hook`` | boolean | flag to control whether wandb hook is active (default = ``false``) | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| ``verbose`` | boolean | flag to control verbose output from TorchFort (default = ``false``) | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| ``enable_cuda_graphs`` | boolean | flag to enable CUDA graph capture for training and inference (default = ``false``). See below. | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ For more information about the wandb hook, see :ref:`wandb_support-ref`. +CUDA Graphs +^^^^^^^^^^^ +When ``enable_cuda_graphs`` is set to ``true``, TorchFort will capture CUDA graphs for the forward pass (inference) +and the forward + loss + backward pass (training). CUDA graphs can significantly reduce kernel launch overhead +and improve performance for models with many small operations. + +**Requirements and limitations:** + +- Input tensors must be on GPU and must have consistent data pointers, shapes, and dtypes across all training/inference calls with the captured model. + If inputs change after graph capture, an error will be thrown. +- For training, CUDA graph capture is automatically disabled when gradient accumulation (``grad_accumulation_steps > 1``) is active. +- The optimizer step and learning rate scheduler updates are not captured in the graph. +- A warmup period of 3 iterations is performed before graph capture to ensure stable execution. + .. _optimizer_properties-ref: Optimizer Properties @@ -613,4 +629,4 @@ Refer to the :ref:`lr_schedule_properties-ref` for available scheduler types and General Remarks ~~~~~~~~~~~~~~~ -Example YAML files for training the different algorithms are available in the `tests/rl/configs <<../../tests/rl/configs/>>`_ directory. \ No newline at end of file +Example YAML files for training the different algorithms are available in the `tests/rl/configs <<../../tests/rl/configs/>>`_ directory. diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h new file mode 100644 index 0000000..9bb9ccd --- /dev/null +++ b/src/csrc/include/internal/cuda_graphs.h @@ -0,0 +1,246 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_GPU + +#include + +#include +#include + +#include +#include + +#include "internal/defines.h" +#include "internal/exceptions.h" + +namespace torchfort { + +// RAII wrapper for cudaGraph_t +class CudaGraph { +public: + CudaGraph() : graph_(nullptr) {} + ~CudaGraph() { + if (graph_) { + cudaGraphDestroy(graph_); + } + } + + // Non-copyable + CudaGraph(const CudaGraph&) = delete; + CudaGraph& operator=(const CudaGraph&) = delete; + + // Movable + CudaGraph(CudaGraph&& other) noexcept : graph_(other.graph_) { + other.graph_ = nullptr; + } + CudaGraph& operator=(CudaGraph&& other) noexcept { + if (this != &other) { + if (graph_) cudaGraphDestroy(graph_); + graph_ = other.graph_; + other.graph_ = nullptr; + } + return *this; + } + + cudaGraph_t& get() { return graph_; } + cudaGraph_t get() const { return graph_; } + bool valid() const { return graph_ != nullptr; } + +private: + cudaGraph_t graph_; +}; + +// RAII wrapper for cudaGraphExec_t +class CudaGraphExec { +public: + CudaGraphExec() : exec_(nullptr) {} + ~CudaGraphExec() { + if (exec_) { + cudaGraphExecDestroy(exec_); + } + } + + // Non-copyable + CudaGraphExec(const CudaGraphExec&) = delete; + CudaGraphExec& operator=(const CudaGraphExec&) = delete; + + // Movable + CudaGraphExec(CudaGraphExec&& other) noexcept : exec_(other.exec_) { + other.exec_ = nullptr; + } + CudaGraphExec& operator=(CudaGraphExec&& other) noexcept { + if (this != &other) { + if (exec_) cudaGraphExecDestroy(exec_); + exec_ = other.exec_; + other.exec_ = nullptr; + } + return *this; + } + + cudaGraphExec_t& get() { return exec_; } + cudaGraphExec_t get() const { return exec_; } + bool valid() const { return exec_ != nullptr; } + + // Launch the graph on a stream + void launch(cudaStream_t stream) { + if (exec_) { + CHECK_CUDA(cudaGraphLaunch(exec_, stream)); + } + } + +private: + cudaGraphExec_t exec_; +}; + +// Input signature for validating consistent inputs +struct InputSignature { + std::vector ptrs; + std::vector> shapes; + std::vector dtypes; + + bool operator==(const InputSignature& other) const { + return ptrs == other.ptrs && shapes == other.shapes && dtypes == other.dtypes; + } + + bool operator!=(const InputSignature& other) const { + return !(*this == other); + } + + bool empty() const { return ptrs.empty(); } +}; + +// Helper to create input signature from tensor list +inline InputSignature make_input_signature(const std::vector& tensors) { + InputSignature sig; + sig.ptrs.reserve(tensors.size()); + sig.shapes.reserve(tensors.size()); + sig.dtypes.reserve(tensors.size()); + for (const auto& t : tensors) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; +} + +// Helper to create input signature from multiple tensor lists (for training) +inline InputSignature make_input_signature(const std::vector& inputs, + const std::vector& labels, + const std::vector& extra_args) { + InputSignature sig; + size_t total = inputs.size() + labels.size() + extra_args.size(); + sig.ptrs.reserve(total); + sig.shapes.reserve(total); + sig.dtypes.reserve(total); + + for (const auto& t : inputs) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : labels) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : extra_args) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; +} + +// Validate that current inputs match the captured signature +inline void validate_input_signature(const InputSignature& expected, + const InputSignature& actual, + const char* context) { + if (expected != actual) { + std::stringstream ss; + ss << "CUDA graph input mismatch in " << context << ". " + << "When cuda_graphs is enabled, input tensors must have consistent " + << "data pointers, shapes, and dtypes across all calls. " + << "If you need to change inputs, disable cuda_graphs."; + THROW_INVALID_USAGE(ss.str()); + } +} + +// Graph state for inference +struct InferenceGraphState { + int warmup_count = 0; + bool captured = false; + + InputSignature input_signature; + CudaGraph graph; + CudaGraphExec graph_exec; + std::vector static_outputs; +}; + +// Graph state for training (single graph for forward + loss + backward) +struct TrainingGraphState { + int warmup_count = 0; + bool captured = false; + + InputSignature input_signature; + CudaGraph graph; + CudaGraphExec graph_exec; + torch::Tensor static_loss; +}; + +// Graph state for a model, including the capture stream +class ModelGraphState { +public: + InferenceGraphState inference; + TrainingGraphState training; + + ModelGraphState(int device_index = 0) + : capture_stream_(nullptr), device_index_(device_index) { + // Create a non-blocking stream for graph capture + CHECK_CUDA(cudaSetDevice(device_index_)); + CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking)); + } + + ~ModelGraphState() { + if (capture_stream_) { + cudaStreamDestroy(capture_stream_); + } + } + + // Non-copyable + ModelGraphState(const ModelGraphState&) = delete; + ModelGraphState& operator=(const ModelGraphState&) = delete; + + cudaStream_t capture_stream() const { return capture_stream_; } + int device_index() const { return device_index_; } + + // Get c10 stream wrapper for the capture stream (for PyTorch integration) + c10::cuda::CUDAStream get_capture_cuda_stream() const { + return c10::cuda::getStreamFromExternal(capture_stream_, device_index_); + } + +private: + cudaStream_t capture_stream_; + int device_index_; +}; + +} // namespace torchfort + +#endif // ENABLE_GPU + diff --git a/src/csrc/include/internal/model_pack.h b/src/csrc/include/internal/model_pack.h index 351e96d..f09b331 100644 --- a/src/csrc/include/internal/model_pack.h +++ b/src/csrc/include/internal/model_pack.h @@ -25,6 +25,9 @@ #include "internal/distributed.h" #include "internal/model_state.h" #include "internal/model_wrapper.h" +#ifdef ENABLE_GPU +#include "internal/cuda_graphs.h" +#endif namespace torchfort { @@ -38,6 +41,11 @@ struct ModelPack { std::shared_ptr state; int grad_accumulation_steps = 1; float max_grad_norm = 0.0; + +#ifdef ENABLE_GPU + // CUDA graph state (initialized if enable_cuda_graphs is true) + std::shared_ptr graph_state; +#endif }; void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true); diff --git a/src/csrc/include/internal/model_state.h b/src/csrc/include/internal/model_state.h index c4c2111..ff63aba 100644 --- a/src/csrc/include/internal/model_state.h +++ b/src/csrc/include/internal/model_state.h @@ -36,6 +36,9 @@ struct ModelState { bool verbose; std::filesystem::path report_file; + // CUDA graph settings + bool enable_cuda_graphs = false; + void save(const std::string& fname); void load(const std::string& fname); }; diff --git a/src/csrc/setup.cpp b/src/csrc/setup.cpp index 7f0c296..8fbfb02 100644 --- a/src/csrc/setup.cpp +++ b/src/csrc/setup.cpp @@ -262,7 +262,8 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ if (state_node["general"]) { auto params = get_params(state_node["general"]); - std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose"}; + std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose", + "enable_cuda_graphs"}; check_params(supported_params, params.keys()); state->report_frequency = params.get_param("report_frequency")[0]; try { @@ -305,6 +306,12 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ } catch (std::out_of_range) { state->verbose = false; } + + try { + state->enable_cuda_graphs = params.get_param("enable_cuda_graphs")[0]; + } catch (std::out_of_range) { + state->enable_cuda_graphs = false; + } } return state; diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index e52f830..bd3eea6 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -125,6 +125,14 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f // Setting up general options models[name].state = get_state(name, config); + +#ifdef ENABLE_GPU + // Initialize graph state if CUDA graphs are enabled + if (models[name].state->enable_cuda_graphs && models[name].model->device().is_cuda()) { + models[name].graph_state = std::make_shared( + models[name].model->device().index()); + } +#endif } catch (const BaseException& e) { std::cerr << e.what(); return e.getResult(); diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 58469f2..a31b44a 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -28,14 +28,38 @@ #include "internal/defines.h" #include "internal/logging.h" +#include "internal/model_pack.h" #include "internal/nvtx.h" #include "internal/tensor_list.h" #include "internal/utils.h" +#ifdef ENABLE_GPU +#include "internal/cuda_graphs.h" +#endif namespace torchfort { // Declaration of external global variables extern std::unordered_map models; +#ifdef ENABLE_GPU +// Number of warmup iterations before CUDA graph capture +constexpr int kCudaGraphWarmupIters = 3; + +// Helper to instantiate a CUDA graph from a captured graph +void instantiate_graph(CudaGraph& graph, CudaGraphExec& exec) { + cudaGraphNode_t error_node; + char log_buffer[1024]; + cudaError_t result = cudaGraphInstantiate(&exec.get(), graph.get(), &error_node, log_buffer, sizeof(log_buffer)); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "CUDA graph instantiation failed: " << cudaGetErrorString(result); + if (strlen(log_buffer) > 0) { + ss << " Log: " << log_buffer; + } + THROW_INTERNAL_ERROR(ss.str()); + } +} +#endif + void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in, cudaStream_t ext_stream = 0) { torchfort::nvtx::rangePush("torchfort_inference"); @@ -55,8 +79,9 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor auto model = models[name].model; -#if ENABLE_GPU +#ifdef ENABLE_GPU c10::cuda::OptionalCUDAStreamGuard guard; + cudaStream_t user_stream = ext_stream; if (model->device().is_cuda()) { auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index()); guard.reset_stream(stream); @@ -66,9 +91,79 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor inputs->to(model->device()); model->eval(); - auto results = model->forward(inputs->tensors); - for (int i = 0; i < results.size(); ++i) { + std::vector results; + +#ifdef ENABLE_GPU + // CUDA graph handling + bool capturing = false; + InferenceGraphState* graph_state = nullptr; + cudaStream_t capture_stream = nullptr; + + if (models[name].state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { + + graph_state = &models[name].graph_state->inference; + capture_stream = models[name].graph_state->capture_stream(); + + // Create input signature for validation + InputSignature current_sig = make_input_signature(inputs->tensors); + + if (graph_state->captured) { + + validate_input_signature(graph_state->input_signature, current_sig, "inference"); + + } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { + + // Store input signature used during capture + graph_state->input_signature = current_sig; + + // Synchronize user stream before capture + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + + // Switch PyTorch to use our capture stream + auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); + guard.reset_stream(capture_c10_stream); + + // Begin capture + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + capturing = true; + } + } +#endif + + // Forward pass +#ifdef ENABLE_GPU + if (!graph_state || !graph_state->captured) { +#endif + results = model->forward(inputs->tensors); +#ifdef ENABLE_GPU + if (graph_state) graph_state->warmup_count++; + } + + if (graph_state) { + + if (capturing) { + // End capture and instantiate the graph + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); + instantiate_graph(graph_state->graph, graph_state->graph_exec); + graph_state->static_outputs = results; + graph_state->captured = true; + + // Switch back to user stream for replay and subsequent operations + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); + guard.reset_stream(user_c10_stream); + } + + // Replay graph + if (graph_state->captured) { + graph_state->graph_exec.launch(user_stream); + results = graph_state->static_outputs; + } + } +#endif + + // Copy results to output tensors + for (size_t i = 0; i < results.size(); ++i) { outputs->tensors[i].copy_(results[i].reshape(outputs->tensors[i].sizes())); } models[name].state->step_inference++; @@ -105,6 +200,7 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #ifdef ENABLE_GPU c10::cuda::OptionalCUDAStreamGuard guard; + cudaStream_t user_stream = ext_stream; if (model->device().is_cuda()) { auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index()); guard.reset_stream(stream); @@ -120,21 +216,101 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo auto opt = models[name].optimizer; auto state = models[name].state; - // fwd pass - auto results = model->forward(inputs->tensors); - auto loss = models[name].loss->forward(results, labels->tensors, - (extra_loss_args) ? extra_loss_args->tensors : std::vector()); + torch::Tensor loss; - // extract loss - *loss_val = loss.item(); +#ifdef ENABLE_GPU + // CUDA graph handling + bool capturing = false; + TrainingGraphState* graph_state = nullptr; + cudaStream_t capture_stream = nullptr; + + if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state && + models[name].grad_accumulation_steps == 1) { + + // Note: CUDA graph capture for training is disabled if gradient accumulation is active + + graph_state = &models[name].graph_state->training; + capture_stream = models[name].graph_state->capture_stream(); + + // Create input signature for validation + std::vector extra_args_vec = extra_loss_args ? extra_loss_args->tensors : std::vector(); + InputSignature current_sig = make_input_signature(inputs->tensors, labels->tensors, extra_args_vec); + + if (graph_state->captured) { + + validate_input_signature(graph_state->input_signature, current_sig, "training"); + + } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { + + // Store input signature used during capture + graph_state->input_signature = current_sig; + capturing = true; + } + } +#endif - // bwd pass if (state->step_train_current % models[name].grad_accumulation_steps == 0) { - opt->zero_grad(); +#ifdef ENABLE_GPU + // Only explicitly call zero_grad for non-replay steps + if (!graph_state || !graph_state->captured) { +#endif + opt->zero_grad(/*set_to_none=*/true); +#ifdef ENABLE_GPU + } +#endif } - loss.backward(); +#ifdef ENABLE_GPU + if (capturing) { + // Synchronize user stream before capture + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + + // Switch PyTorch to use our capture stream + auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); + guard.reset_stream(capture_c10_stream); + + // Begin capture on our non-blocking stream + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + } +#endif + + // Forward + loss + backward +#ifdef ENABLE_GPU + if (!graph_state || !graph_state->captured) { +#endif + auto fwd_results = model->forward(inputs->tensors); + loss = models[name].loss->forward(fwd_results, labels->tensors, + (extra_loss_args) ? extra_loss_args->tensors : std::vector()); + loss.backward(); +#ifdef ENABLE_GPU + if (graph_state) graph_state->warmup_count++; + } + + if (graph_state) { + if (capturing) { + // End graph capture and instantiate + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); + instantiate_graph(graph_state->graph, graph_state->graph_exec); + graph_state->static_loss = loss; + graph_state->captured = true; + + // Switch back to user stream for replay and subsequent operations + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); + guard.reset_stream(user_c10_stream); + } + + // Replay graph + if (graph_state->captured) { + graph_state->graph_exec.launch(user_stream); + loss = graph_state->static_loss; + } + } +#endif + + // Extract loss value + *loss_val = loss.item(); + // Optimizer step and related operations if ((state->step_train_current + 1) % models[name].grad_accumulation_steps == 0) { // allreduce (average) gradients (if running distributed) if (models[name].comm) { From ebcdaa519f8334ae572ccea52f3437e45693acd2 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 12:56:46 -0800 Subject: [PATCH 2/6] Simplifying and cleaning up implementation. Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 377 ++++++++++++++++++------ src/csrc/training.cpp | 138 ++------- 2 files changed, 301 insertions(+), 214 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 9bb9ccd..2eb8838 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -21,9 +21,11 @@ #include +#include #include #include +#include #include #include @@ -32,6 +34,16 @@ namespace torchfort { +// Number of warmup iterations before CUDA graph capture +constexpr int kCudaGraphWarmupIters = 3; + +// Action to take for current iteration +enum class GraphAction { + WARMUP, // Run eager execution, increment warmup count + CAPTURE, // Run eager execution with graph capture + REPLAY // Skip eager execution, replay captured graph +}; + // RAII wrapper for cudaGraph_t class CudaGraph { public: @@ -98,117 +110,293 @@ class CudaGraphExec { cudaGraphExec_t get() const { return exec_; } bool valid() const { return exec_ != nullptr; } - // Launch the graph on a stream - void launch(cudaStream_t stream) { - if (exec_) { - CHECK_CUDA(cudaGraphLaunch(exec_, stream)); - } - } - private: cudaGraphExec_t exec_; }; -// Input signature for validating consistent inputs -struct InputSignature { - std::vector ptrs; - std::vector> shapes; - std::vector dtypes; +// Graph state for inference +class InferenceGraphState { +public: + InferenceGraphState(const char* context = "inference") : context_(context) {} + + // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture + // Returns the action to take. Call begin_capture() after this if action == CAPTURE. + GraphAction prepare(const std::vector& inputs) { + InputSignature current_sig = make_input_signature(inputs); - bool operator==(const InputSignature& other) const { - return ptrs == other.ptrs && shapes == other.shapes && dtypes == other.dtypes; + if (captured_) { + validate_inputs(current_sig); + return GraphAction::REPLAY; + } + + if (warmup_count_ == kCudaGraphWarmupIters) { + input_signature_ = std::move(current_sig); + return GraphAction::CAPTURE; + } + + return GraphAction::WARMUP; } - bool operator!=(const InputSignature& other) const { - return !(*this == other); + // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work + void begin_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); + guard.reset_stream(capture_c10_stream); + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); } - bool empty() const { return ptrs.empty(); } -}; + // Finalize after forward pass - handles capture end or warmup increment + void finalize(GraphAction action, + cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index, + const std::vector& outputs) { + if (action == GraphAction::CAPTURE) { + end_capture(capture_stream, user_stream, guard, device_index); + static_outputs_ = outputs; + } else if (action == GraphAction::WARMUP) { + warmup_count_++; + } + } -// Helper to create input signature from tensor list -inline InputSignature make_input_signature(const std::vector& tensors) { - InputSignature sig; - sig.ptrs.reserve(tensors.size()); - sig.shapes.reserve(tensors.size()); - sig.dtypes.reserve(tensors.size()); - for (const auto& t : tensors) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - sig.dtypes.push_back(t.scalar_type()); - } - return sig; -} - -// Helper to create input signature from multiple tensor lists (for training) -inline InputSignature make_input_signature(const std::vector& inputs, - const std::vector& labels, - const std::vector& extra_args) { - InputSignature sig; - size_t total = inputs.size() + labels.size() + extra_args.size(); - sig.ptrs.reserve(total); - sig.shapes.reserve(total); - sig.dtypes.reserve(total); - - for (const auto& t : inputs) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - sig.dtypes.push_back(t.scalar_type()); - } - for (const auto& t : labels) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - sig.dtypes.push_back(t.scalar_type()); - } - for (const auto& t : extra_args) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - sig.dtypes.push_back(t.scalar_type()); - } - return sig; -} - -// Validate that current inputs match the captured signature -inline void validate_input_signature(const InputSignature& expected, - const InputSignature& actual, - const char* context) { - if (expected != actual) { - std::stringstream ss; - ss << "CUDA graph input mismatch in " << context << ". " - << "When cuda_graphs is enabled, input tensors must have consistent " - << "data pointers, shapes, and dtypes across all calls. " - << "If you need to change inputs, disable cuda_graphs."; - THROW_INVALID_USAGE(ss.str()); - } -} + // Launch captured graph on the given stream + void launch(cudaStream_t stream) { + CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); + } -// Graph state for inference -struct InferenceGraphState { - int warmup_count = 0; - bool captured = false; - - InputSignature input_signature; - CudaGraph graph; - CudaGraphExec graph_exec; - std::vector static_outputs; + // Get static outputs (valid after CAPTURE or REPLAY) + const std::vector& get_outputs() const { return static_outputs_; } + + bool is_captured() const { return captured_; } + +private: + // Input signature for validating consistent inputs + struct InputSignature { + std::vector ptrs; + std::vector> shapes; + std::vector dtypes; + + bool operator!=(const InputSignature& other) const { + return ptrs != other.ptrs || shapes != other.shapes || dtypes != other.dtypes; + } + }; + + static InputSignature make_input_signature(const std::vector& tensors) { + InputSignature sig; + sig.ptrs.reserve(tensors.size()); + sig.shapes.reserve(tensors.size()); + sig.dtypes.reserve(tensors.size()); + for (const auto& t : tensors) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; + } + + void validate_inputs(const InputSignature& current_sig) const { + if (input_signature_ != current_sig) { + std::stringstream ss; + ss << "CUDA graph input mismatch in " << context_ << ". " + << "When enable_cuda_graphs is set, input tensors must have consistent " + << "data pointers, shapes, and dtypes across all calls. " + << "If you need to change inputs, disable enable_cuda_graphs."; + THROW_INVALID_USAGE(ss.str()); + } + } + + void end_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); + instantiate_graph(); + captured_ = true; + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index); + guard.reset_stream(user_c10_stream); + } + + void instantiate_graph() { + cudaGraphNode_t error_node; + char log_buffer[1024]; + cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), + &error_node, log_buffer, sizeof(log_buffer)); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "CUDA graph instantiation failed in " << context_ << ": " + << cudaGetErrorString(result); + if (std::strlen(log_buffer) > 0) { + ss << " Log: " << log_buffer; + } + THROW_INTERNAL_ERROR(ss.str()); + } + } + + const char* context_; + int warmup_count_ = 0; + bool captured_ = false; + InputSignature input_signature_; + CudaGraph graph_; + CudaGraphExec graph_exec_; + std::vector static_outputs_; }; // Graph state for training (single graph for forward + loss + backward) -struct TrainingGraphState { - int warmup_count = 0; - bool captured = false; - - InputSignature input_signature; - CudaGraph graph; - CudaGraphExec graph_exec; - torch::Tensor static_loss; +class TrainingGraphState { +public: + TrainingGraphState(const char* context = "training") : context_(context) {} + + // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture + // Returns the action to take. Call begin_capture() after this if action == CAPTURE. + GraphAction prepare(const std::vector& inputs, + const std::vector& labels, + const std::vector& extra_args) { + InputSignature current_sig = make_input_signature(inputs, labels, extra_args); + + if (captured_) { + validate_inputs(current_sig); + return GraphAction::REPLAY; + } + + if (warmup_count_ == kCudaGraphWarmupIters) { + input_signature_ = std::move(current_sig); + return GraphAction::CAPTURE; + } + + return GraphAction::WARMUP; + } + + // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work + void begin_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); + guard.reset_stream(capture_c10_stream); + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + } + + // Finalize after forward+loss+backward pass - handles capture end or warmup increment + void finalize(GraphAction action, + cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index, + const torch::Tensor& loss) { + if (action == GraphAction::CAPTURE) { + end_capture(capture_stream, user_stream, guard, device_index); + static_loss_ = loss; + } else if (action == GraphAction::WARMUP) { + warmup_count_++; + } + } + + // Launch captured graph on the given stream + void launch(cudaStream_t stream) { + CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); + } + + // Get static loss (valid after CAPTURE or REPLAY) + const torch::Tensor& get_loss() const { return static_loss_; } + + bool is_captured() const { return captured_; } + +private: + // Input signature for validating consistent inputs + struct InputSignature { + std::vector ptrs; + std::vector> shapes; + std::vector dtypes; + + bool operator!=(const InputSignature& other) const { + return ptrs != other.ptrs || shapes != other.shapes || dtypes != other.dtypes; + } + }; + + static InputSignature make_input_signature(const std::vector& inputs, + const std::vector& labels, + const std::vector& extra_args) { + InputSignature sig; + size_t total = inputs.size() + labels.size() + extra_args.size(); + sig.ptrs.reserve(total); + sig.shapes.reserve(total); + sig.dtypes.reserve(total); + + for (const auto& t : inputs) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : labels) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : extra_args) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; + } + + void validate_inputs(const InputSignature& current_sig) const { + if (input_signature_ != current_sig) { + std::stringstream ss; + ss << "CUDA graph input mismatch in " << context_ << ". " + << "When enable_cuda_graphs is set, input tensors must have consistent " + << "data pointers, shapes, and dtypes across all calls. " + << "If you need to change inputs, disable enable_cuda_graphs."; + THROW_INVALID_USAGE(ss.str()); + } + } + + void end_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); + instantiate_graph(); + captured_ = true; + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index); + guard.reset_stream(user_c10_stream); + } + + void instantiate_graph() { + cudaGraphNode_t error_node; + char log_buffer[1024]; + cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), + &error_node, log_buffer, sizeof(log_buffer)); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "CUDA graph instantiation failed in " << context_ << ": " + << cudaGetErrorString(result); + if (std::strlen(log_buffer) > 0) { + ss << " Log: " << log_buffer; + } + THROW_INTERNAL_ERROR(ss.str()); + } + } + + const char* context_; + int warmup_count_ = 0; + bool captured_ = false; + InputSignature input_signature_; + CudaGraph graph_; + CudaGraphExec graph_exec_; + torch::Tensor static_loss_; }; // Graph state for a model, including the capture stream class ModelGraphState { public: - InferenceGraphState inference; - TrainingGraphState training; + InferenceGraphState inference{"inference"}; + TrainingGraphState training{"training"}; ModelGraphState(int device_index = 0) : capture_stream_(nullptr), device_index_(device_index) { @@ -230,11 +418,6 @@ class ModelGraphState { cudaStream_t capture_stream() const { return capture_stream_; } int device_index() const { return device_index_; } - // Get c10 stream wrapper for the capture stream (for PyTorch integration) - c10::cuda::CUDAStream get_capture_cuda_stream() const { - return c10::cuda::getStreamFromExternal(capture_stream_, device_index_); - } - private: cudaStream_t capture_stream_; int device_index_; diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index a31b44a..fa58f32 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -40,26 +40,6 @@ namespace torchfort { // Declaration of external global variables extern std::unordered_map models; -#ifdef ENABLE_GPU -// Number of warmup iterations before CUDA graph capture -constexpr int kCudaGraphWarmupIters = 3; - -// Helper to instantiate a CUDA graph from a captured graph -void instantiate_graph(CudaGraph& graph, CudaGraphExec& exec) { - cudaGraphNode_t error_node; - char log_buffer[1024]; - cudaError_t result = cudaGraphInstantiate(&exec.get(), graph.get(), &error_node, log_buffer, sizeof(log_buffer)); - if (result != cudaSuccess) { - std::stringstream ss; - ss << "CUDA graph instantiation failed: " << cudaGetErrorString(result); - if (strlen(log_buffer) > 0) { - ss << " Log: " << log_buffer; - } - THROW_INTERNAL_ERROR(ss.str()); - } -} -#endif - void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in, cudaStream_t ext_stream = 0) { torchfort::nvtx::rangePush("torchfort_inference"); @@ -81,7 +61,6 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor #ifdef ENABLE_GPU c10::cuda::OptionalCUDAStreamGuard guard; - cudaStream_t user_stream = ext_stream; if (model->device().is_cuda()) { auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index()); guard.reset_stream(stream); @@ -95,69 +74,33 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor std::vector results; #ifdef ENABLE_GPU - // CUDA graph handling - bool capturing = false; + GraphAction action = GraphAction::WARMUP; InferenceGraphState* graph_state = nullptr; - cudaStream_t capture_stream = nullptr; if (models[name].state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { - graph_state = &models[name].graph_state->inference; - capture_stream = models[name].graph_state->capture_stream(); - - // Create input signature for validation - InputSignature current_sig = make_input_signature(inputs->tensors); - - if (graph_state->captured) { - - validate_input_signature(graph_state->input_signature, current_sig, "inference"); - - } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { - - // Store input signature used during capture - graph_state->input_signature = current_sig; - - // Synchronize user stream before capture - CHECK_CUDA(cudaStreamSynchronize(user_stream)); - - // Switch PyTorch to use our capture stream - auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); - guard.reset_stream(capture_c10_stream); + action = graph_state->prepare(inputs->tensors); - // Begin capture - CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); - capturing = true; + if (action == GraphAction::CAPTURE) { + graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index()); } } #endif // Forward pass #ifdef ENABLE_GPU - if (!graph_state || !graph_state->captured) { + if (action != GraphAction::REPLAY) { #endif results = model->forward(inputs->tensors); #ifdef ENABLE_GPU - if (graph_state) graph_state->warmup_count++; } if (graph_state) { + graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), results); - if (capturing) { - // End capture and instantiate the graph - CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); - instantiate_graph(graph_state->graph, graph_state->graph_exec); - graph_state->static_outputs = results; - graph_state->captured = true; - - // Switch back to user stream for replay and subsequent operations - auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); - guard.reset_stream(user_c10_stream); - } - - // Replay graph - if (graph_state->captured) { - graph_state->graph_exec.launch(user_stream); - results = graph_state->static_outputs; + if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { + graph_state->launch(ext_stream); + results = graph_state->get_outputs(); } } #endif @@ -200,7 +143,6 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #ifdef ENABLE_GPU c10::cuda::OptionalCUDAStreamGuard guard; - cudaStream_t user_stream = ext_stream; if (model->device().is_cuda()) { auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index()); guard.reset_stream(stream); @@ -219,40 +161,22 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo torch::Tensor loss; #ifdef ENABLE_GPU - // CUDA graph handling - bool capturing = false; + GraphAction action = GraphAction::WARMUP; TrainingGraphState* graph_state = nullptr; - cudaStream_t capture_stream = nullptr; + // Note: CUDA graph capture for training is disabled if gradient accumulation is active if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state && models[name].grad_accumulation_steps == 1) { - - // Note: CUDA graph capture for training is disabled if gradient accumulation is active - graph_state = &models[name].graph_state->training; - capture_stream = models[name].graph_state->capture_stream(); - - // Create input signature for validation std::vector extra_args_vec = extra_loss_args ? extra_loss_args->tensors : std::vector(); - InputSignature current_sig = make_input_signature(inputs->tensors, labels->tensors, extra_args_vec); - - if (graph_state->captured) { - - validate_input_signature(graph_state->input_signature, current_sig, "training"); - - } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { - - // Store input signature used during capture - graph_state->input_signature = current_sig; - capturing = true; - } + action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec); } #endif if (state->step_train_current % models[name].grad_accumulation_steps == 0) { #ifdef ENABLE_GPU - // Only explicitly call zero_grad for non-replay steps - if (!graph_state || !graph_state->captured) { + // zero_grad is only needed for non-replay steps + if (action != GraphAction::REPLAY) { #endif opt->zero_grad(/*set_to_none=*/true); #ifdef ENABLE_GPU @@ -261,48 +185,28 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo } #ifdef ENABLE_GPU - if (capturing) { - // Synchronize user stream before capture - CHECK_CUDA(cudaStreamSynchronize(user_stream)); - - // Switch PyTorch to use our capture stream - auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); - guard.reset_stream(capture_c10_stream); - - // Begin capture on our non-blocking stream - CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + if (action == GraphAction::CAPTURE) { + graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index()); } #endif // Forward + loss + backward #ifdef ENABLE_GPU - if (!graph_state || !graph_state->captured) { + if (action != GraphAction::REPLAY) { #endif auto fwd_results = model->forward(inputs->tensors); loss = models[name].loss->forward(fwd_results, labels->tensors, (extra_loss_args) ? extra_loss_args->tensors : std::vector()); loss.backward(); #ifdef ENABLE_GPU - if (graph_state) graph_state->warmup_count++; } if (graph_state) { - if (capturing) { - // End graph capture and instantiate - CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); - instantiate_graph(graph_state->graph, graph_state->graph_exec); - graph_state->static_loss = loss; - graph_state->captured = true; - - // Switch back to user stream for replay and subsequent operations - auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); - guard.reset_stream(user_c10_stream); - } + graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), loss); - // Replay graph - if (graph_state->captured) { - graph_state->graph_exec.launch(user_stream); - loss = graph_state->static_loss; + if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { + graph_state->launch(ext_stream); + loss = graph_state->get_loss(); } } #endif From e5bb25b9033de13155b057072bf64145f1172a50 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 15:34:54 -0800 Subject: [PATCH 3/6] Adding graph support for grad accumulation. Cleaning up some ifdefs. Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 17 +++++++----- src/csrc/training.cpp | 35 ++++++++++++------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 2eb8838..611d374 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -18,15 +18,15 @@ #pragma once #ifdef ENABLE_GPU - #include +#include +#include +#endif #include #include #include -#include -#include #include #include "internal/defines.h" @@ -34,9 +34,6 @@ namespace torchfort { -// Number of warmup iterations before CUDA graph capture -constexpr int kCudaGraphWarmupIters = 3; - // Action to take for current iteration enum class GraphAction { WARMUP, // Run eager execution, increment warmup count @@ -44,6 +41,11 @@ enum class GraphAction { REPLAY // Skip eager execution, replay captured graph }; +#ifdef ENABLE_GPU + +// Number of warmup iterations before CUDA graph capture +constexpr int kCudaGraphWarmupIters = 3; + // RAII wrapper for cudaGraph_t class CudaGraph { public: @@ -423,7 +425,8 @@ class ModelGraphState { int device_index_; }; +#endif + } // namespace torchfort -#endif // ENABLE_GPU diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index fa58f32..7224459 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -32,9 +32,7 @@ #include "internal/nvtx.h" #include "internal/tensor_list.h" #include "internal/utils.h" -#ifdef ENABLE_GPU #include "internal/cuda_graphs.h" -#endif namespace torchfort { // Declaration of external global variables @@ -73,8 +71,9 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor std::vector results; -#ifdef ENABLE_GPU GraphAction action = GraphAction::WARMUP; + +#ifdef ENABLE_GPU InferenceGraphState* graph_state = nullptr; if (models[name].state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { @@ -88,13 +87,11 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor #endif // Forward pass -#ifdef ENABLE_GPU if (action != GraphAction::REPLAY) { -#endif results = model->forward(inputs->tensors); -#ifdef ENABLE_GPU } +#ifdef ENABLE_GPU if (graph_state) { graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), results); @@ -160,13 +157,12 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo torch::Tensor loss; -#ifdef ENABLE_GPU GraphAction action = GraphAction::WARMUP; + +#ifdef ENABLE_GPU TrainingGraphState* graph_state = nullptr; - // Note: CUDA graph capture for training is disabled if gradient accumulation is active - if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state && - models[name].grad_accumulation_steps == 1) { + if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { graph_state = &models[name].graph_state->training; std::vector extra_args_vec = extra_loss_args ? extra_loss_args->tensors : std::vector(); action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec); @@ -174,14 +170,19 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #endif if (state->step_train_current % models[name].grad_accumulation_steps == 0) { + // Only run zero_grad on non-replay steps or if gradient accumulation is active + if (action != GraphAction::REPLAY || models[name].grad_accumulation_steps > 1) { + if (models[name].grad_accumulation_steps > 1) { #ifdef ENABLE_GPU - // zero_grad is only needed for non-replay steps - if (action != GraphAction::REPLAY) { + // With graphs and grad accumulation active, gradients must be persistent (set_to_none = false) + opt->zero_grad(/*set_to_none=*/(graph_state == nullptr)); +#else + opt->zero_grad(/*set_to_none=*/true); #endif - opt->zero_grad(/*set_to_none=*/true); -#ifdef ENABLE_GPU + } else { + opt->zero_grad(/*set_to_none=*/true); + } } -#endif } #ifdef ENABLE_GPU @@ -191,16 +192,14 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #endif // Forward + loss + backward -#ifdef ENABLE_GPU if (action != GraphAction::REPLAY) { -#endif auto fwd_results = model->forward(inputs->tensors); loss = models[name].loss->forward(fwd_results, labels->tensors, (extra_loss_args) ? extra_loss_args->tensors : std::vector()); loss.backward(); -#ifdef ENABLE_GPU } +#ifdef ENABLE_GPU if (graph_state) { graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), loss); From 76b058e976bf3dd67df19afd7f9dbed86d1fc9e9 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 15:36:08 -0800 Subject: [PATCH 4/6] Update docs. Signed-off-by: Josh Romero --- docs/api/config.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api/config.rst b/docs/api/config.rst index 8a06f6a..72539d1 100644 --- a/docs/api/config.rst +++ b/docs/api/config.rst @@ -55,7 +55,6 @@ and improve performance for models with many small operations. - Input tensors must be on GPU and must have consistent data pointers, shapes, and dtypes across all training/inference calls with the captured model. If inputs change after graph capture, an error will be thrown. -- For training, CUDA graph capture is automatically disabled when gradient accumulation (``grad_accumulation_steps > 1``) is active. - The optimizer step and learning rate scheduler updates are not captured in the graph. - A warmup period of 3 iterations is performed before graph capture to ensure stable execution. From 02289e71b2cff9e83adf2d61ada07ea1238b0810 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 15:51:48 -0800 Subject: [PATCH 5/6] Formatting. Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 83 +++++++++---------------- src/csrc/setup.cpp | 3 +- src/csrc/torchfort.cpp | 3 +- src/csrc/training.cpp | 14 +++-- 4 files changed, 39 insertions(+), 64 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 611d374..307b395 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -18,9 +18,9 @@ #pragma once #ifdef ENABLE_GPU -#include #include #include +#include #endif #include @@ -36,9 +36,9 @@ namespace torchfort { // Action to take for current iteration enum class GraphAction { - WARMUP, // Run eager execution, increment warmup count - CAPTURE, // Run eager execution with graph capture - REPLAY // Skip eager execution, replay captured graph + WARMUP, // Run eager execution, increment warmup count + CAPTURE, // Run eager execution with graph capture + REPLAY // Skip eager execution, replay captured graph }; #ifdef ENABLE_GPU @@ -61,12 +61,11 @@ class CudaGraph { CudaGraph& operator=(const CudaGraph&) = delete; // Movable - CudaGraph(CudaGraph&& other) noexcept : graph_(other.graph_) { - other.graph_ = nullptr; - } + CudaGraph(CudaGraph&& other) noexcept : graph_(other.graph_) { other.graph_ = nullptr; } CudaGraph& operator=(CudaGraph&& other) noexcept { if (this != &other) { - if (graph_) cudaGraphDestroy(graph_); + if (graph_) + cudaGraphDestroy(graph_); graph_ = other.graph_; other.graph_ = nullptr; } @@ -96,12 +95,11 @@ class CudaGraphExec { CudaGraphExec& operator=(const CudaGraphExec&) = delete; // Movable - CudaGraphExec(CudaGraphExec&& other) noexcept : exec_(other.exec_) { - other.exec_ = nullptr; - } + CudaGraphExec(CudaGraphExec&& other) noexcept : exec_(other.exec_) { other.exec_ = nullptr; } CudaGraphExec& operator=(CudaGraphExec&& other) noexcept { if (this != &other) { - if (exec_) cudaGraphExecDestroy(exec_); + if (exec_) + cudaGraphExecDestroy(exec_); exec_ = other.exec_; other.exec_ = nullptr; } @@ -140,9 +138,7 @@ class InferenceGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, + void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamSynchronize(user_stream)); auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); @@ -151,11 +147,8 @@ class InferenceGraphState { } // Finalize after forward pass - handles capture end or warmup increment - void finalize(GraphAction action, - cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index, + void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, int device_index, const std::vector& outputs) { if (action == GraphAction::CAPTURE) { end_capture(capture_stream, user_stream, guard, device_index); @@ -166,9 +159,7 @@ class InferenceGraphState { } // Launch captured graph on the given stream - void launch(cudaStream_t stream) { - CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); - } + void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); } // Get static outputs (valid after CAPTURE or REPLAY) const std::vector& get_outputs() const { return static_outputs_; } @@ -211,9 +202,7 @@ class InferenceGraphState { } } - void end_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, + void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); instantiate_graph(); @@ -225,12 +214,11 @@ class InferenceGraphState { void instantiate_graph() { cudaGraphNode_t error_node; char log_buffer[1024]; - cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), - &error_node, log_buffer, sizeof(log_buffer)); + cudaError_t result = + cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), &error_node, log_buffer, sizeof(log_buffer)); if (result != cudaSuccess) { std::stringstream ss; - ss << "CUDA graph instantiation failed in " << context_ << ": " - << cudaGetErrorString(result); + ss << "CUDA graph instantiation failed in " << context_ << ": " << cudaGetErrorString(result); if (std::strlen(log_buffer) > 0) { ss << " Log: " << log_buffer; } @@ -254,8 +242,7 @@ class TrainingGraphState { // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture // Returns the action to take. Call begin_capture() after this if action == CAPTURE. - GraphAction prepare(const std::vector& inputs, - const std::vector& labels, + GraphAction prepare(const std::vector& inputs, const std::vector& labels, const std::vector& extra_args) { InputSignature current_sig = make_input_signature(inputs, labels, extra_args); @@ -273,9 +260,7 @@ class TrainingGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, + void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamSynchronize(user_stream)); auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); @@ -284,12 +269,8 @@ class TrainingGraphState { } // Finalize after forward+loss+backward pass - handles capture end or warmup increment - void finalize(GraphAction action, - cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index, - const torch::Tensor& loss) { + void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, int device_index, const torch::Tensor& loss) { if (action == GraphAction::CAPTURE) { end_capture(capture_stream, user_stream, guard, device_index); static_loss_ = loss; @@ -299,9 +280,7 @@ class TrainingGraphState { } // Launch captured graph on the given stream - void launch(cudaStream_t stream) { - CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); - } + void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); } // Get static loss (valid after CAPTURE or REPLAY) const torch::Tensor& get_loss() const { return static_loss_; } @@ -358,9 +337,7 @@ class TrainingGraphState { } } - void end_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, + void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); instantiate_graph(); @@ -372,12 +349,11 @@ class TrainingGraphState { void instantiate_graph() { cudaGraphNode_t error_node; char log_buffer[1024]; - cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), - &error_node, log_buffer, sizeof(log_buffer)); + cudaError_t result = + cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), &error_node, log_buffer, sizeof(log_buffer)); if (result != cudaSuccess) { std::stringstream ss; - ss << "CUDA graph instantiation failed in " << context_ << ": " - << cudaGetErrorString(result); + ss << "CUDA graph instantiation failed in " << context_ << ": " << cudaGetErrorString(result); if (std::strlen(log_buffer) > 0) { ss << " Log: " << log_buffer; } @@ -400,8 +376,7 @@ class ModelGraphState { InferenceGraphState inference{"inference"}; TrainingGraphState training{"training"}; - ModelGraphState(int device_index = 0) - : capture_stream_(nullptr), device_index_(device_index) { + ModelGraphState(int device_index = 0) : capture_stream_(nullptr), device_index_(device_index) { // Create a non-blocking stream for graph capture CHECK_CUDA(cudaSetDevice(device_index_)); CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking)); @@ -428,5 +403,3 @@ class ModelGraphState { #endif } // namespace torchfort - - diff --git a/src/csrc/setup.cpp b/src/csrc/setup.cpp index 8fbfb02..c052065 100644 --- a/src/csrc/setup.cpp +++ b/src/csrc/setup.cpp @@ -262,8 +262,7 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ if (state_node["general"]) { auto params = get_params(state_node["general"]); - std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose", - "enable_cuda_graphs"}; + std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose", "enable_cuda_graphs"}; check_params(supported_params, params.keys()); state->report_frequency = params.get_param("report_frequency")[0]; try { diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index bd3eea6..8ed2445 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -129,8 +129,7 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f #ifdef ENABLE_GPU // Initialize graph state if CUDA graphs are enabled if (models[name].state->enable_cuda_graphs && models[name].model->device().is_cuda()) { - models[name].graph_state = std::make_shared( - models[name].model->device().index()); + models[name].graph_state = std::make_shared(models[name].model->device().index()); } #endif } catch (const BaseException& e) { diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 7224459..4925c1f 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -26,13 +26,13 @@ #endif #include +#include "internal/cuda_graphs.h" #include "internal/defines.h" #include "internal/logging.h" #include "internal/model_pack.h" #include "internal/nvtx.h" #include "internal/tensor_list.h" #include "internal/utils.h" -#include "internal/cuda_graphs.h" namespace torchfort { // Declaration of external global variables @@ -81,7 +81,8 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor action = graph_state->prepare(inputs->tensors); if (action == GraphAction::CAPTURE) { - graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index()); + graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, + model->device().index()); } } #endif @@ -93,7 +94,8 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor #ifdef ENABLE_GPU if (graph_state) { - graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), results); + graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, + model->device().index(), results); if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { graph_state->launch(ext_stream); @@ -164,7 +166,8 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { graph_state = &models[name].graph_state->training; - std::vector extra_args_vec = extra_loss_args ? extra_loss_args->tensors : std::vector(); + std::vector extra_args_vec = + extra_loss_args ? extra_loss_args->tensors : std::vector(); action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec); } #endif @@ -201,7 +204,8 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #ifdef ENABLE_GPU if (graph_state) { - graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), loss); + graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, + model->device().index(), loss); if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { graph_state->launch(ext_stream); From 9e890ec761b9864c5447e356d70beb26a3047240 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 4 Dec 2025 11:00:13 -0800 Subject: [PATCH 6/6] Move loss D2H copy and allreduce after optimizer step call. Signed-off-by: Josh Romero --- src/csrc/training.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 4925c1f..488062e 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -214,9 +214,6 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo } #endif - // Extract loss value - *loss_val = loss.item(); - // Optimizer step and related operations if ((state->step_train_current + 1) % models[name].grad_accumulation_steps == 0) { // allreduce (average) gradients (if running distributed) @@ -227,9 +224,6 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo grads.push_back(p.grad()); } models[name].comm->allreduce(grads, true); - - // average returned loss value - models[name].comm->allreduce(*loss_val, true); } if (models[name].max_grad_norm > 0.0) { @@ -242,6 +236,13 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo } } + // Extract loss value + *loss_val = loss.item(); + if (models[name].comm) { + // average returned loss value (if running distributed) + models[name].comm->allreduce(*loss_val, true); + } + state->step_train++; state->step_train_current++; if (state->report_frequency > 0 && state->step_train % state->report_frequency == 0) {