Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAK

target_sources(${PROJECT_NAME}
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/cuda_wrap.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/logging.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/model_state.cpp
Expand Down
81 changes: 81 additions & 0 deletions src/csrc/cuda_wrap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef ENABLE_GPU
#include <cuda_runtime.h>

#include "internal/cuda_wrap.h"
#include "internal/defines.h"

#if CUDART_VERSION >= 13000
#define LOAD_SYM(symbol, version, optional) \
do { \
cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \
cudaError_t err = cudaGetDriverEntryPointByVersion(#symbol, (void**)(&cuFnTable.pfn_##symbol), version, \
cudaEnableDefault, &driverStatus)); \
if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \
THROW_CUDA_ERROR("cudaGetDriverEntryPointByVersion failed."); \
} \
} while (false)
#elif CUDART_VERSION >= 12000
#define LOAD_SYM(symbol, version, optional) \
do { \
cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \
cudaError_t err = \
cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault, &driverStatus); \
if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \
THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \
} \
} while (false)
#else
#define LOAD_SYM(symbol, version, optional) \
do { \
cudaError_t err = cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault); \
if (err != cudaSuccess && !optional) { \
THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \
} \
} while (false)
#endif

namespace torchfort {

cuFunctionTable cuFnTable; // global table of required CUDA driver functions

void initCuFunctionTable() {
std::lock_guard<std::mutex> guard(cuFnTable.mutex);

if (cuFnTable.initialized) {
return;
}

#if CUDART_VERSION >= 11030
LOAD_SYM(cuCtxGetCurrent, 4000, false);
LOAD_SYM(cuCtxGetDevice, 2000, false);
LOAD_SYM(cuCtxSetCurrent, 4000, false);
LOAD_SYM(cuGetErrorString, 6000, false);
LOAD_SYM(cuStreamGetCtx, 9020, false);
#if CUDART_VERSION >= 12080
LOAD_SYM(cuStreamGetDevice, 12080, true);
#endif
#endif
cuFnTable.initialized = true;
}

} // namespace torchfort

#undef LOAD_SYM
#endif
49 changes: 49 additions & 0 deletions src/csrc/include/internal/cuda_wrap.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <mutex>
#if CUDART_VERSION >= 11030
#include <cudaTypedefs.h>
#endif

#define DECLARE_CUDA_PFN(symbol, version) PFN_##symbol##_v##version pfn_##symbol = nullptr

namespace torchfort {

struct cuFunctionTable {
#if CUDART_VERSION >= 11030
DECLARE_CUDA_PFN(cuCtxGetCurrent, 4000);
DECLARE_CUDA_PFN(cuCtxGetDevice, 2000);
DECLARE_CUDA_PFN(cuCtxSetCurrent, 4000);
DECLARE_CUDA_PFN(cuGetErrorString, 6000);
DECLARE_CUDA_PFN(cuStreamGetCtx, 9020);
#if CUDART_VERSION >= 12080
DECLARE_CUDA_PFN(cuStreamGetDevice, 12080);
#endif
#endif
bool initialized = false;
std::mutex mutex;
};

extern cuFunctionTable cuFnTable;

void initCuFunctionTable();
} // namespace torchfort

#undef DECLARE_CUDA_PFN
22 changes: 22 additions & 0 deletions src/csrc/include/internal/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "internal/base_loss.h"
#include "internal/base_model.h"
#include "internal/cuda_wrap.h"
#include "internal/exceptions.h"
#include "internal/utils.h"

Expand All @@ -44,6 +45,19 @@
} \
} while (false)

#define CHECK_CUDA_DRV(call) \
do { \
if (!cuFnTable.initialized) { \
initCuFunctionTable(); \
} \
CUresult err = cuFnTable.pfn_##call; \
if (CUDA_SUCCESS != err) { \
const char* error_str; \
cuFnTable.pfn_cuGetErrorString(err, &error_str); \
throw torchfort::CudaError(__FILE__, __LINE__, error_str); \
} \
} while (false)

#define CHECK_NCCL(call) \
do { \
ncclResult_t err = call; \
Expand Down Expand Up @@ -72,6 +86,14 @@
} \
} while (false)

#define IS_CUDA_DRV_FUNC_AVAILABLE(symbol) \
([&]() { \
if (!cuFnTable.initialized) { \
initCuFunctionTable(); \
} \
return cuFnTable.pfn_##symbol != nullptr; \
})()

#define BEGIN_MODEL_REGISTRY \
static std::unordered_map<std::string, std::function<std::shared_ptr<BaseModel>()>> model_registry {

Expand Down
41 changes: 16 additions & 25 deletions src/csrc/include/internal/rl/off_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "internal/defines.h"
#include "internal/logging.h"
#include "internal/utils.h"

namespace torchfort {

Expand Down Expand Up @@ -90,12 +91,10 @@ static void update_replay_buffer(const char* name, T* state_old, T* state_new, s
torch::NoGradGuard no_grad;

#ifdef ENABLE_GPU
c10::cuda::OptionalCUDAStreamGuard guard;
auto rb_device = registry[name]->rbDevice();
if (rb_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream);
#endif

// get tensors and copy:
Expand All @@ -121,12 +120,10 @@ static void update_replay_buffer(const char* name, T* state_old, T* state_new, s
torch::NoGradGuard no_grad;

#ifdef ENABLE_GPU
c10::cuda::OptionalCUDAStreamGuard guard;
auto rb_device = registry[name]->rbDevice();
if (rb_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream);
#endif

// get tensors and copy:
Expand All @@ -152,12 +149,10 @@ static void predict_explore(const char* name, T* state, size_t state_dim, int64_

#ifdef ENABLE_GPU
// device and stream handling
c10::cuda::OptionalCUDAStreamGuard guard;
auto model_device = registry[name]->modelDevice();
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
#endif

// create tensors
Expand Down Expand Up @@ -190,12 +185,10 @@ static void predict(const char* name, T* state, size_t state_dim, int64_t* state

#ifdef ENABLE_GPU
// device and stream handling
c10::cuda::OptionalCUDAStreamGuard guard;
auto model_device = registry[name]->modelDevice();
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
#endif

// create tensors
Expand Down Expand Up @@ -228,12 +221,10 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_

#ifdef ENABLE_GPU
// device and stream handling
c10::cuda::OptionalCUDAStreamGuard guard;
auto model_device = registry[name]->modelDevice();
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
#endif

// create tensors
Expand Down
49 changes: 18 additions & 31 deletions src/csrc/include/internal/rl/on_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "internal/defines.h"
#include "internal/logging.h"
#include "internal/utils.h"

namespace torchfort {

Expand Down Expand Up @@ -94,14 +95,10 @@ static void update_rollout_buffer(const char* name, T* state, size_t state_dim,
auto model_device = registry[name]->modelDevice();
auto rb_device = registry[name]->rbDevice();
#ifdef ENABLE_GPU
c10::cuda::OptionalCUDAStreamGuard guard;
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
} else if (rb_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream);
#endif

// get tensors and copy:
Expand All @@ -127,14 +124,10 @@ static void update_rollout_buffer(const char* name, T* state, size_t state_dim,
auto model_device = registry[name]->modelDevice();
auto rb_device = registry[name]->rbDevice();
#ifdef ENABLE_GPU
c10::cuda::OptionalCUDAStreamGuard guard;
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
} else if (rb_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream);
#endif

// get tensors and copy:
Expand All @@ -157,12 +150,10 @@ static void predict_explore(const char* name, T* state, size_t state_dim, int64_

#ifdef ENABLE_GPU
// device and stream handling
c10::cuda::OptionalCUDAStreamGuard guard;
auto model_device = registry[name]->modelDevice();
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
#endif

// create tensors
Expand Down Expand Up @@ -194,12 +185,10 @@ static void predict(const char* name, T* state, size_t state_dim, int64_t* state

#ifdef ENABLE_GPU
// device and stream handling
c10::cuda::OptionalCUDAStreamGuard guard;
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
auto model_device = registry[name]->modelDevice();
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
}
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
#endif

// create tensors
Expand Down Expand Up @@ -232,12 +221,10 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_

#ifdef ENABLE_GPU
// device and stream handling
c10::cuda::OptionalCUDAStreamGuard guard;
auto model_device = registry[name]->modelDevice();
if (model_device.is_cuda()) {
auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index());
guard.reset_stream(stream);
}
c10::cuda::OptionalCUDAStreamGuard stream_guard;
c10::cuda::OptionalCUDAGuard cuda_guard;
set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream);
#endif

// create tensors
Expand Down
9 changes: 9 additions & 0 deletions src/csrc/include/internal/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
#include <vector>

#include <c10/core/TensorOptions.h>
#ifdef ENABLE_GPU
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#endif
#include <torch/torch.h>

#ifdef ENABLE_GPU
Expand Down Expand Up @@ -114,4 +118,9 @@ std::string print_tensor_shape(torch::Tensor tensor);
// Helper function to get the lrs
std::vector<double> get_current_lrs(const char* name);

#ifdef ENABLE_GPU
// Helper function to set the device and stream with device checks
void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard,
torch::Device device, cudaStream_t ext_stream);
#endif
} // namespace torchfort
Loading