@@ -104,6 +104,7 @@ class ModelState : public BackendModel {
     return enable_jit_executor_pair_;
   }
   bool EnabledInferenceMode() { return enable_inference_mode_; }
+  bool EnabledCudnn() { return enable_cudnn_; }
   bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
@@ -125,6 +126,9 @@ class ModelState : public BackendModel {
   // Flag to indicate whether inference mode is enabled. Defaults to false.
   bool enable_inference_mode_;
 
+  // Flag to indicate whether cuDNN is enabled. Defaults to true.
+  bool enable_cudnn_;
+
   // Flag to indicate whether cache cleaning after each run is enabled.
   // Defaults to false.
   bool enable_cache_cleaning_;
@@ -227,8 +231,9 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), enable_optimized_execution_(true),
-      enable_inference_mode_(true), enable_cache_cleaning_(false),
-      enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
+      enable_inference_mode_(true), enable_cudnn_(true),
+      enable_cache_cleaning_(false), enable_weight_sharing_(false),
+      enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true})
 {
@@ -393,6 +398,24 @@ ModelState::ParseParameters()
          " for model instance '" + Name() + "'")
             .c_str());
 
+    // If 'DISABLE_CUDNN' is not present in 'parameters' then no update is
+    // made to 'enable_cudnn_'.
+    bool disable_cudnn = false;
+    err = ParseParameter(params, "DISABLE_CUDNN", &disable_cudnn);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    }
+    enable_cudnn_ = !disable_cudnn;
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_INFO,
+        (std::string("cuDNN is ") + (enable_cudnn_ ? "enabled" : "disabled") +
+         " for model instance '" + Name() + "'")
+            .c_str());
+
     // If 'ENABLE_TENSOR_FUSER' is not present in 'parameters' then no
     // update is made to 'enable_tensor_fuser'.
     bool enable_tensor_fuser = false;
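For reference, the new key follows the same config.pbtxt parameter convention as the other flags parsed here; a minimal sketch of how a model would opt out of cuDNN (the key name comes from this diff, the surrounding fields are the standard Triton model-config parameter layout):

parameters: {
  key: "DISABLE_CUDNN"
  value: {
    string_value: "true"
  }
}

As with the other boolean parameters in this function, the value travels as a string and is converted to a bool by ParseParameter; when the key is absent, the NOT_FOUND error is swallowed and the default (cuDNN enabled) stands.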
@@ -1562,6 +1585,9 @@ ModelInstanceState::Execute(
   // enable/disable inference mode - supersedes NoGradGuard
   torch::InferenceMode infer_guard(model_state_->EnabledInferenceMode());
 
+  // enable/disable cuDNN
+  at::globalContext().setUserEnabledCuDNN(model_state_->EnabledCudnn());
+
   // JIT. No change is made unless parameter is explicitly set.
   if (std::get<0>(model_state_->EnabledJitProfiling())) {
     torch::jit::getProfilingMode() =
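The setter used above is ATen's global cuDNN switch, not backend-local state. A minimal standalone libtorch sketch of the same call, paired with the matching getter userEnabledCuDNN() (assumes a libtorch build; not part of this diff):

#include <ATen/Context.h>
#include <iostream>

int main() {
  // Flip the process-wide flag that cuDNN-backed ops (e.g. convolutions,
  // RNNs) consult before dispatching to cuDNN kernels.
  at::globalContext().setUserEnabledCuDNN(false);
  std::cout << "cuDNN enabled: " << at::globalContext().userEnabledCuDNN()
            << "\n";  // prints 0
  return 0;
}

Because the flag is global to the process, it affects every model instance in the server, which is presumably why Execute() re-applies the per-model setting on every invocation rather than setting it once at load time.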