|
5 | 5 | from .cpupool import CPUPool |
6 | 6 | from .task import Task |
7 | 7 | import copy |
| 8 | +import warnings |
8 | 9 |
|
9 | 10 | class MultiStreamModuleHint(object): |
10 | 11 | def __init__(self, *args, **kwargs): |
@@ -91,6 +92,9 @@ def __init__(self, |
91 | 92 | output_concat_hint: MultiStreamModuleHint = default_multi_stream_module_concat_hint): |
92 | 93 | super(MultiStreamModule, self).__init__() |
93 | 94 | assert type(cpu_pool) is CPUPool, "Input of cpu_pool must be provided with type of ipex.cpu.runtime.CPUPool" |
| 95 | + if not isinstance(model, torch.jit.ScriptModule): |
 | 96 | +        warnings.warn("Creating MultiStreamModule on an nn.Module. This can be slow due "
 | 97 | +            "to the Python Global Interpreter Lock (GIL). Consider using a JIT ScriptModule for better performance.")
94 | 98 | self.core_list = cpu_pool.core_ids |
95 | 99 | if isinstance(num_streams, str): |
96 | 100 | # For str input of num_streams, it must be "auto" |
@@ -215,7 +219,7 @@ def _do_get_input_for_each_stream(self, hint_object, input_object, stream_input_ |
215 | 219 | self.init_forward_status(input_object[idx_or_key].size(hint_object[idx_or_key]), stream_id) |
216 | 220 | # Get the split input for each stream |
217 | 221 | # Here we assume split along the outside dim, otherwise memory copy happens and obviously hurt multi stream module's performance. |
218 | | - if hint_object[idx_or_key] is 0: |
| 222 | + if hint_object[idx_or_key] == 0: |
219 | 223 | # Split along dim 0, the slice will not create new tensor |
220 | 224 | stream_input_object[idx_or_key] = input_object[idx_or_key][self.current_split_start_idx:self.current_split_end_idx] |
221 | 225 | else: |
|
0 commit comments