77import sys
88import warnings
99
10+
# Parse custom build arguments
def parse_build_args():
    """Parse custom build arguments for CPU/GPU selection.

    Usage:
        python setup.py develop cpu   # Force CPU-only build
        python setup.py develop gpu   # Force GPU build (fallback to CPU if not available)
        python setup.py develop       # Auto-detect (prefer GPU if available)

    Returns:
        str: one of 'cpu', 'gpu', or 'auto'.
    """
    argv = sys.argv

    # Show our custom options whenever any help flag is present; setuptools
    # still handles its own --help output afterwards.
    if 'help' in argv or '--help' in argv:
        print("\nSparse Transformers Build Options:")
        print("  python setup.py develop cpu    # Force CPU-only build")
        print("  python setup.py develop gpu    # Force GPU build")
        print("  python setup.py develop        # Auto-detect (prefer GPU)")
        print()

    # Consume our custom tokens so setuptools never sees them as commands.
    if 'cpu' in argv:
        argv.remove('cpu')
        print("Forced CPU-only build mode")
        return 'cpu'
    if 'gpu' in argv:
        argv.remove('gpu')
        print("Forced GPU build mode")
        return 'gpu'

    print("Auto-detecting build mode (default: GPU if available)")
    return 'auto'
43+
44+
# Check PyTorch C++ ABI compatibility
def get_pytorch_abi_flag():
    """Return the -D_GLIBCXX_USE_CXX11_ABI define matching this PyTorch build.

    Extensions must be compiled with the same C++ ABI setting as libtorch,
    otherwise symbol resolution fails at import time.
    """
    # torch._C._GLIBCXX_USE_CXX11_ABI is a bool; render it as 0/1.
    abi_enabled = int(torch._C._GLIBCXX_USE_CXX11_ABI)
    return '-D_GLIBCXX_USE_CXX11_ABI=%d' % abi_enabled
1449
50+
# Resolve the PyTorch ABI flag once and echo it so build logs record the ABI.
pytorch_abi_flag = get_pytorch_abi_flag()
print(f"Using PyTorch C++ ABI flag: {pytorch_abi_flag}")

# Parse build mode from command line (consumes our custom 'cpu'/'gpu' tokens).
build_mode = parse_build_args()
57+
1958# Create build directory if it doesn't exist
2059build_dir = Path (__file__ ).parent / 'build'
2160if build_dir .exists ():
@@ -36,12 +75,15 @@ def get_pytorch_abi_flag():
3675 for i in range (torch .cuda .device_count ()):
3776 arch_list .append (torch .cuda .get_device_capability (i ))
3877 arch_list = sorted (list (set (arch_list )))
39- arch_flags = [f"-gencode=arch=compute_{ arch [0 ]} { arch [1 ]} ,code=sm_{ arch [0 ]} { arch [1 ]} " for arch in arch_list ]
78+ arch_flags = [
79+ f"-gencode=arch=compute_{ arch [0 ]} { arch [1 ]} ,code=sm_{ arch [0 ]} { arch [1 ]} "
80+ for arch in arch_list
81+ ]
4082 print (f"CUDA architectures detected: { arch_list } " )
4183 except Exception as e :
4284 warnings .warn (f"Error detecting CUDA architecture: { e } " )
4385 # Use a common architecture as fallback
44- arch_flags = [" -gencode=arch=compute_86,code=sm_86" ]
86+ arch_flags = [' -gencode=arch=compute_86,code=sm_86' ]
4587
4688# Common optimization flags (compatible with both old and new ABI)
4789common_compile_args = [
@@ -59,6 +101,7 @@ def get_pytorch_abi_flag():
59101# Try to detect if we can use advanced CPU optimizations safely
60102try :
61103 import platform
104+
62105 if platform .machine () in ['x86_64' , 'AMD64' ]:
63106 advanced_cpu_flags = [
64107 '-march=native' , # Optimize for local CPU architecture
@@ -72,28 +115,39 @@ def get_pytorch_abi_flag():
72115 advanced_cpu_flags = []
73116
# CPU-specific optimization flags
# NOTE(fix): the previous list passed both '-fvisibility=hidden' and
# '-fvisibility=default'. With GCC/Clang the last visibility flag wins, so
# the 'hidden' entry never took effect and only misled readers; it has been
# removed. Effective compiler behavior is unchanged (visibility=default).
cpu_compile_args = common_compile_args + advanced_cpu_flags + [
    '-flto',                 # Link-time optimization
    '-funroll-loops',        # Unroll loops
    '-fno-math-errno',       # Assume math functions never set errno
    '-fno-trapping-math',    # Assume FP ops don't generate traps
    '-fno-plt',              # Improve indirect call performance
    '-fuse-linker-plugin',   # Enable LTO plugin
    '-fomit-frame-pointer',  # Remove frame pointers
    '-fno-stack-protector',  # Disable stack protector
    '-fdata-sections',       # Place each data item into its own section
    '-ffunction-sections',   # Place each function into its own section
    '-fvisibility=default',  # Keep symbols visible for extension loading
]
89136
# CUDA-specific optimization flags (ensure C++17 compatibility and ABI matching)
# Host-compiler flags are forwarded through nvcc via --compiler-options.
_nvcc_host_flags = [
    '--compiler-options', "'-fPIC'",
    '--compiler-options', "'-O3'",
    '-std=c++17',  # Force C++17 for compatibility
    '--compiler-options', "'-fvisibility=default'",
]
cuda_compile_args = ['-O3', '--use_fast_math', *arch_flags, *_nvcc_host_flags]
97151
98152# Add advanced CPU flags to CUDA compilation if available
99153if advanced_cpu_flags :
@@ -112,28 +166,30 @@ def get_pytorch_abi_flag():
112166 '-Wl,--exclude-libs,ALL' , # Don't export any symbols from libraries
113167]
114168
169+
# Get CUDA include paths
def get_cuda_include_dirs():
    """Locate a CUDA toolkit and return its include directories.

    Search order: $CUDA_HOME (default /usr/local/cuda), then $CUDA_PATH
    (Windows convention), then a few common install locations.

    Returns:
        list[str]: include directories, or [] when no CUDA install is found
        (a warning is emitted in that case).
    """
    cuda_home = os.getenv('CUDA_HOME', '/usr/local/cuda')
    if not os.path.exists(cuda_home):
        cuda_home = os.getenv('CUDA_PATH')  # Windows
        # Fix: an exported-but-nonexistent CUDA_PATH previously slipped
        # through unvalidated, skipping the fallback search below and
        # producing bogus include paths. Validate it like the others.
        if cuda_home is not None and not os.path.exists(cuda_home):
            cuda_home = None

    if cuda_home is None:
        # Try common CUDA locations
        for path in ['/usr/local/cuda', '/opt/cuda', '/usr/cuda']:
            if os.path.exists(path):
                cuda_home = path
                break

    if cuda_home is None:
        warnings.warn('CUDA installation not found. CUDA extensions will not be built.')
        return []

    return [
        os.path.join(cuda_home, 'include'),
        os.path.join(cuda_home, 'samples', 'common', 'inc'),
    ]
136191
192+
137193# Base extension configuration
138194base_include_dirs = [
139195 os .path .dirname (torch .__file__ ) + '/include' ,
@@ -152,7 +208,34 @@ def get_cuda_include_dirs():
152208 warnings .warn (f"C++ source file not found: { cpp_source } " )
153209 raise FileNotFoundError (f"Missing source file: { cpp_source } " )
154210
# Determine whether the CUDA extension should be built for this build mode.
# 'cpu' never builds CUDA; 'gpu' demands it but degrades gracefully; 'auto'
# builds it opportunistically when both the runtime and source are present.
should_build_cuda = False

if build_mode == 'cpu':
    print("CPU-only build requested - skipping CUDA")
elif build_mode == 'gpu':
    print("GPU build requested")
    if not torch.cuda.is_available():
        print("WARNING: GPU build requested but PyTorch CUDA not available")
        print("   Falling back to CPU-only build")
    elif not os.path.exists(cuda_source):
        print("WARNING: GPU build requested but CUDA source file not found")
        print("   Falling back to CPU-only build")
    else:
        should_build_cuda = True
else:  # auto mode
    # Default behavior: prefer GPU if available, otherwise CPU
    if torch.cuda.is_available() and os.path.exists(cuda_source):
        print("Auto-detected: Building GPU extension (CUDA available)")
        should_build_cuda = True
    else:
        print("Auto-detected: Building CPU-only extension (CUDA not available)")
237+
238+ if should_build_cuda :
156239 print ("Building CUDA extension..." )
157240 cuda_include_dirs = get_cuda_include_dirs ()
158241 if cuda_include_dirs :
@@ -161,35 +244,34 @@ def get_cuda_include_dirs():
161244 name = 'sparse_transformers.sparse_transformers' ,
162245 sources = [cpp_source , cuda_source ],
163246 include_dirs = base_include_dirs ,
164- extra_compile_args = {
165- 'cxx' : cpu_compile_args ,
166- 'nvcc' : cuda_compile_args
167- },
247+ extra_compile_args = {'cxx' : cpu_compile_args , 'nvcc' : cuda_compile_args },
168248 extra_link_args = extra_link_args ,
169249 libraries = ['gomp' , 'cudart' ],
170250 library_dirs = [str (build_dir / 'lib' )],
171- define_macros = [('WITH_CUDA' , None )]
251+ define_macros = [('WITH_CUDA' , None )],
172252 )
173253 else :
174- print ("CUDA include directories not found, falling back to CPU-only extension..." )
175- raise RuntimeError ("CUDA headers not found" )
176- else :
254+ print (
255+ "CUDA include directories not found, falling back to CPU-only extension..."
256+ )
257+ should_build_cuda = False
258+
259+ if not should_build_cuda :
177260 print ("Building CPU-only extension..." )
178- cuda_include_dirs = get_cuda_include_dirs ()
179- if cuda_include_dirs :
180- base_include_dirs .extend (cuda_include_dirs )
181261 extension = CppExtension (
182262 name = 'sparse_transformers.sparse_transformers' ,
183263 sources = [cpp_source ],
184264 extra_compile_args = cpu_compile_args ,
185265 extra_link_args = extra_link_args ,
186266 library_dirs = [str (build_dir / 'lib' )],
187267 include_dirs = base_include_dirs ,
188- libraries = ['gomp' ]
268+ libraries = ['gomp' ],
269+ define_macros = [('CPU_ONLY' , None )],
189270 )
190271
191272ext_modules .append (extension )
192- print (f"Extension configured successfully: { extension .name } " )
273+ build_type = "CUDA" if should_build_cuda else "CPU-only"
274+ print (f"Extension configured successfully: { extension .name } ({ build_type } )" )
193275
194276
195277# Custom build extension to handle clean builds and ABI compatibility
@@ -203,18 +285,19 @@ def get_ext_fullpath(self, ext_name):
203285 # Override to ensure extension is built in our build directory
204286 filename = self .get_ext_filename (ext_name )
205287 return str (build_dir / 'lib' / filename )
206-
288+
207289 def build_extensions (self ):
208290 # Disable parallel build for better error reporting and CUDA compatibility
209291 if self .parallel :
210292 self .parallel = False
211-
293+
212294 # Print compilation info for debugging
213295 print (f"Building extensions with PyTorch { torch .__version__ } " )
214296 print (f"PyTorch C++ ABI: { pytorch_abi_flag } " )
215297 super ().build_extensions ()
216298 print ("C++ extension built successfully!" )
217299
300+
218301# Read requirements from requirements.txt
219302def read_requirements ():
220303 requirements_path = Path (__file__ ).parent / 'requirements.txt'
@@ -228,6 +311,7 @@ def read_requirements():
228311 return requirements
229312 return []
230313
314+
231315setup (
232316 name = 'sparse_transformers' ,
233317 version = '0.0.1' ,
@@ -241,4 +325,4 @@ def read_requirements():
241325 python_requires = '>=3.8' ,
242326 include_package_data = True ,
243327 zip_safe = False , # Required for C++ extensions
244- )
328+ )
0 commit comments