Skip to content

Commit cb6ed16

Browse files
committed
Update base for Update on "[Executorch] Make module constructors uniform across"
Existing constructors dont compose well such that if you want data loader or data files constructor then you cannot get to override memory allocator. Fix that. Differential Revision: [D86120037](https://our.internmc.facebook.com/intern/diff/D86120037/) [ghstack-poisoned]
2 parents ed145f5 + 93bf861 commit cb6ed16

File tree

133 files changed

+4838
-1363
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

133 files changed

+4838
-1363
lines changed

.github/workflows/pull.yml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,15 +862,24 @@ jobs:
862862
# Install Node.js and Emscripten
863863
source .ci/scripts/setup-emscripten.sh
864864
865+
export PNPM_VERSION=10.24.0
866+
867+
curl -fsSL https://get.pnpm.io/install.sh | env PNPM_VERSION=$PNPM_VERSION SHELL="$(which bash)" sh -
868+
869+
export PNPM_HOME="$HOME/.local/share/pnpm"
870+
export PATH="$PNPM_HOME:$PATH"
871+
872+
pnpm --version
873+
865874
# Test selective build
866875
bash scripts/build_wasm_tests.sh ${{ matrix.enable-etdump }}
867876
868877
# Install Jest
869878
cd cmake-out-wasm/extension/wasm/test
870-
npm install --save-dev jest
879+
pnpm add -D jest@30.2.0 --ignore-scripts
871880
872881
# Run unit test
873-
npm test
882+
pnpm test
874883
875884
unittest-nxp-neutron:
876885
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/aoti/aoti_backend.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import typing
1010
from abc import ABC, abstractmethod
1111
from enum import Enum
12-
from typing import Any, Dict, List, Optional, Set
12+
from typing import Any, Dict, List, Set
1313

1414
import torch
1515
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
@@ -70,10 +70,15 @@ def get_aoti_compile_options(
7070

7171
@classmethod
7272
@abstractmethod
73-
def get_custom_passes(cls) -> List[typing.Any]:
73+
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
7474
"""Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition."""
7575
pass
7676

77+
@classmethod
78+
def get_extra_aoti_compile_context_manager(cls):
79+
"""Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager."""
80+
return contextlib.nullcontext()
81+
7782
@classmethod
7883
@contextlib.contextmanager
7984
def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
@@ -91,39 +96,24 @@ def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]
9196
)
9297

9398
def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
94-
self,
95-
kernel: str,
96-
args: list[str],
97-
device: str,
98-
*,
99-
debug_args: Optional[list[str]] = None,
100-
debug_handle: Optional[int] = None,
101-
):
99+
self, kernel: str, *args: Any, **kwargs: Any
100+
) -> None:
102101
if kernel not in supported_kernels:
103102
missing_fallback_kernels.add(kernel)
104103

105-
original_generate_c_shim_extern_kernel_call(
106-
self,
107-
kernel,
108-
args,
109-
device,
110-
debug_args=debug_args,
111-
debug_handle=debug_handle,
104+
return original_generate_c_shim_extern_kernel_call(
105+
self, kernel, *args, **kwargs
112106
)
113107

114108
def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
115-
self,
116-
op_overload,
117-
raw_args,
118-
output_args,
119-
raw_outputs,
120-
):
109+
self, op_overload: Any, *args: Any, **kwargs: Any
110+
) -> None:
121111
kernel_name = getattr(op_overload, "_name", str(op_overload))
122112
if kernel_name not in supported_kernels:
123113
missing_fallback_kernels.add(kernel_name)
124114

125-
original_generate_fallback_kernel_with_runtime_lookup_aot(
126-
self, op_overload, raw_args, output_args, raw_outputs
115+
return original_generate_fallback_kernel_with_runtime_lookup_aot(
116+
self, op_overload, *args, **kwargs
127117
)
128118

129119
CppWrapperCpu.generate_c_shim_extern_kernel_call = (
@@ -164,7 +154,7 @@ def preprocess(
164154
ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)
165155

166156
# Apply custom backend-specific passes
167-
custom_passes = cls.get_custom_passes()
157+
custom_passes = cls.get_custom_passes(compile_specs)
168158
for custom_pass in custom_passes:
169159
custom_pass(device_edge_program.graph_module)
170160

@@ -189,7 +179,7 @@ def preprocess(
189179
# Compile with fallback kernel collection
190180
with cls.collect_unsupported_fallback_kernels(
191181
missing_fallback_kernels
192-
), torch.no_grad():
182+
), torch.no_grad(), cls.get_extra_aoti_compile_context_manager():
193183
paths = torch._inductor.aot_compile(
194184
edge_program_module, tuple(user_input_placeholders), options=options
195185
)

backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -101,39 +101,50 @@ + (NSString *)debugSymbolToHandlesKeyName {
101101
}
102102

103103
+ (nullable NSString *)assetsDirectoryPath {
104-
static dispatch_once_t onceToken;
105-
static NSString *result = nil;
106-
dispatch_once(&onceToken, ^{
107-
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
108-
if (paths.count > 0) {
109-
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
110-
}
111-
});
112-
113-
return result;
104+
#if defined(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH)
105+
return @(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH);
106+
#else
107+
static dispatch_once_t onceToken;
108+
static NSString *result = nil;
109+
dispatch_once(&onceToken, ^{
110+
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
111+
if (paths.count > 0) {
112+
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
113+
}
114+
});
115+
116+
return result;
117+
#endif
114118
}
115119

116120
+ (nullable NSString *)trashDirectoryPath {
117-
static dispatch_once_t onceToken;
118-
static NSString *result = nil;
119-
dispatch_once(&onceToken, ^{
120-
result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
121-
});
122-
123-
return result;
121+
#if defined(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH)
122+
return @(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH);
123+
#else
124+
static dispatch_once_t onceToken;
125+
static NSString *result = nil;
126+
dispatch_once(&onceToken, ^{
127+
result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
128+
});
129+
130+
return result;
131+
#endif
124132
}
125133

126134
+ (nullable NSString *)databaseDirectoryPath {
127-
static dispatch_once_t onceToken;
128-
static NSString *result = nil;
129-
dispatch_once(&onceToken, ^{
130-
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
131-
if (paths.count > 0) {
132-
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
133-
}
134-
});
135-
136-
return result;
135+
#if defined(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH)
136+
return @(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH);
137+
#else
138+
static dispatch_once_t onceToken;
139+
static NSString *result = nil;
140+
dispatch_once(&onceToken, ^{
141+
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
142+
if (paths.count > 0) {
143+
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
144+
}
145+
});
146+
return result;
147+
#endif
137148
}
138149

139150

backends/apple/metal/metal_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
4242
return {}
4343

4444
@classmethod
45-
def get_custom_passes(cls) -> List[typing.Any]:
45+
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
4646
"""Return Metal-specific passes (currently none)"""
4747
return []
4848

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
1010
from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
1111
from .broadcast_args_pass import BroadcastArgsPass # noqa
12-
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
1312
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1413
from .cast_to_int32_pass import CastToInt32Pass # noqa
1514
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
@@ -52,6 +51,7 @@
5251
from .decompose_int16_activation_conv2d_pass import ( # noqa
5352
DecomposeConv2dWithInt16ActivationPass,
5453
)
54+
from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa
5555
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
5656
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
5757
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
@@ -100,6 +100,7 @@
100100
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
101101
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
102102
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
103+
from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa
103104
from .remove_getitem_pass import RemoveGetItemPass # noqa
104105
from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa
105106
from .remove_noop_pass import RemoveNoopPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
AnnotateDecomposedMatmulPass,
1515
AnnotateOutputDimOrderPass,
1616
BroadcastArgsPass,
17-
CastBoolToInt8Pass,
1817
CastInt64BuffersToInt32Pass,
1918
CastToInt32Pass,
2019
ComputeConstantOpsAOTPass,
@@ -55,6 +54,7 @@
5554
DecomposeGluPass,
5655
DecomposeGroupedConvPass,
5756
DecomposeGroupNormPass,
57+
DecomposeInt32ClampPass,
5858
DecomposeIntPowPass,
5959
DecomposeLayerNormPass,
6060
DecomposeLeakyReLUPass,
@@ -92,6 +92,7 @@
9292
InsertTableOpsPass,
9393
MatchArgDtypePass,
9494
MatchArgRanksPass,
95+
PromoteBoolOperandsPass,
9596
QuantizeClampArgumentsPass,
9697
RemoveGetItemPass,
9798
RemoveGraphAssertsPass,
@@ -122,7 +123,6 @@
122123

123124

124125
class ArmPassManager(PassManager):
125-
126126
def __init__(self, tosa_spec: TosaSpecification) -> None:
127127
self.tosa_spec = tosa_spec
128128
super().__init__()
@@ -174,6 +174,7 @@ def _tosa_pipeline(
174174
FuseQuantizedActivationPass(),
175175
RemoveGetItemPass(),
176176
ConvertToClampPass(),
177+
DecomposeInt32ClampPass(),
177178
DecomposeGroupNormPass(),
178179
DecomposeLayerNormPass(),
179180
DecomposeBatchNormNoStatsPass(),
@@ -217,7 +218,7 @@ def _tosa_pipeline(
217218
DecomposeEluPass(),
218219
DecomposeExpm1Pass(),
219220
DecomposeIntPowPass(),
220-
CastBoolToInt8Pass(),
221+
PromoteBoolOperandsPass(),
221222
DecomposeSinhPass(),
222223
DecomposeSignPass(),
223224
DecomposeFloorDividePass(),
@@ -329,7 +330,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
329330
DecomposeScaledDotProductAttentionPass(),
330331
DecomposeRoundPass(),
331332
DecomposeLogitPass(),
332-
CastBoolToInt8Pass(),
333+
PromoteBoolOperandsPass(),
333334
DecomposeSignPass(),
334335
DecomposeAddmmPass(),
335336
DecomposeRemainderPass(),

backends/arm/_passes/cast_bool_to_int8_pass.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)