|
1 | | -""" Model / Layer Config Singleton |
| 1 | +""" Model / Layer Config singleton state |
2 | 2 | """ |
3 | | -from typing import Any |
| 3 | +from typing import Any, Optional |
4 | 4 |
|
5 | | -__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit'] |
| 5 | +__all__ = [ |
| 6 | + 'is_exportable', 'is_scriptable', 'is_no_jit', |
| 7 | + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' |
| 8 | +] |
6 | 9 |
|
7 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) |
8 | 11 | _NO_JIT = False |
9 | 12 |
|
10 | 13 | # Set to True if prefer to have activation layers with no jit optimization |
| 14 | +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying |
| 15 | +# the jit flags so far are activations. This will change as more layers are updated and/or added. |
11 | 16 | _NO_ACTIVATION_JIT = False |
12 | 17 |
|
13 | 18 | # Set to True if exporting a model with Same padding via ONNX |
@@ -72,3 +77,39 @@ def __exit__(self, *args: Any) -> bool: |
72 | 77 | global _SCRIPTABLE |
73 | 78 | _SCRIPTABLE = self.prev |
74 | 79 | return False |
| 80 | + |
| 81 | + |
| 82 | +class set_layer_config: |
| 83 | + """ Layer config context manager that allows setting all layer config flags at once. |
| 84 | + If a flag arg is None, it will not change the current value. |
| 85 | + """ |
| 86 | + def __init__( |
| 87 | + self, |
| 88 | + scriptable: Optional[bool] = None, |
| 89 | + exportable: Optional[bool] = None, |
| 90 | + no_jit: Optional[bool] = None, |
| 91 | + no_activation_jit: Optional[bool] = None): |
| 92 | + global _SCRIPTABLE |
| 93 | + global _EXPORTABLE |
| 94 | + global _NO_JIT |
| 95 | + global _NO_ACTIVATION_JIT |
| 96 | + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT |
| 97 | + if scriptable is not None: |
| 98 | + _SCRIPTABLE = scriptable |
| 99 | + if exportable is not None: |
| 100 | + _EXPORTABLE = exportable |
| 101 | + if no_jit is not None: |
| 102 | + _NO_JIT = no_jit |
| 103 | + if no_activation_jit is not None: |
| 104 | + _NO_ACTIVATION_JIT = no_activation_jit |
| 105 | + |
| 106 | + def __enter__(self) -> None: |
| 107 | + pass |
| 108 | + |
| 109 | + def __exit__(self, *args: Any) -> bool: |
| 110 | + global _SCRIPTABLE |
| 111 | + global _EXPORTABLE |
| 112 | + global _NO_JIT |
| 113 | + global _NO_ACTIVATION_JIT |
| 114 | + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev |
| 115 | + return False |
0 commit comments