Skip to content

Commit 9f9b2e0

Browse files
committed
FEAT: Changes to use the unified backend from arrayfire when available
1 parent 5e92e72 commit 9f9b2e0

File tree

2 files changed

+165
-61
lines changed

2 files changed

+165
-61
lines changed

arrayfire/array.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,6 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False):
369369

370370
_type_char='f'
371371

372-
backend.lock()
373-
374372
if src is not None:
375373

376374
if (isinstance(src, Array)):

arrayfire/library.py

Lines changed: 165 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -26,65 +26,6 @@ class _Enum_Type(object):
2626
def __init__(self, v):
2727
self.value = v
2828

29-
class _clibrary(object):
30-
31-
def __libname(self, name):
32-
platform_name = platform.system()
33-
assert(len(platform_name) >= 3)
34-
35-
libname = 'libaf' + name
36-
if platform_name == 'Linux':
37-
libname += '.so'
38-
elif platform_name == 'Darwin':
39-
libname += '.dylib'
40-
elif platform_name == "Windows" or platform_name[:3] == "CYG":
41-
libname += '.dll'
42-
libname = libname[3:] # remove 'lib'
43-
if platform_name == "Windows":
44-
'''
45-
Supressing crashes caused by missing dlls
46-
http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
47-
https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
48-
'''
49-
ct.windll.kernel32.SetErrorMode(0x0001 | 0x0002);
50-
else:
51-
raise OSError(platform_name + ' not supported')
52-
53-
return libname
54-
55-
def set(self, name, unsafe=False):
56-
if (not unsafe and self.__lock):
57-
raise RuntimeError("Can not change backend after creating an Array")
58-
if (self.clibs[name] is None):
59-
raise RuntimeError("Could not load any ArrayFire %s backend" % name)
60-
self.name = name
61-
return
62-
63-
def __init__(self):
64-
self.clibs = {}
65-
self.name = None
66-
self.__lock = False
67-
# Iterate in reverse order of preference
68-
for name in ('cpu', 'opencl', 'cuda'):
69-
try:
70-
libname = self.__libname(name)
71-
ct.cdll.LoadLibrary(libname)
72-
self.clibs[name] = ct.CDLL(libname)
73-
self.name = name
74-
except:
75-
self.clibs[name] = None
76-
77-
if (self.name is None):
78-
raise RuntimeError("Could not load any ArrayFire libraries")
79-
80-
def get(self):
81-
return self.clibs[self.name]
82-
83-
def lock(self):
84-
self.__lock = True
85-
86-
backend = _clibrary()
87-
8829
class ERR(_Enum):
8930
"""
9031
Error values. For internal use only.
@@ -373,3 +314,168 @@ class BACKEND(_Enum):
373314
CPU = _Enum_Type(1)
374315
CUDA = _Enum_Type(2)
375316
OPENCL = _Enum_Type(4)
317+
318+
class _clibrary(object):
319+
320+
def __libname(self, name):
321+
platform_name = platform.system()
322+
assert(len(platform_name) >= 3)
323+
324+
libname = 'libaf' + name
325+
if platform_name == 'Linux':
326+
libname += '.so'
327+
elif platform_name == 'Darwin':
328+
libname += '.dylib'
329+
elif platform_name == "Windows" or platform_name[:3] == "CYG":
330+
libname += '.dll'
331+
libname = libname[3:] # remove 'lib'
332+
if platform_name == "Windows":
333+
'''
334+
Supressing crashes caused by missing dlls
335+
http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
336+
https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
337+
'''
338+
ct.windll.kernel32.SetErrorMode(0x0001 | 0x0002);
339+
else:
340+
raise OSError(platform_name + ' not supported')
341+
342+
return libname
343+
344+
def set_unsafe(self, name):
345+
lib = self.__clibs[name]
346+
if (lib is None):
347+
raise RuntimeError("Backend not found")
348+
self.__name = name
349+
350+
def __init__(self):
351+
self.__name = None
352+
353+
self.__clibs = {'cuda' : None,
354+
'opencl' : None,
355+
'cpu' : None,
356+
'' : None}
357+
358+
self.__backend_map = {0 : 'default',
359+
1 : 'cpu' ,
360+
2 : 'cuda' ,
361+
4 : 'opencl' }
362+
363+
self.__backend_name_map = {'default' : 0,
364+
'cpu' : 1,
365+
'cuda' : 2,
366+
'opencl' : 4}
367+
368+
# Iterate in reverse order of preference
369+
for name in ('cpu', 'opencl', 'cuda', ''):
370+
try:
371+
libname = self.__libname(name)
372+
ct.cdll.LoadLibrary(libname)
373+
self.__clibs[name] = ct.CDLL(libname)
374+
self.__name = name
375+
except:
376+
pass
377+
378+
if (self.__name is None):
379+
raise RuntimeError("Could not load any ArrayFire libraries")
380+
381+
def get_id(self, name):
382+
return self.__backend_name_map[name]
383+
384+
def get_name(self, bk_id):
385+
return self.__backend_map[bk_id]
386+
387+
def get(self):
388+
return self.__clibs[self.__name]
389+
390+
def name(self):
391+
return self.__name
392+
393+
def is_unified(self):
394+
return self.__name == ''
395+
396+
def parse(self, res):
397+
lst = []
398+
for key,value in self.__backend_name_map.items():
399+
if (value & res):
400+
lst.append(key)
401+
return tuple(lst)
402+
403+
backend = _clibrary()
404+
405+
def set_backend(name, unsafe=False):
406+
"""
407+
Set a specific backend by name
408+
409+
Parameters
410+
----------
411+
412+
name : str.
413+
414+
unsafe : optional: bool. Default: False.
415+
If False, does not switch backend if current backend is not unified backend.
416+
"""
417+
if (backend.is_unified() == False and unsanfe == False):
418+
raise RuntimeError("Can not change backend after loading %s" % name)
419+
420+
if (backend.is_unified()):
421+
safe_call(backend.get().af_set_backend(backend.get_id(name)))
422+
else:
423+
backend.set_unsafe(name)
424+
return
425+
426+
def get_backend_id(A):
427+
"""
428+
Get backend name of an array
429+
430+
Parameters
431+
----------
432+
A : af.Array
433+
434+
Returns
435+
----------
436+
437+
name : str.
438+
Backend name
439+
"""
440+
if (backend.is_unified()):
441+
backend_id = ct.c_int(BACKEND.DEFAULT.value)
442+
safe_call(backend.get().af_get_backend_id(ct.pointer(backend_id), A.arr))
443+
return backend.get_name(backend_id.value)
444+
else:
445+
return backend.name()
446+
447+
def get_backend_count():
448+
"""
449+
Get number of available backends
450+
451+
Returns
452+
----------
453+
454+
count : int
455+
Number of available backends
456+
"""
457+
if (backend.is_unified()):
458+
count = ct.c_int(0)
459+
safe_call(backend.get().af_get_backend_count(ct.pointer(count)))
460+
return count.value
461+
else:
462+
return 1
463+
464+
def get_available_backends():
465+
"""
466+
Get names of available backends
467+
468+
Returns
469+
----------
470+
471+
names : tuple of strings
472+
Names of available backends
473+
"""
474+
if (backend.is_unified()):
475+
available = ct.c_int(0)
476+
safe_call(backend.get().af_get_available_backends(ct.pointer(available)))
477+
return backend.parse(int(available.value))
478+
else:
479+
return (backend.name(),)
480+
481+
from .util import safe_call

0 commit comments

Comments
 (0)