Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion backends/apple/coreml/compiler/coreml_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class COMPILE_SPEC_KEYS(Enum):
MODEL_COMPUTE_PRECISION = "model_compute_precision"
OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
ENUMERATED_SHAPES = "enumerated_shapes"
PASS_PIPELINE = "pass_pipeline"


class MODEL_PATHS(Enum):
Expand Down Expand Up @@ -220,6 +221,33 @@ def op_linear_quantizer_config_from_compile_specs(

return None

@staticmethod
def generate_pass_pipeline_compile_spec(pass_names: List[str]) -> CompileSpec:
"""
Creates a compile spec representing the pass pipeline to be used by the CoreML backend
:param pass_names: the list of pass names
"""
str_representation = json.dumps(pass_names)
byte_representation = str_representation.encode("utf-8")
return CompileSpec(COMPILE_SPEC_KEYS.PASS_PIPELINE.value, byte_representation)

@staticmethod
def pass_pipeline_from_compile_specs(
compile_specs: List[CompileSpec],
) -> ct.PassPipeline:
"""
Creates a PassPipeline from the list of compile specs, or returns the default if none are provided.
"""
for compile_spec in compile_specs:
if compile_spec.key == COMPILE_SPEC_KEYS.PASS_PIPELINE.value:
pass_names_str = compile_spec.value.decode("utf-8")
pass_names = json.loads(pass_names_str)
return ct.PassPipeline(
pass_names, pipeline_name="executorch_user_pipeline"
)

return ct.PassPipeline.DEFAULT

@staticmethod
def generate_enumerated_shapes_compile_spec(
ep: ExportedProgram,
Expand Down Expand Up @@ -275,6 +303,7 @@ def generate_compile_specs(
compute_precision: ct.precision = ct.precision.FLOAT16,
model_type: MODEL_TYPE = MODEL_TYPE.MODEL,
op_linear_quantizer_config: Optional[Dict] = None,
pass_names: Optional[List[str]] = None,
) -> List[CompileSpec]:
"""
Returns the list of compile specs that's used by CoreMLBackend to lower the module.
Expand All @@ -298,6 +327,10 @@ def generate_compile_specs(
op_linear_quantizer_config
)
)
if pass_names is not None:
compile_specs.append(
CoreMLBackend.generate_pass_pipeline_compile_spec(pass_names)
)

return compile_specs

Expand Down Expand Up @@ -503,6 +536,9 @@ def preprocess(
enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs(
compile_specs
)
pass_pipeline: ct.PassPipeline = CoreMLBackend.pass_pipeline_from_compile_specs(
compile_specs
)

# If using enumerated shapes, we need to pass the inputs to CoreML's convert() function
# explicitly
Expand Down Expand Up @@ -530,7 +566,7 @@ def preprocess(
model=edge_program,
source="pytorch",
convert_to="mlprogram",
pass_pipeline=ct.PassPipeline.DEFAULT,
pass_pipeline=pass_pipeline,
skip_model_load=skip_model_load,
compute_precision=model_compute_precision,
minimum_deployment_target=minimum_deployment_target,
Expand Down