diff --git a/.github/workflows/ifu.yml b/.github/workflows/ifu.yml
new file mode 100644
index 000000000000..82c13c6a12a3
--- /dev/null
+++ b/.github/workflows/ifu.yml
@@ -0,0 +1,63 @@
+name: IntegrateFromUpstream
+on:
+# schedule:
+# # verified via crontab.guru website. “At 06:55 on Monday.”
+# - cron: '55 6 * * 1'
+ workflow_dispatch:
+ inputs:
+ message:
+ description: 'Reason for manual trigger'
+ required: false
+ default: 'refresh branch'
+jobs:
+ IntegrateFromUpstream:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+ - name: Get Current Date
+ id: date
+ run: echo "::set-output name=date::$(date +'%Y-%m-%d')"
+ - name: Extract branch name
+ id: extract_branch
+ shell: bash
+      run: echo "::set-output name=branch::${GITHUB_REF#refs/heads/}"
+ - name: Fetch and Merge
+ id: fetch_and_merge
+ run: |
+ echo "Reason for trigger: ${{ github.event.inputs.message }}"
+ echo "Actor for trigger: ${{ github.actor }}"
+ git config user.name github-actions
+ git config user.email github-actions@github.com
+ git remote add upstream https://github.com/microsoft/DeepSpeed
+ git fetch upstream master
+ git merge upstream/master
+ # Since we use our own fork of DeepSpeedExamples, ignore theirs
+ git checkout HEAD DeepSpeedExamples
+ - name: Create Pull Request
+ id: create_pull_request
+ uses: jithunnair-amd/create-pull-request@v3
+ with:
+# token: ${{ secrets.PAT }}
+ branch: IFU-${{ steps.extract_branch.outputs.branch }}-${{ steps.date.outputs.date }}
+ title: IFU-${{ steps.extract_branch.outputs.branch }}-${{ steps.date.outputs.date }}
+ assignees: rraminen
+ reviewers: jithunnair-amd
+ delete-branch: true
+ - name: Send email
+ uses: jithunnair-amd/action-send-mail@v3.1.0
+ if: always()
+ with:
+ server_address: smtp.gmail.com
+ server_port: 465
+ secure: true
+ username: ${{ secrets.GMAIL_USERNAME }}
+ password: ${{ secrets.GMAIL_PASSWORD }}
+ subject: IFU to ${{ steps.extract_branch.outputs.branch }} branch of ${{ github.repository }}
+ to: Jithun.Nair@amd.com, RamyaSai.Ramineni@amd.com
+ from: ${{ secrets.GMAIL_USERNAME }}
+ html_body: |
+ Fetch and Merge: ${{ steps.fetch_and_merge.outcome }}
+ Create Pull Request: ${{ steps.create_pull_request.outcome }}
+ Pull request: ${{ steps.create_pull_request.outputs.pull-request-url }}
diff --git a/.gitmodules b/.gitmodules
index 37adb6f39e5c..072bdc50817e 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,4 +1,3 @@
[submodule "DeepSpeedExamples"]
path = DeepSpeedExamples
- url = https://github.com/microsoft/DeepSpeedExamples
- branch = master
+ url = https://github.com/ROCmSoftwarePlatform/DeepSpeedExamples.git
diff --git a/DeepSpeedExamples b/DeepSpeedExamples
index 25d73cf73fb3..36846da89d5b 160000
--- a/DeepSpeedExamples
+++ b/DeepSpeedExamples
@@ -1 +1 @@
-Subproject commit 25d73cf73fb3dc66faefa141b7319526555be9fc
+Subproject commit 36846da89d5be7e13465f95be7074b4ccd5898cd
diff --git a/csrc/includes/cublas_wrappers.h b/csrc/includes/cublas_wrappers.h
index 19d726c3bcd3..9bb6cc30f6ae 100644
--- a/csrc/includes/cublas_wrappers.h
+++ b/csrc/includes/cublas_wrappers.h
@@ -5,7 +5,9 @@
#include
#include
#include
+#ifndef __HIP_PLATFORM_HCC__
 #include <mma.h>
+#endif
#include
int cublas_gemm_ex(cublasHandle_t handle,
@@ -19,7 +21,11 @@ int cublas_gemm_ex(cublasHandle_t handle,
const float* A,
const float* B,
float* C,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
+#endif
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
@@ -32,7 +38,11 @@ int cublas_gemm_ex(cublasHandle_t handle,
const __half* A,
const __half* B,
__half* C,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
@@ -49,7 +59,11 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_B,
int stride_C,
int batch,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
+#endif
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
@@ -66,4 +80,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_B,
int stride_C,
int batch,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
diff --git a/csrc/includes/custom_cuda_layers.h b/csrc/includes/custom_cuda_layers.h
index eb10234d11f2..57d372f6ee14 100644
--- a/csrc/includes/custom_cuda_layers.h
+++ b/csrc/includes/custom_cuda_layers.h
@@ -5,7 +5,11 @@
#include
#include
+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_cooperative_groups.h>
+#else
 #include <cooperative_groups.h>
+#endif
#include
#include "context.h"
diff --git a/csrc/includes/feed_forward.h b/csrc/includes/feed_forward.h
index 7b7379d9b998..3a59d56ee6cd 100644
--- a/csrc/includes/feed_forward.h
+++ b/csrc/includes/feed_forward.h
@@ -43,7 +43,11 @@ class FeedForward {
weights,
input_ptr,
out,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo(config_.gemm_algos[0]));
+#else
cublasGemmAlgo_t(config_.gemm_algos[0]));
+#endif
}
void Backward(int bsz,
const T* out_grad,
@@ -68,7 +72,11 @@ class FeedForward {
input_ptr,
out_grad,
weights_grad,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo(config_.gemm_algos[1]));
+#else
cublasGemmAlgo_t(config_.gemm_algos[1]));
+#endif
cublas_gemm_ex(_cublasHandle,
CUBLAS_OP_N,
@@ -81,7 +89,11 @@ class FeedForward {
weights,
out_grad,
inp_grad_out,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo(config_.gemm_algos[2]));
+#else
cublasGemmAlgo_t(config_.gemm_algos[2]));
+#endif
launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, config_.outputSize, stream);
}
diff --git a/csrc/includes/gemm_test.h b/csrc/includes/gemm_test.h
index b920896b419e..647861c09afd 100644
--- a/csrc/includes/gemm_test.h
+++ b/csrc/includes/gemm_test.h
@@ -2,7 +2,9 @@
#pragma once
#include
+#ifndef __HIP_PLATFORM_HCC__
 #include <cuda_profiler_api.h>
+#endif
#include
#include
#include
@@ -58,7 +60,11 @@ class GemmTest {
B,
A,
C,
+#ifdef __HIP_PLATFORM_HCC__
+            static_cast<rocblas_gemm_algo>(algo));
+#else
             static_cast<cublasGemmAlgo_t>(algo));
+#endif
});
int algo_bw1 = Run(loops, [=](int algo) {
@@ -73,7 +79,11 @@ class GemmTest {
A,
C,
B,
+#ifdef __HIP_PLATFORM_HCC__
+            static_cast<rocblas_gemm_algo>(algo));
+#else
             static_cast<cublasGemmAlgo_t>(algo));
+#endif
});
int algo_bw2 = Run(loops, [=](int algo) {
@@ -88,7 +98,11 @@ class GemmTest {
B,
C,
A,
+#ifdef __HIP_PLATFORM_HCC__
+            static_cast<rocblas_gemm_algo>(algo));
+#else
             static_cast<cublasGemmAlgo_t>(algo));
+#endif
});
         return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
@@ -100,8 +114,13 @@ class GemmTest {
         float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
+#ifdef __HIP_PLATFORM_HCC__
+ for (int algo = (int)rocblas_gemm_algo_standard;
+ algo <= (int)rocblas_gemm_algo_standard;
+#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
+#endif
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
@@ -186,7 +205,11 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+            static_cast<rocblas_gemm_algo>(algo));
+#else
             static_cast<cublasGemmAlgo_t>(algo));
+#endif
});
int algo_bw1 = Run(loops, [=](int algo) {
@@ -216,7 +239,11 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+            static_cast<rocblas_gemm_algo>(algo));
+#else
             static_cast<cublasGemmAlgo_t>(algo));
+#endif
});
int algo_bw2 = Run(loops, [=](int algo) {
@@ -243,7 +270,11 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+            static_cast<rocblas_gemm_algo>(algo));
+#else
             static_cast<cublasGemmAlgo_t>(algo));
+#endif
});
         return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
@@ -255,8 +286,13 @@ class StridedGemmTest {
         float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
+#ifdef __HIP_PLATFORM_HCC__
+ for (int algo = (int)rocblas_gemm_algo_standard;
+ algo <= (int)rocblas_gemm_algo_standard;
+#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
+#endif
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
diff --git a/csrc/includes/general_kernels.h b/csrc/includes/general_kernels.h
index 588cf2aaa048..62416f0124dc 100644
--- a/csrc/includes/general_kernels.h
+++ b/csrc/includes/general_kernels.h
@@ -3,7 +3,11 @@
#include
#include
+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_cooperative_groups.h>
+#else
 #include <cooperative_groups.h>
+#endif
#include
#include "context.h"
diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h
index 44a1b313b986..6233e3e06240 100644
--- a/csrc/includes/strided_batch_gemm.h
+++ b/csrc/includes/strided_batch_gemm.h
@@ -72,7 +72,11 @@ class StridedBatchGemm {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo(_config.gemm_algos[0]));
+#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
+#endif
}
void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
@@ -96,7 +100,11 @@ class StridedBatchGemm {
stride_b,
stride_c,
_config.batch_size,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo(_config.gemm_algos[0]));
+#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
+#endif
k_buf = _buffer_a;
q_buf = _buffer_b;
@@ -136,7 +144,11 @@ class StridedBatchGemm {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo(_config.gemm_algos[1]));
+#else
cublasGemmAlgo_t(_config.gemm_algos[1]));
+#endif
// A need to transpose.
cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
@@ -161,7 +173,11 @@ class StridedBatchGemm {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_gemm_algo(_config.gemm_algos[2]));
+#else
cublasGemmAlgo_t(_config.gemm_algos[2]));
+#endif
}
inline int GetN() const { return _config.k; }
diff --git a/csrc/lamb/fused_lamb_cuda_kernel.cu b/csrc/lamb/fused_lamb_cuda_kernel.cu
index 0448a45368b9..10a17e98a13d 100644
--- a/csrc/lamb/fused_lamb_cuda_kernel.cu
+++ b/csrc/lamb/fused_lamb_cuda_kernel.cu
@@ -14,7 +14,11 @@
#include
//#include
+#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
+#include <hip/hip_cooperative_groups.h>
+#else
 #include <cooperative_groups.h>
+#endif
#include
#include
diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index c48ae38969e3..f79c3ecb1e12 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -5,7 +5,7 @@ namespace cg = cooperative_groups;
__global__ void qunatize_kernel(__half* vals, int group_size, int num_bits)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -206,7 +206,7 @@ __global__ void sr_qunatize_kernel(__half* vals,
int num_bits,
                                    std::pair<uint64_t, uint64_t> seed)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -484,7 +484,7 @@ template void launch_sr_qunatize_kernel(__half* vals,
__global__ void qunatize_kernel_asym(__half* vals, int group_size, int num_bits)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -729,7 +729,7 @@ __global__ void sr_qunatize_kernel_asym(__half* vals,
int num_bits,
                                         std::pair<uint64_t, uint64_t> seed)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
diff --git a/csrc/transformer/cublas_wrappers.cu b/csrc/transformer/cublas_wrappers.cu
index 72b62386ea6d..d94fc5169efb 100644
--- a/csrc/transformer/cublas_wrappers.cu
+++ b/csrc/transformer/cublas_wrappers.cu
@@ -1,5 +1,19 @@
#include "cublas_wrappers.h"
+#ifdef __HIP_PLATFORM_HCC__
+int cublas_gemm_ex(rocblas_handle handle,
+ rocblas_operation transa,
+ rocblas_operation transb,
+ int m,
+ int n,
+ int k,
+ const float* alpha,
+ const float* beta,
+ const float* A,
+ const float* B,
+ float* C,
+ rocblas_gemm_algo algo)
+#else
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
@@ -12,7 +26,34 @@ int cublas_gemm_ex(cublasHandle_t handle,
const float* B,
float* C,
cublasGemmAlgo_t algo)
+#endif
{
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_status status = rocblas_gemm_ex(handle,
+ transa,
+ transb,
+ m,
+ n,
+ k,
+ (const void*)alpha,
+ (const void*)A,
+ rocblas_datatype_f32_r,
+ (transa == rocblas_operation_none) ? m : k,
+ (const void*)B,
+ rocblas_datatype_f32_r,
+ (transb == rocblas_operation_none) ? k : n,
+ (const void*)beta,
+ C,
+ rocblas_datatype_f32_r,
+ m,
+ C,
+ rocblas_datatype_f32_r,
+ m,
+ rocblas_datatype_f32_r,
+ algo,
+ 0,
+ 0);
+#else
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
@@ -32,8 +73,13 @@ int cublas_gemm_ex(cublasHandle_t handle,
m,
CUDA_R_32F,
algo);
+#endif
+#ifdef __HIP_PLATFORM_HCC__
+ if (status != rocblas_status_success) {
+#else
if (status != CUBLAS_STATUS_SUCCESS) {
+#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
@@ -45,6 +91,20 @@ int cublas_gemm_ex(cublasHandle_t handle,
return 0;
}
+#ifdef __HIP_PLATFORM_HCC__
+int cublas_gemm_ex(rocblas_handle handle,
+ rocblas_operation transa,
+ rocblas_operation transb,
+ int m,
+ int n,
+ int k,
+ const float* alpha,
+ const float* beta,
+ const __half* A,
+ const __half* B,
+ __half* C,
+ rocblas_gemm_algo algo)
+#else
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
@@ -57,7 +117,34 @@ int cublas_gemm_ex(cublasHandle_t handle,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
+#endif
{
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_status status = rocblas_gemm_ex(handle,
+ transa,
+ transb,
+ m,
+ n,
+ k,
+ (const void*)alpha,
+ (const void*)A,
+                                            rocblas_datatype_f16_r,
+ (transa == rocblas_operation_none) ? m : k,
+ (const void*)B,
+ rocblas_datatype_f16_r,
+ (transb == rocblas_operation_none) ? k : n,
+ (const void*)beta,
+ (void*)C,
+ rocblas_datatype_f16_r,
+ m,
+ (void*)C,
+ rocblas_datatype_f16_r,
+ m,
+ rocblas_datatype_f32_r,
+ algo,
+ 0,
+ 0);
+#else
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
@@ -77,8 +164,13 @@ int cublas_gemm_ex(cublasHandle_t handle,
m,
CUDA_R_32F,
algo);
+#endif
+#ifdef __HIP_PLATFORM_HCC__
+ if (status != rocblas_status_success) {
+#else
if (status != CUBLAS_STATUS_SUCCESS) {
+#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
@@ -90,6 +182,24 @@ int cublas_gemm_ex(cublasHandle_t handle,
return 0;
}
+#ifdef __HIP_PLATFORM_HCC__
+int cublas_strided_batched_gemm(rocblas_handle handle,
+ int m,
+ int n,
+ int k,
+ const float* alpha,
+ const float* beta,
+ const float* A,
+ const float* B,
+ float* C,
+ rocblas_operation op_A,
+ rocblas_operation op_B,
+ int stride_A,
+ int stride_B,
+ int stride_C,
+ int batch,
+ rocblas_gemm_algo algo)
+#else
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
@@ -106,7 +216,39 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
+#endif
{
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
+ op_A,
+ op_B,
+ m,
+ n,
+ k,
+ alpha,
+ A,
+ rocblas_datatype_f32_r,
+ (op_A == rocblas_operation_none) ? m : k,
+ stride_A,
+ B,
+ rocblas_datatype_f32_r,
+ (op_B == rocblas_operation_none) ? k : n,
+ stride_B,
+ beta,
+ C,
+ rocblas_datatype_f32_r,
+ m,
+ stride_C,
+ C,
+ rocblas_datatype_f32_r,
+ m,
+ stride_C,
+ batch,
+ rocblas_datatype_f32_r,
+ algo,
+ 0,
+ 0);
+#else
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
@@ -130,8 +272,13 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
batch,
CUDA_R_32F,
algo);
+#endif
+#ifdef __HIP_PLATFORM_HCC__
+    if (status != rocblas_status_success) {
+#else
     if (status != CUBLAS_STATUS_SUCCESS) {
+#endif
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
@@ -144,6 +291,24 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
return 0;
}
+#ifdef __HIP_PLATFORM_HCC__
+int cublas_strided_batched_gemm(rocblas_handle handle,
+ int m,
+ int n,
+ int k,
+ const float* alpha,
+ const float* beta,
+ const __half* A,
+ const __half* B,
+ __half* C,
+ rocblas_operation op_A,
+ rocblas_operation op_B,
+ int stride_A,
+ int stride_B,
+ int stride_C,
+ int batch,
+ rocblas_gemm_algo algo)
+#else
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
@@ -160,7 +325,39 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
+#endif
{
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
+ op_A,
+ op_B,
+ m,
+ n,
+ k,
+ alpha,
+ A,
+ rocblas_datatype_f16_r,
+ (op_A == rocblas_operation_none) ? m : k,
+ stride_A,
+ B,
+ rocblas_datatype_f16_r,
+ (op_B == rocblas_operation_none) ? k : n,
+ stride_B,
+ beta,
+ C,
+ rocblas_datatype_f16_r,
+ m,
+ stride_C,
+ C,
+ rocblas_datatype_f16_r,
+ m,
+ stride_C,
+ batch,
+ rocblas_datatype_f32_r,
+ algo,
+ 0,
+ 0);
+#else
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
@@ -184,8 +381,13 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
batch,
CUDA_R_32F,
algo);
+#endif
+#ifdef __HIP_PLATFORM_HCC__
+ if (status != rocblas_status_success) {
+#else
if (status != CUBLAS_STATUS_SUCCESS) {
+#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp
index e729e15a3912..bc9cf59ec902 100644
--- a/csrc/transformer/ds_transformer_cuda.cpp
+++ b/csrc/transformer/ds_transformer_cuda.cpp
@@ -140,7 +140,9 @@ BertTransformerLayer::~BertTransformerLayer()
template
void BertTransformerLayer::Initialize()
{
+#ifndef __HIP_PLATFORM_HCC__
     if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
+#endif
}
template
diff --git a/csrc/transformer/gelu_kernels.cu b/csrc/transformer/gelu_kernels.cu
index 12048006266e..dbb8828ce977 100644
--- a/csrc/transformer/gelu_kernels.cu
+++ b/csrc/transformer/gelu_kernels.cu
@@ -60,7 +60,7 @@ __global__ void gelu_kernel(const float* input, float* vals, int row_stride, int
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
@@ -131,7 +131,7 @@ __global__ void fused_bias_gelu(const __half* input,
int row_stride,
int iterations)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
@@ -214,7 +214,7 @@ __global__ void d_gelu_func(__half* d_output,
int row_stride,
int iterations)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu
index ddf7a958822a..43d50f00c058 100644
--- a/csrc/transformer/inference/csrc/dequantize.cu
+++ b/csrc/transformer/inference/csrc/dequantize.cu
@@ -46,7 +46,7 @@ __global__ void dequantize_kernel(__half* output,
unsigned groups,
unsigned merge_count)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index fc3faacc54e8..1737855f614d 100755
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -39,7 +39,7 @@ __global__ void fused_bias_gelu(__half* input,
int total_count,
int intermediate_size)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
     float2* input_cast = reinterpret_cast<float2*>(input);
     const float2* bias_cast = reinterpret_cast<const float2*>(bias);
@@ -117,7 +117,7 @@ __global__ void fused_bias_add(float* input, const float* bias, int total_count,
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
     float2* input_cast = reinterpret_cast<float2*>(input);
     const float2* bias_cast = reinterpret_cast<const float2*>(bias);
@@ -195,7 +195,7 @@ __global__ void fused_bias_residual(__half* input,
int total_count,
int intermediate_size)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
     float2* input_cast = reinterpret_cast<float2*>(input);
     const float2* residual_cast = reinterpret_cast<const float2*>(residual);
diff --git a/csrc/transformer/inference/csrc/normalize.cu b/csrc/transformer/inference/csrc/normalize.cu
index ecd73154f37f..dc0f6be01144 100755
--- a/csrc/transformer/inference/csrc/normalize.cu
+++ b/csrc/transformer/inference/csrc/normalize.cu
@@ -85,7 +85,7 @@ __global__ void fused_bias_residual_layer_norm(__half* output,
float epsilon,
int row_stride)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
@@ -287,7 +287,7 @@ __global__ void fused_residual_layer_norm(__half* norm,
int row_stride,
bool preLN)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu
index cee509965106..b347945df636 100644
--- a/csrc/transformer/inference/csrc/softmax.cu
+++ b/csrc/transformer/inference/csrc/softmax.cu
@@ -37,7 +37,7 @@ __global__ void attn_softmax_v2(__half* vals,
int num_seq,
float scale)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
     cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
diff --git a/csrc/transformer/normalize_kernels.cu b/csrc/transformer/normalize_kernels.cu
index 366e93724638..c9bc4a46ee5e 100644
--- a/csrc/transformer/normalize_kernels.cu
+++ b/csrc/transformer/normalize_kernels.cu
@@ -121,7 +121,7 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
__half* means,
int row_stride)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
@@ -404,7 +404,7 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
__half* vars,
int row_stride)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu
index be776b0c074d..a4d84c37dd3b 100644
--- a/csrc/transformer/softmax_kernels.cu
+++ b/csrc/transformer/softmax_kernels.cu
@@ -142,7 +142,7 @@ __global__ void attn_softmax(__half* vals,
int seq_length,
int iterations)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
diff --git a/csrc/transformer/transform_kernels.cu b/csrc/transformer/transform_kernels.cu
index 7d8a27eeeb43..b68d70f67ae1 100755
--- a/csrc/transformer/transform_kernels.cu
+++ b/csrc/transformer/transform_kernels.cu
@@ -96,7 +96,7 @@ __global__ void transform_0213<__half>(__half* output,
int heads,
int head_ext)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
@@ -219,7 +219,7 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
int heads,
int head_ext)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
@@ -289,7 +289,7 @@ __global__ void bias_add_transform_0213_v2(__half* output,
int seq_length,
int heads)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
@@ -451,7 +451,7 @@ __global__ void transform4d_0213<__half>(__half* out,
int hidden_dim,
int head_ext)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
@@ -487,7 +487,7 @@ __global__ void transform4d_0213_v2(__half* out,
int seq_length,
int hidden_dim)
{
-#if __CUDA_ARCH__ >= 700
+#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
diff --git a/deepspeed/env_report.py b/deepspeed/env_report.py
index b14ac4464835..7166104fde95 100644
--- a/deepspeed/env_report.py
+++ b/deepspeed/env_report.py
@@ -85,6 +85,8 @@ def debug_report():
torch.__version__),
("torch cuda version",
torch.version.cuda),
+ ("torch hip version",
+ torch.version.hip),
("nvcc version",
nvcc_version()),
("deepspeed install path",
@@ -93,7 +95,7 @@ def debug_report():
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
),
("deepspeed wheel compiled w.",
- f"torch {torch_info['version']}, cuda {torch_info['cuda_version']}"),
+ f"torch {torch_info['version']}, cuda {torch_info['cuda_version']}, hip {torch_info['hip_version']}"),
]
print("DeepSpeed general environment info:")
for name, value in report:
diff --git a/deepspeed/git_version_info.py b/deepspeed/git_version_info.py
index f04982c74f0d..a806475c397b 100644
--- a/deepspeed/git_version_info.py
+++ b/deepspeed/git_version_info.py
@@ -14,4 +14,4 @@
from .ops.op_builder import ALL_OPS
installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
- torch_info = {'version': "0.0", "cuda_version": "0.0"}
+ torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"}
diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py
index ea5bb7b046fa..dac5378a101c 100755
--- a/deepspeed/ops/__init__.py
+++ b/deepspeed/ops/__init__.py
@@ -1,6 +1,8 @@
from . import adam
from . import lamb
-from . import sparse_attention
+from ..git_version_info_installed import installed_ops as __installed_ops__
+if __installed_ops__['sparse_attn']:
+ from . import sparse_attention
from . import transformer
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
new file mode 100644
index 000000000000..95f9233ff21e
--- /dev/null
+++ b/docker/Dockerfile.rocm
@@ -0,0 +1,183 @@
+FROM rocm/pytorch:latest
+
+
+##############################################################################
+# Temporary Installation Directory
+##############################################################################
+ENV STAGE_DIR=/tmp
+RUN mkdir -p ${STAGE_DIR}
+
+##############################################################################
+# Installation/Basic Utilities
+##############################################################################
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ software-properties-common build-essential autotools-dev \
+ nfs-common pdsh \
+ cmake g++ gcc \
+ curl wget vim tmux emacs less unzip \
+ htop iftop iotop ca-certificates openssh-client openssh-server \
+ rsync iputils-ping net-tools sudo \
+ llvm-9-dev
+
+##############################################################################
+# Installation Latest Git
+##############################################################################
+RUN add-apt-repository ppa:git-core/ppa -y && \
+ apt-get update && \
+ apt-get install -y git && \
+ git --version
+
+##############################################################################
+# Client Liveness & Uncomment Port 22 for SSH Daemon
+##############################################################################
+# Keep SSH client alive from server side
+RUN echo "ClientAliveInterval 30" >> /etc/ssh/sshd_config
+RUN cp /etc/ssh/sshd_config ${STAGE_DIR}/sshd_config && \
+ sed "0,/^#Port 22/s//Port 22/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config
+
+##############################################################################
+# Mellanox OFED
+##############################################################################
+#ENV MLNX_OFED_VERSION=4.6-1.0.1.1
+#RUN apt-get install -y libnuma-dev
+#RUN cd ${STAGE_DIR} && \
+# wget -q -O - http://www.mellanox.com/downloads/ofed/MLNX_OFED-${MLNX_OFED_VERSION}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64.tgz | tar xzf - && \
+# cd MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64 && \
+# ./mlnxofedinstall --user-space-only --without-fw-update --all -q && \
+# cd ${STAGE_DIR} && \
+# rm -rf ${STAGE_DIR}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64*
+
+##############################################################################
+# OPENMPI
+##############################################################################
+#ENV OPENMPI_BASEVERSION=4.0
+#ENV OPENMPI_VERSION=${OPENMPI_BASEVERSION}.1
+#RUN cd ${STAGE_DIR} && \
+# wget -q -O - https://download.open-mpi.org/release/open-mpi/v${OPENMPI_BASEVERSION}/openmpi-${OPENMPI_VERSION}.tar.gz | tar xzf - && \
+# cd openmpi-${OPENMPI_VERSION} && \
+# ./configure --prefix=/usr/local/openmpi-${OPENMPI_VERSION} && \
+# make -j"$(nproc)" install && \
+# ln -s /usr/local/openmpi-${OPENMPI_VERSION} /usr/local/mpi && \
+# # Sanity check:
+# test -f /usr/local/mpi/bin/mpic++ && \
+# cd ${STAGE_DIR} && \
+# rm -r ${STAGE_DIR}/openmpi-${OPENMPI_VERSION}
+#ENV PATH=/usr/local/mpi/bin:${PATH} \
+# LD_LIBRARY_PATH=/usr/local/lib:/usr/local/mpi/lib:/usr/local/mpi/lib64:${LD_LIBRARY_PATH}
+## Create a wrapper for OpenMPI to allow running as root by default
+#RUN mv /usr/local/mpi/bin/mpirun /usr/local/mpi/bin/mpirun.real && \
+# echo '#!/bin/bash' > /usr/local/mpi/bin/mpirun && \
+# echo 'mpirun.real --allow-run-as-root --prefix /usr/local/mpi "$@"' >> /usr/local/mpi/bin/mpirun && \
+# chmod a+x /usr/local/mpi/bin/mpirun
+
+##############################################################################
+# Python
+##############################################################################
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHON_VERSION=3.6
+RUN apt-get install -y python3.6 python3.6-dev && \
+ rm -f /usr/bin/python && \
+ ln -s /usr/bin/python3.6 /usr/bin/python && \
+ curl -O https://bootstrap.pypa.io/get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py && \
+ pip install --upgrade pip && \
+        # Print python and pip version
+ python -V && pip -V
+RUN pip install pyyaml
+RUN pip install ipython
+
+##############################################################################
+# TensorFlow
+##############################################################################
+RUN pip install tensorflow-rocm
+
+##############################################################################
+# Some Packages
+##############################################################################
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ libsndfile-dev \
+ libjpeg-dev \
+ libpng-dev \
+ screen
+RUN pip install psutil \
+ yappi \
+ cffi \
+ ipdb \
+ pandas \
+ matplotlib \
+ py3nvml \
+ pyarrow \
+ graphviz \
+ astor \
+ boto3 \
+ tqdm \
+ sentencepiece \
+ msgpack \
+ requests \
+ sphinx \
+ sphinx_rtd_theme \
+ scipy \
+ numpy \
+ sklearn \
+ scikit-learn \
+ mpi4py \
+ h5py
+
+##############################################################################
+## SSH daemon port inside container cannot conflict with host OS port
+###############################################################################
+ENV SSH_PORT=2222
+RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \
+ sed "0,/^#Port 22/s//Port ${SSH_PORT}/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config
+
+##############################################################################
+# PyTorch
+##############################################################################
+#ENV PYTORCH_VERSION=1.2.0
+#ENV TORCHVISION_VERSION=0.4.0
+#ENV TENSORBOARDX_VERSION=1.8
+#RUN pip install torch==${PYTORCH_VERSION}
+#RUN pip install torchvision==${TORCHVISION_VERSION}
+#RUN pip install tensorboardX==${TENSORBOARDX_VERSION}
+
+##############################################################################
+# PyYAML build issue
+# https://stackoverflow.com/a/53926898
+##############################################################################
+RUN rm -rf /usr/lib/python3/dist-packages/yaml && \
+ rm -rf /usr/lib/python3/dist-packages/PyYAML-*
+
+##############################################################################
+## CuPy installation
+###############################################################################
+RUN git clone https://github.com/ROCmSoftwarePlatform/cupy ${STAGE_DIR}/cupy
+RUN cd ${STAGE_DIR}/cupy && \
+ git submodule update --init && \
+ CUPY_INSTALL_USE_HIP=1 ROCM_HOME=/opt/rocm pip install -e . --no-cache-dir -vvvv
+RUN rm -rf ${STAGE_DIR}/cupy
+
+##############################################################################
+## Add deepspeed user
+###############################################################################
+# Add a deepspeed user with user id 8877
+#RUN useradd --create-home --uid 8877 deepspeed
+#RUN useradd --create-home --uid 1000 --shell /bin/bash deepspeed
+#RUN usermod -aG sudo deepspeed
+#RUN echo "deepspeed ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
+# # Change to non-root privilege
+#USER deepspeed
+
+##############################################################################
+# DeepSpeed
+##############################################################################
+RUN git clone https://github.com/ROCmSoftwarePlatform/DeepSpeed.git ${STAGE_DIR}/DeepSpeed
+RUN cd ${STAGE_DIR}/DeepSpeed && \
+ git checkout . && \
+ git checkout master && \
+ DS_BUILD_FUSED_ADAM=1 DS_BUILD_FUSED_LAMB=1 DS_BUILD_CPU_ADAM=1 DS_BUILD_TRANSFORMER=1 DS_BUILD_STOCHASTIC_TRANSFORMER=1 DS_BUILD_UTILS=1 ./install.sh --allow_sudo
+RUN rm -rf ${STAGE_DIR}/DeepSpeed
+RUN cd ~ && python -c "import deepspeed; print(deepspeed.__version__)"
diff --git a/install.sh b/install.sh
index 7c26883d6db0..a03455efa2e3 100755
--- a/install.sh
+++ b/install.sh
@@ -156,7 +156,7 @@ python setup.py $VERBOSE bdist_wheel
if [ "$local_only" == "1" ]; then
echo "Installing deepspeed"
- $PIP_SUDO pip uninstall -y deepspeed
+# $PIP_SUDO pip uninstall -y deepspeed
$PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl
ds_report
else
@@ -171,7 +171,11 @@ else
tmp_wheel_path="/tmp/deepspeed_wheels"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*; else mkdir -pv $tmp_wheel_path; fi"
- pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/
+ if [ -e "/opt/rocm" ]; then
+ pdcp -w $hosts requirements/requirements-rocm.txt ${tmp_wheel_path}/
+ else
+ pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/
+ fi
echo "Installing deepspeed"
pdsh -w $hosts "$PIP_SUDO pip uninstall -y deepspeed"
diff --git a/op_builder/__init__.py b/op_builder/__init__.py
index f19ed916c332..bf315489044e 100755
--- a/op_builder/__init__.py
+++ b/op_builder/__init__.py
@@ -9,7 +9,7 @@
from .stochastic_transformer import StochasticTransformerBuilder
from .utils import UtilsBuilder
from .async_io import AsyncIOBuilder
-from .builder import get_default_compute_capatabilities
+from .builder import get_default_compute_capatabilities, is_rocm_pytorch
from .transformer_inference import InferenceBuilder
from .quantizer import QuantizerBuilder
diff --git a/op_builder/builder.py b/op_builder/builder.py
index e061072a88b9..f2a96534d815 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -17,6 +17,10 @@
DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
+is_rocm_pytorch = False
+if tuple(int(x) for x in torch.__version__.split('.')[:2]) >= (1, 5):
+    from torch.utils.cpp_extension import ROCM_HOME
+    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
def installed_cuda_version():
import torch.utils.cpp_extension
@@ -66,17 +70,31 @@ def assert_no_cuda_mismatch():
def assert_torch_info(torch_info):
install_torch_version = torch_info['version']
install_cuda_version = torch_info['cuda_version']
+ install_hip_version = torch_info['hip_version']
+
+ if not is_rocm_pytorch:
+ current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
+ else:
+ current_hip_version = ".".join(torch.version.hip.split('.')[:2])
- current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
current_torch_version = ".".join(torch.__version__.split('.')[:2])
- if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
- raise RuntimeError(
- "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
- "with a different version than what is being used at runtime. Please re-install "
- f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
- f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
- f"torch={current_torch_version}, cuda={current_cuda_version}")
+ if not is_rocm_pytorch:
+ if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
+ raise RuntimeError(
+ "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
+ "with a different version than what is being used at runtime. Please re-install "
+ f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
+ f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
+ f"torch={current_torch_version}, cuda={current_cuda_version}")
+ else:
+ if install_hip_version != current_hip_version or install_torch_version != current_torch_version:
+ raise RuntimeError(
+ "PyTorch and HIP version mismatch! DeepSpeed ops were compiled and installed "
+ "with a different version than what is being used at runtime. Please re-install "
+ f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
+ f"torch={install_torch_version}, hip={install_hip_version}, runtime versions:"
+ f"torch={current_torch_version}, hip={current_hip_version}")
class OpBuilder(ABC):
@@ -227,7 +245,7 @@ def jit_load(self, verbose=True):
f"Unable to JIT load the {self.name} op due to ninja not being installed."
)
- if isinstance(self, CUDAOpBuilder):
+ if isinstance(self, CUDAOpBuilder) and not is_rocm_pytorch:
assert_no_cuda_mismatch()
self.jit_mode = True
@@ -241,9 +259,10 @@ def jit_load(self, verbose=True):
os.makedirs(ext_path, exist_ok=True)
start_build = time.time()
+ sources = [self.deepspeed_src_path(path) for path in self.sources()]
op_module = load(
name=self.name,
- sources=[self.deepspeed_src_path(path) for path in self.sources()],
+ sources=sources,
extra_include_paths=[
self.deepspeed_src_path(path) for path in self.include_paths()
],
@@ -331,7 +350,8 @@ def is_compatible(self):
def builder(self):
from torch.utils.cpp_extension import CUDAExtension
- assert_no_cuda_mismatch()
+ if not is_rocm_pytorch:
+ assert_no_cuda_mismatch()
return CUDAExtension(name=self.absolute_name(),
sources=self.sources(),
include_dirs=self.include_paths(),
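
For readers porting similar extensions, the ROCm-vs-CUDA detection added in op_builder/builder.py can be exercised on its own. The sketch below mirrors that logic; it is not part of the diff, and the try/except around ROCM_HOME is an extra safeguard for PyTorch builds that do not expose it.

```python
# Minimal standalone sketch of the detection used in op_builder/builder.py (illustrative only).
import torch

is_rocm_pytorch = False
if tuple(int(x) for x in torch.__version__.split('.')[:2]) >= (1, 5):
    try:
        from torch.utils.cpp_extension import ROCM_HOME
    except ImportError:
        ROCM_HOME = None
    # ROCm builds of PyTorch report a HIP version and ship ROCM_HOME;
    # CUDA builds leave torch.version.hip set to None.
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

print("ROCm PyTorch build:", is_rocm_pytorch)
```
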
diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py
index 75fa042613fc..d5a8216d2095 100644
--- a/op_builder/cpu_adam.py
+++ b/op_builder/cpu_adam.py
@@ -5,7 +5,7 @@
import sys
import torch
import subprocess
-from .builder import CUDAOpBuilder
+from .builder import CUDAOpBuilder, is_rocm_pytorch
class CPUAdamBuilder(CUDAOpBuilder):
@@ -26,8 +26,15 @@ def sources(self):
return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/custom_cuda_kernel.cu']
def include_paths(self):
- CUDA_INCLUDE = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")
- return ['csrc/includes', CUDA_INCLUDE]
+ if not is_rocm_pytorch:
+ CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
+ else:
+ CUDA_INCLUDE = [
+ os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
+ os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
+ os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
+ ]
+ return ['csrc/includes'] + CUDA_INCLUDE
def simd_width(self):
if not self.command_exists('lscpu'):
@@ -40,14 +47,17 @@ def simd_width(self):
result = subprocess.check_output('lscpu', shell=True)
result = result.decode('utf-8').strip().lower()
if 'genuineintel' in result:
- if 'avx512' in result:
+ if not is_rocm_pytorch and 'avx512' in result:
return '-D__AVX512__'
elif 'avx2' in result:
return '-D__AVX256__'
return '-D__SCALAR__'
def cxx_args(self):
- CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
+ if not is_rocm_pytorch:
+ CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
+ else:
+ CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
SIMD_WIDTH = self.simd_width()
return [
@@ -66,11 +76,20 @@ def cxx_args(self):
def nvcc_args(self):
args = [
'-O3',
- '--use_fast_math',
- '-std=c++14',
- '-U__CUDA_NO_HALF_OPERATORS__',
- '-U__CUDA_NO_HALF_CONVERSIONS__',
- '-U__CUDA_NO_HALF2_OPERATORS__'
+ '-std=c++14'
]
- args += self.compute_capability_args()
+ if is_rocm_pytorch:
+ args += [
+ '-U__HIP_NO_HALF_OPERATORS__',
+ '-U__HIP_NO_HALF_CONVERSIONS__',
+ '-U__HIP_NO_HALF2_OPERATORS__'
+ ]
+ else:
+ args += [
+ '--use_fast_math',
+ '-U__CUDA_NO_HALF_OPERATORS__',
+ '-U__CUDA_NO_HALF_CONVERSIONS__',
+ '-U__CUDA_NO_HALF2_OPERATORS__'
+ ]
+ args += self.compute_capability_args()
return args
diff --git a/op_builder/fused_adam.py b/op_builder/fused_adam.py
index 0340ed02a8fb..6fa240ac3a3f 100644
--- a/op_builder/fused_adam.py
+++ b/op_builder/fused_adam.py
@@ -2,7 +2,7 @@
Copyright 2020 The Microsoft DeepSpeed Team
"""
import torch
-from .builder import CUDAOpBuilder
+from .builder import CUDAOpBuilder, is_rocm_pytorch
class FusedAdamBuilder(CUDAOpBuilder):
@@ -19,14 +19,14 @@ def sources(self):
return ['csrc/adam/fused_adam_frontend.cpp', 'csrc/adam/multi_tensor_adam.cu']
def include_paths(self):
- return ['csrc/includes']
+ return ['csrc/includes', 'csrc/adam']
def cxx_args(self):
args = super().cxx_args()
return args + self.version_dependent_macros()
def nvcc_args(self):
- return ['-lineinfo',
- '-O3',
- '--use_fast_math'
- ] + self.version_dependent_macros() + self.compute_capability_args()
+        nvcc_flags = ['-O3'] + self.version_dependent_macros()
+ if not is_rocm_pytorch:
+ nvcc_flags.extend(['-lineinfo', '--use_fast_math'] + self.compute_capability_args())
+ return nvcc_flags
diff --git a/op_builder/fused_lamb.py b/op_builder/fused_lamb.py
index 4c73ecb404e3..cb60412baea7 100644
--- a/op_builder/fused_lamb.py
+++ b/op_builder/fused_lamb.py
@@ -2,7 +2,7 @@
Copyright 2020 The Microsoft DeepSpeed Team
"""
import torch
-from .builder import CUDAOpBuilder
+from .builder import CUDAOpBuilder, is_rocm_pytorch
class FusedLambBuilder(CUDAOpBuilder):
@@ -26,7 +26,7 @@ def cxx_args(self):
return args + self.version_dependent_macros()
def nvcc_args(self):
- return ['-lineinfo',
- '-O3',
- '--use_fast_math'
- ] + self.version_dependent_macros() + self.compute_capability_args()
+        nvcc_flags = ['-O3'] + self.version_dependent_macros()
+ if not is_rocm_pytorch:
+ nvcc_flags.extend(['-lineinfo', '--use_fast_math'] + self.compute_capability_args())
+ return nvcc_flags
diff --git a/op_builder/transformer.py b/op_builder/transformer.py
index a94c6da62e39..606d0be255ef 100644
--- a/op_builder/transformer.py
+++ b/op_builder/transformer.py
@@ -2,7 +2,7 @@
Copyright 2020 The Microsoft DeepSpeed Team
"""
import torch
-from .builder import CUDAOpBuilder
+from .builder import CUDAOpBuilder, is_rocm_pytorch
class TransformerBuilder(CUDAOpBuilder):
@@ -29,16 +29,29 @@ def sources(self):
]
def include_paths(self):
- return ['csrc/includes']
+ includes = ['csrc/includes']
+ if is_rocm_pytorch:
+ from torch.utils.cpp_extension import ROCM_HOME
+ includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)]
+ return includes
def nvcc_args(self):
args = [
'-O3',
- '--use_fast_math',
'-std=c++14',
- '-U__CUDA_NO_HALF_OPERATORS__',
- '-U__CUDA_NO_HALF_CONVERSIONS__',
- '-U__CUDA_NO_HALF2_OPERATORS__'
]
-
- return args + self.compute_capability_args()
+ if is_rocm_pytorch:
+ args += [
+ '-U__HIP_NO_HALF_OPERATORS__',
+ '-U__HIP_NO_HALF_CONVERSIONS__',
+ '-U__HIP_NO_HALF2_OPERATORS__'
+ ]
+ else:
+ args += [
+ '--use_fast_math',
+ '-U__CUDA_NO_HALF_OPERATORS__',
+ '-U__CUDA_NO_HALF_CONVERSIONS__',
+ '-U__CUDA_NO_HALF2_OPERATORS__'
+ ]
+ args += self.compute_capability_args()
+ return args
diff --git a/requirements/requirements-rocm.txt b/requirements/requirements-rocm.txt
new file mode 100644
index 000000000000..e0e1d0c482ce
--- /dev/null
+++ b/requirements/requirements-rocm.txt
@@ -0,0 +1,7 @@
+#torch>=1.2
+#torchvision>=0.4.0
+tqdm
+tensorboardX==1.8
+ninja
+numpy
+psutil
diff --git a/setup.py b/setup.py
index 654b983eec81..27374a7f748b 100755
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
raise ImportError('Unable to import torch, please visit https://pytorch.org/ '
'to see how to properly install torch on your system.')
-from op_builder import ALL_OPS, get_default_compute_capatabilities
+from op_builder import ALL_OPS, get_default_compute_capatabilities, is_rocm_pytorch
def fetch_requirements(path):
@@ -38,6 +38,9 @@ def fetch_requirements(path):
install_requires = fetch_requirements('requirements/requirements.txt')
+if is_rocm_pytorch:
+ print("NOTE: Please manually install torch and torchvision packages for ROCm")
+ install_requires = fetch_requirements('requirements/requirements-rocm.txt')
extras_require = {
'1bit_adam': fetch_requirements('requirements/requirements-1bit-adam.txt'),
'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'),
@@ -47,7 +50,10 @@ def fetch_requirements(path):
# If MPI is available add 1bit-adam requirements
if torch.cuda.is_available():
if shutil.which('ompi_info') or shutil.which('mpiname'):
- cupy = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}"
+ if is_rocm_pytorch:
+ cupy = "cupy"
+ else:
+ cupy = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}"
extras_require['1bit_adam'].append(cupy)
# Make an [all] extra that installs all needed dependencies
@@ -170,9 +176,13 @@ def create_dir_symlink(src, dest):
torch_version = ".".join([TORCH_MAJOR, TORCH_MINOR])
# Set cuda_version to 0.0 if cpu-only
cuda_version = "0.0"
+# Set hip_version to 0.0 if cpu-only
+hip_version = "0.0"
if torch.version.cuda is not None:
cuda_version = ".".join(torch.version.cuda.split('.')[:2])
-torch_info = {"version": torch_version, "cuda_version": cuda_version}
+if torch.version.hip is not None:
+ hip_version = ".".join(torch.version.hip.split('.')[:2])
+torch_info = {"version": torch_version, "cuda_version": cuda_version, "hip_version": hip_version}
print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}")
with open('deepspeed/git_version_info_installed.py', 'w') as fd:
@@ -206,7 +216,7 @@ def create_dir_symlink(src, dest):
extras_require=extras_require,
packages=find_packages(exclude=["docker",
"third_party"]),
- include_package_data=True,
+# include_package_data=True, #FIXME
scripts=[
'bin/deepspeed',
'bin/deepspeed.pt',
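
A hedged sketch of how the new hip_version field is meant to be consumed: setup.py above records it in torch_info at wheel-build time, and assert_torch_info in op_builder/builder.py compares it against the interpreter's torch at runtime. The standalone function below is illustrative only, follows the names used in the diff, and assumes a GPU-enabled PyTorch build.

```python
# Illustrative runtime check mirroring assert_torch_info(); not part of the diff.
import torch

def check_torch_info(torch_info, is_rocm_pytorch):
    install_torch = torch_info['version']
    current_torch = ".".join(torch.__version__.split('.')[:2])
    if is_rocm_pytorch:
        backend, installed = 'hip', torch_info['hip_version']
        current = ".".join(torch.version.hip.split('.')[:2])
    else:
        backend, installed = 'cuda', torch_info['cuda_version']
        current = ".".join(torch.version.cuda.split('.')[:2])
    if (install_torch, installed) != (current_torch, current):
        raise RuntimeError(
            f"DeepSpeed ops were compiled with torch={install_torch}, {backend}={installed}, "
            f"but the runtime has torch={current_torch}, {backend}={current}")
```
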
diff --git a/tests/unit/common.py b/tests/unit/common.py
index f92b1058aa92..316fcf227232 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -8,10 +8,24 @@
import deepspeed
import pytest
+from functools import wraps
+import unittest
# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
+TEST_WITH_ROCM = os.getenv('DEEPSPEED_TEST_WITH_ROCM', '0') == '1'
+
+def skipIfRocm(reason="test doesn't currently work on the ROCm stack"):
+ def decorator(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ if TEST_WITH_ROCM:
+ raise unittest.SkipTest(reason)
+ else:
+ fn(*args, **kwargs)
+ return wrapper
+ return decorator
def distributed_test(world_size=2, backend='nccl'):
"""A decorator for executing a function (e.g., a unit test) in a distributed manner.
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 01004a0fa867..bf207fcb80b2 100755
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -3,7 +3,7 @@
import pytest
import json
import argparse
-from common import distributed_test
+from common import distributed_test, skipIfRocm
from simple_model import SimpleModel, create_config_from_dict, random_dataloader
import torch.distributed as dist
@@ -56,6 +56,7 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success):
(2,32,8,2,True),
(2,33,17,2,False),
(2,32,18,1,False)]) # yapf: disable
+@skipIfRocm()
def test_batch_config(num_ranks, batch, micro_batch, gas, success):
@distributed_test(world_size=2)
def _test_batch_config(num_ranks, batch, micro_batch, gas, success):
diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py
index 62ccdbdc68d9..fbc70447a39d 100755
--- a/tests/unit/test_cuda_backward.py
+++ b/tests/unit/test_cuda_backward.py
@@ -8,6 +8,7 @@
import time
import copy
from torch import nn
+from common import skipIfRocm
from modelingpreln import BertEncoder as BertEncoderPreln
from modeling import BertEncoder as BertEncoderPostln
from modeling import BertConfig, BertLayerNorm
@@ -268,6 +269,7 @@ def run_backward(ds_config, seq_len, atol=1e-2, verbose=False):
#(3,128,51,2,24,False,False, 0.1),
#(3,128,54,2,24,False,True, 0.2),
]) # yapf: disable
+@skipIfRocm()
def test_backward(batch_size,
hidden_size,
seq_len,
diff --git a/tests/unit/test_dist.py b/tests/unit/test_dist.py
index 04b97031b3e5..18a74b0a16fd 100644
--- a/tests/unit/test_dist.py
+++ b/tests/unit/test_dist.py
@@ -5,7 +5,6 @@
import pytest
-
@distributed_test(world_size=3)
def test_init():
assert dist.is_initialized()
diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py
index 9796a70953f8..919042d04c7c 100644
--- a/tests/unit/test_onebit.py
+++ b/tests/unit/test_onebit.py
@@ -14,7 +14,7 @@
from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology
PipeTopo = PipeDataParallelTopology
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
-from common import distributed_test
+from common import distributed_test, skipIfRocm
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args
from test_pipe import AlexNetPipe, train_cifar
@@ -25,6 +25,7 @@
allow_module_level=True)
+@skipIfRocm("Skipped for now as cupy is not available on ROCm")
def test_onebitadam_fp16_basic(tmpdir):
config_dict = {
"train_batch_size": 2,
@@ -68,6 +69,7 @@ def _test_onebitadam_fp16_basic(args, model, hidden_dim):
_test_onebitadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)
+@skipIfRocm("Skipped for now as cupy is not available on ROCm")
def test_onebitadam_fp32_basic(tmpdir):
config_dict = {
"train_batch_size": 2,
@@ -107,6 +109,7 @@ def _test_onebitadam_fp32_basic(args, model, hidden_dim):
_test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim)
+@skipIfRocm("Skipped for now as cupy is not available on ROCm")
def test_onebitadam_exp_avg_mask(tmpdir):
config_dict = {
"train_batch_size": 2,
@@ -168,6 +171,7 @@ def _test_onebitadam_exp_avg_mask(args, model, hidden_dim):
_test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim)
+@skipIfRocm("Skipped for now as cupy is not available on ROCm")
def test_onebitadam_checkpointing(tmpdir):
config_dict = {
"train_batch_size": 2,
@@ -849,6 +853,7 @@ def _helper(topo, tmpdir, steps=500):
_helper(topo, tmpdir)
+@skipIfRocm("Skipped for now as cupy is not available on ROCm")
def test_compressed_allreduce_basic(tmpdir):
@distributed_test(world_size=[1, 2])
def _test_compressed_allreduce_basic():
diff --git a/tests/unit/test_pipe_module.py b/tests/unit/test_pipe_module.py
index a29d22a2a954..af87546c4ee3 100644
--- a/tests/unit/test_pipe_module.py
+++ b/tests/unit/test_pipe_module.py
@@ -14,7 +14,7 @@
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.utils import RepeatingLoader
-from common import distributed_test
+from common import distributed_test, skipIfRocm
from simple_model import args_from_dict
HIDDEN_DIM = 32
@@ -56,6 +56,7 @@ def simple_args(tmpdir):
return args
+@skipIfRocm()
def test_pipe_module_sequential(sequential_model, simple_args):
batch_input = torch.randn(1, HIDDEN_DIM)
diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py
index 176363688de4..e907af06427c 100644
--- a/tests/unit/test_topology.py
+++ b/tests/unit/test_topology.py
@@ -7,7 +7,7 @@
from deepspeed.runtime.pipe.topology import ProcessTopology as Topo
from deepspeed.runtime.pipe.topology import _prime_factors
-from common import distributed_test
+from common import distributed_test, skipIfRocm
def test_topology_2d():
@@ -157,6 +157,7 @@ def test_topology_comm_list():
assert topo.get_axis_comm_lists('jeff') == []
+@skipIfRocm()
@distributed_test(world_size=4)
def test_grid_pipe_data():
topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
@@ -182,7 +183,6 @@ def test_grid_pipe_data():
data_group = grid.dp_group
assert torch.all(rank_tensor == sum(data_group))
-
@distributed_test(world_size=4)
def test_stage_to_global():
topo = Topo(axes=['pipe', 'data'], dims=[2, 2])