From 9d05ca2b24a4302bf5a9ae55d71afd119ad14a03 Mon Sep 17 00:00:00 2001 From: OkkyunWoo Date: Wed, 24 Sep 2025 12:59:04 +0000 Subject: [PATCH 1/4] PyTorch version upgrade: tested on single-operator tests --- PyTorchSimFrontend/extension_codecache.py | 3 +- PyTorchSimFrontend/extension_device.cpp | 6 +- .../extension_device_interface.py | 63 +++++++++++++++++ .../extension_device_op_overrides.py | 25 +++++++ PyTorchSimFrontend/extension_utils.py | 26 +++++++ PyTorchSimFrontend/llvm/llvm_common.py | 3 +- PyTorchSimFrontend/mlir/mlir_autotune.py | 8 ++- .../mlir/mlir_codegen_backend.py | 67 ++++++++++++++++--- PyTorchSimFrontend/mlir/mlir_common.py | 34 +++++----- PyTorchSimFrontend/mlir/mlir_scheduling.py | 51 +++++++------- PyTorchSimFrontend/mlir/mlir_template.py | 19 +++--- Scheduler/scheduler.py | 19 ++++-- 12 files changed, 250 insertions(+), 74 deletions(-) create mode 100644 PyTorchSimFrontend/extension_device_interface.py create mode 100644 PyTorchSimFrontend/extension_device_op_overrides.py create mode 100644 PyTorchSimFrontend/extension_utils.py diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 20152e9f..6bd5e63c 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -3,7 +3,8 @@ import shlex import subprocess -from torch._inductor.codecache import AsyncCompile, get_lock_dir, get_hash, write +from torch._inductor.codecache import get_lock_dir, get_hash, write +from torch._inductor.async_compile import AsyncCompile from AsmParser.tog_generator import tog_generator from PyTorchSimFrontend.mlir.mlir_caller_codegen import MLIRKernelCallerCodeGen from PyTorchSimFrontend import extension_config diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimFrontend/extension_device.cpp index 1a02bfe3..34cdc7d2 100644 --- a/PyTorchSimFrontend/extension_device.cpp +++ b/PyTorchSimFrontend/extension_device.cpp @@ -113,7 +113,7 @@ at::Tensor custom_to_device( // A dummy allocator for our custom device, that secretly uses the CPU struct DummyCustomAllocator final : at::Allocator { DummyCustomAllocator() = default; - at::DataPtr allocate(size_t nbytes) const override { + at::DataPtr allocate(size_t nbytes) override { void* data = c10::alloc_cpu(nbytes); return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; } @@ -128,6 +128,10 @@ struct DummyCustomAllocator final : at::Allocator { at::DeleterFnPtr raw_deleter() const override { return &ReportAndDelete; } + + void copy_data(void* dest, const void* src, std::size_t count) const override { + std::memcpy(dest, src, count); + } }; // Register our dummy allocator diff --git a/PyTorchSimFrontend/extension_device_interface.py b/PyTorchSimFrontend/extension_device_interface.py new file mode 100644 index 00000000..e5875ab7 --- /dev/null +++ b/PyTorchSimFrontend/extension_device_interface.py @@ -0,0 +1,63 @@ +import torch +from torch._dynamo.device_interface import DeviceInterface, caching_worker_current_devices, caching_worker_device_properties + +class _ExtensionDeviceProperties: # FIXME: Dummy property values + name: str = "Extension_device" + platform_name: str + vendor: str + driver_version: str + version: str + max_compute_units: int + gpu_eu_count: int + max_work_group_size: int + max_num_sub_groups: int + sub_group_sizes: list[int] + has_fp16: bool + has_fp64: bool + has_atomic64: bool + has_bfloat16_conversions: bool + has_subgroup_matrix_multiply_accumulate: bool + 
has_subgroup_matrix_multiply_accumulate_tensor_float32: bool + has_subgroup_2d_block_io: bool + total_memory: int + multi_processor_count: int = 128 # gpu_subslice_count, num_sm + architecture: int + type: str + +_ExtensionDeviceProperties = _ExtensionDeviceProperties + +class ExtensionDeviceInterface(DeviceInterface): + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["extension_device"] = device + + @staticmethod + def current_device() -> int: + if "extension_device" in caching_worker_current_devices: + return caching_worker_current_devices["extension_device"] + return torch.xpu.current_device() + + @staticmethod + def get_device_properties(device: torch.types.Device = None) -> _ExtensionDeviceProperties: + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "extension_device" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = ExtensionDeviceInterface.Worker.current_device() + + if "extension_device" not in caching_worker_device_properties: + device_prop = [ + torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + ] + caching_worker_device_properties["extension_device"] = device_prop + + return _ExtensionDeviceProperties + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + return 36 \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_device_op_overrides.py b/PyTorchSimFrontend/extension_device_op_overrides.py new file mode 100644 index 00000000..b76dae0f --- /dev/null +++ b/PyTorchSimFrontend/extension_device_op_overrides.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from textwrap import dedent + +from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides + +class ExtensionDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def set_device(self, device_idx: int) -> str: + return "pass" + + def synchronize(self) -> str: + return "pass" + + def device_guard(self, device_idx: int) -> str: + return "pass" + +register_device_op_overrides("extension_device", ExtensionDeviceOpOverrides()) \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_utils.py b/PyTorchSimFrontend/extension_utils.py new file mode 100644 index 00000000..0418cacd --- /dev/null +++ b/PyTorchSimFrontend/extension_utils.py @@ -0,0 +1,26 @@ +import sympy +import torch + +""" +NOTE: Temporary File + +This file contains functions that were removed or changed in newer versions +of PyTorch. It is kept here only to temporarily enable compatibility while +upgrading to PyTorch 2.8 from PyTorch 2.2. + +These functions will eventually be integrated into the appropriate source files +or removed once no longer needed. + +This file is not intended to be permanent and should be deleted in the future. +""" + +def free_symbol_startswith(index: sympy.Expr, prefix: str): + return any(v.name.startswith(prefix) for v in index.free_symbols) + +def sympy_symbol(name: str) -> sympy.Symbol: + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). 
+ return sympy.Symbol(name, integer=True, nonnegative=True) \ No newline at end of file diff --git a/PyTorchSimFrontend/llvm/llvm_common.py b/PyTorchSimFrontend/llvm/llvm_common.py index 1c76b826..68cacc2d 100644 --- a/PyTorchSimFrontend/llvm/llvm_common.py +++ b/PyTorchSimFrontend/llvm/llvm_common.py @@ -11,13 +11,14 @@ from torch.utils._sympy.value_ranges import ValueRanges from torch._inductor.utils import ( - free_symbol_startswith, get_sympy_Expr_dtype, IndentedBuffer, sympy_subs, unique, ) +from PyTorchSimFrontend.extension_utils import free_symbol_startswith + schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") DTYPE_TO_LLVM = { diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index af101f44..8f4cb233 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -41,6 +41,9 @@ def __init__( self.extra_args = extra_args #self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + def make_run_fn( self, input_tensors: torch.Tensor, output_tensors: torch.Tensor ) -> Callable[[], None]: @@ -62,5 +65,6 @@ def make_run_fn( *args, ) - def __str__(self) -> str: - return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" \ No newline at end of file + def update_workspace_size(self) -> None: + # FIXME: Not implemented yet. Checkout torch/_inductor/codegen/rocm/rocm_benchmark_request.py + return \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 21d2868e..eb8b4fc7 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -5,10 +5,12 @@ from functools import reduce from operator import mul import torch +from typing import Optional from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from torch._dynamo.utils import dynamo_timed from torch._inductor.codegen import cpp, wrapper, common, memory_planning +from torch._inductor.ir import GraphPartitionSignature from torch._inductor.virtualized import V, _ops as ops from torch._inductor.codecache import write_atomic, write from torch._inductor.utils import ( @@ -76,10 +78,25 @@ def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" raise AssertionError(reduction_type) -class ExtensionWrapperCodegen(wrapper.WrapperCodeGen): +class ExtensionWrapperCodegen(wrapper.PythonWrapperCodegen): def __init__(self): super().__init__() + @classmethod + def create( + cls, + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[wrapper.PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None and parent_wrapper is not None + return wrapper.SubgraphPythonWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + return cls() + def write_header(self): self.header.splice( f""" @@ -108,6 +125,7 @@ def write_header(self): reinterpret_tensor = torch.ops.aten._reinterpret_tensor custom_async_compile = CustomAsyncCompile() os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__ + print(f\'Wrapper Codegen Path = {{__file__}}\') """ ) self.header.splice( @@ -151,7 +169,7 @@ def call(args): 
self.prefix.writeline(f"{lhs} = args") self.prefix.writeline("args.clear()") - self.codegen_inputs(self.prefix, V.graph.graph_inputs) + self.codegen_inputs() self.codegen_input_size_asserts() self.codegen_sram_plan_prefix() @@ -171,10 +189,27 @@ def codegen_sram_plan_postfix(self, outputs): continue self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") - @dynamo_timed + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + device = device or V.graph.get_current_device_or_throw() + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + return + def generate(self, is_inference): result = IndentedBuffer() - result.splice(self.header) + # result.splice(self.header) with contextlib.ExitStack() as stack: stack.enter_context(self.wrapper_call.indent()) @@ -189,8 +224,13 @@ def generate(self, is_inference): if isinstance(line, wrapper.MemoryPlanningLine): line.codegen(self.wrapper_call) + elif isinstance(line, wrapper.KernelCallLine): + self.wrapper_call.writeline(self.wrap_kernel_call(line.kernel_name, line.call_args)) else: - self.wrapper_call.writeline(line) + if isinstance(line, wrapper.WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) # Add buffer plan hook for alloc if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") @@ -199,7 +239,9 @@ def generate(self, is_inference): self.mark_output_type() self.generate_return(output_refs) - self.append_precomputed_sizes_to_prefix() + # self.append_precomputed_sizes_to_prefix() # FIXME: Need to replace append_precomputed_sizes_to_prefix() + result.splice(self.header) + self.finalize_prefix() result.splice(self.prefix) @@ -208,7 +250,10 @@ def generate(self, is_inference): self.generate_end(result) self.add_benchmark_harness(result) - return result.getvaluewithlinemap() + return ( + result.getvaluewithlinemap(), + self.kernel_declarations.getvaluewithlinemap(), + ) def memory_plan(self): self.lines = memory_planning.MemoryPlanner(self).plan(self.lines) @@ -1603,16 +1648,16 @@ def get_cycle(choice): return optimal_src_code def codegen_nodes(self, nodes, kernel_name): - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) self._prepare_simulator_headers(src_code) if not extension_config.CONFIG_AUTOTUNE or extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: - return src_code + return src_code, meta_code else: optimal_src_code = self.autotune(nodes, kernel_name) if optimal_src_code: - return optimal_src_code + return optimal_src_code, meta_code else: - return src_code + return src_code, meta_code def _prepare_simulator_headers(self, src_code): write_path = extension_codecache.get_write_path(src_code) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 73996351..cd4fdb74 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -13,6 +13,7 @@ from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep +from torch._inductor.codegen.wrapper import KernelDefinitionLine from torch.utils._sympy.functions import ModularIndexing import 
sympy import contextlib @@ -24,15 +25,19 @@ import torch.fx from torch.utils._sympy.value_ranges import ValueRanges from torch._inductor.utils import ( - free_symbol_startswith, get_sympy_Expr_dtype, IndentedBuffer, sympy_subs, - sympy_symbol, unique, ) from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest + +from PyTorchSimFrontend.extension_utils import ( + free_symbol_startswith, + sympy_symbol +) + schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") DTYPE_TO_MLIR = { @@ -520,7 +525,7 @@ def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this - wrapper.generate_kernel_call(kernel_name, call_args, cuda=False) + wrapper.generate_kernel_call(kernel_name, call_args, triton=False) def is_modular_indexing(self, expr): return "ModularIndexing" in str(expr) @@ -740,8 +745,8 @@ def codegen_nodes(self, nodes, kernel_name): V.graph.removed_buffers |= self.removed_buffers # V.graph.inplaced_to_remove |= self.inplaced_to_remove src_code = self.codegen_kernel(kernel_name=kernel_name) - self.meta_kernel() - return src_code + meta_code = self.meta_kernel() + return src_code, meta_code def run_bench(self, nodes, kernel_name, src_code): _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() @@ -784,12 +789,9 @@ def codegen_kernel(self, kernel_name): return code.getvalue() def meta_kernel(self): - wrapper = V.graph.wrapper_code _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") - return arg_attributes + meta_code = arg_attributes + return meta_code def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] @@ -886,10 +888,10 @@ def load(name: str, index: sympy.Expr): if name in store_cache: return store_cache[name] key = name+str(index) - if key not in self.cse.cache: + if key not in self.cse._cache: result = self.load(name, index) - self.cse.cache[key] = result - return self.cse.cache[key] + self.cse._cache[key] = result + return self.cse._cache[key] @staticmethod def store(name, index, value, mode=None): @@ -897,7 +899,7 @@ def store(name, index, value, mode=None): if mode is None: self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: return self.store(name, index, value, mode=mode) @@ -907,7 +909,7 @@ def store_reduction(name, index, value): self.store_buffer_names.add(name) self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: @@ -953,7 +955,7 @@ def bucketize( super().__enter__() assert self.overrides - parent_handler = self.overrides(V.get_ops_handler()) + parent_handler = self.overrides() self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) self.exit_stack.enter_context(V.set_kernel_handler(self)) return self diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 2bbdb41d..4979df3f 100644 --- 
a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -22,8 +22,6 @@ class MLIRScheduling(BaseScheduling): target_kernel = MLIRKernel def __init__(self, scheduler): self.scheduler = scheduler - self.scheduler.can_fuse_origin = self.scheduler.can_fuse - self.scheduler.can_fuse = self.can_fuse_with_exceptions #self.scheduler.enter_context = self.enter_context_fixed # FIXME. Monkey patch: For fixing the inductor bug self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._ready_to_flush = False @@ -90,6 +88,9 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule def _set_flush_status(self, status: bool): self._ready_to_flush = status + def reset_kernel_group(self): + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + def can_fuse_vertical(self, node1, node2): return self.can_fuse_horizontal(node1, node2) @@ -101,7 +102,7 @@ def can_fuse_horizontal(self, node1, node2): # Reduction is currently not supported if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template() and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION: - return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users + return vars1 == vars2 and reduce1 == reduce2 # and node1.inverse_users == node2.inverse_users if node1.is_reduction() or node2.is_reduction(): return False @@ -178,7 +179,8 @@ def revert_group(self, act_nodes, args=None, var_ranges=None): def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) - def codegen_nodes(self, nodes): + def codegen_node(self, _node): + nodes = _node.get_nodes() _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group @@ -208,8 +210,8 @@ def codegen_nodes(self, nodes): kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) - kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, + src_code, meta_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name = self.define_kernel(src_code, meta_code, kernel_name_candidate, ex_kernel.vector_lane, ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() @@ -228,35 +230,39 @@ def codegen_sync(self): pass def flush(self): - self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + self.reset_kernel_group() self._set_flush_status(False) def define_function(self, kernel): partial_code, function_name = kernel.def_function() if partial_code is not None and function_name not in self.outer_function: with V.set_kernel_handler(kernel): - code = partial_code.finalize() + code = partial_code.finalize_all() wrapper = V.graph.wrapper_code wrapper.header.writeline(code) self.outer_function.add(function_name) - def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): + def define_kernel(self, src_code, meta_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: kernel_name = 
wrapper.src_to_kernel[src_code] else: wrapper.src_to_kernel[src_code] = kernel_name - codecache_def = IndentedBuffer() codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") codecache_def.writeline(f"vectorlane_size={vector_lane},") codecache_def.writeline(f"loop_size={loop_size},") codecache_def.writeline(f"spad_info={spad_info},") codecache_def.writeline(f"origins={origins},") - codecache_def.writeline("arg_attributes=arg_attributes,") + codecache_def.writeline(f"arg_attributes={meta_code},") codecache_def.writeline(f"vlen={extension_config.CONFIG_VLEN})") - wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) + wrapper.define_kernel(kernel_name, codecache_def.getvalue(), gpu=False) return kernel_name def codegen_template_code(self, kernel, render, template_node, prologue_nodes, epilogue_nodes): @@ -330,7 +336,7 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e src_code = ( partial_code if isinstance(partial_code, str) - else partial_code.finalize() + else partial_code.finalize_all() ) # For consistency, white space could make wrong write_path @@ -338,18 +344,7 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e buffer.splice(src_code) return buffer.getvalue() - def codegen_template(self, template_node, epilogue_nodes): - # Handle prologue pattern - prologue_nodes = [] - if not template_node.is_template(): - epilogue_nodes = [template_node] + epilogue_nodes - for i, node in enumerate(epilogue_nodes): - if node.is_template(): - template_node = node - prologue_nodes = epilogue_nodes[:i] - epilogue_nodes = epilogue_nodes[i+1:] - break - + def codegen_template(self, template_node, epilogue_nodes, prologue_nodes): _, (numel, rnumel) = template_node.group template_buffer = template_node.node kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) @@ -367,8 +362,8 @@ def codegen_template(self, template_node, epilogue_nodes): spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({kernel.spad_info['spad_size']*kernel.vector_lane})));" codegen_header(src_code, (kernel.header.getvalue()+spad_end_symbol+spad_section_end_symbol, kernel.gem5_header.getvalue())) - kernel.meta_kernel() - kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, + meta_code = kernel.meta_kernel() + kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) self.define_function(kernel) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 820d5c0d..90594ba0 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -11,8 +11,8 @@ from typing import List, Optional from unittest.mock import patch -from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, View +from torch._inductor.codegen.common import Kernel, KernelTemplate, OpOverrides, CSE, DeferredLine +from torch._inductor.ir import Buffer, ChoiceCaller, IRNode, TemplateBuffer, View from torch._inductor.select_algorithm import PartialRender from 
torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -394,18 +394,14 @@ def meta_kernel(self): for idx in range(len(arg_attributes)): if arg_attributes[idx][0] == name: arg_attributes[idx][1] = attr - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"loop_info = {self.loop_info}") - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") + return arg_attributes def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", - call_args, cuda=False) + kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args) def codegen_prologue_body(self): body = IndentedBuffer() @@ -626,7 +622,7 @@ def hook(): return "" def def_function(self): - _, call_args, _ = self.kernel_group.args.python_argdefs() + _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) return PartialRender( @@ -1055,7 +1051,7 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): """ super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] - self.output_node: Buffer = Buffer("buf_out", layout) + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) self.input_reorder = input_reorder self.layout = layout @@ -1119,7 +1115,10 @@ def make_kernel_render( self.output_node.get_layout(), make_kernel_render, bmreq, + False, # supports_epilogue_fusion self, + kwargs, + "" # Currently Empty description ) def render(self, **kwargs) -> str: diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index 834698a6..7423c57c 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -1,5 +1,6 @@ from typing import List import os +import sys import numpy as np import torch from pathlib import Path @@ -7,6 +8,10 @@ from PyTorchSimFrontend.extension_codecache import hash_prefix from Simulator.simulator import BackendSimulator from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.extension_device_interface import ExtensionDeviceInterface + +from torch._dynamo.device_interface import register_interface_for_device + def import_module_from_path(module_name, path): module_path = Path(path) # Convert to Path object for safety @@ -191,16 +196,22 @@ def setup_device(): from PyTorchSimFrontend.mlir.mlir_scheduling import ( MLIRScheduling ) + register_backend_for_device( - "extension_device", MLIRScheduling, ExtensionWrapperCodegen - ) - assert( - get_scheduling_for_device("extension_device") == MLIRScheduling + "extension_device", + lambda scheduling: MLIRScheduling(scheduling), + ExtensionWrapperCodegen ) + import PyTorchSimFrontend.extension_device_op_overrides + assert( get_wrapper_codegen_for_device("extension_device") == ExtensionWrapperCodegen ) + + torch.utils.rename_privateuse1_backend("extension_device") + sys.modules['torch.extension_device'] = module + register_interface_for_device(module.custom_device(), ExtensionDeviceInterface) return module def submit(self, batched_req, partition_idx) -> List[RequestReturn]: From bc53b4bb9b4bcb912fb89071e203211b5e17679d Mon Sep 17 00:00:00 2001 From: OkkyunWoo Date: 
Wed, 24 Sep 2025 13:28:55 +0000
Subject: [PATCH 2/4] [Test] Add torch.no_grad(), change to use torch.nn.ReLU, fusion off

---
 PyTorchSimFrontend/mlir/mlir_scheduling.py |  1 +
 tests/test_activation.py                   |  5 +++--
 tests/test_conv2d.py                       | 25 +++++++++++-----------
 tests/test_layernorm.py                    |  5 +++--
 4 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py
index 4979df3f..acc5ec9d 100644
--- a/PyTorchSimFrontend/mlir/mlir_scheduling.py
+++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py
@@ -95,6 +95,7 @@ def can_fuse_vertical(self, node1, node2):
         return self.can_fuse_horizontal(node1, node2)
 
     def can_fuse_horizontal(self, node1, node2):
+        return False
         if (len(node1.get_nodes())+ len(node2.get_nodes())) > self.max_fusion_size:
             return False
         _, (vars1, reduce1) = node1.group
diff --git a/tests/test_activation.py b/tests/test_activation.py
index de3542c3..40bcca8e 100644
--- a/tests/test_activation.py
+++ b/tests/test_activation.py
@@ -23,9 +23,10 @@ def test_ReLU(device, size=(128, 128)):
     input = torch.randn(size)
     x1 = input.to(device=device)
     x2 = input.to("cpu")
-    opt_fn = torch.compile(dynamic=False)(torch.nn.functional.relu)
+    ReLU = torch.nn.ReLU()
+    opt_fn = torch.compile(dynamic=False)(ReLU)
     y = opt_fn(x1)
-    cpu_y = torch.nn.functional.relu(x2)
+    cpu_y = ReLU(x2)
     test_result("ReLU", y, cpu_y)
 
 def test_GeLU(device, size=(128, 128), approximate='none'):
diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py
index 21bbfec7..6b7c60cf 100644
--- a/tests/test_conv2d.py
+++ b/tests/test_conv2d.py
@@ -44,15 +44,16 @@ def custom_conv2d(a, b, bias):
     module = ExecutionEngine.setup_device()
     device = module.custom_device()
     torch._dynamo.config.cache_size_limit = 64
-    test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0)
-    test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3)
-    test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3)
-    test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3)
-    test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3)
-    test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2)
-    test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3)
-    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1)
-    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1)
-    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0)
-    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0)
-    test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0)
+    with torch.no_grad():
+        test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0)
+        test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3)
+        test_conv2d(device, batch_size=2, in_channels=3, 
out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0) diff --git a/tests/test_layernorm.py b/tests/test_layernorm.py index 1cea9d9f..f812b3f5 100644 --- a/tests/test_layernorm.py +++ b/tests/test_layernorm.py @@ -44,5 +44,6 @@ def test_LayerNorm(device, size=(64, 64)): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - #test_LayerNorm(device) - test_LayerNorm(device, shape) + with torch.no_grad(): + #test_LayerNorm(device) + test_LayerNorm(device, shape) From 5d9195ef89f1f699b00ae92c57444b3a2f533916 Mon Sep 17 00:00:00 2001 From: OkkyunWoo Date: Wed, 24 Sep 2025 14:06:29 +0000 Subject: [PATCH 3/4] [CI] Upgrade PyTorch version to 2.8 --- Dockerfile.base | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.base b/Dockerfile.base index 1ac5e175..778ffec3 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -23,7 +23,7 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
+FROM pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime
 
 # Copied from Gem5 Docker file
 ENV DEBIAN_FRONTEND=noninteractive

From ab88edc847896ca78056f35a404f93e0e59128f1 Mon Sep 17 00:00:00 2001
From: OkkyunWoo
Date: Thu, 6 Nov 2025 05:28:52 +0000
Subject: [PATCH 4/4] [Implement] Hook and GuardImpl for extension device

---
 PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp |   8 ++
 PyTorchSimDevice/ExtensionDeviceGuardImpl.h   | 127 ++++++++++++++++++
 .../extension_device.cpp                      |  10 +-
 .../extension_device_interface.py             |   0
 .../extension_device_op_overrides.py          |   0
 PyTorchSimDevice/extension_hooks.cpp          |  48 +++++++
 PyTorchSimDevice/extension_hooks.h            |  30 +++++
 Scheduler/scheduler.py                        |   8 +-
 8 files changed, 221 insertions(+), 10 deletions(-)
 create mode 100644 PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp
 create mode 100644 PyTorchSimDevice/ExtensionDeviceGuardImpl.h
 rename {PyTorchSimFrontend => PyTorchSimDevice}/extension_device.cpp (98%)
 rename {PyTorchSimFrontend => PyTorchSimDevice}/extension_device_interface.py (100%)
 rename {PyTorchSimFrontend => PyTorchSimDevice}/extension_device_op_overrides.py (100%)
 create mode 100644 PyTorchSimDevice/extension_hooks.cpp
 create mode 100644 PyTorchSimDevice/extension_hooks.h

diff --git a/PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp b/PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp
new file mode 100644
index 00000000..a0b1395d
--- /dev/null
+++ b/PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp
@@ -0,0 +1,8 @@
+#include "ExtensionDeviceGuardImpl.h"
+#include 
+
+namespace c10::extension_device::impl {
+
+C10_REGISTER_GUARD_IMPL(extension_device, ExtensionDeviceGuardImpl);
+
+} // namespace c10::extension_device::impl
diff --git a/PyTorchSimDevice/ExtensionDeviceGuardImpl.h b/PyTorchSimDevice/ExtensionDeviceGuardImpl.h
new file mode 100644
index 00000000..6d35677b
--- /dev/null
+++ b/PyTorchSimDevice/ExtensionDeviceGuardImpl.h
@@ -0,0 +1,127 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10::extension_device::impl {
+
+struct ExtensionDeviceGuardImpl final : public c10::impl::DeviceGuardImplInterface {
+  static constexpr DeviceType static_type = DeviceType::PrivateUse1; // ✅ your backend type
+
+  ExtensionDeviceGuardImpl() = default;
+
+  explicit ExtensionDeviceGuardImpl(DeviceType t) {
+    TORCH_CHECK(
+        t == static_type,
+        "ExtensionDeviceGuardImpl initialized with non-extension_device DeviceType: ",
+        t);
+  }
+
+  // --------------------------------------------------------------------------
+  // Basic device guard (behaves like the CPU)
+  // --------------------------------------------------------------------------
+  DeviceType type() const override {
+    return static_type;
+  }
+
+  Device exchangeDevice(Device d) const override {
+    TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
+    return d; // nothing to exchange, CPU-like
+  }
+
+  Device getDevice() const override {
+    return Device(static_type, 0);
+  }
+
+  void setDevice(Device d) const override {
+    TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
+  }
+
+  void uncheckedSetDevice(Device d) const noexcept override {}
+
+  DeviceIndex deviceCount() const noexcept override {
+    return 1; // pretend single device
+  }
+
+  // --------------------------------------------------------------------------
+  // Stream handling (execution is synchronous, so only the default stream is used)
+  // --------------------------------------------------------------------------
+  Stream getStream(Device d) const override {
+    return Stream(Stream::DEFAULT, d);
+  }
+
+  Stream getNewStream(Device d, int priority = 0) const override {
+    return Stream(Stream::DEFAULT, d);
+  }
+
+  Stream getStreamFromGlobalPool(Device d, bool = false) const override {
+    return Stream(Stream::DEFAULT, d);
+  }
+
+  Stream exchangeStream(Stream s) const override {
+    return s;
+  }
+
+  bool queryStream(const Stream& stream) const override {
+    (void)stream;
+    return true;
+  }
+
+  void synchronizeStream(const Stream& stream) const override {
+    (void)stream;
+  }
+
+  void synchronizeDevice(DeviceIndex device_index) const override {
+    (void)device_index;
+  }
+
+  // --------------------------------------------------------------------------
+  // Event handling (all no-ops)
+  // --------------------------------------------------------------------------
+  void destroyEvent(void* event, const DeviceIndex device_index) const noexcept override {
+    (void)event;
+    (void)device_index;
+  }
+
+  void record(void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override {
+    (void)event;
+    (void)stream;
+    (void)device_index;
+    (void)flag;
+  }
+
+  void block(void* event, const Stream& stream) const override {
+    (void)event;
+    (void)stream;
+  }
+
+  bool queryEvent(void* event) const override {
+    (void)event;
+    return true;
+  }
+
+  void synchronizeEvent(void* event) const override {
+    (void)event;
+  }
+
+  double elapsedTime(void* start_event, void* end_event, const DeviceIndex device_index) const override {
+    (void)start_event;
+    (void)end_event;
+    (void)device_index;
+    return 0.0;
+  }
+
+  // --------------------------------------------------------------------------
+  // Misc (allocator integration)
+  // --------------------------------------------------------------------------
+  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) const override {
+    (void)data_ptr;
+    (void)stream;
+  }
+};
+
+} // namespace c10::extension_device::impl
diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimDevice/extension_device.cpp
similarity index 98%
rename from PyTorchSimFrontend/extension_device.cpp
rename to PyTorchSimDevice/extension_device.cpp
index 34cdc7d2..68a1b370 100644
--- a/PyTorchSimFrontend/extension_device.cpp
+++ b/PyTorchSimDevice/extension_device.cpp
@@ -17,16 +17,12 @@
 #include 
 #include 
 
+#include "ExtensionDeviceGuardImpl.h"
+
 static uint64_t op_counter = 0;
 static uint64_t last_saved_value = 0;
 
-// register guard
-namespace at {
-namespace detail {
-
-C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl);
-
-}} // namespace at::detail
+C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::extension_device::impl::ExtensionDeviceGuardImpl);
 
 // basic dummy add function
 at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
diff --git a/PyTorchSimFrontend/extension_device_interface.py b/PyTorchSimDevice/extension_device_interface.py
similarity index 100%
rename from PyTorchSimFrontend/extension_device_interface.py
rename to PyTorchSimDevice/extension_device_interface.py
diff --git a/PyTorchSimFrontend/extension_device_op_overrides.py b/PyTorchSimDevice/extension_device_op_overrides.py
similarity index 100%
rename from PyTorchSimFrontend/extension_device_op_overrides.py
rename to PyTorchSimDevice/extension_device_op_overrides.py
diff --git a/PyTorchSimDevice/extension_hooks.cpp b/PyTorchSimDevice/extension_hooks.cpp
new file mode 100644
index 00000000..aadd6d2a
--- /dev/null
+++ b/PyTorchSimDevice/extension_hooks.cpp
@@ -0,0 +1,48 @@
+#include "extension_hooks.h"
+
+bool ExtensionPU1Hooks::isBuilt() const { return true; }
+bool ExtensionPU1Hooks::isAvailable() const { return true; }
+
+const at::Generator& ExtensionPU1Hooks::getDefaultGenerator(c10::DeviceIndex idx) const {
+  if (idx < 0) idx = 0;
+  static std::vector gens;
+  static std::mutex m;
+  std::lock_guard g(m);
+  if (gens.size() <= (size_t)idx) gens.resize((size_t)idx + 1);
+  if (!gens[idx].defined()) gens[idx] = at::GetGeneratorForPrivateuse1(idx);
+  return gens[idx]; // return a reference to a persistent object
+}
+
+at::Generator ExtensionPU1Hooks::getNewGenerator(c10::DeviceIndex idx) const {
+  if (idx < 0) idx = 0;
+  return at::GetGeneratorForPrivateuse1(idx);
+}
+
+at::Device ExtensionPU1Hooks::getDeviceFromPtr(void* data) const {
+  return at::Device(at::kPrivateUse1, 0); // MVP: assume a single device
+}
+
+bool ExtensionPU1Hooks::isPinnedPtr(const void* data) const {
+  return false;
+}
+
+at::Allocator* ExtensionPU1Hooks::getPinnedMemoryAllocator() const {
+  return at::getHostAllocator(at::kPrivateUse1);
+}
+
+bool ExtensionPU1Hooks::hasPrimaryContext(c10::DeviceIndex device_index) const { return true; }
+
+void ExtensionPU1Hooks::resizePrivateUse1Bytes(const c10::Storage&, size_t) const {
+  TORCH_CHECK(false, "resizePrivateUse1Bytes not implemented");
+}
+
+// REGISTER_EXTENSION_HOOKS(ExtensionPU1Hooks);
+
+namespace {
+struct AutoRegistrar {
+  AutoRegistrar() {
+    at::RegisterPrivateUse1HooksInterface(new ExtensionPU1Hooks());
+  }
+};
+static AutoRegistrar _auto_registrar;
+}
diff --git a/PyTorchSimDevice/extension_hooks.h b/PyTorchSimDevice/extension_hooks.h
new file mode 100644
index 00000000..fdf3505a
--- /dev/null
+++ b/PyTorchSimDevice/extension_hooks.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+struct ExtensionPU1Hooks final : public at::PrivateUse1HooksInterface {
+  ExtensionPU1Hooks() {}
+  bool isBuilt() const;
+  bool isAvailable() const;
+
+  const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override;
+
+  at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;
+
+  at::Device getDeviceFromPtr(void* data) const override;
+
+  bool isPinnedPtr(const void* data) const override;
+
+  at::Allocator* getPinnedMemoryAllocator() const override;
+
+  bool hasPrimaryContext(c10::DeviceIndex device_index) const override;
+
+  void resizePrivateUse1Bytes(const c10::Storage& /*storage*/, size_t /*newsize*/) const override;
+};
\ No newline at end of file
diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py
index 7423c57c..cb3453da 100644
--- a/Scheduler/scheduler.py
+++ b/Scheduler/scheduler.py
@@ -8,7 +8,7 @@
 from PyTorchSimFrontend.extension_codecache import hash_prefix
 from Simulator.simulator import BackendSimulator
 from PyTorchSimFrontend import extension_config
-from PyTorchSimFrontend.extension_device_interface import ExtensionDeviceInterface
+from PyTorchSimDevice.extension_device_interface import ExtensionDeviceInterface
 
 from torch._dynamo.device_interface import register_interface_for_device
 
@@ -171,14 +171,16 @@ def __init__(self, backend_simulator : BackendSimulator, num_partion=1) -> None:
     def setup_device():
         source_file_path = os.path.dirname(os.path.abspath(__file__))
         source_file = os.path.join(
-            source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimFrontend/extension_device.cpp"
+            source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimDevice/extension_device.cpp"
         )
+        hook_file = os.path.join(source_file_path, 
f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimDevice/extension_hooks.cpp") import torch.utils.cpp_extension module = torch.utils.cpp_extension.load( name="extension_device", sources=[ str(source_file), + str(hook_file), ], extra_cflags=["-g"], verbose=True, @@ -202,7 +204,7 @@ def setup_device(): lambda scheduling: MLIRScheduling(scheduling), ExtensionWrapperCodegen ) - import PyTorchSimFrontend.extension_device_op_overrides + import PyTorchSimDevice.extension_device_op_overrides assert( get_wrapper_codegen_for_device("extension_device")
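Not part of the patch series: a minimal smoke-test sketch for the upgraded backend, modeled on tests/test_activation.py above. It assumes ExecutionEngine.setup_device() and module.custom_device() behave as they do in those single-operator tests, and it simply compares the compiled extension_device result against an eager CPU reference.

import torch
from Scheduler.scheduler import ExecutionEngine

# Set up the simulated extension device (same entry point the tests use).
module = ExecutionEngine.setup_device()
device = module.custom_device()

with torch.no_grad():
    x = torch.randn(128, 128)
    relu = torch.nn.ReLU()
    opt_fn = torch.compile(dynamic=False)(relu)
    y = opt_fn(x.to(device=device))   # compiled through the extension_device Inductor backend
    y_ref = relu(x)                   # eager CPU reference
    print(torch.allclose(y.to("cpu"), y_ref))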