@@ -26,65 +26,6 @@ class _Enum_Type(object):
2626 def __init__ (self , v ):
2727 self .value = v
2828
29- class _clibrary (object ):
30-
31- def __libname (self , name ):
32- platform_name = platform .system ()
33- assert (len (platform_name ) >= 3 )
34-
35- libname = 'libaf' + name
36- if platform_name == 'Linux' :
37- libname += '.so'
38- elif platform_name == 'Darwin' :
39- libname += '.dylib'
40- elif platform_name == "Windows" or platform_name [:3 ] == "CYG" :
41- libname += '.dll'
42- libname = libname [3 :] # remove 'lib'
43- if platform_name == "Windows" :
44- '''
45- Supressing crashes caused by missing dlls
46- http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
47- https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
48- '''
49- ct .windll .kernel32 .SetErrorMode (0x0001 | 0x0002 );
50- else :
51- raise OSError (platform_name + ' not supported' )
52-
53- return libname
54-
55- def set (self , name , unsafe = False ):
56- if (not unsafe and self .__lock ):
57- raise RuntimeError ("Can not change backend after creating an Array" )
58- if (self .clibs [name ] is None ):
59- raise RuntimeError ("Could not load any ArrayFire %s backend" % name )
60- self .name = name
61- return
62-
63- def __init__ (self ):
64- self .clibs = {}
65- self .name = None
66- self .__lock = False
67- # Iterate in reverse order of preference
68- for name in ('cpu' , 'opencl' , 'cuda' ):
69- try :
70- libname = self .__libname (name )
71- ct .cdll .LoadLibrary (libname )
72- self .clibs [name ] = ct .CDLL (libname )
73- self .name = name
74- except :
75- self .clibs [name ] = None
76-
77- if (self .name is None ):
78- raise RuntimeError ("Could not load any ArrayFire libraries" )
79-
80- def get (self ):
81- return self .clibs [self .name ]
82-
83- def lock (self ):
84- self .__lock = True
85-
86- backend = _clibrary ()
87-
8829class ERR (_Enum ):
8930 """
9031 Error values. For internal use only.
@@ -373,3 +314,168 @@ class BACKEND(_Enum):
373314 CPU = _Enum_Type (1 )
374315 CUDA = _Enum_Type (2 )
375316 OPENCL = _Enum_Type (4 )
317+
318+ class _clibrary (object ):
319+
320+ def __libname (self , name ):
321+ platform_name = platform .system ()
322+ assert (len (platform_name ) >= 3 )
323+
324+ libname = 'libaf' + name
325+ if platform_name == 'Linux' :
326+ libname += '.so'
327+ elif platform_name == 'Darwin' :
328+ libname += '.dylib'
329+ elif platform_name == "Windows" or platform_name [:3 ] == "CYG" :
330+ libname += '.dll'
331+ libname = libname [3 :] # remove 'lib'
332+ if platform_name == "Windows" :
333+ '''
334+ Supressing crashes caused by missing dlls
335+ http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
336+ https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
337+ '''
338+ ct .windll .kernel32 .SetErrorMode (0x0001 | 0x0002 );
339+ else :
340+ raise OSError (platform_name + ' not supported' )
341+
342+ return libname
343+
344+ def set_unsafe (self , name ):
345+ lib = self .__clibs [name ]
346+ if (lib is None ):
347+ raise RuntimeError ("Backend not found" )
348+ self .__name = name
349+
350+ def __init__ (self ):
351+ self .__name = None
352+
353+ self .__clibs = {'cuda' : None ,
354+ 'opencl' : None ,
355+ 'cpu' : None ,
356+ '' : None }
357+
358+ self .__backend_map = {0 : 'default' ,
359+ 1 : 'cpu' ,
360+ 2 : 'cuda' ,
361+ 4 : 'opencl' }
362+
363+ self .__backend_name_map = {'default' : 0 ,
364+ 'cpu' : 1 ,
365+ 'cuda' : 2 ,
366+ 'opencl' : 4 }
367+
368+ # Iterate in reverse order of preference
369+ for name in ('cpu' , 'opencl' , 'cuda' , '' ):
370+ try :
371+ libname = self .__libname (name )
372+ ct .cdll .LoadLibrary (libname )
373+ self .__clibs [name ] = ct .CDLL (libname )
374+ self .__name = name
375+ except :
376+ pass
377+
378+ if (self .__name is None ):
379+ raise RuntimeError ("Could not load any ArrayFire libraries" )
380+
381+ def get_id (self , name ):
382+ return self .__backend_name_map [name ]
383+
384+ def get_name (self , bk_id ):
385+ return self .__backend_map [bk_id ]
386+
387+ def get (self ):
388+ return self .__clibs [self .__name ]
389+
390+ def name (self ):
391+ return self .__name
392+
393+ def is_unified (self ):
394+ return self .__name == ''
395+
396+ def parse (self , res ):
397+ lst = []
398+ for key ,value in self .__backend_name_map .items ():
399+ if (value & res ):
400+ lst .append (key )
401+ return tuple (lst )
402+
403+ backend = _clibrary ()
404+
405+ def set_backend (name , unsafe = False ):
406+ """
407+ Set a specific backend by name
408+
409+ Parameters
410+ ----------
411+
412+ name : str.
413+
414+ unsafe : optional: bool. Default: False.
415+ If False, does not switch backend if current backend is not unified backend.
416+ """
417+ if (backend .is_unified () == False and unsanfe == False ):
418+ raise RuntimeError ("Can not change backend after loading %s" % name )
419+
420+ if (backend .is_unified ()):
421+ safe_call (backend .get ().af_set_backend (backend .get_id (name )))
422+ else :
423+ backend .set_unsafe (name )
424+ return
425+
426+ def get_backend_id (A ):
427+ """
428+ Get backend name of an array
429+
430+ Parameters
431+ ----------
432+ A : af.Array
433+
434+ Returns
435+ ----------
436+
437+ name : str.
438+ Backend name
439+ """
440+ if (backend .is_unified ()):
441+ backend_id = ct .c_int (BACKEND .DEFAULT .value )
442+ safe_call (backend .get ().af_get_backend_id (ct .pointer (backend_id ), A .arr ))
443+ return backend .get_name (backend_id .value )
444+ else :
445+ return backend .name ()
446+
447+ def get_backend_count ():
448+ """
449+ Get number of available backends
450+
451+ Returns
452+ ----------
453+
454+ count : int
455+ Number of available backends
456+ """
457+ if (backend .is_unified ()):
458+ count = ct .c_int (0 )
459+ safe_call (backend .get ().af_get_backend_count (ct .pointer (count )))
460+ return count .value
461+ else :
462+ return 1
463+
464+ def get_available_backends ():
465+ """
466+ Get names of available backends
467+
468+ Returns
469+ ----------
470+
471+ names : tuple of strings
472+ Names of available backends
473+ """
474+ if (backend .is_unified ()):
475+ available = ct .c_int (0 )
476+ safe_call (backend .get ().af_get_available_backends (ct .pointer (available )))
477+ return backend .parse (int (available .value ))
478+ else :
479+ return (backend .name (),)
480+
481+ from .util import safe_call
0 commit comments