@@ -28,22 +28,34 @@ static bool emit_cuda_c_code(CudaKernel* spec) {
2828 Module * final_mod ;
2929 emit_c (config , emitter_config , dst_mod , & spec -> cuda_code_size , & spec -> cuda_code , & final_mod );
3030 spec -> final_module = final_mod ;
31+
32+ if (get_log_level () <= DEBUG )
33+ write_file ("cuda_dump.cu" , spec -> cuda_code_size - 1 , spec -> cuda_code );
34+
3135 return true;
3236}
3337
3438static bool cuda_c_to_ptx (CudaKernel * kernel ) {
3539 nvrtcProgram program ;
3640 CHECK_NVRTC (nvrtcCreateProgram (& program , kernel -> cuda_code , kernel -> key .entry_point , 0 , NULL , NULL ), return false );
37- const char * args [] = { "--use_fast_math" };
38- nvrtcResult compile_result = nvrtcCompileProgram (program , sizeof (args ) / sizeof (* args ), args );
41+
42+ assert (kernel -> device -> cc_major < 10 && kernel -> device -> cc_minor < 10 );
43+
44+ char arch_flag [] = "-arch=compute_00" ;
45+ arch_flag [14 ] = '0' + kernel -> device -> cc_major ;
46+ arch_flag [15 ] = '0' + kernel -> device -> cc_minor ;
47+
48+ const char * options [] = {
49+ arch_flag ,
50+ "--use_fast_math"
51+ };
52+
53+ nvrtcResult compile_result = nvrtcCompileProgram (program , sizeof (options )/sizeof (* options ), options );
3954 if (compile_result != NVRTC_SUCCESS ) {
4055 error_print ("NVRTC compilation failed: %s\n" , nvrtcGetErrorString (compile_result ));
4156 debug_print ("Dumping source:\n%s" , kernel -> cuda_code );
4257 }
4358
44- if (get_log_level () <= DEBUG )
45- write_file ("cuda_dump.cu" , kernel -> cuda_code_size - 1 , kernel -> cuda_code );
46-
4759 size_t log_size ;
4860 CHECK_NVRTC (nvrtcGetProgramLogSize (program , & log_size ), return false );
4961 char * log_buffer = calloc (log_size , 1 );
@@ -61,13 +73,58 @@ static bool cuda_c_to_ptx(CudaKernel* kernel) {
6173 read_file (override_file , & kernel -> ptx_size , & kernel -> ptx );
6274 }
6375
76+ if (get_log_level () <= DEBUG )
77+ write_file ("cuda_dump.ptx" , kernel -> ptx_size - 1 , kernel -> ptx );
78+
6479 return true;
6580}
6681
6782static bool load_ptx_into_cuda_program (CudaKernel * kernel ) {
68- CHECK_CUDA (cuModuleLoadDataEx (& kernel -> cuda_module , kernel -> ptx , 0 , NULL , NULL ), return false );
69- CHECK_CUDA (cuModuleGetFunction (& kernel -> entry_point_function , kernel -> cuda_module , kernel -> key .entry_point ), return false );
83+ char info_log [10240 ] = {};
84+ char error_log [10240 ] = {};
85+
86+ CUjit_option options [] = {
87+ CU_JIT_INFO_LOG_BUFFER , CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES ,
88+ CU_JIT_ERROR_LOG_BUFFER , CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES ,
89+ CU_JIT_TARGET
90+ };
91+
92+ void * option_values [] = {
93+ info_log , (void * )(uintptr_t )sizeof (info_log ),
94+ error_log , (void * )(uintptr_t )sizeof (error_log ),
95+ (void * )(uintptr_t )(kernel -> device -> cc_major * 10 + kernel -> device -> cc_minor )
96+ };
97+
98+ CUlinkState linker ;
99+ CHECK_CUDA (cuLinkCreate (sizeof (options )/sizeof (options [0 ]), options , option_values , & linker ), goto err_linker_create );
100+ CHECK_CUDA (cuLinkAddData (linker , CU_JIT_INPUT_PTX , kernel -> ptx , kernel -> ptx_size , NULL , 0U , NULL , NULL ), goto err_post_linker_create );
101+
102+ void * binary ;
103+ size_t binary_size ;
104+ CHECK_CUDA (cuLinkComplete (linker , & binary , & binary_size ), goto err_post_linker_create );
105+
106+ if (* info_log )
107+ info_print ("CUDA JIT info: %s\n" , info_log );
108+
109+ if (get_log_level () <= DEBUG )
110+ write_file ("cuda_dump.cubin" , binary_size , binary );
111+
112+ CHECK_CUDA (cuModuleLoadData (& kernel -> cuda_module , binary ), goto err_post_linker_create );
113+ CHECK_CUDA (cuModuleGetFunction (& kernel -> entry_point_function , kernel -> cuda_module , kernel -> key .entry_point ), goto err_post_module_load );
114+
115+ cuLinkDestroy (linker );
70116 return true;
117+
118+ err_post_module_load :
119+ cuModuleUnload (kernel -> cuda_module );
120+ err_post_linker_create :
121+ cuLinkDestroy (linker );
122+ if (* info_log )
123+ info_print ("CUDA JIT info: %s\n" , info_log );
124+ if (* error_log )
125+ error_print ("CUDA JIT failed: %s\n" , error_log );
126+ err_linker_create :
127+ return false;
71128}
72129
73130static CudaKernel * create_specialized_program (CudaDevice * device , SpecProgramKey key ) {
0 commit comments