|
19 | 19 | from torch.autograd import Variable |
20 | 20 |
|
21 | 21 | from tensor_comprehensions.tc import ATenCompilationUnit |
22 | | -from tensor_comprehensions.tc import global_debug_init as GlobalDebugInit |
| 22 | +from tensor_comprehensions.tc import set_logtostderr, set_debug_lang, set_debug_halide, set_debug_tc_mapper, set_debug_cuda, set_debug_tuner, set_dump_cuda |
23 | 23 | from tensor_comprehensions.torch_tc.tc_function import TCFunction, unpack_variables, get_tensors, make_contiguous |
24 | 24 | from tensor_comprehensions.autotuner import ATenAutotuner |
25 | 25 | from tensor_comprehensions.mapping_options import Options |
|
38 | 38 | "threads": 32, "generations": 5, "tuner_min_launch_total_threads": 1, |
39 | 39 | } |
40 | 40 |
|
| 41 | +############################################################################### |
| 42 | +# Set global debugging flags |
| 43 | +############################################################################### |
| 44 | +class SetDebugFlags(object): |
| 45 | + def __init__(self, **kwargs): |
| 46 | + self.set_gflags(**kwargs) |
| 47 | + |
| 48 | + def set_gflags( |
| 49 | + self, debug_lang=False, debug_halide=False, debug_tc_mapper=False, |
| 50 | + debug_cuda=False, debug_tuner=False, dump_cuda=False, **kwargs |
| 51 | + ): |
| 52 | + set_logtostderr(True) |
| 53 | + set_debug_lang(debug_lang) |
| 54 | + set_debug_halide(debug_halide) |
| 55 | + set_debug_tc_mapper(debug_tc_mapper) |
| 56 | + set_debug_cuda(debug_cuda) |
| 57 | + set_debug_tuner(debug_tuner) |
| 58 | + set_dump_cuda(dump_cuda) |
| 59 | + |
| 60 | + |
41 | 61 | ############################################################################### |
42 | 62 | # Some helper functions |
43 | 63 | ############################################################################### |
|
0 commit comments