Skip to content

Commit ae7e167

Browse files
min-jean-chomin-jean-cho
andauthored
fix test_launcher (#911)
* updated test memory allocator * reverted warning print, updated search pattern * remove redundant code * remove redundant code Co-authored-by: min-jean-cho <minjeanc@mlp-prod-skx-7825.ra.intel.com>
1 parent 75eb59d commit ae7e167

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

tests/cpu/test_launcher.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,22 @@ def find_lib(self, lib_type):
2828

2929
def test_memory_allocator_setup(self):
3030
launcher = Launcher()
31-
31+
3232
# tcmalloc
33-
launcher.set_memory_allocator(enable_tcmalloc=True)
3433
find_tcmalloc = self.find_lib("tcmalloc")
35-
ld_preload_in_os = "LD_PRELOAD" in os.environ
36-
tcmalloc_enabled = "libtcmalloc.so" in os.environ["LD_PRELOAD"] if ld_preload_in_os else False
34+
cmd = ["python", "-m", "intel_extension_for_pytorch.cpu.launch", "--enable_tcmalloc", "--no_python", "ls"]
35+
r = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
36+
tcmalloc_enabled = "libtcmalloc.so" in str(r.stdout, "utf-8").split("INFO - LD_PRELOAD=", 1)[1]
3737
self.assertEqual(find_tcmalloc, tcmalloc_enabled)
38-
39-
# jemalloc
40-
launcher.set_memory_allocator(enable_tcmalloc=False, enable_jemalloc=True)
38+
39+
# jemalloc
4140
find_jemalloc = self.find_lib("jemalloc")
42-
jemalloc_enabled = "libjemalloc.so" in os.environ["LD_PRELOAD"] if ld_preload_in_os else False
41+
cmd = ["python", "-m", "intel_extension_for_pytorch.cpu.launch", "--enable_jemalloc", "--no_python", "ls"]
42+
r = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
43+
jemalloc_enabled = "libjemalloc.so" in str(r.stdout, "utf-8").split("INFO - LD_PRELOAD=", 1)[1]
4344
self.assertEqual(find_jemalloc, jemalloc_enabled)
4445
if jemalloc_enabled:
45-
self.assertEqual(jemalloc_enabled, "MALLOC_CONF" in os.environ)
46+
self.assertEqual(jemalloc_enabled, "MALLOC_CONF" in str(r.stdout, "utf-8"))
4647

4748
def test_mpi_pin_domain_and_ccl_worker_affinity(self):
4849
launcher = DistributedTrainingLauncher()

0 commit comments

Comments
 (0)