diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index fe8a4a7d..8444f318 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -663,6 +663,27 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_scheduler.py + test_llama: + name: Run test_llama1&2 + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_llama.py + run: | + echo "Running test_llama.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/Llama/test_llama.py + test_accuracy: name: Run test_accuracy runs-on: self-hosted diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 4d57b987..2e35220c 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -278,7 +278,7 @@ def dummy_simulator(*args, **kwargs): vectorlane_size=vectorlane_size, spad_info=spad_info, silent_mode=silent_mode) if not extension_config.pytorchsim_timing_mode: - return + return [float("inf")] onnx_path = os.path.join(result_path, "tile_graph.onnx") attribute_path = os.path.join(runtime_path, "attribute") @@ -286,7 +286,7 @@ def dummy_simulator(*args, **kwargs): TOGSim = TOGSimulator(togsim_path, extension_config.CONFIG_TOGSIM_CONFIG) TOGSim.vectorlane_size = vectorlane_size attribute_path = TOGSim.create_attribute_file(attribute_path, args, loop_size=loop_size) - result_path = TOGSim.simulation(onnx_path, attribute_path, silent_mode=silent_mode) + result_path = TOGSim.simulation(onnx_path, attribute_path, silent_mode=silent_mode, autotune_mode=autotune) result = TOGSimulator.get_result_from_file(result_path) return result diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 239bbefe..ab8aea69 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -1,7 +1,7 @@ import os import sys import importlib -import json +import yaml CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') CONFIG_GEM5_PATH = os.environ.get('GEM5_PATH', default="/workspace/gem5/build/RISCV/gem5.opt") @@ -13,51 +13,53 @@ def __getattr__(name): # TOGSim config config_path = os.environ.get('TOGSIM_CONFIG', - default=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json") + default=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml") if name == "CONFIG_TOGSIM_CONFIG": return config_path - config_json = json.load(open(config_path, 'r')) + + with open(config_path, 'r') as f: + config_yaml = yaml.safe_load(f) # Hardware info config if name == "vpu_num_lanes": - return config_json["vpu_num_lanes"] + return config_yaml["vpu_num_lanes"] if name == "CONFIG_SPAD_INFO": return { "spad_vaddr" : 0xD0000000, "spad_paddr" : 0x2000000000, - "spad_size" : config_json["vpu_spad_size_kb_per_lane"] << 10 # Note: spad size per lane + "spad_size" : config_yaml["vpu_spad_size_kb_per_lane"] << 10 # Note: spad size per lane } if name == "CONFIG_PRECISION": return 4 # 32bit if name == "CONFIG_NUM_CORES": - return 
config_json["num_cores"] + return config_yaml["num_cores"] if name == "vpu_vector_length_bits": - return config_json["vpu_vector_length_bits"] + return config_yaml["vpu_vector_length_bits"] if name == "pytorchsim_functional_mode": - return config_json['pytorchsim_functional_mode'] + return config_yaml['pytorchsim_functional_mode'] if name == "pytorchsim_timing_mode": - return config_json['pytorchsim_timing_mode'] + return config_yaml['pytorchsim_timing_mode'] # Mapping strategy if name == "codegen_mapping_strategy": - codegen_mapping_strategy = config_json["codegen_mapping_strategy"] + codegen_mapping_strategy = config_yaml["codegen_mapping_strategy"] assert(codegen_mapping_strategy in ["heuristic", "autotune", "external-then-heuristic", "external-then-autotune"]), "Invalid mapping strategy!" return codegen_mapping_strategy if name == "codegen_external_mapping_file": - return config_json["codegen_external_mapping_file"] + return config_yaml["codegen_external_mapping_file"] # Autotune config if name == "codegen_autotune_max_retry": - return config_json["codegen_autotune_max_retry"] + return config_yaml["codegen_autotune_max_retry"] if name == "codegen_autotune_template_topk": - return config_json["codegen_autotune_template_topk"] + return config_yaml["codegen_autotune_template_topk"] # Compiler Optimization if name == "codegen_compiler_optimization": - opt_level = config_json["codegen_compiler_optimization"] + opt_level = config_yaml["codegen_compiler_optimization"] valid_opts = { "fusion", "reduction_epilogue", @@ -67,7 +69,7 @@ def __getattr__(name): "multi_tile_conv", "subtile" } - if opt_level == "all" or opt_level is "none": + if opt_level == "all" or opt_level == "none": pass elif isinstance(opt_level, list): # Check if provided list contains only valid options diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimFrontend/extension_device.cpp index 1a02bfe3..cfaecf2b 100644 --- a/PyTorchSimFrontend/extension_device.cpp +++ b/PyTorchSimFrontend/extension_device.cpp @@ -16,6 +16,44 @@ #include #include #include +#include +namespace py = pybind11; + +namespace { + bool g_amp_enabled = false; + at::ScalarType g_amp_dtype = at::kFloat; +} + +static at::ScalarType to_scalar_type(const py::object& dtype_obj) { + py::module torch_mod = py::module::import("torch"); + if (dtype_obj.is(torch_mod.attr("bfloat16"))) return at::kBFloat16; + if (dtype_obj.is(torch_mod.attr("float16"))) return at::kHalf; + if (dtype_obj.is(torch_mod.attr("float32"))) return at::kFloat; + if (dtype_obj.is(torch_mod.attr("float64"))) return at::kDouble; + throw std::runtime_error("Unsupported dtype for extension_device AMP"); +} + +static py::object to_torch_dtype(at::ScalarType st) { + py::module torch_mod = py::module::import("torch"); + switch (st) { + case at::kBFloat16: return torch_mod.attr("bfloat16"); + case at::kHalf: return torch_mod.attr("float16"); + case at::kFloat: return torch_mod.attr("float32"); + case at::kDouble: return torch_mod.attr("float64"); + default: + throw std::runtime_error("Unsupported scalar type in get_autocast_dtype"); + } +} + +static inline at::MemoryFormat fix_memory_format(c10::optional mf_opt) { + if (!mf_opt.has_value()) return at::MemoryFormat::Contiguous; + + auto mf = mf_opt.value(); + if (mf == at::MemoryFormat::Preserve) { + return at::MemoryFormat::Contiguous; + } + return mf; +} static uint64_t op_counter = 0; static uint64_t last_saved_value = 0; @@ -99,8 +137,16 @@ at::Tensor custom_to_device( TORCH_CHECK(self.is_contiguous()); op_counter += 1; - if (device != 
at::DeviceType::CPU) { - return at::empty(self.sizes(), self.options()); + if (device.type() == at::DeviceType::CPU) { + auto out = at::empty(self.sizes(), dtype, self.options().layout(), + device, false, memory_format); + std::memcpy(out.mutable_data_ptr(), self.data_ptr(), self.nbytes()); + return out; + } else { + auto opts = self.options().device(device).dtype(dtype); + auto out = at::empty(self.sizes(), opts); + std::memcpy(out.mutable_data_ptr(), self.data_ptr(), self.nbytes()); + return out; } auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, false, memory_format); @@ -135,33 +181,86 @@ static DummyCustomAllocator global_custom_alloc; REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc); at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) { - TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows dummy device."); + TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, + "Dummy test only allows dummy device."); TORCH_CHECK(self.is_contiguous()); - // TORCH_CHECK(self.scalar_type() == c10::ScalarType::Float); op_counter += 1; - if (self.scalar_type() == c10::ScalarType::Float) { - auto _data = static_cast(self.mutable_data_ptr()); - for (size_t idx = 0; idx < self.numel(); idx++) { - _data[idx] = value.toFloat(); + + switch (self.scalar_type()) { + case c10::ScalarType::Float: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = value.toFloat(); + } + break; } - return self; - } else if (self.scalar_type() == c10::ScalarType::Int) { - auto _data = static_cast(self.mutable_data_ptr()); - for (size_t idx = 0; idx < self.numel(); idx++) { - _data[idx] = value.toInt(); + case c10::ScalarType::Double: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = value.toDouble(); + } + break; } - return self; - } else if (self.scalar_type() == c10::ScalarType::Long) { - auto _data = static_cast(self.mutable_data_ptr()); - for (size_t idx = 0; idx < self.numel(); idx++) { - _data[idx] = value.toLong(); + case c10::ScalarType::Half: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = at::Half(value.toHalf()); + } + break; } - return self; - } else { - TORCH_CHECK(false, "Unsupported scalar type."); + case c10::ScalarType::BFloat16: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = at::BFloat16(value.toBFloat16()); + } + break; + } + case c10::ScalarType::Int: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = value.toInt(); + } + break; + } + case c10::ScalarType::Long: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = value.toLong(); + } + break; + } + case c10::ScalarType::Short: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = static_cast(value.toShort()); + } + break; + } + case c10::ScalarType::Char: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = static_cast(value.toChar()); + } + break; + } + case c10::ScalarType::Byte: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + data[i] = static_cast(value.toByte()); + } + break; + } + case c10::ScalarType::Bool: { + auto* data = self.mutable_data_ptr(); + for (int64_t i = 0; i < self.numel(); i++) { + 
data[i] = value.toBool(); + } + break; + } + default: + TORCH_CHECK(false, "Unsupported scalar type: ", self.scalar_type()); } - return self; } @@ -204,6 +303,9 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool "Dummy test only allows copy from cpu -> dummy device."); // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. + if (self.numel() != dst.numel()) { + custom_resize_(dst, self.sizes(), c10::nullopt); + } TORCH_CHECK(self.sizes() == dst.sizes()); const bool same_dtype = (self.scalar_type() == dst.scalar_type()); @@ -247,7 +349,7 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional dty constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); auto dtype = c10::dtype_or_default(dtype_opt); - return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, optional_memory_format); + return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, fix_memory_format(optional_memory_format)); } at::Tensor& custom_arange_start_out_impl( @@ -255,9 +357,36 @@ at::Tensor& custom_arange_start_out_impl( const c10::Scalar& end, const c10::Scalar& step, at::Tensor& out) { - //const int64_t n = arange_len(start.toDouble(), end.toDouble(), step.toDouble()); - //at::native::resize_output(out, {n}); - return out; + double s = start.toDouble(); + double e = end.toDouble(); + double st = step.toDouble(); + TORCH_CHECK(st != 0.0, "step must be nonzero"); + + int64_t length = 0; + if (st > 0) { + if (e > s) length = static_cast(std::ceil((e - s) / st)); + } else { + if (e < s) length = static_cast(std::ceil((e - s) / st)); + } + + // Resize out tensor + custom_resize_(out, {length}, c10::nullopt); + + if (out.scalar_type() == at::kFloat || out.scalar_type() == at::kDouble) { + double* data = out.mutable_data_ptr(); + for (int64_t i = 0; i < length; i++) { + data[i] = s + i * st; + } + } else if (out.scalar_type() == at::kLong) { + int64_t* data = out.mutable_data_ptr(); + for (int64_t i = 0; i < length; i++) { + data[i] = static_cast(s + i * st); + } + } else { + TORCH_CHECK(false, "Unsupported dtype for arange on dummy device"); + } + + return out; } static at::Tensor custom_to_dtype_impl(const at::Tensor& self, @@ -267,6 +396,62 @@ static at::Tensor custom_to_dtype_impl(const at::Tensor& self, return at::native::to(self, dtype, non_blocking, copy, memory_format); } +at::Tensor custom_zeros_like( + const at::Tensor& input, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt) +{ + // dtype / layout / device fallback + auto dtype = dtype_opt.value_or(input.scalar_type()); + auto layout = layout_opt.value_or(input.layout()); + auto device = device_opt.value_or(input.device()); + auto memfmt = memory_format_opt.value_or(c10::MemoryFormat::Contiguous); + + TORCH_CHECK( + device.type() == c10::DeviceType::PrivateUse1, + "custom_zeros_like: device must be PrivateUse1"); + + at::Tensor out = custom_empty( + input.sizes(), + dtype, + layout, + device, + pin_memory_opt, + memfmt + ); + size_t nbytes = out.numel() * out.element_size(); + void* ptr = out.mutable_data_ptr(); + + TORCH_CHECK(ptr != nullptr, + "custom_zeros_like: out.mutable_data_ptr() returned NULL"); + std::memset(ptr, 0, nbytes); + return out; +} + +at::Tensor& custom_zero_impl(at::Tensor& self) +{ + TORCH_CHECK( + self.device().type() == c10::DeviceType::PrivateUse1, + "custom_zero_: expected a PrivateUse1 
device tensor"); + + if (self.numel() == 0) { + return self; + } + + void* data = self.mutable_data_ptr(); + TORCH_CHECK(data != nullptr, + "custom_zero_: self.mutable_data_ptr() returned NULL " + "(storage was not allocated)"); + + size_t nbytes = self.numel() * self.element_size(); + std::memset(data, 0, nbytes); + + return self; +} + // With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend. // For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key. // Later in this file, we map a custom device to the PrivateUse1 device type, @@ -276,16 +461,18 @@ static at::Tensor custom_to_dtype_impl(const at::Tensor& self, // This macro registers your kernels to the PyTorch Dispatcher. // More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - m.impl("to.Device", &custom_to_device); - m.impl("to.dtype", &custom_to_dtype_impl); - m.impl("fill_.Scalar", &custom_fill__scalar); - m.impl("_copy_from", &custom__copy_from); + m.impl("to.Device", &custom_to_device); + m.impl("to.dtype", &custom_to_dtype_impl); + m.impl("fill_.Scalar", &custom_fill__scalar); + m.impl("_copy_from", &custom__copy_from); m.impl("_copy_from_and_resize", &custom__copy_from_and_resize); - m.impl("empty_strided", &custom_empty_strided); - m.impl("empty.memory_format", &custom_empty); - m.impl("as_strided", at::native::as_strided_tensorimpl); - m.impl("view", at::native::view); - m.impl("arange.start_out", &custom_arange_start_out_impl); + m.impl("empty_strided", &custom_empty_strided); + m.impl("empty.memory_format", &custom_empty); + m.impl("as_strided", at::native::as_strided_tensorimpl); + m.impl("view", at::native::view); + m.impl("arange.start_out", &custom_arange_start_out_impl); + m.impl("zeros_like", &custom_zeros_like); + m.impl("zero_", &custom_zero_impl); } TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) { @@ -293,11 +480,11 @@ TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) { } TORCH_LIBRARY_FRAGMENT(aten, m) { -m.def( - "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor", - torch::dispatch( - c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor), - {at::Tag::pt2_compliant_tag}); + m.def( + "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor", + torch::dispatch(c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor), + {at::Tag::pt2_compliant_tag} + ); } void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { @@ -305,39 +492,162 @@ void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack } TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("abs", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("abs.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("abs_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("absolute", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("absolute.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("absolute_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("add.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("add.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("add.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - 
m.impl("abs.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("add_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("cat", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("cat.names", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("cat.names_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("div.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("div.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow.Tensor_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("zero_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("div.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("div_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("div_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("eq.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("eq.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("eq.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("eq.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("equal", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("erf", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("erf.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("erf_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("erfc", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("erfc.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("erfc_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("exp", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("ge.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("ge.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("ge.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("ge.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("gt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("gt.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("gt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("gt.Tensor_out", 
torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("le.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("le.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("le.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("le.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("lt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("lt.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("lt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("lt.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("ne.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("ne.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("ne.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("ne.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("logical_and", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_and.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_and_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_not", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_not.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_not_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_or", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_or.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_or_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_xor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_xor.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("logical_xor_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("neg", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("neg.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("neg_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("mul.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("mul_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("pow.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("pow.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("pow.Tensor_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("pow.Tensor_Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("pow.Tensor_Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("pow.Tensor_Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("pow_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("pow_.Tensor", 
torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("sub.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sub.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sub_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sub_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("sum", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sum.DimnameList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("sum.IntList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("eq.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("all.all_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_local_scalar_dense", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_log_softmax", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_log_softmax_backward_data", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("mse_loss.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nll_loss_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nll_loss_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_lerp_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_mul_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_addcmul_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_sqrt", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_div_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sum.dim_DimnameList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sum.dim_IntList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("resize_as_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + // Foreach ops + m.impl("_foreach_add.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("_foreach_add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("_foreach_add_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + // Indexed + m.impl("index_add.out", 
torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_add_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_copy.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_copy_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_fill.int_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_fill.int_Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_fill.int_Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_fill.int_Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_fill_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("tril", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("tril_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("triu", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("triu_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("nll_loss2d_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("nll_loss2d_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("nll_loss_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("nll_loss_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("scatter.src_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("scatter.value_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("index_put.Default", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("mm.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("sigmoid.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("gather.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("silu.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + + m.impl("all.all_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("_local_scalar_dense", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("_log_softmax", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("_log_softmax_backward_data", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("mse_loss.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("where.self", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("min", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("max", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index_select", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); } // This basic implementation doesn't bother dealing with different device 
indices @@ -360,7 +670,6 @@ bool custom_op_called() { class PrivateGeneratorImpl : public at::CPUGeneratorImpl { public: - // Constructors PrivateGeneratorImpl(c10::DeviceIndex device_index) { device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); @@ -382,7 +691,21 @@ void register_generator() { // that's implemented in C++. // The implementation in this file maps directly to the `PrivateUse1` device type. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("custom_device", &get_custom_device, "get custom device object"); - m.def("custom_op_called", &custom_op_called, "check if our custom function was called"); - m.def("register_generator", ®ister_generator, "register generator for custom device"); + m.def("custom_device", &get_custom_device, "get custom device object"); + m.def("custom_op_called", &custom_op_called, "check if our custom function was called"); + m.def("register_generator", ®ister_generator, "register generator for custom device"); + m.def("is_autocast_enabled", []() -> bool { return g_amp_enabled;}); + m.def("set_autocast_enabled", [](bool flag) -> void {g_amp_enabled = flag;}); + m.def("get_autocast_dtype", []() -> py::object { return to_torch_dtype(g_amp_dtype); }); + m.def("set_autocast_dtype", [](py::object dtype_obj) -> void { + auto st = to_scalar_type(dtype_obj); + g_amp_dtype = st; + }); + m.def("get_amp_supported_dtype", []() -> py::list { + py::module torch_mod = py::module::import("torch"); + py::list lst; + lst.append(torch_mod.attr("float16")); + lst.append(torch_mod.attr("float32")); + return lst; + }); } \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_op.py b/PyTorchSimFrontend/extension_op.py index 786e7398..18bf65c3 100644 --- a/PyTorchSimFrontend/extension_op.py +++ b/PyTorchSimFrontend/extension_op.py @@ -276,7 +276,7 @@ def sparse_mm_stonne_outer(a, b, out): onnx_path, attribute_path, c_result_path = prepare_outer_product_matrix(a, b, out) togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") - stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_single_c1_simple_noc.json' + stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_single_c1_simple_noc.yml' TOGSim = TOGSimulator(togsim_path, stonne_config_path) result_path = TOGSim.simulation(onnx_path) TOGSimulator.get_result_from_file(result_path) diff --git a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py index dff6b0fd..a539bdb9 100644 --- a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py +++ b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py @@ -1,4 +1,5 @@ import os +import math import subprocess import shlex import re @@ -58,7 +59,11 @@ def load_arg(self): if self.is_in_arg(arg_attribute[0]): argv_idx = self.get_argv_idx() if arg_name not in self.load_args else self.load_args[arg_name] self.load_args[arg_name] = argv_idx - self.writeline(f'if(load_arg(c_{arg_name}, sizeof(c_{arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') + ctype = DTYPE_TO_C[arg_attribute[1]] + elem_count = arg_attribute[2] + size_expr = f'({elem_count}ULL * sizeof({ctype}))' + + self.writeline(f'if(load_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): self.writeline(f'return -1{self.ending}') self.writeline(self.closed_bracket) @@ -67,7 +72,10 @@ def dump_arg(self): for arg_name, arg_attribute in self.arg_attributes: if self.is_out_arg(arg_attribute[0]): argv_idx 
= self.get_argv_idx() if not self.is_inout_arg(arg_attribute[0]) else self.load_args[arg_name] - self.writeline(f'if(dump_arg(c_{arg_name}, sizeof(c_{arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') + ctype = DTYPE_TO_C[arg_attribute[1]] + elem_count = arg_attribute[2] + size_expr = f'({elem_count}ULL * sizeof({ctype}))' + self.writeline(f'if(dump_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): self.writeline(f'return -1{self.ending}') self.writeline(self.closed_bracket) @@ -84,29 +92,25 @@ def generate_kernel_declare(self): def generate_args_define(self): name_set = set() if self.validation: - self.writeline(f'int padding[0x100000]{self.ending}') # FIXME. For pooling operation... Some pooling layer use negative offset + self.writeline(f"int* padding = malloc(0x100000ULL * sizeof(int)){self.ending}") for arg_name, (_, arg_type, arg_size, arg_sizes, arg_stride) in self.arg_attributes: if not arg_name in name_set: - if self.validation: - self.writeline(f'{DTYPE_TO_C[arg_type]} c_{arg_name}[{arg_size}ULL]{self.ending}') + if torch.is_floating_point(torch.tensor([], dtype=arg_type)): + bits = torch.finfo(arg_type).bits + elif arg_type == torch.bool: + bits = 8 else: - if torch.is_floating_point(torch.tensor([], dtype=arg_type)): - bits = torch.finfo(arg_type).bits - elif arg_type == torch.bool: - bits = 8 - else: - bits = torch.iinfo(arg_type).bits - self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({arg_size * bits // 8}ULL){self.ending}') + bits = torch.iinfo(arg_type).bits + buffer_size = int(math.ceil(arg_size * bits // 8 / 64) * 64) # Round up to 64 bytes + self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({buffer_size}ULL){self.ending}') name_set.add(arg_name) self.writeline(self.newline) def generate_main(self): - if self.validation: - self.generate_args_define() - self.writeline(f'{self.newline}int main(int argc, char *argv[]) {self.open_bracket}{self.newline}') with self.code.indent(): if self.validation: + self.generate_args_define() self.load_arg() self.writeline(self.newline) else: diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 6650f429..266d884b 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -24,6 +24,7 @@ from PyTorchSimFrontend import extension_config from . 
import mlir_common from .mlir_common import LoopLevel, LoopNest +from .mlir_ops import ExtensionOverrides from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest def reduction_init(reduction_type, dtype): @@ -36,19 +37,9 @@ def reduction_init(reduction_type, dtype): if reduction_type == "prod": return float(1) if dtype.is_floating_point else int(1) if reduction_type in {"max", "argmax"}: - if dtype == torch.float32: - return f"0x{mlir_common.MLIR_INF['-inf']['f32']:x}" - elif dtype == torch.float64: - return f"0x{mlir_common.MLIR_INF['-inf']['f64']:x}" - else: - return "0.0" + return "-inf" if reduction_type in {"min", "argmin"}: - if dtype == torch.float32: - return f"0x{mlir_common.MLIR_INF['inf']['f32']:x}" - elif dtype == torch.float64: - return f"0x{mlir_common.MLIR_INF['inf']['f64']:x}" - else: - return "0.0" + return "inf" if reduction_type in {"welford_reduce"}: return f"0.0" raise AssertionError(reduction_type) @@ -66,19 +57,6 @@ def reduction_partial_combine_vec(reduction_type, vector_value, init_value): return ops.logical_and(vector_value, init_value) raise AssertionError(reduction_type) -def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape): - if reduction_type == "sum": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "prod": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "max": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "min": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "any": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - raise AssertionError(reduction_type) - class ExtensionWrapperCodegen(wrapper.WrapperCodeGen): def __init__(self): super().__init__() @@ -215,675 +193,6 @@ def generate(self, is_inference): def memory_plan(self): self.lines = memory_planning.MemoryPlanner(self).plan(self.lines) -class ExtensionOverrides(common.OpOverrides): - # Binary element wise operations - @staticmethod - def custom_cast(operand, target_type, *args, var_info=None, **kwargs): - dtype = var_info[operand][1] - if dtype == "index": - ret = ops.index_cast(operand, target_type, var_info=var_info) - else: - ret = ops.to_dtype(operand, target_type, var_info=var_info) - return ret, var_info[ret] - - @staticmethod - def binary_elementwise_common(operand1, operand2, var_info): - operand1.bounds = operand1.bounds.unknown() - operand2.bounds = operand2.bounds.unknown() - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - # Tile size check - if op_type1[0] != op_type2[0]: - # Try to broad cast - lhs_tile_size, lhs_dtype = op_type1 - rhs_tile_size, rhs_dtype = op_type2 - if lhs_tile_size > rhs_tile_size: - operand2 = ops.broadcast(operand2, operand1, var_info=var_info) - op_type2 = var_info[operand2] - elif lhs_tile_size < rhs_tile_size: - operand1 = ops.broadcast(operand1, operand2, var_info=var_info) - op_type1 = var_info[operand1] - - # Data type check - if op_type1[1] != op_type2[1]: - if op_type1[1] == "index" or op_type1 == "index": - if op_type1[1] == "index": - operand1 = ops.index_cast(operand1, op_type2[1], var_info) - op_type1 = var_info[operand1] - if op_type2[1] == "index": - operand2 = ops.index_cast(operand2, op_type1[1], var_info) 
- op_type2 = var_info[operand2] - elif op_type1[1][0] == "i" and op_type2[1][0] == "f": - operand1 = ops.to_dtype(operand1, op_type2[1], var_info) - op_type1 = var_info[operand1] - elif op_type1[1][0] == "f" and op_type2[1][0] == "i": - operand2 = ops.to_dtype(operand2, op_type1[1], var_info) - op_type2 = var_info[operand2] - elif op_type1[1][0] == op_type2[1][0]: - if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: - operand2 = ops.ext(operand2, op_type1[1]) - op_type2 = var_info[operand2] - elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: - operand1 = ops.ext(operand1, op_type2[1]) - op_type1 = var_info[operand1] - else: - raise NotImplementedError("Unsupported type converting") - - # Updated var info - tile_size = op_type1[0] - ret_type = op_type1[1] - return tile_size, ret_type, operand1, operand2 - - @staticmethod - def add(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.add{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def sub(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.sub{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def mul(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.mul{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def div(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.divf' - else: - opcode = f'arith.divui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def truediv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.divf' - else: - opcode = f'arith.divui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def modular(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - raise NotImplementedError("Not support remainder operation for floating point") - else: - opcode = f'arith.remui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def minimum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 
= ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.minimumf' - else: - opcode = f'arith.minimumui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def maximum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.maximumf' - else: - opcode = f'arith.maximumui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): - src_mlir_dtype = var_info[operand][1] - if src_mlir_dtype == "index": - operand = ops.index_cast(operand, "i64", var_info=var_info) - src_mlir_dtype = var_info[operand][1] - - tile_size = var_info[operand][0] - if isinstance(dst_mlir_dtype, torch.dtype): - dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype] - dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype] - src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype] - shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype - src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype - if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f": - return f"arith.fptoui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": - return f"arith.uitofp %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - if dst_mlir_dtype[0] == "i": - if dst_bits > src_bits: - return f"arith.extui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - elif dst_bits < src_bits: - return f"arith.trunc %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - return f"arith.maximumi %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] - elif dst_mlir_dtype[0] == "f": - if dst_bits > src_bits: - return f"arith.extf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - elif dst_bits < src_bits: - return f"arith.trunf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - return f"arith.maximumf %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] - else: - raise NotImplementedError("Unsupported type for to_dtype ops") - - @staticmethod - def constant(value, src_type, *args, var_info=None, **kwargs): - if isinstance(src_type, torch.dtype): - src_type = mlir_common.DTYPE_TO_MLIR[src_type] - - if "inf" == str(value) or "-inf" == str(value) or "nan" == str(value): - value = f"0x{mlir_common.MLIR_INF[str(value)][src_type]:x}" - # if value represented by e notation, convert to float (ex 1e-3 -> 1.0e-3) - elif "e" in str(value): - value = format(float(value), ".20f") - elif src_type[0] == "f": - value = format(value, ".20f") - elif src_type[0] == "i": - value = int(value) - return f'arith.constant {value} : {src_type}', [1, src_type] - - @staticmethod - def alloc(size, src_type, *args, var_info=None, **kwargs): - return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] - - @staticmethod - def extractelement(operand, idx, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - 
return f"vector.extract %{operand}[{idx}]: {dtype} from {shape}", [1, dtype] - - # transcendental functions - @staticmethod - def exp(operand, *args, var_info=None, **kwargs): - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.exp(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.exp %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def exp2(operand, *args, var_info=None, **kwargs): - # Hands-on part: implement exp2 using math.exp2 - # var_info = {operand: [tile_size, dtype]} - # Ex) var_info[operand] = [8, "f32"] - - ln2 = math.log(2) - coeff = ops.constant(ln2, "f32") - operand = ops.mul(operand, coeff) - return ops.exp(operand), var_info[operand] - - @staticmethod - def erf(operand, *args, var_info=None, **kwargs): - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.erf(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.erf %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def tanh(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.tanh(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.tanh %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def sin(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.sin(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sin %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def cos(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.cos(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - 
var_info[operand] = dtype - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.cos %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def sqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def rsqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def pow(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - # Type check & auto cast - if ret_type[0] != "f": - operand1, ret_type = ops.to_dtype(operand1, "f32", var_info=var_info) - var_info[operand1] = ret_type - - # Type check & auto cast - if ret_type[0] != "f": - operand2, ret_type = ops.to_dtype(operand2, "f32", var_info=var_info) - var_info[operand2] = ret_type - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f"math.pow{ret_type[0]} %{operand1}, %{operand2} : {shape}", [tile_size, ret_type] - - @staticmethod - def log(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.log %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def reciprocal(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - return ops.div(ops.constant(1.0, dtype), operand), [tile_size, dtype] - - @staticmethod - def ext(operand, dtype, *args, var_info=None, **kwargs): - op_type = var_info[operand] - shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" - target_type = f"vector<{op_type[0]}x{dtype}>" if op_type[0] > 1 else f"{dtype}" - if op_type[0] == "f": - opcode = f'arith.extf' - else: - opcode = f'arith.extui' - return f'{opcode} %{operand} : {shape} to {target_type}', [op_type[0], dtype] - - # Logical operations - @staticmethod - def neg(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.negf %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def eq(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = 
ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "oeq" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "eq" - else: - raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def ne(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "one" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sne" - else: - raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def lt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "olt" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "slt" - else: - raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def gt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "ogt" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sgt" - else: - raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def le(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "ole" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sle" - else: - raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def ge(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "oge" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sge" - else: - raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def and_(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = 
var_info[operand2] - - # Type check & auto cast - if op_type1[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand1] = dtype - - # Type check & auto cast - if op_type2[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand2] = dtype - - ret_type = op_type1[1] - tile_size = op_type1[0] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def or_(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - - # Type check & auto cast - if op_type1[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand1] = dtype - - # Type check & auto cast - if op_type2[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand2] = dtype - - ret_type = op_type1[1] - tile_size = op_type1[0] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def xor(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - - # Type check & auto cast - if op_type1[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand1] = dtype - - # Type check & auto cast - if op_type2[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand2] = dtype - - ret_type = op_type1[1] - tile_size = op_type1[0] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - - @staticmethod - def logical_and(operand1, operand2, *args, var_info=None, **kwargs): - op_type = var_info[operand1] - # Type check & auto cast - if op_type[1] != "i1": - raise NotImplementedError("Logical operation with not bool data type") - return ExtensionOverrides.and_(operand1, operand2, *args, var_info=var_info, **kwargs) - - @staticmethod - def logical_not(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - ret_type = op_type[1] - tile_size = op_type[0] - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - const_one = ops.constant(0, ret_type) - const_one = ops.broadcast(const_one, operand, var_info=var_info) - ret = ops.eq(operand,const_one) - return ret, [tile_size, var_info[ret]] - - @staticmethod - def logical_or(operand1, operand2, *args, var_info=None, **kwargs): - op_type = var_info[operand1] - # Type check & auto cast - if op_type[1] != "i1": - raise NotImplementedError("Logical operation with not bool data type") - return ExtensionOverrides.or_(operand1, operand2, *args, var_info=var_info, **kwargs) - - @staticmethod - def logical_xor(operand1, operand2, *args, var_info=None, **kwargs): - op_type = var_info[operand1] - # Type check & auto cast - if op_type[1] != "i1": - raise NotImplementedError("Logical operation with not bool data type") - return ExtensionOverrides.xor(operand1, operand2, *args, var_info=var_info, **kwargs) - - @staticmethod - def relu(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - ret_type = "f32" - return ops.maximum(operand, ops.constant(0.0, "f32")), [tile_size, ret_type] - - @staticmethod - def sigmoid(operand, 
*args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - ret_type = "f32" - one = ops.constant(1, "f32") - return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, ret_type] - - # Special operaitons - @staticmethod - def where(condition, operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - cond_type = var_info[condition] - if cond_type[0] < tile_size: - condition = ops.broadcast(condition, operand1, var_info=var_info) - elif cond_type[0] > tile_size: - operand1 = ops.broadcast(operand1, condition, var_info=var_info) - operand2 = ops.broadcast(operand2, condition, var_info=var_info) - tile_size, ret_type = var_info[operand1] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" - return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape} {shape}", [tile_size, ret_type] - - - @staticmethod - def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): - result = body() - val = ops.constant(other, dtype, *args, **kwargs) - result = ops.where(mask, result, val) - return result, var_info[result] - - @staticmethod - def index_cast(operand, target_type, *args, var_info=None, **kwrags): - op_type = var_info[operand] - src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] - des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type - return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] - - @staticmethod - def broadcast_unflat(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>"# if op_type1[0] > 1 else op_type1[1] - des_shape = f"vector<{op_type2[0]//op_type1[0]}x{op_type1[0]}x{op_type1[1]}>"# if op_type2[0] > 1 else op_type1[1] # Use tile size only - - expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" - return expand, [op_type2[0], op_type1[1]] - - @staticmethod - def broadcast(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>" if op_type1[0] > 1 else op_type1[1] - des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" # if op_type2[0] > 1 else op_type1[1] # Use tile size only - - # Special case for length 2 vector. We used this vector to avoid scalar operations... 
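The special case flagged in the comment above (and kept by the reimplemented `broadcast` in the new `mlir_ops.py`) lowers a short vector to the target width in two steps: replicate it along a new outer dimension, then flatten the result back. A minimal sketch of the IR this path is expected to emit for a `vector<2xf32>` source and a 16-wide target; `%src`, `%dst`, and `expected_ir` are illustrative names, not identifiers from the diff:

```python
# Illustration only (not part of the diff): expected MLIR for broadcasting a
# vector<2xf32> %src up to vector<16xf32> via broadcast_unflat + shape_cast.
expected_ir = [
    # step 1: 16 // 2 = 8 copies along a new outer dimension
    "%unflat = vector.broadcast %src : vector<2xf32> to vector<8x2xf32>",
    # step 2: flatten the 8x2 result back into the 1-D vector the consumer expects
    "%dst = vector.shape_cast %unflat : vector<8x2xf32> to vector<16xf32>",
]
```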
- if op_type1[0] != 1 and op_type2[0] % op_type1[0] == 0: - unflat_operand = ops.broadcast_unflat(operand1, operand2) - unflat_shape = f"vector<{op_type2[0]//op_type1[0]}x{op_type1[0]}x{op_type1[1]}>" - expand = f"vector.shape_cast %{unflat_operand} : {unflat_shape} to {des_shape}" - elif op_type1[0] == 1: - expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" - else: - raise NotImplementedError("Not supporting broadcast type...") - return expand, [op_type2[0], op_type1[1]] RTYPE_TO_MLIR = { "sum": "add", @@ -977,8 +286,10 @@ def convert_index(self, expr, buffer): expr_str = expr_str.replace("//", " floordiv ") else: raise NotImplementedError("What is this case?") - - indices = [expr.args[0]] + first_arg = expr.args[0] + if len(first_arg.free_symbols) != 1: + raise NotImplementedError("What is this case?") + indices = [list(first_arg.free_symbols)[0]] args = ", ".join(map(str, indices)) map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) @@ -1031,7 +342,6 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) -> common.CSEVariable: if buffer is None: buffer = self.applys - zero_var = self.get_const_cse(0) expr_list = [arg for arg in expr_list] dim_list = [f"d{i}" for i in range(len(expr_list))] @@ -1102,6 +412,7 @@ def load(self, name: str, index: sympy.Expr): # Define scratch pad buffer sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # MVIN Encoding attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}" @@ -1110,31 +421,34 @@ def load(self, name: str, index: sympy.Expr): self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching if not comptute_depedency: - compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector load instruction - if compute_vec_size > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - - out = self.cse.generate(load_buffer, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) - self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] - return out + with self.override_buffer_cse(buffer=load_buffer): + out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape) else: + # FIXME. Any good idea? 
out = sram_var self.register_var_info(out, [compute_vec_size, mlir_dtype]) - self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] - return out + self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] + return out - def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): + def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs): index = self.rename_indexing(index) - dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + # Handle scatter store + if "tmp" in str(index): + if mode == "atomic_add": + # Convert the output buffer type to the inplace buffer + arg_name = V.graph.scheduler.mutation_real_name.get(name, name) + if arg_name not in self.kernel_group.args.inplace_buffers: + self.kernel_group.args.make_inplace(arg_name, arg_name) + + loaded_value = ops.load(name, index) + value = ops.add(loaded_value, value) + index, _ = self.convert_indirect_indexing(index) + dram_var = self.kernel_group.args.output(name) + # Prepare dma instruction local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) vlane_split_axis = local_tile_desc.vmap.vlane_split_axis @@ -1148,9 +462,6 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() require_store = True - if compute_vec_size < self.var_info[value][0]: - value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}") - self.register_var_info(value, [compute_vec_size, mlir_dtype]) if str(value) in self.spad_buffer_dict: # Todo. If tile_size is not same (i.e., view operation), we can't apply peephole optimization easily @@ -1161,17 +472,16 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector store instruction - store_size, operand_type = self.var_info[value] + _, operand_type = self.var_info[value] if mlir_dtype != operand_type: - value = ops.custom_cast(value, mlir_dtype, var_info=self.var_info) + value = ops.to_dtype(value, mlir_dtype) - if compute_vec_size > 1 and store_size > 1: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" - self.stores.writeline(common.DeferredLine(name, line)) # TODO: Should be changed to self.compute? 
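The scatter-store branch added to `store()` above lowers `atomic_add` into an explicit load-add-store on the output buffer, after marking that buffer in-place so each load observes earlier updates. A self-contained toy model of the intended semantics (not kernel code; `scatter_atomic_add` and its arguments are made up for illustration):

```python
# Toy model of the atomic_add scatter-store semantics: duplicate indices accumulate
# because every update re-reads the in-place output buffer before adding.
def scatter_atomic_add(out_buffer, indices, updates):
    for i, u in zip(indices, updates):
        out_buffer[i] = out_buffer[i] + u  # read-modify-write
    return out_buffer

assert scatter_atomic_add([0.0, 0.0, 0.0], [1, 1, 2], [2.0, 3.0, 4.0]) == [0.0, 5.0, 4.0]
```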
+ if compute_vec_size < self.var_info[value][0]: + value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}") + self.register_var_info(value, [compute_vec_size, mlir_dtype]) + + with self.override_buffer_cse(buffer=self.stores): + ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=name) else: sram_var = self.spad_buffer_dict[str(value)][0] sram_index_var = self.spad_buffer_dict[str(value)][3] @@ -1206,10 +516,12 @@ def reduction(self, dtype, src_dtype, reduction_type, value): vec_len = self.kernel_group.tile_desc.get_compute_vec_size() reduced_shape = self.kernel_group.tile_desc.get_mlir_vshape(type_name) + + # Prepare reduction init - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - init_vec = init if vec_len == 1 else self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") - self.register_var_info(init_vec, [vec_len, type_name]) + with self.override_buffer_cse(cse=self.const_cse, buffer=self.const_buffer): + init = self.get_const_cse(reduction_init(reduction_type, dtype), type_name) + init_vec = init if vec_len == 1 else ops.broadcast(init, vec_len) acc_var_list = [] iter_var_list = [] @@ -1239,104 +551,76 @@ def reduction(self, dtype, src_dtype, reduction_type, value): _, mask_var = self.get_mask() if mask_var is not None: value = ops.where(mask_var, value, init_vec) + result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) + result = ops.to_dtype(result, type_name) + self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iter_var_list[-1], reduced_shape) self.compute_body_loop.affine_yield[result] = reduced_shape - # Register affine yield var for reduction_depth, acc in enumerate(acc_var_list[1:]): self.affine_yield[acc] = reduced_shape, reduction_depth # Final reduction - acc = acc_var_list[0] # Set outermost acc var reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() + acc = acc_var_list[0] # Set outermost acc var + self.register_var_info(acc, [reduction_size, type_name]) assert(vec_len % reduction_size==0) - if vec_len > reduction_size: - init = self.const_cse.generate(self.reductions_suffix, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - if reduction_size == 1: - final_reduced_shape = f"{type_name}" - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, acc, init, axis=0, shape=reduced_shape, reduced_shape=final_reduced_shape)) - else: - final_reduced_shape = f"vector<{reduction_size}x{type_name}>" - init_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{init} : {type_name} to {final_reduced_shape}") - new_vshape= f"vector<{vec_len//reduction_size}x{reduction_size}x{type_name}>" - value = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{acc} : {reduced_shape} to {new_vshape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, value, init_vec, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape)) - acc = out - - # reigster reduction output - var_info = [reduction_size, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(acc, var_info) + + # Prepare init value + init = self.get_const_cse(reduction_init(reduction_type, dtype), 
type_name) + if reduction_size != 1: + with self.override_buffer_cse(buffer=self.reductions_suffix): + init = ops.broadcast(init, reduction_size) + + # Final reduction codegen + with self.override_buffer_cse(buffer=self.reductions_suffix): + if vec_len > reduction_size: + acc = ops.multi_reduction(acc, init, vec_len, reduction_size, reduced_shape, reduction_type, type_name) return acc def store_reduction(self, name, index, value): - # Note: Change cse temporaily # Store reduction can't share cached value stored in cse, # since it is not innermost loop body. - tmp_cse = self.cse - tmp_apply_cse = self.apply_cse - self.cse = self.reduction_cse - self.apply_cse = self.reduction_cse - dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] index = self.rename_indexing(index) - # Tile is always reuduced in inner loop - local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) - vlane_split_axis = local_tile_desc.vmap.vlane_split_axis - vlane_stride = local_tile_desc.vmap.vlane_stride - - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = local_tile_desc.get_tile_stride() - compute_vec_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() - if compute_vec_size == 1: - vshape = f"{mlir_dtype}" - else: - vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" - sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) - if self.welford_reduce_out is not None: - sum, sqr_sum, _ = self.welford_reduce_out - # mean - reduction_numel = reduce(mul, self.ranges[self.reduction_depth:], 1) - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(reduction_numel)} : f32") - if compute_vec_size > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{mlir_dtype}>") - else: - divider_vec = divider - mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {vshape}") - - # m2 = (E(X^2) - E(X)^2) * N - sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sqr_sum}, %{divider_vec} : {vshape}") - mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{mean}, %{mean} : {vshape}") - variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {vshape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {vshape}") - if self.current_node.node.origin_node: # FIXME: This is a temporary solution - value = mean - else: - value = m2 - - # Select src type - if compute_vec_size == 1: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}" - else: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" - self.reductions_suffix.writeline(common.DeferredLine(name, line)) + with self.override_buffer_cse(cse=self.reduction_cse): + # Tile is always reuduced in inner loop + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) + vlane_split_axis = local_tile_desc.vmap.vlane_split_axis + vlane_stride = local_tile_desc.vmap.vlane_stride - # MVOUT Encoding - # Generate 
DMA instruction - attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) - self.reductions_suffix.writeline(common.DeferredLine(name, code)) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = local_tile_desc.get_tile_stride() - # Restore origin cse - self.cse = tmp_cse - self.apply_cse = tmp_apply_cse + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) + with self.override_buffer_cse(buffer=self.reductions_suffix): + if self.welford_reduce_out is not None: + # Calc var and mean + sum, sqr_sum, _ = self.welford_reduce_out + reduction_numel = reduce(mul, self.ranges[self.reduction_depth:], 1) + divider = self.get_const_cse(float(reduction_numel), "f32") + mean = ops.truediv(sum, divider) + sqr_mean = ops.truediv(sqr_sum, divider) + mean_sqr = ops.mul(mean, mean) + variance = ops.sub(sqr_mean, mean_sqr) + m2 = ops.mul(variance, divider) + if self.current_node.node.origin_node: # FIXME: This is a temporary solution + value = mean + else: + value = m2 + # Store value to scratch pad + ops._store(value, sram_var, sram_index_var, tile_shape, buffer_name=name) + + # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + self.reductions_suffix.writeline(common.DeferredLine(name, code)) def indirect_indexing(self, index_var, size, check=True): return str(index_var) @@ -1354,77 +638,71 @@ def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): strides = tile_desc.get_tile_stride_per_lane() # Create vector index - compute_vec = self.cse.generate(self.compute, f"vector.broadcast %{self.compute_idx} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(compute_vec, [compute_vec_size, "index"]) + compute_vec = ops.broadcast(self.compute_idx, compute_vec_size) vector_index = ops.add(base_vector_index, compute_vec) # Create tile_dim index dim_list = [] for idx in range(len(tile_size)): - div_coeff = self.get_const_cse(strides[idx], "index") - mod_coeff = self.get_const_cse(tile_size[idx], "index") - div_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{div_coeff} : index to vector<{compute_vec_size}xindex>") - mod_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{mod_coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(div_vec, [compute_vec_size, "index"]) - self.register_var_info(mod_vec, [compute_vec_size, "index"]) - dim = ops.modular(ops.div(vector_index, div_vec), mod_vec) - if idx == tile_desc.vmap.vlane_split_axis: # Need to add vector lane offset - offset = tile_desc.vmap.vlane_stride #* strides[idx] - outer_sz = tile_size[idx] // tile_desc.vmap.vlane_stride - - nr_vector_lane = self.get_const_cse(self.vector_lane, "index") - nr_vector_lane_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{nr_vector_lane} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(nr_vector_lane_vec, [compute_vec_size, "index"]) - + # Prepare initial values + offset = tile_desc.vmap.vlane_stride #* strides[idx] + outer_sz = 
tile_size[idx] // tile_desc.vmap.vlane_stride + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + div_coeff = self.get_const_cse(strides[idx], "index") + mod_coeff = self.get_const_cse(tile_size[idx], "index") vlane_stride_coeff = self.get_const_cse(tile_desc.vmap.vlane_stride, "index") vlane_outer_coeff = self.get_const_cse(outer_sz, "index") - vlane_stride_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_stride_coeff} : index to vector<{compute_vec_size}xindex>") - vlane_outer_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_outer_coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(vlane_stride_vec, [compute_vec_size, "index"]) - self.register_var_info(vlane_outer_vec, [compute_vec_size, "index"]) - stride_dim = ops.modular(dim, vlane_stride_vec) - outer_dim = ops.modular(ops.div(dim, vlane_stride_vec), vlane_outer_vec) + nr_vector_lane = self.get_const_cse(self.vector_lane, "index") + vlane_coeff = self.get_const_cse(0, "i64") - dim = ops.add(stride_dim, ops.mul(outer_dim, nr_vector_lane_vec)) + div_vec = ops.broadcast(div_coeff, compute_vec_size) + mod_vec = ops.broadcast(mod_coeff, compute_vec_size) + nr_vector_lane_vec = ops.broadcast(nr_vector_lane, compute_vec_size) + vlane_stride_vec = ops.broadcast(vlane_stride_coeff, compute_vec_size) + vlane_outer_vec = ops.broadcast(vlane_outer_coeff, compute_vec_size) # Prepare vlane offset (vidx) - vlane_coeff = self.get_const_cse(0, "i64") vlane_vec_size = 4 - vlane_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_coeff} : i64 to vector<{vlane_vec_size}xi64>") + vlane_vec = ops.broadcast(vlane_coeff, vlane_vec_size) + + dim = ops.remainder(ops.truncdiv(vector_index, div_vec), mod_vec) + if idx == tile_desc.vmap.vlane_split_axis: # Need to add vector lane offset + stride_dim = ops.remainder(dim, vlane_stride_vec) + outer_dim = ops.remainder(ops.truncdiv(dim, vlane_stride_vec), vlane_outer_vec) + dim = ops.add(stride_dim, ops.mul(outer_dim, nr_vector_lane_vec)) + vlane_offset = self.const_cse.generate(self.const_buffer, f"arith.addi %{vlane_vec}, %{vlane_vec} {{ vlane_offset={offset} }} : vector<{vlane_vec_size}xi64> // vlane offset") self.register_var_info(vlane_offset, [vlane_vec_size, "i64"]) vlane_offset = ops.index_cast(vlane_offset, "index") - self.register_var_info(vlane_offset, [vlane_vec_size, "index"]) - dim = ops.add(dim, vlane_offset) dim_list.append(dim) indices = [str(i) for i in index.free_symbols] for idx in indices: i = int(idx[5:]) - index_vec = self.cse.generate(self.compute, f"vector.broadcast %{idx} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(index_vec, [compute_vec_size, "index"]) + idx = self.itervar_cses[idx] + index_vec = ops.broadcast(idx, compute_vec_size) offset = ops.add(index_vec, dim_list[i]) dim_list[i] = offset arg_lists = [] for arg in renamed_expression.args: if isinstance(arg, sympy.Integer): - offset = self.get_const_cse(int(arg)) - offset_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{offset} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(offset_vec, [compute_vec_size, "index"]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + offset = self.get_const_cse(int(arg), "index") + offset_vec = ops.broadcast(offset, compute_vec_size) arg_lists.append(offset_vec) elif isinstance(arg, sympy.Mul): if isinstance(arg.args[0], sympy.Integer) and isinstance(arg.args[1], 
sympy.Symbol): - coeff = self.get_const_cse(int(arg.args[0])) - coeff_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(coeff_vec, [compute_vec_size, "index"]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + coeff = self.get_const_cse(int(arg.args[0]), "index") + coeff_vec = ops.broadcast(coeff, compute_vec_size) result = ops.mul(dim_list[int(str(arg.args[1])[1:])], coeff_vec) arg_lists.append(result) elif isinstance(arg.args[1], sympy.Integer) and isinstance(arg.args[0], sympy.Symbol): - coeff = self.get_const_cse(int(arg.args[1])) - coeff_vec = self.cse.generate(self.compute, f"vector.broadcast %{coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(coeff_vec, [compute_vec_size, "index"]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + coeff = self.get_const_cse(int(arg.args[1]), "index") + coeff_vec = ops.broadcast(coeff, compute_vec_size) result = ops.mul(dim_list[int(str(arg.args[0])[1:])], coeff_vec) arg_lists.append(result) else: @@ -1474,18 +752,16 @@ def index_expr(self, index, dtype): # Initialize base vector if not self.base_vector_initialized: - init_iter = "iter" + init_iter = self.register_var_cse("init_iter", 1, "index") parallel_map = f"affine.parallel (%{init_iter}) = ({0}) to ({compute_vec_size}) {{ // Base vector initializer" self.spad_buffer.writeline(parallel_map) with self.spad_buffer.indent(): - self.spad_buffer.writeline(f"%init_vec = vector.broadcast %{init_iter} : index to vector<2xindex>") - self.spad_buffer.writeline(f"affine.vector_store %init_vec, %{sram_var}[%{init_iter}] : {tile_shape}, vector<2xindex>") + with self.override_buffer_cse(buffer=self.spad_buffer, cse=self.init_vec_cse): + init_vec = ops.broadcast(init_iter, 2) + ops._store(init_vec, sram_var, f"%{init_iter}", tile_shape) self.spad_buffer.writeline("}") self.base_vector_initialized = True - - line = f"affine.vector_load %{sram_var}[0] : {tile_shape}, {vshape}" - base_vector_index = self.cse.generate(self.compute, line) - self.register_var_info(base_vector_index, [compute_vec_size, "index"]) + base_vector_index = ops._load(compute_vec_size, "index", sram_var, "0", tile_shape) renamed_symbols = {symbol: "d"+str(symbol)[5:] for symbol in index.free_symbols} renamed_expression = index.subs(renamed_symbols) @@ -1643,7 +919,7 @@ def get_cycle(choice): return float("inf") # Exceeded maximum number of autotuning attempts choices = self.make_choices(*args) - if len(choices) == 0: # can't autotune + if len(choices) == 0: # Can't autotune return [None, None] with ThreadPoolExecutor(max_workers=8) as executor: results = list(executor.map(get_cycle, choices)) @@ -1736,15 +1012,15 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe total_dims = [int(str(i)[5:]) for i in self.itervars] local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop - indirect_dims = [f"{i}" for i in index.free_symbols if "tmp" in str(i)] - for indirect_dim in indirect_dims: - index = index.replace(sympy.Symbol(indirect_dim), 0) + indirect_syms = [s for s in index.free_symbols if "tmp" in s.name] + index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True) + indirect_dims = [f"{i}" for i in indirect_syms] # Reduction can have two type of tile size if broadcast and (total_dims != local_dims or 
(self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): local_dims = total_dims # Brodatcast tile shape - index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims) + index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims, comments=f"// store_reduction={store_reduction}") if kg_tile_desc.vmap.vlane_split_axis in local_dims: local_vlane_split_axis = local_dims.index(kg_tile_desc.vmap.vlane_split_axis) @@ -1957,14 +1233,18 @@ def get_scratchpad_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=N return sram_var, sram_index_var def get_const_cse(self, value, dtype="index") -> common.CSEVariable: + # Why not use ops.constant? Because there are some cases that can't use ops (e.g., def_dma_op) # Type convert - if dtype[0] == "f": + if value in ["inf", "-inf", "nan"]: + value = f"0x{mlir_common.MLIR_INF[value][dtype]:x}" + elif dtype[0] == "f": value = float(value) else: value = int(value) if value not in self.consts: self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") + self.register_var_info(self.consts[str(value)+dtype], [1, dtype]) return self.consts[str(value)+dtype] def get_tag_cse(self, value=None, shape="memref<1xi32>"): @@ -1979,16 +1259,16 @@ def get_mask(self): if self.compute_body_loop.size % self.compute_body_loop.step == 0: return None, None compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() - index_shape = f"vector<{self.compute_body_loop.step}xindex>" mask_shape = f"vector<{compute_vec_size}xi1>" - upper_bound = self.get_const_cse(self.compute_body_loop.size) - step_vec = self.const_cse.generate(self.const_buffer, f"vector.step : {index_shape}") + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + upper_bound = ops.constant(self.compute_body_loop.size, "index") + step_vec = ops.step(self.compute_body_loop.step, "index") - gap = self.mask_cse.generate(self.masks, f"arith.subi %{upper_bound}, %{self.compute_idx} : index") - gap_vec = self.mask_cse.generate(self.masks, f"vector.broadcast %{gap} : index to {index_shape}") - mask_var = self.mask_cse.generate(self.masks, f"arith.cmpi ult, %{step_vec}, %{gap_vec} : {index_shape}") - self.register_var_info(mask_var, [compute_vec_size, "i1"]) + with self.override_buffer_cse(buffer=self.masks, cse=self.mask_cse): + gap = ops.sub(upper_bound, self.compute_idx) + gap_vec = ops.broadcast(gap, self.compute_body_loop.step) + mask_var = ops.lt(step_vec, gap_vec) return mask_shape, mask_var def convert_indirect_indexing(self, index :sympy.Expr): @@ -2007,14 +1287,8 @@ def convert_indirect_indexing(self, index :sympy.Expr): indirect_dims.sort() first_dim = indirect_dims[0] spad_vars = dict() - old_compute, old_dma_lods, old_dma_stores = self.compute, self.dma_loads, self.dma_stores compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims]) - if compute_dependecy: - self.compute = old_dma_stores - target_dma_buffers = self.dma_stores - else: - self.compute = old_dma_lods - target_dma_buffers = self.dma_loads + target_dma_buffers = self.dma_stores if compute_dependecy else self.dma_loads # Load indirect operands for target_dim in indirect_dims: @@ -2028,6 +1302,7 @@ def convert_indirect_indexing(self, index :sympy.Expr): local_tile_desc = self.kernel_group.tile_desc tile_numel_per_lane = local_tile_desc.get_numel_per_lane() tile_shape = local_tile_desc.get_mlir_shape(var_info[1]) + tile_vec = 
local_tile_desc.get_compute_vec_size() vshape = f"vector<{var_info[0]}x{var_info[1]}>" sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim) self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] @@ -2038,52 +1313,37 @@ def convert_indirect_indexing(self, index :sympy.Expr): line = f"{opeartion} %{target_dim}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" self.stores.writeline(line) mlir_dtype = vshape.split("x")[1][:-1] - vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute... - if tile_numel_per_lane > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape} // For indirect access" - else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape} // For indirect access" - out = self.cse.generate(target_dma_buffers, line) - self.register_var_info(out, [tile_numel_per_lane, mlir_dtype]) - spad_vars[target_dim] = out - - # Apply stride - for arg in index.args: - if "tmp" not in str(arg): - continue - if arg.is_Mul and arg.args[0].is_number: - coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1] - coeff = ops.constant(int(arg.args[0]), coeff_dtype) - spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff) - index = index.replace(arg, 0) - - # Sum - for dim, var in spad_vars.items(): - if dim == first_dim: - continue - spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) + with self.override_buffer_cse(buffer=target_dma_buffers): + out = ops._load(tile_numel_per_lane, mlir_dtype, sram_var, sram_index_var, tile_shape) + spad_vars[target_dim] = out + + with self.override_buffer_cse(buffer=target_dma_buffers): + # Apply stride + for arg in index.args: + if "tmp" not in str(arg): + continue + if arg.is_Mul and arg.args[0].is_number: + coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1] + coeff = self.get_const_cse(int(arg.args[0]), coeff_dtype) + spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff) + index = index.replace(arg, 0) + + # Sum + for dim, var in spad_vars.items(): + if dim == first_dim: + continue + spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) # Store index var sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim] mlir_dtype = vshape.split("x")[1][:-1] - vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute... - if tile_numel_per_lane > 1: - operation = "affine.vector_store" - line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.store" - line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}" - out = self.cse.generate(target_dma_buffers, line, assignment=False) + with self.override_buffer_cse(buffer=target_dma_buffers): + ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) # FIXME. Maybe require fine grain compute... 
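At the symbolic level, the effect of `convert_indirect_indexing` is to scale each `tmp*` term by its constant stride, accumulate those contributions on-device, zero the `tmp*` terms out of the sympy index, and splice a single combined offset symbol back in. A standalone sympy illustration of that rewrite (the symbol names and strides are made up; `combined` stands for the accumulated on-device value):

```python
# Standalone illustration of the index rewrite; not kernel code.
import sympy

index0, tmp0, tmp1, combined = sympy.symbols("index0 tmp0 tmp1 combined")
index = 32 * index0 + 16 * tmp0 + tmp1

# Indirect terms are scaled and summed on-device, then dropped from the symbolic index...
index = index.replace(16 * tmp0, 0).replace(tmp1, 0)  # -> 32*index0
# ...and the single accumulated offset is re-attached as one extra symbol.
index = index + combined
assert index == 32 * index0 + combined
```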
# Conversion mlir_dtype = self.var_info[spad_vars[first_dim]][1] - line = f"affine.load %{sram_var}[{sram_index_var}] : {tile_shape}" - out = self.cse.generate(target_dma_buffers, line) - if mlir_dtype != "index": - line = f"arith.index_cast %{out} : {mlir_dtype} to {'index'}" - out = self.cse.generate(target_dma_buffers, line) - self.register_var_info(out, [1, "index", [1]]) - self.compute, self.dma_loads, self.dma_stores = old_compute, old_dma_lods, old_dma_stores + with self.override_buffer_cse(buffer=target_dma_buffers): + out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape) + if mlir_dtype != "index": + out = ops.index_cast(out, "index") return index + sympy.Symbol(str(out)), compute_dependecy diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 4d33eea4..15408c0d 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -1,5 +1,7 @@ import dataclasses import math +import contextvars +from contextlib import contextmanager from dataclasses import dataclass from typing import Dict from typing import List @@ -68,7 +70,7 @@ torch.int8: "int8_t", torch.uint8: "uint8_t", torch.bool: "uint8_t", - torch.bfloat16: "bfloat16", + torch.bfloat16: "uint16_t", } MLIR_TO_BIT = { @@ -588,6 +590,7 @@ def __init__(self, kernel_group, reason=None): self.ranges = None self.reduction_depth = None self.itervars = None + self.itervar_cses = None # Code buffer self.vector_compute = IndentedBuffer() self.reductions_suffix = IndentedBuffer() @@ -595,12 +598,17 @@ def __init__(self, kernel_group, reason=None): # MLIR SSA tracker self.var_info = {} # MLIR variable info self.buffer_types : dict = None # format: dtype, numel, size, stride - self.compute_idx = "compute_idx" + # Create compute idx + self.compute_idx = self.register_var_cse("compute_idx", 1, "index") self.compute_body_loop = LoopLevel(self.compute_idx, 1) self.prologue_compute_body_loop = LoopLevel(self.compute_idx, 1) self.recodegen = reason # spad overflow, tile size, vlane stride self.stop_autotune = False + # Context var for codegen + self.target_buffer_override = contextvars.ContextVar("Handler_compute_override", default=self.compute) + self.target_cse_override = contextvars.ContextVar("Handler_cse_override", default=self.cse) + def set_ranges(self, lengths, reduction_lengths): if self.call_ranges: assert self.call_ranges == tuple(lengths) + tuple( @@ -611,6 +619,7 @@ def set_ranges(self, lengths, reduction_lengths): self.call_ranges = tuple(lengths) + tuple(reduction_lengths) self.ranges = [self.rename_indexing(x) for x in self.call_ranges] self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] + self.itervar_cses = {str(index) : self.register_var_cse(str(index), 1, "index") for index in self.itervars} self.reduction_depth = len(lengths) return ( self.itervars[: self.reduction_depth], @@ -783,8 +792,6 @@ def codegen_kernel(self, kernel_name): code.splice(self.codegen_global_init()) code.writeline(f'func.func @{kernel_decl_name}({arg_defs})') with code.indent(): - for old, new in self.kernel_group.args.aliases(): - code.writeline(f"auto {old} = {new};") # Loop body part code.splice(self.codegen_loops()) return code.getvalue() @@ -801,28 +808,6 @@ def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] return constant_vector - def get_constant_vector2(self, expr): - # Case 0. symbol ex) index 0 - # Case 1. inner product form ex) 16 * index0 + 1 * index1 - # Case 2. 
Complicated form ex) 16 * index0 + 8 * (index//4) + (index % 4) - constant_vector = [] - if expr.is_symbol: - constant_vector.append(tuple([1, expr])) - return constant_vector - - for arg in expr.args: - if arg.is_symbol: - constant_vector.append(tuple([1,arg])) - continue - if len(arg.args) == 0: #TODO: check this - continue - if arg.args[0].is_number: - constant_vector.append(arg.args) - else: - constant_vector.append([1, arg]) - - return constant_vector - def find_node_by_name(self, name): if name in V.graph.graph_inputs: return V.graph.graph_inputs[name] @@ -837,6 +822,11 @@ def is_scalar(self, name): def roundup_vectorlane(self, size, amp=1): return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp + def register_var_cse(self, name, size, dtype): + var = self.create_cse_var(name, ValueRanges.unknown()) + self.register_var_info(var, [size, dtype]) + return var + def register_var_info(self, var, var_info): self.var_info[var] = var_info @@ -854,6 +844,21 @@ def rename_indexing(self, index) -> sympy.Expr: } return sympy_subs(index, replacements) + @contextmanager + def override_buffer_cse(self, *, buffer=None, cse=None): + target_buffer = target_cse = None + try: + if buffer is not None: + target_buffer = self.target_buffer_override.set(buffer) + if cse is not None: + target_cse = self.target_cse_override.set(cse) + yield self + finally: + if target_cse is not None: + self.target_cse_override.reset(target_cse) + if target_buffer is not None: + self.target_buffer_override.reset(target_buffer) + def __enter__(self): class CSEProxy: self.name = "CSEProxy" @@ -861,16 +866,22 @@ class CSEProxy: @staticmethod def __getattr__(name: str) -> Callable[..., common.CSEVariable]: # type: ignore[misc] def inner(*args, **kwargs): - code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info) - csevar = self.cse.generate( - self.compute, - code, - bounds=ValueRanges.unknown(), - assignment=(ret_info[0] is not None) - ) - if ret_info[0] is not None: - self.register_var_info(csevar, ret_info) - csevar.update_on_args(name, args, kwargs) + code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info, **kwargs) + target_buffer = self.target_buffer_override.get() + target_cse = self.target_cse_override.get() + if isinstance(code, common.DeferredLine): + target_buffer.writeline(code) + return None + else: + csevar = target_cse.generate( + target_buffer, + code, + bounds=ValueRanges.unknown(), + assignment=(ret_info[0] is not None) + ) + if ret_info[0] is not None: + self.register_var_info(csevar, ret_info) + csevar.update_on_args(name, args, kwargs) return csevar return inner diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py new file mode 100644 index 00000000..21995512 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -0,0 +1,1038 @@ +import math +import torch + +from torch._inductor.codegen import common +from torch._inductor.virtualized import V, _ops as ops +from . 
import mlir_common + +def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape): + if reduction_type == "sum": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "prod": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "max": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "min": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "any": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + raise AssertionError(reduction_type) + +class ExtensionOverrides(common.OpOverrides): + @staticmethod + def constant(value, src_type, *args, var_info=None, **kwargs): + if isinstance(src_type, torch.dtype): + src_type = mlir_common.DTYPE_TO_MLIR[src_type] + + str_val = str(value) + if "inf" == str_val or "-inf" == str_val or "nan" == str_val: + value = f"0x{mlir_common.MLIR_INF[str_val][src_type]:x}" + # scientific notation check + elif "e" in str_val: + value = format(float(value), ".20f") + elif src_type[0] == "f": + value = format(float(value), ".20f") + elif src_type[0] == "i": + value = int(float(value)) + return f'arith.constant {value} : {src_type}', [1, src_type] + + @staticmethod + def broadcast(operand, target_size, *args, var_info=None, **kwargs): + src_size, dtype = var_info[operand] + + src_shape = f"vector<{src_size}x{dtype}>" if src_size > 1 else dtype + dst_shape = f"vector<{target_size}x{dtype}>" + + op_str = "" + # Special case for length 2 vector. We used this vector to avoid scalar operations... 
+ if src_size > 1: + if target_size % src_size == 0: + unflat_operand = ops.broadcast_unflat(operand, target_size) + outer_dim = target_size // src_size + unflat_shape = f"vector<{outer_dim}x{src_size}x{dtype}>" + # Flatten back to 1D + op_str = f"vector.shape_cast %{unflat_operand} : {unflat_shape} to {dst_shape}" + else: + raise NotImplementedError( + f"Vector broadcast size mismatch: src={src_size} cannot broadcast to target={target_size}" + ) + elif src_size == 1: + op_str = f"vector.broadcast %{operand} : {src_shape} to {dst_shape}" + else: + raise ValueError(f"Invalid source size: {src_size}") + return op_str, [target_size, dtype] + + @staticmethod + def broadcast_unflat(operand, target_size, *args, var_info=None, **kwargs): + src_size, dtype = var_info[operand] + + outer_dim = target_size // src_size + src_shape = f"vector<{src_size}x{dtype}>" + dst_shape = f"vector<{outer_dim}x{src_size}x{dtype}>" + + op_str = f"vector.broadcast %{operand} : {src_shape} to {dst_shape}" + return op_str, [target_size, dtype] + + def load_seed(self, *args, **kwargs): + raise NotImplementedError + + def rand(self, *args, **kwargs): + raise NotImplementedError + + def randn(self, *args, **kwargs): + raise NotImplementedError + + def randint64(self, *args, **kwargs): + raise NotImplementedError + + # Special operaitons + @staticmethod + def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): + result = body() + val = ops.constant(other, dtype, *args, **kwargs) + result = ops.where(mask, result, val) + return result, var_info[result] + + @staticmethod + def where(condition, operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + cond_type = var_info[condition] + operand_type = var_info[operand1] + condition = ops.to_bool(condition) + if cond_type[0] < tile_size: + condition = ops.broadcast(condition, tile_size) + elif cond_type[0] > tile_size: + operand1 = ops.broadcast(operand1, cond_type[0]) + operand2 = ops.broadcast(operand2, cond_type[0]) + tile_size, ret_type = var_info[operand1] + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + cond_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "" + return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape}, {shape}", [tile_size, ret_type] + + @staticmethod + def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): + # Extract source information + src_mlir_dtype = var_info[operand][1] + tile_size = var_info[operand][0] + + # Normalize destination type (Torch dtype -> MLIR string) + if isinstance(dst_mlir_dtype, torch.dtype): + dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype] + + if src_mlir_dtype == "index" and dst_mlir_dtype != "index": + operand = ops.index_cast(operand, "i64") + src_mlir_dtype = "i64" # Update explicitly + + if dst_mlir_dtype == "index": + # If source is already index, return as is; otherwise cast + if src_mlir_dtype == "index": + return operand, [tile_size, "index"] + return ops.index_cast(operand, "index"), [tile_size, "index"] + + # Early return if types are identical + if src_mlir_dtype == dst_mlir_dtype: + return operand, [tile_size, dst_mlir_dtype] + + dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype] + src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype] + shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype + src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" 
if tile_size > 1 else src_mlir_dtype + src_type_char = src_mlir_dtype[0] # 'i' or 'f' + dst_type_char = dst_mlir_dtype[0] # 'i' or 'f'o + + op_str = "" + + # Case A: Integer -> Float + if src_type_char == "i" and dst_type_char == "f": + op_str = f"arith.sitofp %{operand} : {src_shape} to {shape}" + # Case B: Float -> Integer + elif src_type_char == "f" and dst_type_char == "i": + op_str = f"arith.fptosi %{operand} : {src_shape} to {shape}" + # Case C: Integer -> Integer (Extension / Truncation) + elif src_type_char == "i" and dst_type_char == "i": + if dst_bits > src_bits: + op_str = f"arith.extsi %{operand} : {src_shape} to {shape}" + elif dst_bits < src_bits: + # Use arith.trunci for integer truncation + op_str = f"arith.trunci %{operand} : {src_shape} to {shape}" + else: + return operand, [tile_size, dst_mlir_dtype] + # Case D: Float -> Float (Extension / Truncation) + elif src_type_char == "f" and dst_type_char == "f": + if dst_bits > src_bits: + op_str = f"arith.extf %{operand} : {src_shape} to {shape}" + elif dst_bits < src_bits: + # Corrected 'trunf' to 'truncf' + op_str = f"arith.truncf %{operand} : {src_shape} to {shape}" + else: + return operand, [tile_size, dst_mlir_dtype] + else: + raise NotImplementedError(f"Unsupported conversion: {src_mlir_dtype} -> {dst_mlir_dtype}") + + return op_str, [tile_size, dst_mlir_dtype] + + @staticmethod + def identity(operand, *args, var_info=None, **kwargs): + operand_info = var_info[operand] + return operand, operand_info + + @staticmethod + def to_dtype_bitcast(operand, dtype, *args, var_info=None, **kwargs): + tile_size, current_src_type = var_info[operand] + + if isinstance(dtype, torch.dtype): + dst_mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] + else: + dst_mlir_type = dtype + + src_bits = mlir_common.MLIR_TO_BIT[current_src_type] + dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_type] + + if src_bits != dst_bits: + raise ValueError( + f"Bitcast failed: Bit width mismatch. 
" + f"Src: {current_src_type}({src_bits}b) != Dst: {dst_mlir_type}({dst_bits}b)" + ) + + src_shape = f"vector<{tile_size}x{current_src_type}>" if tile_size > 1 else current_src_type + dst_shape = f"vector<{tile_size}x{dst_mlir_type}>" if tile_size > 1 else dst_mlir_type + + return f"arith.bitcast %{operand} : {src_shape} to {dst_shape}", [tile_size, dst_mlir_type] + + # Binary element wise operations + @staticmethod + def binary_elementwise_common(operand1, operand2, var_info): + operand1.bounds = operand1.bounds.unknown() + operand2.bounds = operand2.bounds.unknown() + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + # Tile size check + if op_type1[0] != op_type2[0]: + # Try to broad cast + lhs_tile_size, lhs_dtype = op_type1 + rhs_tile_size, rhs_dtype = op_type2 + if lhs_tile_size > rhs_tile_size: + operand2 = ops.broadcast(operand2, lhs_tile_size) + op_type2 = var_info[operand2] + elif lhs_tile_size < rhs_tile_size: + operand1 = ops.broadcast(operand1, rhs_tile_size) + op_type1 = var_info[operand1] + + # Data type check + if op_type1[1] != op_type2[1]: + if op_type1[1] == "index" or op_type1 == "index": + if op_type1[1] == "index": + operand1 = ops.index_cast(operand1, op_type2[1]) + op_type1 = var_info[operand1] + if op_type2[1] == "index": + operand2 = ops.index_cast(operand2, op_type1[1]) + op_type2 = var_info[operand2] + elif op_type1[1][0] == "i" and op_type2[1][0] == "f": + operand1 = ops.to_dtype(operand1, op_type2[1]) + op_type1 = var_info[operand1] + elif op_type1[1][0] == "f" and op_type2[1][0] == "i": + operand2 = ops.to_dtype(operand2, op_type1[1]) + op_type2 = var_info[operand2] + elif op_type1[1][0] == op_type2[1][0]: + if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: + operand2 = ops.ext(operand2, op_type1[1]) + op_type2 = var_info[operand2] + elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: + operand1 = ops.ext(operand1, op_type2[1]) + op_type1 = var_info[operand1] + else: + raise NotImplementedError("Unsupported type converting") + + # Updated var info + tile_size = op_type1[0] + ret_type = op_type1[1] + return tile_size, ret_type, operand1, operand2 + + @staticmethod + def abs(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def exp(operand, *args, var_info=None, **kwargs): + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.exp(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.exp %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def exp2(operand, *args, var_info=None, **kwargs): + # Hands-on part: implement exp2 using math.exp2 + # var_info = {operand: [tile_size, dtype]} + # Ex) var_info[operand] = [8, "f32"] + + ln2 = math.log(2) + coeff = ops.constant(ln2, "f32") + operand = ops.mul(operand, coeff) + return ops.exp(operand), var_info[operand] + + @staticmethod + def expm1(operand, *args, var_info=None, **kwargs): + coeff = ops.constant(1.0, "f32") + operand = ops.exp(operand) + operand = ops.sub(operand, coeff) + return operand, var_info[operand] + + @staticmethod + def sqrt(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = 
ops.to_dtype(operand, "f32") + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def relu(operand, *args, var_info=None, **kwargs): + src_mlir_dtype = var_info[operand][1] + tile_size = var_info[operand][0] + return ops.maximum(operand, ops.constant(0, src_mlir_dtype)), [tile_size, src_mlir_dtype] + + @staticmethod + def minimum(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.minimumf' + else: + opcode = f'arith.minsi' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def maximum(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.maximumf' + else: + opcode = f'arith.maxsi' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def cos(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.cos(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.cos %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def sin(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.sin(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.sin %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def tan(operand, *args, var_info=None, **kwargs): + sin_res = ops.sin(operand) + cos_res = ops.cos(operand) + operand = ops.truediv(sin_res, cos_res) + return operand, var_info[operand] + + @staticmethod + def lgamma(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def erf(operand, *args, var_info=None, **kwargs): + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.erf(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.erf %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def cosh(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def sinh(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def tanh(operand, *args, 
var_info=None, **kwargs): + op_type = var_info[operand] + + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.tanh(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.tanh %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def acos(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def acosh(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def asin(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def asinh(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def atan2(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def atan(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def atanh(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def copysign(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def erfc(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def erfinv(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def frexp(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def hypot(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def log10(operand, *args, var_info=None, **kwargs): + val_ln = ops.log(operand) + + tile_size, dtype = var_info[val_ln] + inv_ln10 = 1/math.log(10) + const_op = ops.constant(inv_ln10, dtype) + + # Multiply: ln(x) * (1/ln(10)) + result = ops.mul(val_ln, const_op) + return result, var_info[result] + + @staticmethod + def log2(operand, *args, var_info=None, **kwargs): + val_ln = ops.log(operand) + + tile_size, dtype = var_info[val_ln] + inv_ln2 = 1/math.log(2) + const_op = ops.constant(inv_ln2, dtype) + + # Multiply: ln(x) * (1/ln(2)) + result = ops.mul(val_ln, const_op) + return result, var_info[result] + + @staticmethod + def log(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.log %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def log1p(operand, *args, var_info=None, **kwargs): + tile_size, dtype = var_info[operand] + const_one = ops.constant(1, dtype) + + # 
Compute (x + 1) + # Assumes ops.add returns (result_ssa, result_info) + val_add = ops.add(operand, const_one) + result = ops.log(val_add) + return result, var_info[result] + + @staticmethod + def nextafter(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def logical_and(operand1, operand2, *args, var_info=None, **kwargs): + if var_info[operand1][1] != "i1": + operand1 = ops.to_bool(operand1) + + if var_info[operand2][1] != "i1": + operand2 = ops.to_bool(operand2) + result = ops.and_(operand1, operand2) + return result, var_info[result] + + @staticmethod + def logical_or(operand1, operand2, *args, var_info=None, **kwargs): + if var_info[operand1][1] != "i1": + operand1 = ops.to_bool(operand1) + + if var_info[operand2][1] != "i1": + operand2 = ops.to_bool(operand2) + result = ops.or_(operand1, operand2) + return result, var_info[result] + + @staticmethod + def logical_xor(operand1, operand2, *args, var_info=None, **kwargs): + if var_info[operand1][1] != "i1": + operand1 = ops.to_bool(operand1) + + if var_info[operand2][1] != "i1": + operand2 = ops.to_bool(operand2) + result = ops.xor(operand1, operand2) + return result, var_info[result] + + @staticmethod + def logical_not(operand, *args, var_info=None, **kwargs): + op_info = var_info[operand] + tile_size = op_info[0] + dtype = op_info[1] + + zero_const = ops.constant(0, dtype) + result = ops.eq(operand, zero_const) + return result, var_info[result] + + @staticmethod + def bitwise_and(operand1, operand2, *args, var_info=None, **kwargs): + # Float check + if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + raise ValueError("Bitwise AND not supported for floats") + + result = ops.and_(operand1, operand2) + return result, var_info[result] + + @staticmethod + def bitwise_not(operand, *args, var_info=None, **kwargs): + tile_size, dtype = var_info[operand] + # Float check + if var_info[operand][1].startswith("f"): + raise ValueError("Bitwise NOT not supported for floats") + + neg_one = ops.constant(-1, dtype) + result = ops.xor(operand, neg_one) + return result, var_info[result] + + @staticmethod + def bitwise_or(operand1, operand2, *args, var_info=None, **kwargs): + # Float check + if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + raise ValueError("Bitwise OR not supported for floats") + + result = ops.or_(operand1, operand2) + return result, var_info[result] + + @staticmethod + def bitwise_xor(operand1, operand2, *args, var_info=None, **kwargs): + # Float check + if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + raise ValueError("Bitwise XOR not supported for floats") + + result = ops.xor(operand1, operand2) + return result, var_info[result] + + @staticmethod + def bitwise_left_shift(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def bitwise_right_shift(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def rsqrt(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def sigmoid(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = 
op_type[1] + one = ops.constant(1, dtype) + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, dtype] + + @staticmethod + def fmod(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def isinf(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def isnan(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def round(operand, *args, var_info=None, **kwargs): + tile_size, dtype = var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + return f"math.roundeven %{operand} : {shape}", [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + @staticmethod + def floor(operand, *args, var_info=None, **kwargs): + tile_size, dtype = var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + return f"math.floor %{operand} : {shape}", [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + @staticmethod + def sign(operand, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def trunc(operand, *args, var_info=None, **kwargs): + tile_size, dtype = var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + return f"math.trunc %{operand} : {shape}", [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + @staticmethod + def ceil(operand, *args, var_info=None, **kwargs): + tile_size, dtype = var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + return f"math.ceil %{operand} : {shape}", [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + # Logical operations + @staticmethod + def neg(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'arith.negf %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def reciprocal(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + + return ops.truediv(ops.constant(1.0, dtype), operand), [tile_size, dtype] + + @staticmethod + def eq(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oeq" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "eq" + else: + raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def ne(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "one" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "ne" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: 
{ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def lt(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "olt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "slt" + else: + raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def gt(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ogt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sgt" + else: + raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def le(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ole" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sle" + else: + raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def ge(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oge" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sge" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def add(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.add{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def sub(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.sub{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def mul(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = 
f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.mul{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def pow(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + # Type check & auto cast + if ret_type.startswith("f"): + operand1 = ops.to_dtype(operand1, "f32") + + # Type check & auto cast + if ret_type.startswith("f"): + operand2 = ops.to_dtype(operand2, "f32") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f"math.pow{ret_type[0]} %{operand1}, %{operand2} : {shape}", [tile_size, ret_type] + + @staticmethod + def and_(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def or_(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def xor(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def lshift(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def rshift(operand1, operand2, *args, var_info=None, **kwargs): + raise NotImplementedError + + @staticmethod + def truncdiv(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if ret_type.startswith("f"): + raise ValueError("truncdiv is strictly for integers. 
Use truediv for floats.") + + # arith.divsi: Signed Integer Division (Result is truncated) + return f'arith.divsi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def floordiv(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if ret_type.startswith("f"): + # Float의 floor division은 보통 divf 후 floor를 하므로 여기선 정수만 처리 + raise ValueError("floordiv implementation expects integers based on definition.") + + # arith.floordivsi: Floor Division for Signed Integers + return f'arith.floordivsi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def truediv(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if not ret_type.startswith("f"): + raise ValueError(f"truediv expects float inputs, but got {ret_type}. Use int_truediv for integers.") + + return f'arith.divf %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def int_truediv(operand1, operand2, *args, var_info=None, **kwargs): + """ + True division for Integers (Int -> Float). + Promotes integers to floats, then performs floating-point division. + """ + tile_size, src_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if not src_type.startswith("f"): + target_float_type = "f32" + operand1 = ops.to_dtype(operand1, target_float_type) + operand2 = ops.to_dtype(operand2, target_float_type) + src_type = target_float_type + + result = ops.truediv(operand1, operand2) + return result, var_info[result] + + @staticmethod + def mod(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + raise NotImplementedError("Not support remainder operation for floating point") + else: + opcode = f'arith.remsi' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def remainder(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if ret_type.startswith("f"): + opcode = 'arith.remf' + else: + opcode = 'arith.remsi' # Signed Integer Remainder (LHS sign) + + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def square(operand, *args, var_info=None, **kwargs): + result = ops.mul(operand, operand) + return result, var_info[result] + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # PyTorchSim specific operations + + @staticmethod + def alloc(size, src_type, *args, var_info=None, **kwargs): + return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] + + @staticmethod + def extractelement(operand, idx, *args, var_info=None, **kwargs): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f"vector.extract 
%{operand}[{idx}]: {dtype} from {shape}", [1, dtype] + + @staticmethod + def ext(operand, dtype, *args, var_info=None, **kwargs): + op_type = var_info[operand] + shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" + target_type = f"vector<{op_type[0]}x{dtype}>" if op_type[0] > 1 else f"{dtype}" + if op_type[1][0] == "f": + opcode = f'arith.extf' + else: + opcode = f'arith.extui' + return f'{opcode} %{operand} : {shape} to {target_type}', [op_type[0], dtype] + + @staticmethod + def to_bool(operand, *args, var_info=None, **kwargs): + tile_size, ret_type = var_info[operand] + if ret_type == "i1": + return operand, [tile_size, ret_type] + + const_zero = ops.constant(0, ret_type) + if tile_size > 1: + const_zero = ops.broadcast(const_zero, tile_size) + ret = ops.ne(operand, const_zero) + return ret, [tile_size, "i1"] + + @staticmethod + def step(size, dtype, *args, **kwargs): + index_shape = f"vector<{size}x{dtype}>" + return f"vector.step : {index_shape}", [size, dtype] + + @staticmethod + def index_cast(operand, target_type, *args, var_info=None, **kwargs): + op_type = var_info[operand] + src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] + des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type + return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] + + @staticmethod + def shape_cast(operand, src_shape, dst_shape, *args, var_info=None, **kwargs): + operand_type = var_info[operand] + return f"vector.shape_cast %{operand} : {src_shape} to {dst_shape}", operand_type + + @staticmethod + def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_name, *args, **kwargs): + if red_size == 1: + final_reduced_shape = f"{type_name}" + line = reduction_combine_vec(red_type, acc, init, axis=0, shape=red_shape, reduced_shape=final_reduced_shape) + else: + final_reduced_shape = f"vector<{red_size}x{type_name}>" + new_vshape = f"vector<{vec_size//red_size}x{red_size}x{type_name}>" + value = ops.shape_cast(acc, red_shape, new_vshape) + line = reduction_combine_vec(red_type, value, init, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape) + return line, [red_size, type_name] + + @staticmethod + def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, var_info=None, **kwargs): + if compute_vec_size == 1: + vshape = f"{mlir_dtype}" + operation = "affine.load" + line = f"{operation} %{buffer}[{indices}] : {buffer_shape}" + else: + vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" + operation = "affine.vector_load" + line = f"{operation} %{buffer}[{indices}] : {buffer_shape}, {vshape}" + return line, [compute_vec_size, mlir_dtype] + + @staticmethod + def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, var_info=None, **kwargs): + compute_vec_size, mlir_dtype = var_info[operand][0], var_info[operand][1] + + if compute_vec_size == 1: + vshape = f"{mlir_dtype}" + operation = "affine.store" + line = f"{operation} %{operand}, %{buffer}[{indices}] : {buffer_shape}" + else: + vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" + operation = "affine.vector_store" + line = f"{operation} %{operand}, %{buffer}[{indices}] : {buffer_shape}, {vshape}" + + if buffer_name is not None: + return common.DeferredLine(buffer_name, line), [None, None] + else: + return line, [None, None] \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index e493464a..a36bc907 100644 
--- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -25,7 +25,7 @@ import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo -from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, is_welford_reduction from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode from torch._inductor.codegen import common @@ -85,7 +85,8 @@ def as_local(self): } try: self.set_buffers() - yield self + with self.kernel.override_buffer_cse(buffer=self.compute, cse=self.cse): + yield self finally: self.restore_buffers() @@ -822,7 +823,7 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") attribute = " {" + ", ".join(attribute_parts) + "}" code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, "") + dram_shape, tile_shape, "") local_code.writeline(code) local_code.writeline(attribute) return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() @@ -885,28 +886,18 @@ def load_epilogue(self, name: str, index: sympy.Expr): zero_var = self.get_const_cse(0) if not self.reduction_fusion: compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) - if compute_vec_size > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - out = self.cse.generate(self.loads, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) + with self.override_buffer_cse(buffer=self.loads): + out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape) else: # For reduction case reduce_size = self.reduction_nr_outer_loop vsize = compute_vec_size//reduce_size - vshape = f"vector<{vsize}x{mlir_dtype}>" if compute_vec_size > 1: offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - out = self.cse.generate(self.loads, line) - else: - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - out = self.cse.generate(self.loads, line) + + with self.override_buffer_cse(buffer=self.loads): + out = ops._load(vsize, mlir_dtype, sram_var, compute_index_var, tile_shape) self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype]) return out @@ -924,10 +915,6 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) tile_stride = self.kernel_group.tile_desc.get_tile_stride() - # Compute vector unit size - vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) - 
compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() - if name not in self.buffer_names: sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) self.buffer_names[name] = sram_var @@ -945,14 +932,9 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) # Generate vector load instruction - if compute_vec_size > 1: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" - line = line if store_force else DeferredLine(name, line) - self.stores.writeline(line) + buffer_name = name if not store_force else None + with self.override_buffer_cse(buffer=self.stores): + ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=buffer_name) # Generate DMA instruction attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" @@ -991,6 +973,7 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): tile_shape = local_tile_desc.get_mlir_shape(type_name) vshape = local_tile_desc.get_mlir_vshape(type_name) + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() name = f"{reduction_type}_buffer{self.reduction_buffer_idx}" self.reduction_buffer_idx += 1 @@ -1002,24 +985,21 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): zero_var_list = [f"%{self.get_const_cse(0)}"] * local_tile_desc.get_nr_dim() zero_var_list[-2] = f"%{self.reduction_loop_idx}" compute_index_var = ", ".join(zero_var_list) - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - out = self.cse.generate(self.loads, line) - self.register_var_info(out, [self.compute_body_loop.step, type_name]) + with self.override_buffer_cse(buffer=self.loads): + out = ops._load(vec_size, type_name, sram_var, compute_index_var, tile_shape) # Reduction body codegen - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {vshape}") - self.register_var_info(init_vec, [local_tile_desc.get_compute_vec_size(), type_name]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + init = ops.constant(reduction_init(reduction_type, dtype), type_name) + init_vec = ops.broadcast(init, compute_vec_size) + mask_shape, mask_var = self.get_mask() if mask_var is not None: value = ops.where(mask_var, value, init_vec) result = reduction_partial_combine_vec(reduction_type, value, out) # Store partial result - operation = "affine.vector_store" - line = f"{operation} %{result}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - self.compute.writeline(line) # Need to be placed after partial reduction + ops._store(result, sram_var, compute_index_var, tile_shape) # Need to be placed after partial reduction self.reduction_info[sram_var] = [reduction_type, local_tile_desc] return sram_var @@ -1050,63 +1030,59 @@ def store_reduction_epilogue(self, name, index, value): partial_tile_shape = partial_tile_desc.get_mlir_shape(mlir_dtype) # Prepare constant - init = 
self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value][0], dtype)} : {mlir_dtype}") + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + init = ops.constant(reduction_init(self.reduction_info[value][0], dtype), mlir_dtype) + init_vec = ops.broadcast(init, partial_vec_size) + init_vec2 = ops.broadcast(init, 2) + partial_zero_var_list = [f"%{self.get_const_cse(0)}"] * partial_tile_desc.get_nr_dim() final_zero_var_list = [f"%{self.get_const_cse(0)}"] * final_tile_desc.get_nr_dim() for i in range(self.reduction_body_loop.size): # Load partial result - body_index_var = self.const_cse.generate(self.const_buffer, f"arith.constant {i} : index") - partial_zero_var_list[-2] = f"%{body_index_var}" - compute_index_var = ",".join(partial_zero_var_list) - - operation = "affine.vector_load" - line = f"{operation} %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" - out = self.cse.generate(self.reductions_suffix, line) - operation = "affine.vector_store" - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {partial_vshape}") - line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" - self.reductions_suffix.writeline(line) - - # 2 step reduction - new_vec_size = 2 - new_vshape = f"vector<{partial_vec_size//new_vec_size}x{new_vec_size}x{mlir_dtype}>" - new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" - out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {partial_vshape} to {new_vshape}") - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {new_reduced_shape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value][0], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) - out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + body_index_var = ops.constant(i, "index") + partial_zero_var_list[-2] = f"%{body_index_var}" + compute_index_var = ",".join(partial_zero_var_list) + + with self.override_buffer_cse(buffer=self.reductions_suffix): + out = ops._load(partial_vec_size, mlir_dtype, value, compute_index_var, partial_tile_shape) + ops._store(init_vec, value, compute_index_var, partial_tile_shape) # Clear the partial buffer to zero - self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - self.register_var_info(out, [new_vec_size, mlir_dtype]) + # 2 step reduction + new_vec_size = 2 + new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" + reduction_type = self.reduction_info[value][0] + out = ops.multi_reduction(out, init_vec2, partial_vec_size, new_vec_size, partial_vshape, reduction_type, mlir_dtype) + + out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") self.register_var_info(out2, [new_vec_size, mlir_dtype]) - out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) - self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - - if self.welford_reduce_out is not None: - # NOTE: It not a real welford algorithm... 
We just used E(X^2) - E(X)^2 - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.r_dim_size)} : f32") - if self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") - else: - divider_vec = divider - if self.current_node.node.origin_node: # FIXME: This is a temporary solution - # mean = SUM(X) / N - self.reduction_mean.append(self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}")) - out = self.reduction_mean[i] - else: - # m2 = (E(X^2) - E(X)^2) * N - sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") - mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") - variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") - out = m2 - - final_zero_var_list[-1] = f"%{body_index_var}" - final_compute_index_var = ",".join(final_zero_var_list) - operation = "affine.vector_store" - line = f"{operation} %{out}, %{sram_var}[{final_compute_index_var}] : {final_tile_shape}, {new_reduced_shape}" - self.reductions_suffix.writeline(DeferredLine(name, line)) + with self.override_buffer_cse(buffer=self.reductions_suffix): + out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) + + if self.welford_reduce_out is not None: + # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2 + divider = ops.constant(float(self.r_dim_size), "f32") + if self.buffer_types[name][1] > 1: + divider_vec = ops.broadcast(divider, new_vec_size) + else: + divider_vec = divider + + if self.current_node.node.origin_node: # FIXME: This is a temporary solution + # mean = SUM(X) / N + self.reduction_mean.append(ops.truediv(out, divider_vec)) + out = self.reduction_mean[i] + else: + # m2 = (E(X^2) - E(X)^2) * N + sqr_mean = ops.truediv(out, divider_vec) + mean_sqr = ops.mul(self.reduction_mean[i], self.reduction_mean[i]) + variance = ops.sub(sqr_mean, mean_sqr) + m2 = ops.mul(variance, divider_vec) + out = m2 + + final_zero_var_list[-1] = f"%{body_index_var}" + final_compute_index_var = ",".join(final_zero_var_list) + ops._store(out, sram_var, final_compute_index_var, final_tile_shape, buffer_name=name) # MVOUT Encoding # Generate DMA instruction diff --git a/README.md b/README.md index 103131c1..4d98baa4 100644 --- a/README.md +++ b/README.md @@ -220,7 +220,7 @@ Our load generator supports multi-tenancy experiments. You can run a simple exam python tests/test_scheduler.py ``` Below is an example code of multi-tenancy `resnet18` and `EncoderBlock`. -In this example, the `Scheduler` is initialized with a number of request queues, a scheduling policy, and a TOGSimulator config file(`.json`). The compiled PyTorch models are then registered with a unique model id. +In this example, the `Scheduler` is initialized with a number of request queues, a scheduling policy, and a TOGSimulator config file(`.yml`). The compiled PyTorch models are then registered with a unique model id. 
```python3 import os @@ -228,7 +228,7 @@ import sys import torch from torchvision.models import resnet18 base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') -config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json' +config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml' sys.path.append(base_path) from tests.test_transformer import EncoderBlock @@ -244,7 +244,7 @@ SchedulerDNNModel.register_model("model0", opt_model0) SchedulerDNNModel.register_model("model1", opt_model1) ``` -The config file(`.json`) specifies two key items: +The config file(`.yml`) specifies two key items: - `num_partition`: The total number of independent request queues to create. - `partition`: Defines the hardware mapping, assigning each queue (identified by its index) to a specific physical core. For example, the configuration below creates two scheduling queues (`0` and `1`) and maps `core_0` to queue `0` and `core_1` to queue `1`: @@ -415,7 +415,7 @@ export TORCHSIM_USE_TIMING_POOLING=0 # use lightweight pooling for timing ``` You can set TOGSim config path as below. ```bash -export TORCHSIM_CONFIG=/workspace/PyTorchSim/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json +export TORCHSIM_CONFIG=/workspace/PyTorchSim/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml ``` ## Future Works Currently, PyTorchSim supports PyTorch 2.2. Support for newer versions will be added soon. diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index ffe8e4fc..98ebb1d5 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -179,6 +179,7 @@ def setup_device(): ) torch.utils.rename_privateuse1_backend("npu") + torch._register_device_module("npu", module) from torch._inductor.codegen.common import ( get_scheduling_for_device, get_wrapper_codegen_for_device, @@ -357,6 +358,7 @@ def __init__(self, num_request_queue=1, max_batch=1, engine_select=FIFO_ENGINE, togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") self.tog_simulator = TOGSimulator(togsim_path, togsim_config) + os.environ['TOGSIM_CONFIG'] = togsim_config self.tog_simulator.interactive_simulation() if engine_select == Scheduler.FIFO_ENGINE: self.execution_engine = FIFORunner(self.tog_simulator, self.num_request_queue) diff --git a/Simulator/simulator.py b/Simulator/simulator.py index 322d9b12..672ae6ec 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -4,11 +4,12 @@ import subprocess import re import sys -import json +import yaml import time import datetime import threading from pathlib import Path +import uuid import torch import numpy as np @@ -16,6 +17,8 @@ from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs from PyTorchSimFrontend import extension_config +print_lock = threading.Lock() + TORCH_TO_NUMPY = { torch.float32: np.float32, torch.float64: np.float64, @@ -53,7 +56,7 @@ def write_arg(self, arg, path, name): tensor = arg.cpu().detach() buffer_size = tensor.untyped_storage().size() buffer = (ctypes.c_char * buffer_size).from_address(tensor.data_ptr()) - t_arr = np.frombuffer(buffer, dtype=tensor.numpy().dtype, count=buffer_size // tensor.element_size()) + t_arr = np.frombuffer(buffer, dtype=TORCH_TO_NUMPY[tensor.dtype], count=buffer_size // tensor.element_size()) t_arr.tofile(data_path) else: assert(0) @@ -157,9 +160,12 @@ def show_progress(): while not finished: i = (i + 1) % 3 tail = "." * i + " " * (3-i) - sys.stdout.write("\r[Gem5] Gem5 is running." 
+ tail) + with print_lock: + sys.stdout.write("\r[Gem5] Gem5 is running." + tail) + sys.stdout.flush() time.sleep(1) - print("") + with print_lock: + print("") dir_path = os.path.join(os.path.dirname(target_binary), "m5out") gem5_script_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "gem5_script/script_systolic.py") @@ -199,7 +205,7 @@ class TOGSimulator(): def __init__(self, togsim_path, config_path, vectorlane_size=-1) -> None: self.base_dir = togsim_path self.config_path = config_path - self.config_json = self.load_json(self.config_path) + self.config_yaml = self.load_yaml(self.config_path) self.process = None self.vectorlane_size = vectorlane_size @@ -209,7 +215,7 @@ def get_togsim_command(self): cmd = f"{bin} --config {config}" return cmd - def simulation(self, model_path, attribute_path="", silent_mode=False): + def simulation(self, model_path, attribute_path="", silent_mode=False, autotune_mode=False): def show_progress(): i = 0 while not finished: @@ -240,19 +246,35 @@ def show_progress(): if not silent_mode: finished = True progress_thread.join() - print("[TOGSim] Command failed with exit code", e.returncode) - print("[TOGSim] Error output:", e.output) + with print_lock: + print("[TOGSim] Command failed with exit code", e.returncode) + print("[TOGSim] Error output:", e.output) assert 0 - # Save result to result_path - result_path = extension_config.CONFIG_TORCHSIM_LOG_PATH - os.makedirs(result_path, exist_ok=True) - file_name = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')+".log" - result_path = os.path.join(result_path, file_name) + + # Separate Autotune logs + if autotune_mode: + base_dir = Path(model_path).parent / "togsim_result" + base_dir.mkdir(parents=True, exist_ok=True) + file_name = f"{len(list(base_dir.iterdir()))}.log" + else: + base_dir = Path(extension_config.CONFIG_TORCHSIM_LOG_PATH) + unique_id = uuid.uuid4().hex[:8] + timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + file_name = f"{unique_id}_{timestamp}.log" + + base_dir.mkdir(parents=True, exist_ok=True) + result_path = base_dir / file_name + + # Prevent race condition with open(result_path, "w") as f: f.write(result.decode()) + f.flush() + os.fsync(f.fileno()) + if not silent_mode or extension_config.CONFIG_DEBUG_MODE: model_path_log = f' of "{model_path}" ' if extension_config.CONFIG_DEBUG_MODE else " " - print(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') + with print_lock: + print(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') return result_path def interactive_simulation(self): @@ -342,40 +364,41 @@ def sram_dealloc(cls, buf_name, addr_range): def create_attribute_file(self, attribute_path, inputs, **kwargs): address_info = {} sram_buffer = {} - json_content = {} + yaml_content = {} + os.makedirs(attribute_path, exist_ok=True) index = str(len(os.listdir(attribute_path))) attribute_path = os.path.join(attribute_path, index) for idx, tensor in enumerate(inputs): address_info[f"arg{idx}"] = tensor.data_ptr() - json_content["address_info"] = address_info + yaml_content["address_info"] = address_info for buf_name, range in self.ALLOC_POOL.items(): sram_buffer[buf_name] = range - json_content["sram_alloc"] = sram_buffer + yaml_content["sram_alloc"] = sram_buffer with open(attribute_path, "w") as f: - json.dump(json_content, f, indent=4) + yaml.dump(yaml_content, f, default_flow_style=False) f.flush() os.fsync(f.fileno()) # There could be a race condition. 
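+ # Illustrative sketch of the attribute file emitted above (block-style YAML because of + # default_flow_style=False). The keys come from yaml_content; the addresses and the + # [start, end] ranges are hypothetical example values, and this assumes ALLOC_POOL maps + # each buffer name to a [start, end] address pair. + # address_info: + #   arg0: 140362104832000 + #   arg1: 140362104834048 + # sram_alloc: + #   buf0: + #   - 3489660928 + #   - 3489662976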
return attribute_path - def load_json(self, config_path): + def load_yaml(self, config_path): config_path = Path(config_path) if not config_path.is_file(): - raise FileNotFoundError(f"JSON file not found: {config_path}") + raise FileNotFoundError(f"YAML file not found: {config_path}") try: with open(config_path, "r") as file: - data = json.load(file) + data = yaml.safe_load(file) return data - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON format: {e}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML format: {e}") def get_core_freq(self): - if "core_freq_mhz" in self.config_json: - return self.config_json["core_freq_mhz"] * 1000 * 1000 # MHz + if "core_freq_mhz" in self.config_yaml: + return self.config_yaml["core_freq_mhz"] * 1000 * 1000 # MHz else: raise KeyError("Key 'core_freq' not found in JSON.") @@ -400,9 +423,9 @@ def find_zero_sub_tensors(self, tensor): def get_result_from_file(result_path): core_metrics = {} dram_channel_bw = {} - avg_dram_bw = None - simulation_time = None - total_cycle = None + avg_dram_bw = 0.0 + simulation_time = float("inf") + total_cycle = float("inf") # Read and find total stat position with open(result_path, "r") as f: @@ -417,7 +440,7 @@ def get_result_from_file(result_path): break if simulation_finished_idx == -1: - print("[TOGSim] Tried to parsing wrong formated output file!") + print(f"[TOGSim] Warning: Unable to parse the output file ({result_path}). The file may be improperly formatted.") return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time total_stat_lines = lines[simulation_finished_idx:] @@ -457,6 +480,6 @@ def get_result_from_file(result_path): return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time, total_cycle if __name__ == "__main__": - sim = TOGSimulator("/workspace/PyTorchSim/TOGSim", "/workspace/PyTorchSim/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json") + sim = TOGSimulator("/workspace/PyTorchSim/TOGSim", "/workspace/PyTorchSim/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml") sim.interactive_simulation() sim.until(4000) \ No newline at end of file diff --git a/TOGSim/conanfile.txt b/TOGSim/conanfile.txt index 7a57f52f..ce5268c7 100644 --- a/TOGSim/conanfile.txt +++ b/TOGSim/conanfile.txt @@ -2,6 +2,6 @@ boost/1.79.0 robin-hood-hashing/3.11.5 spdlog/1.11.0 -nlohmann_json/3.11.2 +yaml-cpp/0.8.0 [generators] cmake diff --git a/TOGSim/include/Common.h b/TOGSim/include/Common.h index 640cba0c..2fd62681 100644 --- a/TOGSim/include/Common.h +++ b/TOGSim/include/Common.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -14,7 +15,6 @@ #include "SimulationConfig.h" #include "Instruction.h" -#include "nlohmann/json.hpp" #define MIN(x, y) (((x) > (y)) ? 
(y) : (x)) #define MIN3(x, y, z) MIN(MIN(x, y), z) @@ -24,10 +24,8 @@ #define PAGE_SIZE 4096 -using json = nlohmann::json; - typedef uint64_t addr_type; typedef uint64_t cycle_type; -uint32_t generate_id(); -SimulationConfig initialize_config(json config); \ No newline at end of file +bool loadConfig(const std::string& config_path, YAML::Node& config_yaml); +SimulationConfig initialize_config(YAML::Node config); \ No newline at end of file diff --git a/TOGSim/include/SimulationConfig.h b/TOGSim/include/SimulationConfig.h index 64cfa223..090f5520 100644 --- a/TOGSim/include/SimulationConfig.h +++ b/TOGSim/include/SimulationConfig.h @@ -1,13 +1,11 @@ #pragma once -#include #include - -using json = nlohmann::json; +#include enum class CoreType { WS_MESH, STONNE }; -enum class DramType { SIMPLE, RAMULATOR1, RAMULATOR2 }; +enum class DramType { SIMPLE, RAMULATOR2 }; enum class IcntType { SIMPLE, BOOKSIM2 }; diff --git a/TOGSim/include/SparseCore.h b/TOGSim/include/SparseCore.h index 9188b21d..02781ab3 100644 --- a/TOGSim/include/SparseCore.h +++ b/TOGSim/include/SparseCore.h @@ -1,5 +1,6 @@ #include #include +#include #include "Core.h" #include "sstStonne.h" #include "SimpleMem.h" diff --git a/TOGSim/include/TileGraphParser.h b/TOGSim/include/TileGraphParser.h index 9cc61d4a..9c176966 100644 --- a/TOGSim/include/TileGraphParser.h +++ b/TOGSim/include/TileGraphParser.h @@ -2,19 +2,18 @@ #include #include #include -#include +#include #include #include #include "TileGraph.h" #include "Instruction.h" #include "sstStonne.h" #include "IntervalTree.h" +#include "Common.h" #include "onnx/defs/schema.h" #include "onnx/onnx-operators_pb.h" #include "onnx/onnx_pb.h" -using json = nlohmann::json; - enum class TileType{ LOOP_INDEX_NODE, LOOP_END_NODE, @@ -35,8 +34,6 @@ enum class LoopType { INNER_LOOP }; -bool loadConfig(const std::string& config_path, json& config_json); - class TileNode { public: TileNode(onnx::NodeProto& node); @@ -80,9 +77,9 @@ class TileGraphParser { LoopType get_loop_type(std::string key) { return std::get<2>(_loop_size_map[key]); } const std::map> & get_loop_map() { return _loop_size_map; } const std::vector &lookupNumaInfo(std::string key); - int getCoreIdFromJson(const json& attribute_json, int subgraph_id); + int getCoreIdFromConfig(const YAML::Node& attribute_config, int subgraph_id); std::string getMetaByName(std::string key) { return _tog_meta[key]; } - const json& get_attribute_file() { return _attribute_json; } + const YAML::Node& get_attribute_file() { return _attribute_config; } std::vector calc_tag(std::vector& accum_tag, std::vector& tag_idx, std::vector& tag_stride); void register_memory_tag(std::string name, std::vector& tag_key); bool check_memory_tag(std::string name, std::vector& tag_key); @@ -135,8 +132,8 @@ class TileGraphParser { void _tile_index_generate() {} int _loop_stack_pointer = 0; - json _attribute_json; - json _config_json; + YAML::Node _attribute_config; + YAML::Node _config_yaml; std::string _tog_path; std::string _attribute_path; uint64_t indirect_counter = 0; diff --git a/TOGSim/src/Common.cc b/TOGSim/src/Common.cc index 9a6b7798..b15381a6 100644 --- a/TOGSim/src/Common.cc +++ b/TOGSim/src/Common.cc @@ -1,28 +1,41 @@ #include "Common.h" -uint32_t generate_id() { - static uint32_t id_counter{0}; - return id_counter++; +bool loadConfig(const std::string& config_path, YAML::Node& config_yaml) { + try { + config_yaml = YAML::LoadFile(config_path); + spdlog::info("[LoadConfig] Success to open \"{}\"", config_path); + return true; + } catch (const 
YAML::BadFile& e) { + spdlog::error("[LoadConfig] Failed to open \"{}\" (File not found or inaccessible)", config_path); + return false; + } catch (const YAML::ParserException& e) { + spdlog::error("[LoadConfig] Failed to parse YAML file \"{}\": {}", config_path, e.what()); + return false; + } catch (const std::exception& e) { + spdlog::error("[LoadConfig] Unknown error loading \"{}\": {}", config_path, e.what()); + return false; + } } template -T get_config_value(json config, std::string key) { - if (config.contains(key)) { - return config[key]; +T get_config_value(const YAML::Node& config, std::string key) { + if (config[key]) { + return config[key].as(); } else { throw std::runtime_error(fmt::format("Config key {} not found", key)); } } -SimulationConfig initialize_config(json config) { +SimulationConfig initialize_config(YAML::Node config) { SimulationConfig parsed_config; - // print json - spdlog::info("TOGSim Config: {}", config.dump(2)); + YAML::Emitter emitter; + emitter << config; + spdlog::info("PyTorchSim config:\n{}", emitter.c_str()); /* Core configs */ - parsed_config.num_cores = config["num_cores"]; - if (config.contains("core_type")) { - std::vector core_types = config["core_type"].get>(); + parsed_config.num_cores = get_config_value(config, "num_cores"); + if (config["core_type"]) { + std::vector core_types = config["core_type"].as>(); if (core_types.size() != parsed_config.num_cores) throw std::runtime_error("Mismatch between num_cores and core_type list size"); @@ -41,100 +54,105 @@ SimulationConfig initialize_config(json config) { for (int i=0; i(config, "core_freq_mhz"); + if (config["num_systolic_array_per_core"]) + parsed_config.num_systolic_array_per_core = config["num_systolic_array_per_core"].as(); + if (config["num_stonne_per_core"]) + parsed_config.num_stonne_per_core = config["num_stonne_per_core"].as(); + if (config["num_stonne_port"]) + parsed_config.num_stonne_port = config["num_stonne_port"].as(); parsed_config.core_print_interval = get_config_value(config, "core_stats_print_period_cycles"); - /* Stonne config */ - if (config.contains("stonne_config_path")) - parsed_config.stonne_config_path = config["stonne_config_path"]; + /* Stonne config */ + if (config["stonne_config_path"]) + parsed_config.stonne_config_path = config["stonne_config_path"].as(); /* DRAM config */ - if ((std::string)config["dram_type"] == "simple") + std::string dram_type_str = get_config_value(config, "dram_type"); + + if (dram_type_str == "simple") { parsed_config.dram_type = DramType::SIMPLE; - else if ((std::string)config["dram_type"] == "ramulator") - parsed_config.dram_type = DramType::RAMULATOR1; - else if ((std::string)config["dram_type"] == "ramulator2") + parsed_config.dram_latency = get_config_value(config, "dram_latency"); + } else if (dram_type_str == "ramulator2") { parsed_config.dram_type = DramType::RAMULATOR2; - else - throw std::runtime_error(fmt::format("Not implemented dram type {} ", - (std::string)config["dram_type"])); - parsed_config.dram_freq_mhz = config["dram_freq_mhz"]; - if (config.contains("dram_latency")) - parsed_config.dram_latency = config["dram_latency"]; - if (config.contains("ramulator_config_path")) - parsed_config.dram_config_path = config["ramulator_config_path"]; - parsed_config.dram_channels = config["dram_channels"]; - if (config.contains("dram_req_size_byte")) - parsed_config.dram_req_size = config["dram_req_size_byte"]; - if (config.contains("dram_stats_print_period_cycles")) - parsed_config.dram_print_interval = 
config["dram_stats_print_period_cycles"]; - if(config.contains("dram_num_burst_length")) - parsed_config.dram_nbl = config["dram_num_burst_length"]; - if (config.contains("dram_num_partitions")) { - parsed_config.dram_num_partitions = config["dram_num_partitions"]; + parsed_config.dram_config_path = get_config_value(config, "ramulator_config_path"); + } else { + throw std::runtime_error(fmt::format("Not implemented dram type {} ", dram_type_str)); + } + + parsed_config.dram_freq_mhz = get_config_value(config, "dram_freq_mhz"); + parsed_config.dram_channels = get_config_value(config, "dram_channels"); + parsed_config.dram_req_size = get_config_value(config, "dram_req_size_byte"); + parsed_config.dram_nbl = get_config_value(config, "dram_num_burst_length"); + + if (config["dram_stats_print_period_cycles"]) + parsed_config.dram_print_interval = config["dram_stats_print_period_cycles"].as(); + if (config["dram_num_partitions"]) { + parsed_config.dram_num_partitions = config["dram_num_partitions"].as(); if (parsed_config.dram_channels % parsed_config.dram_num_partitions != 0) { throw std::runtime_error("[Config] DRAM channels must be divisible by dram_num_partitions"); } } - parsed_config.dram_channels_per_partitions = - parsed_config.dram_channels / parsed_config.dram_num_partitions; + if (parsed_config.dram_num_partitions != 0) { + parsed_config.dram_channels_per_partitions = + parsed_config.dram_channels / parsed_config.dram_num_partitions; + } else { + parsed_config.dram_channels_per_partitions = parsed_config.dram_channels; + } /* L2D config */ - if (config.contains("l2d_type")) { - if ((std::string)config["l2d_type"] == "nocache") + if (config["l2d_type"]) { + std::string l2d_type_str = config["l2d_type"].as(); + if (l2d_type_str == "nocache") parsed_config.l2d_type = L2CacheType::NOCACHE; - else if ((std::string)config["l2d_type"] == "datacache") + else if (l2d_type_str == "datacache") { parsed_config.l2d_type = L2CacheType::DATACACHE; - else - throw std::runtime_error(fmt::format("Not implemented l2 cache type {} ", - (std::string)config["l2d_type"])); + parsed_config.l2d_config_str = get_config_value(config, "l2d_config"); + if (config["l2d_hit_latency"]) + parsed_config.l2d_hit_latency = config["l2d_hit_latency"].as(); + } else + throw std::runtime_error(fmt::format("Not implemented l2 cache type {} ", l2d_type_str)); } else { parsed_config.l2d_type = L2CacheType::NOCACHE; } - if (config.contains("l2d_config")) - parsed_config.l2d_config_str = config["l2d_config"]; - if (config.contains("l2d_hit_latency")) - parsed_config.l2d_config_str = config["l2d_hit_latency"]; - /* Icnt config */ - if ((std::string)config["icnt_type"] == "simple") + std::string icnt_type_str = config["icnt_type"].as(); + if (icnt_type_str == "simple") { parsed_config.icnt_type = IcntType::SIMPLE; - else if ((std::string)config["icnt_type"] == "booksim2") + if (config["icnt_latency_cycles"]) + parsed_config.icnt_latency = config["icnt_latency_cycles"].as(); + } else if (icnt_type_str == "booksim2") { parsed_config.icnt_type = IcntType::BOOKSIM2; - else - throw std::runtime_error(fmt::format("Not implemented icnt type {} ", - (std::string)config["icnt_type"])); - parsed_config.icnt_freq_mhz = config["icnt_freq_mhz"]; - if (config.contains("icnt_latency_cycles")) - parsed_config.icnt_latency = config["icnt_latency_cycles"]; - if (config.contains("booksim_config_path")) - parsed_config.icnt_config_path = config["booksim_config_path"]; - if (config.contains("icnt_stats_print_period_cycles")) - 
parsed_config.icnt_stats_print_period_cycles = config["icnt_stats_print_period_cycles"]; - if (config.contains("icnt_injection_ports_per_core")) - parsed_config.icnt_injection_ports_per_core = config["icnt_injection_ports_per_core"]; - - if (config.contains("scheduler")) - parsed_config.scheduler_type = config["scheduler"]; - if (config.contains("num_partition")) - parsed_config.num_partition = config["num_partition"]; - if (config.contains("partition")) { + parsed_config.icnt_config_path = get_config_value(config, "booksim_config_path"); + } else + throw std::runtime_error(fmt::format("Not implemented icnt type {} ", icnt_type_str)); + + parsed_config.icnt_freq_mhz = config["icnt_freq_mhz"].as(); + if (config["icnt_stats_print_period_cycles"]) + parsed_config.icnt_stats_print_period_cycles = config["icnt_stats_print_period_cycles"].as(); + if (config["icnt_injection_ports_per_core"]) + parsed_config.icnt_injection_ports_per_core = config["icnt_injection_ports_per_core"].as(); + + if (config["scheduler"]) + parsed_config.scheduler_type = config["scheduler"].as(); + if (config["num_partition"]) + parsed_config.num_partition = config["num_partition"].as(); + if (config["partition"]) { for (int i=0; i(); + parsed_config.partiton_map[i] = partition_id; + spdlog::info("[Config/Core] CPU {}: Partition {}", i, partition_id); + } else { + spdlog::warn("[Config/Core] CPU {}: Partition key not found, defaulting to 0", i); + parsed_config.partiton_map[i] = 0; + } } } else { - /* Default: all partition 0 */ for (int i=0; i> config_json; - config_file.close(); - spdlog::info("[LoadConfig] Success to open \"{}\"", config_path); - return true; - } else { - spdlog::error("[LoadConfig] Failed to open \"{}\"", config_path); - return false; - } -} - void printIndexMap(std::string prefix, const std::map& indexMap) { std::ostringstream oss; for (const auto& [key, value] : indexMap) { @@ -87,26 +74,33 @@ bool find_output_idx(TileGraphParser* tog_parser, std::vector& output_ m = output_idx.at(0); n = output_idx.at(1); k = output_idx.at(2); + auto attr_file = tog_parser->get_attribute_file(); - auto attr_json = tog_parser->get_attribute_file(); + if (!attr_file["zero_skip"]) { + return false; + } - // Check arg0: m -> k + YAML::Node zero_skip = attr_file["zero_skip"]; bool found_arg0 = false; - if (attr_json["zero_skip"].contains("arg0")) { - auto& arg0 = attr_json["zero_skip"]["arg0"]; - if (arg0.contains(std::to_string(m)) && arg0[std::to_string(m)].contains(std::to_string(k))) { + if (zero_skip["arg0"]) { + YAML::Node arg0 = zero_skip["arg0"]; + std::string m_str = std::to_string(m); + std::string k_str = std::to_string(k); + if (arg0[m_str] && arg0[m_str][k_str]) { found_arg0 = true; } } - // Check arg1: n -> k bool found_arg1 = false; - if (attr_json["zero_skip"].contains("arg1")) { - auto& arg1 = attr_json["zero_skip"]["arg1"]; - if (arg1.contains(std::to_string(k)) && arg1[std::to_string(k)].contains(std::to_string(n))) { + if (zero_skip["arg1"]) { + YAML::Node arg1 = zero_skip["arg1"]; + std::string k_str = std::to_string(k); + std::string n_str = std::to_string(n); + if (arg1[k_str] && arg1[k_str][n_str]) { found_arg1 = true; } } + return found_arg0 || found_arg1; } @@ -692,42 +686,58 @@ void TileLoopNode::print_node() { } TileGraphParser::TileGraphParser(std::string onnx_path, std::string attribute_path, std::string config_path) { - loadConfig(attribute_path, _attribute_json); - loadConfig(config_path, _config_json); + loadConfig(attribute_path, _attribute_config); + loadConfig(config_path, 
_config_yaml); _attribute_path = attribute_path; + if (!std::filesystem::exists(onnx_path)) { + throw std::runtime_error("Error: ONNX file not found at path: " + onnx_path); + } /* Note: this parsing algorithm assume that all node are sorted in topological-order */ std::ifstream model_istream(onnx_path); google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); onnx::ModelProto model_proto; /* Attribute parsing */ - if (_attribute_json.contains("address_info")) { - auto address_info = _attribute_json["address_info"]; - for (auto it = address_info.begin(); it != address_info.end(); ++it) { - uint64_t value = it.value(); - _arg_to_address[it.key()] = value; - spdlog::info("[TOGParser/Attribute] Address Attribute key: {} address: 0x{:x}", it.key(), value); + if (_attribute_config["address_info"]) { + const auto& address_info = _attribute_config["address_info"]; + for (YAML::const_iterator it = address_info.begin(); it != address_info.end(); ++it) { + std::string key = it->first.as(); + uint64_t value = it->second.as(); + + _arg_to_address[key] = value; + spdlog::info("[TOGParser/Attribute] Address Attribute key: {} address: 0x{:x}", key, value); } } - if (_attribute_json.contains("address_numa_stride")) { - auto address_numa_stride = _attribute_json["address_numa_stride"]; - for (auto it = address_numa_stride.begin(); it != address_numa_stride.end(); ++it) { - auto value_list = it.value(); - for (auto value : value_list) { - _arg_numa_stride[it.key()].push_back(value); + + if (_attribute_config["address_numa_stride"]) { + const auto& address_numa_stride = _attribute_config["address_numa_stride"]; + for (YAML::const_iterator it = address_numa_stride.begin(); it != address_numa_stride.end(); ++it) { + std::string key = it->first.as(); + const auto& value_list = it->second; // YAML Sequence Node + + for (const auto& val : value_list) { + _arg_numa_stride[key].push_back(val.as()); } - spdlog::info("[TOGParser/Attribute] Address numa info key: {} numa stride : {}", it.key(), fmt::join(_arg_numa_stride[it.key()], ", ")); + spdlog::info("[TOGParser/Attribute] Address numa info key: {} numa stride : {}", key, fmt::join(_arg_numa_stride[key], ", ")); } } - if (_attribute_json.contains("sram_alloc") and _config_json.contains("l2d_type") and _config_json["l2d_type"] == "datacache") { - auto sram_alloc_list = _attribute_json["sram_alloc"]; + + if (_attribute_config["sram_alloc"] && + _config_yaml["l2d_type"] && + _config_yaml["l2d_type"].as() == "datacache") { + + auto sram_alloc_list = _attribute_config["sram_alloc"]; spdlog::info("[TOGParser/Attribute] ================= SRAM Alloc Plan ================"); - for (auto it = sram_alloc_list.begin(); it != sram_alloc_list.end(); ++it) { - auto value_list = it.value(); - unsigned long long start = value_list.at(0); - unsigned long long end = value_list.at(1); - spdlog::info("[TOGParser/Attribute] {:16s}: 0x{:016x} ~ 0x{:016x}", it.key(), start, end); + + for (YAML::const_iterator it = sram_alloc_list.begin(); it != sram_alloc_list.end(); ++it) { + std::string key = it->first.as(); + const auto& value_list = it->second; // List [start, end] + + unsigned long long start = value_list[0].as(); + unsigned long long end = value_list[1].as(); + + spdlog::info("[TOGParser/Attribute] {:16s}: 0x{:016x} ~ 0x{:016x}", key, start, end); Interval entry = {start, end, 0}; _cache_plan.push_back(entry); } @@ -835,7 +845,7 @@ TileGraphParser::TileGraphParser(std::string onnx_path, std::string attribute_pa /* Iterate outer loop and initialize inner loop */ for 
(auto iter=_tile_graph->begin(); iter!=_tile_graph->end(); ++iter) { std::shared_ptr subgraph = std::make_shared(); - subgraph->set_core_id(getCoreIdFromJson(_attribute_json, subgraph->get_id())); + subgraph->set_core_id(getCoreIdFromConfig(_attribute_config, subgraph->get_id())); auto indices = iter.get_indices(); for (auto loop : _loop_nodes.at(last_outer_idx)) { std::shared_ptr outer_loop = std::static_pointer_cast(loop); @@ -938,11 +948,12 @@ const std::vector& TileGraphParser::lookupNumaInfo(std::string key) { return _arg_numa_stride.at(key); } -int TileGraphParser::getCoreIdFromJson(const json& attribute_json, int subgraph_id) { - if (attribute_json.contains("subgraph_map")) { - const auto& subgraph_map = attribute_json["subgraph_map"]; - if (subgraph_map.contains(std::to_string(subgraph_id)) && subgraph_map[std::to_string(subgraph_id)].is_number_integer()) { - return subgraph_map[std::to_string(subgraph_id)]; +int TileGraphParser::getCoreIdFromConfig(const YAML::Node& attribute_config, int subgraph_id) { + std::string key = std::to_string(subgraph_id); + if (attribute_config["subgraph_map"]) { + const auto& subgraph_map = attribute_config["subgraph_map"]; + if (subgraph_map[key]) { + return subgraph_map[key].as(); } } return -1; diff --git a/TOGSim/src/main.cc b/TOGSim/src/main.cc index 77c1bae7..bee1b45f 100644 --- a/TOGSim/src/main.cc +++ b/TOGSim/src/main.cc @@ -22,11 +22,11 @@ void launchKernel(Simulator* simulator, std::string onnx_path, std::string attri } Simulator* create_simulator(std::string config_path) { - json config_json; - if(!loadConfig(config_path, config_json)) { + YAML::Node config_yaml; + if (!loadConfig(config_path, config_yaml)) exit(1); - } - SimulationConfig config = initialize_config(config_json); + SimulationConfig config = initialize_config(config_yaml); + auto simulator = new Simulator(config); return simulator; } diff --git a/configs/heterogeneous_c2_simple_noc.json b/configs/heterogeneous_c2_simple_noc.json deleted file mode 100644 index a68f38c2..00000000 --- a/configs/heterogeneous_c2_simple_noc.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "core_type" : ["stonne", "ws_mesh"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - - "num_stonne_per_core" : 8, - "num_stonne_port" : 64, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "num_partition" : 2, - "partition": { - "core_0":0, - "core_1":1 - }, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/heterogeneous_c2_simple_noc.yml b/configs/heterogeneous_c2_simple_noc.yml new file mode 100644 index 00000000..9c596d85 --- /dev/null +++ b/configs/heterogeneous_c2_simple_noc.yml @@ -0,0 +1,37 @@ +core_type: +- stonne +- ws_mesh +stonne_config_path: 
/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 8 +num_stonne_port: 64 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +num_partition: 2 +partition: + core_0: 0 + core_1: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/stonne_big_c1_simple_noc.json b/configs/stonne_big_c1_simple_noc.json deleted file mode 100644 index 0a8ca3c2..00000000 --- a/configs/stonne_big_c1_simple_noc.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "core_type" : ["stonne"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_stonne_per_core" : 8, - "num_stonne_port" : 64, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 8, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycless": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16 -} \ No newline at end of file diff --git a/configs/stonne_big_c1_simple_noc.yml b/configs/stonne_big_c1_simple_noc.yml new file mode 100644 index 00000000..b14838c8 --- /dev/null +++ b/configs/stonne_big_c1_simple_noc.yml @@ -0,0 +1,21 @@ +core_type: +- stonne +stonne_config_path: /workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 8 +num_stonne_port: 64 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 8 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycless: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 diff --git a/configs/stonne_single_c1_simple_noc.json b/configs/stonne_single_c1_simple_noc.json deleted file mode 100644 index 3421d4f1..00000000 --- a/configs/stonne_single_c1_simple_noc.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "core_type" : ["stonne"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 1, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - "num_stonne_per_core" : 1, - "num_stonne_port" : 8, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 700, - "dram_channels": 8, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 8 -} \ No newline at end of file diff --git a/configs/stonne_single_c1_simple_noc.yml 
b/configs/stonne_single_c1_simple_noc.yml new file mode 100644 index 00000000..0ed7962c --- /dev/null +++ b/configs/stonne_single_c1_simple_noc.yml @@ -0,0 +1,21 @@ +core_type: +- stonne +stonne_config_path: /workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 1 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 1 +num_stonne_port: 8 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 8 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 8 diff --git a/configs/stonne_validation_c1_simple_noc.json b/configs/stonne_validation_c1_simple_noc.json deleted file mode 100644 index fb196dfb..00000000 --- a/configs/stonne_validation_c1_simple_noc.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "core_type" : ["stonne"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 1, - "core_freq_mhz" : 1000, - "core_stats_print_period_cycles" : 10000, - "num_stonne_per_core" : 1, - "num_stonne_port" : 32, - - "dram_type" : "simple", - "dram_freq_mhz" : 1000, - "dram_channels": 1, - "dram_req_size_byte": 32, - "dram_latency" : 100, - "dram_stats_print_period_cycles": 10000, - "l2d_type" : "datacache", - "l2d_config" : "S:128:128:64,32,L:T:m:W:L,A:192:4,32:0,32", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 1000, - "icnt_injection_ports_per_core" : 8 -} \ No newline at end of file diff --git a/configs/stonne_validation_c1_simple_noc.yml b/configs/stonne_validation_c1_simple_noc.yml new file mode 100644 index 00000000..f86dcce1 --- /dev/null +++ b/configs/stonne_validation_c1_simple_noc.yml @@ -0,0 +1,22 @@ +core_type: +- stonne +stonne_config_path: /workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 1 +core_freq_mhz: 1000 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 1 +num_stonne_port: 32 + +dram_type: simple +dram_freq_mhz: 1000 +dram_channels: 1 +dram_req_size_byte: 32 +dram_latency: 100 +dram_stats_print_period_cycles: 10000 +l2d_type: datacache +l2d_config: S:128:128:64,32,L:T:m:W:L,A:192:4,32:0,32 + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 1000 +icnt_injection_ports_per_core: 8 diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv2.json b/configs/systolic_ws_128x128_c1_booksim_tpuv2.json deleted file mode 100644 index 686827dc..00000000 --- a/configs/systolic_ws_128x128_c1_booksim_tpuv2.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :700, - "dram_channels": 16, - "dram_req_size_byte": 32, - - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/fly_c16_m16.icnt", - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end 
of file diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv2.yml b/configs/systolic_ws_128x128_c1_booksim_tpuv2.yml new file mode 100644 index 00000000..08149005 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_booksim_tpuv2.yml @@ -0,0 +1,26 @@ +num_cores: 1 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/fly_c16_m16.icnt + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv3.json b/configs/systolic_ws_128x128_c1_booksim_tpuv3.json deleted file mode 100644 index 1109dc0f..00000000 --- a/configs/systolic_ws_128x128_c1_booksim_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/fly_c16_m16.icnt", - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv3.yml b/configs/systolic_ws_128x128_c1_booksim_tpuv3.yml new file mode 100644 index 00000000..12304ce2 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_booksim_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/fly_c16_m16.icnt + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json deleted file mode 100644 index 22aedcf8..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - - "vpu_num_lanes" : 128, - 
"vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 700, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycless": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.yml new file mode 100644 index 00000000..aec29ff8 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.yml @@ -0,0 +1,29 @@ +num_cores: 1 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycless: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json deleted file mode 100644 index e8e489d9..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml new file mode 100644 index 00000000..72873f1c --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + 
+icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json deleted file mode 100644 index 980bfc73..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 8, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml new file mode 100644 index 00000000..c2e962e3 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 8 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json deleted file mode 100644 index 02bfd75c..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 1050, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 4, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :1200, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - "l2d_type" : "datacache", - "l2d_config" : "S:128:128:512,32,L:T:m:W:L,A:192:4,32:0,32", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 1050, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - 
"pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.yml new file mode 100644 index 00000000..0415876d --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.yml @@ -0,0 +1,32 @@ +num_cores: 1 +core_freq_mhz: 1050 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 4 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 1200 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml +l2d_type: datacache +l2d_config: S:128:128:512,32,L:T:m:W:L,A:192:4,32:0,32 + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 1050 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3.json b/configs/systolic_ws_128x128_c2_booksim_tpuv3.json deleted file mode 100644 index 66566324..00000000 --- a/configs/systolic_ws_128x128_c2_booksim_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/fly_c32_m32.icnt", - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml b/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml new file mode 100644 index 00000000..e411c0f3 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/fly_c32_m32.icnt + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: 
all diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.json b/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.json deleted file mode 100644 index 8ef47e87..00000000 --- a/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "sram_size" : 65536, - "core_print_interval" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq" : 940, - "dram_channels": 8, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_nbl" : 2, - "dram_print_interval": 10000, - "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_latency_cycles" : 10, - "icnt_freq" : 940, - "icnt_injection_ports_per_core" : 16, - "icnt_config_path" : "../configs/booksim2_configs/fly_c32_m8.icnt", - - "precision" : 4, - "scheduler" : "simple", - "num_partition" : 2, - "partition": { - "core_0":0, - "core_1":0 - }, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.yml b/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.yml new file mode 100644 index 00000000..f164b108 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.yml @@ -0,0 +1,39 @@ +num_cores: 2 +core_freq_mhz: 940 +sram_size: 65536 +core_print_interval: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq: 940 +dram_channels: 8 +dram_req_size: 32 +dram_latency: 10 +dram_nbl: 2 +dram_print_interval: 10000 +dram_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_latency_cycles: 10 +icnt_freq: 940 +icnt_injection_ports_per_core: 16 +icnt_config_path: ../configs/booksim2_configs/fly_c32_m8.icnt +precision: 4 +scheduler: simple +num_partition: 2 +partition: + core_0: 0 + core_1: 0 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json b/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json deleted file mode 100644 index ecd671bf..00000000 --- a/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "dram_num_partitions" : 2, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 1000, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/chiplet_32_32_2.icnt", - "icnt_stats_print_period_cycles" : 10000, - - 
"pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml b/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml new file mode 100644 index 00000000..e38f091f --- /dev/null +++ b/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml @@ -0,0 +1,32 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +dram_num_partitions: 2 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 1000 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/chiplet_32_32_2.icnt +icnt_stats_print_period_cycles: 10000 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json b/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json deleted file mode 100644 index 168fbe3a..00000000 --- a/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "dram_num_partitions" : 1, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 1000, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/chiplet_32_32_2.icnt", - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml b/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml new file mode 100644 index 00000000..57696243 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml @@ -0,0 +1,31 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +dram_num_partitions: 1 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 1000 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/chiplet_32_32_2.icnt + +pytorchsim_functional_mode: 1 
+pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json deleted file mode 100644 index 0a5f15b2..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :700, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml new file mode 100644 index 00000000..f0686055 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml @@ -0,0 +1,29 @@ +num_cores: 2 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json deleted file mode 100644 index f099b93d..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.yml 
b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.yml new file mode 100644 index 00000000..511a5a09 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json deleted file mode 100644 index 681ef884..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "num_partition" : 2, - "partition": { - "core_0":0, - "core_1":1 - }, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml new file mode 100644 index 00000000..499ad823 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml @@ -0,0 +1,34 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +num_partition: 2 +partition: + core_0: 0 + core_1: 1 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json deleted file mode 100644 index d09228a1..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "num_cores" : 2, 
- "core_freq_mhz" : 1050, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 4, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :1200, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - "l2d_type" : "datacache", - "l2d_config" : "S:64:128:512,32,L:B:m:W:L,A:192:4,32:0,32", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 1050, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml new file mode 100644 index 00000000..da40f01e --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml @@ -0,0 +1,32 @@ +num_cores: 2 +core_freq_mhz: 1050 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 4 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 1200 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml +l2d_type: datacache +l2d_config: S:64:128:512,32,L:B:m:W:L,A:192:4,32:0,32 + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 1050 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_8x8_c1_booksim.json b/configs/systolic_ws_8x8_c1_booksim.json deleted file mode 100644 index 851664e6..00000000 --- a/configs/systolic_ws_8x8_c1_booksim.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 800, - "core_stats_print_period_cycles" : 100000, - - "vpu_num_lanes" : 8, - "vpu_spad_size_kb_per_lane" : 32, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :800, - "dram_channels": 1, - "dram_req_size_byte": 64, - "dram_num_burst_length" : 4, - "dram_stats_print_period_cycles": 100000, - "ramulator_config_path" : "../configs/ramulator2_configs/DDR4.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 800, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_8x8_c1_booksim.yml b/configs/systolic_ws_8x8_c1_booksim.yml new file mode 100644 index 00000000..6fd305f9 --- /dev/null +++ b/configs/systolic_ws_8x8_c1_booksim.yml @@ -0,0 +1,27 @@ +num_cores: 1 +core_freq_mhz: 800 +core_stats_print_period_cycles: 100000 + +vpu_num_lanes: 8 +vpu_spad_size_kb_per_lane: 32 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 800 +dram_channels: 1 +dram_req_size_byte: 64 +dram_num_burst_length: 4 
+dram_stats_print_period_cycles: 100000 +ramulator_config_path: ../configs/ramulator2_configs/DDR4.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 800 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_8x8_c1_simple_noc.json b/configs/systolic_ws_8x8_c1_simple_noc.json deleted file mode 100644 index 2eb7e183..00000000 --- a/configs/systolic_ws_8x8_c1_simple_noc.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 800, - "core_stats_print_period_cycles" : 100000, - - "vpu_num_lanes" : 8, - "vpu_spad_size_kb_per_lane" : 32, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :800, - "dram_channels": 1, - "dram_req_size_byte": 64, - "dram_num_burst_length" : 4, - "dram_stats_print_period_cycles": 100000, - "ramulator_config_path" : "../configs/ramulator2_configs/DDR4.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 800, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_8x8_c1_simple_noc.yml b/configs/systolic_ws_8x8_c1_simple_noc.yml new file mode 100644 index 00000000..274f633c --- /dev/null +++ b/configs/systolic_ws_8x8_c1_simple_noc.yml @@ -0,0 +1,28 @@ +num_cores: 1 +core_freq_mhz: 800 +core_stats_print_period_cycles: 100000 + +vpu_num_lanes: 8 +vpu_spad_size_kb_per_lane: 32 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 800 +dram_channels: 1 +dram_req_size_byte: 64 +dram_num_burst_length: 4 +dram_stats_print_period_cycles: 100000 +ramulator_config_path: ../configs/ramulator2_configs/DDR4.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 800 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/experiments/BERT.py b/experiments/BERT.py index 3311682c..5ccd3084 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -36,7 +36,7 @@ def run_BERT(size, input_seq, config): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path FIXME: gem5 result is different as directoy name sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/artifact/cycle_validation/run_cycle.sh b/experiments/artifact/cycle_validation/run_cycle.sh index 99eed4ed..9cfd1e98 100755 --- a/experiments/artifact/cycle_validation/run_cycle.sh +++ b/experiments/artifact/cycle_validation/run_cycle.sh @@ -1,7 +1,7 @@ #!/bin/bash set -e -export TORCHSIM_CONFIG=$TORCHSIM_DIR/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json +export TORCHSIM_CONFIG=$TORCHSIM_DIR/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml 
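As the script changes above show, the experiment and artifact scripts select a config through the TORCHSIM_CONFIG environment variable, with a .yml default under the configs directory. A minimal sketch of reading such a config on the Python side, assuming PyYAML is available (the keys printed below appear in the converted config files):

import os
import yaml

# Resolve the config path the same way the experiment scripts do:
# TORCHSIM_CONFIG wins, otherwise fall back to a default under TORCHSIM_DIR.
base_dir = os.environ.get('TORCHSIM_DIR', '/workspace/PyTorchSim')
config_path = os.environ.get(
    'TORCHSIM_CONFIG',
    f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml')

with open(config_path) as f:
    cfg = yaml.safe_load(f)

# A few keys present in the converted configs.
print(cfg['num_cores'], cfg['dram_type'], cfg['icnt_type'])
print(cfg.get('codegen_mapping_strategy', 'heuristic'))

Since yaml.safe_load returns plain dicts and scalars, optional keys can be handled with dict.get, mirroring how optional entries are checked elsewhere in the patch.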
LOG_DIR=$TORCHSIM_DIR/experiments/artifact/logs mkdir -p $LOG_DIR diff --git a/experiments/artifact/speedup/run_speedup.sh b/experiments/artifact/speedup/run_speedup.sh index 9a19e9af..e84ab1a9 100755 --- a/experiments/artifact/speedup/run_speedup.sh +++ b/experiments/artifact/speedup/run_speedup.sh @@ -4,8 +4,8 @@ CONFIG_DIR="$TORCHSIM_DIR/configs" SIMULATOR_BIN="$TORCHSIM_DIR/TOGSim/build/bin/Simulator" configs=( - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - "systolic_ws_128x128_c2_booksim_tpuv3.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3.yml" + "systolic_ws_128x128_c2_booksim_tpuv3.yml" ) target_list=( diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh b/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh index fe872e02..467949af 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh @@ -2,10 +2,10 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" + # "systolic_ws_8x8_c1_simple_noc.yml" + "systolic_ws_128x128_c2_simple_noc_tpuv3.yml" + #"systolic_ws_128x128_c2_booksim_tpuv3.yml" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SIZE_LIST=( diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh b/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh index 19613a34..fb681c74 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh @@ -2,10 +2,10 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" + # "systolic_ws_8x8_c1_simple_noc.yml" + "systolic_ws_128x128_c2_simple_noc_tpuv3.yml" + #"systolic_ws_128x128_c2_booksim_tpuv3.yml" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SHAPE_LIST=( diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh b/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh index 6f3385f1..dc0fdd20 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh @@ -2,10 +2,10 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" + # "systolic_ws_8x8_c1_simple_noc.yml" + "systolic_ws_128x128_c2_simple_noc_tpuv3.yml" + #"systolic_ws_128x128_c2_booksim_tpuv3.yml" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SHAPE_LIST=( diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh b/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh index ca4cfa39..2346ab3c 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh @@ -2,10 +2,10 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # 
"systolic_ws_128x128_c2_simple_noc_tpuv4.json" + # "systolic_ws_8x8_c1_simple_noc.yml" + "systolic_ws_128x128_c2_simple_noc_tpuv3.yml" + #"systolic_ws_128x128_c2_booksim_tpuv3.yml" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SIZE_LIST=( diff --git a/experiments/attention.py b/experiments/attention.py index bbd2734e..842f105a 100644 --- a/experiments/attention.py +++ b/experiments/attention.py @@ -36,7 +36,7 @@ def attention(query, key, value): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/conv.py b/experiments/conv.py index f439c5e3..25952fb0 100644 --- a/experiments/conv.py +++ b/experiments/conv.py @@ -37,7 +37,7 @@ def custom_conv2d(a, b, bias): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/gemm.py b/experiments/gemm.py index e92200d1..3090e331 100644 --- a/experiments/gemm.py +++ b/experiments/gemm.py @@ -31,7 +31,7 @@ def custom_matmul(a, b): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml) config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/layernorm.py b/experiments/layernorm.py index 74b6d286..9c9934a1 100644 --- a/experiments/layernorm.py +++ b/experiments/layernorm.py @@ -27,7 +27,7 @@ def run_layernorm(size, config): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/resnet18.py b/experiments/resnet18.py index 45311d59..5451e0f5 100644 --- a/experiments/resnet18.py +++ b/experiments/resnet18.py @@ -29,7 +29,7 @@ def run_resnet(batch, config): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') + config = os.environ.get('TORCHSIM_CONFIG', 
default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/resnet50.py b/experiments/resnet50.py index 4f03ea15..83d82db4 100644 --- a/experiments/resnet50.py +++ b/experiments/resnet50.py @@ -29,7 +29,7 @@ def run_resnet(batch, config): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/softmax.py b/experiments/softmax.py index b47bd685..580d56ca 100644 --- a/experiments/softmax.py +++ b/experiments/softmax.py @@ -27,7 +27,7 @@ def run_softmax(size, config, dim=1): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/scripts/CompilerOpt_experiment/DMAopt.sh b/scripts/CompilerOpt_experiment/DMAopt.sh index 5c2dc65c..9e494d9b 100644 --- a/scripts/CompilerOpt_experiment/DMAopt.sh +++ b/scripts/CompilerOpt_experiment/DMAopt.sh @@ -1,5 +1,5 @@ #!/bin/bash -export TORCHSIM_CONFIG="/root/workspace/PyTorchSim/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json" +export TORCHSIM_CONFIG="/root/workspace/PyTorchSim/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml" # None FG DMA export TORCHSIM_SUBTILE=0 diff --git a/scripts/chiplet.sh b/scripts/chiplet.sh index 0d56ecae..e622874b 100755 --- a/scripts/chiplet.sh +++ b/scripts/chiplet.sh @@ -19,11 +19,11 @@ GEMM_DIR_NAME=$(basename "$GEMM_PATH") echo "GEMM Directory Name: $GEMM_DIR_NAME" CONFIG_LIST=( - "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json" + "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml" ) CONFIG_LIST2=( - "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_booksim_tpuv3.json" - "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json" + "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml" + "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml" ) shift shift @@ -39,7 +39,7 @@ MODELS_LIST="$GEMM_PATH/tile_graph.onnx" ATTRIBUTE_PATH="$GEMM_PATH/runtime_0000/attribute" for CONFIG in "${CONFIG_LIST[@]}"; do - CONFIG_NAME=$(basename "$CONFIG" .json) + CONFIG_NAME=$(basename "$CONFIG" .yml) for ATTRIBUTE_FILE in "${ATTRIBUTE_FILES[@]}"; do ATTRIBUTE_NAME=$(basename "$ATTRIBUTE_FILE") @@ -56,7 +56,7 @@ for CONFIG in "${CONFIG_LIST[@]}"; do done for CONFIG in "${CONFIG_LIST2[@]}"; do - CONFIG_NAME=$(basename "$CONFIG" .json) + CONFIG_NAME=$(basename "$CONFIG" .yml) ATTRIBUTE_NAME=0 RESULTS_DIR="./chiplet_results$INDEX_NAME/$GEMM_DIR_NAME/$ATTRIBUTE_NAME" mkdir -p "$RESULTS_DIR" diff --git a/scripts/chiplet_prep.py b/scripts/chiplet_prep.py index 32f7ad50..4f8b7f7c 100644 --- 
a/scripts/chiplet_prep.py +++ b/scripts/chiplet_prep.py @@ -1,5 +1,5 @@ import os -import json +import yaml import shutil import argparse import torch @@ -41,9 +41,11 @@ def modify_file(dump_path, name, address_numa_stride=None, subgraph_map=None): if not os.path.exists(file_path): print(f"File {file_path} does not exist.") return + with open(file_path, 'r') as f: - data = json.load(f) - # address_numa_stride와 subgraph_map 추가 + data = yaml.safe_load(f) + + # address_numa_stride, subgraph_map if address_numa_stride: data['address_numa_stride'] = address_numa_stride if subgraph_map: @@ -52,8 +54,9 @@ def modify_file(dump_path, name, address_numa_stride=None, subgraph_map=None): output_path = file_path = os.path.join(dump_path, 'runtime_0000', 'attribute') os.makedirs(output_path, exist_ok=True) output_file = os.path.join(output_path, name) + with open(output_file, 'w') as f: - json.dump(data, f, indent=4) + yaml.dump(data, f, default_flow_style=False, sort_keys=False) print(f"Modified file saved to {output_file}") if __name__ == "__main__": diff --git a/scripts/sparsity_experiment/run.sh b/scripts/sparsity_experiment/run.sh index 4f5dd3a6..84c818ac 100755 --- a/scripts/sparsity_experiment/run.sh +++ b/scripts/sparsity_experiment/run.sh @@ -5,7 +5,7 @@ export TORCHSIM_FORCE_TIME_M=8 export TORCHSIM_FORCE_TIME_N=8 OUTPUT_DIR="12GB" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_12G_simple_noc.json" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_12G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -13,7 +13,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="24GB" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_24G_simple_noc.json" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_24G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -21,7 +21,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="48GB" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_48G_simple_noc.json" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_48G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -29,7 +29,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="12GB_2core" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_12G_simple_noc.json" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_12G_simple_noc.yml" python3 
/workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -37,7 +37,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="24GB_2core" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_24G_simple_noc.json" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_24G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -45,7 +45,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="48GB_2core" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_48G_simple_noc.json" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_48G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 diff --git a/scripts/stonne_experiment/run.sh b/scripts/stonne_experiment/run.sh index 1825817f..2e386d9c 100755 --- a/scripts/stonne_experiment/run.sh +++ b/scripts/stonne_experiment/run.sh @@ -2,8 +2,8 @@ export TORCHSIM_FORCE_TIME_M=1024 export TORCHSIM_FORCE_TIME_K=1024 export TORCHSIM_FORCE_TIME_N=1024 -python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config stonne_big_c1_simple_noc.json --mode 0 > hetero/big_sparse.log -python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config systolic_ws_128x128_c1_simple_noc_tpuv3_half.json --mode 1 > hetero/big.log -python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config heterogeneous_c2_simple_noc.json --mode 2 > hetero/hetero.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config stonne_big_c1_simple_noc.yml --mode 0 > hetero/big_sparse.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml --mode 1 > hetero/big.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config heterogeneous_c2_simple_noc.yml --mode 2 > hetero/hetero.log echo "All processes completed!" 
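
Note on the config migration: the script changes above only repoint TORCHSIM_CONFIG from .json to .yml files, while scripts/chiplet_prep.py (above) switches its attribute handling from json to yaml using default_flow_style=False and sort_keys=False. For locally kept configs that still exist only as JSON, a small one-off converter in the same dump style could produce the YAML files the renamed scripts expect. This helper is a hypothetical sketch and not part of the patch:

import json
import sys
from pathlib import Path

import yaml  # PyYAML, the same dependency the updated chiplet_prep.py uses


def convert_config(json_path):
    """Convert a legacy TOGSim JSON config into the .yml file the updated scripts expect."""
    src = Path(json_path)
    with src.open("r") as f:
        data = json.load(f)

    dst = src.with_suffix(".yml")
    with dst.open("w") as f:
        # Same dump style as chiplet_prep.py: block layout, keys kept in original order.
        yaml.dump(data, f, default_flow_style=False, sort_keys=False)
    return dst


if __name__ == "__main__":
    for path in sys.argv[1:]:
        print("wrote", convert_config(path))

Running it once over the existing configs directory (e.g. python3 convert_config.py configs/*.json, with a hypothetical script name) would generate the .yml counterparts referenced by the updated run scripts.
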
diff --git a/scripts/stonne_experiment2/tog_gen.py b/scripts/stonne_experiment2/tog_gen.py index d4f93d4d..e8013da7 100644 --- a/scripts/stonne_experiment2/tog_gen.py +++ b/scripts/stonne_experiment2/tog_gen.py @@ -72,7 +72,7 @@ def extract_simulation_stats(result_path): continue tog_path = os.path.join(path, "tile_graph.onnx") togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") - stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_validation_c1_simple_noc.json' + stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_validation_c1_simple_noc.yml' backsim = TOGSimulator(togsim_path, stonne_config_path) result_path = backsim.simulation(tog_path) nr_multiplications, total_cycle, sim_time = extract_simulation_stats(result_path) diff --git a/tests/Fusion/test_matmul_vector.py b/tests/Fusion/test_matmul_vector.py new file mode 100644 index 00000000..bf1bd513 --- /dev/null +++ b/tests/Fusion/test_matmul_vector.py @@ -0,0 +1,52 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_matmul_vector(device, size=[56, 78, 239], dim=0): + def matmul_fused(a, b, c, d): + return torch.matmul(a, b) + c + d + torch.manual_seed(0) + input = torch.randn(size[:2]) + weight = torch.randn(size[1:]) + output_sz = [size[0], size[2]] + output_sz[dim]=1 + bias = torch.zeros(output_sz) + add = torch.zeros(output_sz) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + a1 = add.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + a2 = add.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, a1, b1) + y = matmul_fused(x2, w2, a2, b2) + test_result("Matmul Vector Fusion Forward", res, y) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_matmul_vector(device, size=[253, 123, 47], dim=0) + test_matmul_vector(device, size=[253, 123, 47], dim=1) \ No newline at end of file diff --git a/tests/Llama/test_llama.py b/tests/Llama/test_llama.py new file mode 100644 index 00000000..443f3fc2 --- /dev/null +++ b/tests/Llama/test_llama.py @@ -0,0 +1,394 @@ +import os +import sys +import argparse +import copy +import torch +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, LlamaModel + +def test_result(name, out, ref, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), ref.cpu(), rtol=rtol, atol=atol): + msg = f"|{name} Test Passed|" + print("-" * len(msg)); print(msg); print("-" * len(msg)) + else: + msg = f"|{name} Test Failed|" + print("-" * len(msg)); print(msg); print("-" * len(msg)) + diff = (out.cpu() - ref.cpu()).abs().max().item() + print("device out:", out.detach().cpu()) + print("cpu ref :", ref.detach().cpu()) + 
print(f"Max abs diff: {diff}") + sys.exit(1) + +@torch.no_grad() +def run_rmsnorm_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaRMSNorm Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + hidden_size = 4096 + eps = 1e-6 + + print(f"Building LlamaRMSNorm (hidden_size={hidden_size}, eps={eps})") + base_norm = LlamaRMSNorm(hidden_size=hidden_size, eps=eps).eval() + cpu_norm = copy.deepcopy(base_norm).eval() + + cpu_norm.to(dtype=torch_dtype, device="cpu") + model = base_norm.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + hidden_states = torch.randn(batch, seq_len, hidden_size, generator=g, dtype=torch_dtype) + hs_dev = hidden_states.to(device) + + print("Compiling LlamaRMSNorm with torch.compile(...)") + compiled_norm = torch.compile(model, dynamic=False) + + out_cpu = cpu_norm(hidden_states) + out_dev = compiled_norm(hs_dev) + + test_result("LlamaRMSNorm forward", out_dev, out_cpu, rtol=rtol, atol=atol) + print("Max diff >", (out_dev.detach().cpu() - out_cpu.detach().cpu()).abs().max().item()) + + +@torch.no_grad() +def run_rotary_embedding_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaRotaryEmbedding Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + hidden_size = 4096 + num_heads = 32 + head_dim = hidden_size // num_heads + + cfg = LlamaConfig( + _name_or_path="custom-llama", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=4096, + mlp_bias=False, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-06, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype=dtype, + transformers_version="4.43.4", + use_cache=True, + vocab_size=8192, + _attn_implementation = "sdpa" + ) + base_rope = LlamaRotaryEmbedding(cfg) + + cpu_rope = copy.deepcopy(base_rope) + + cpu_rope.to(device="cpu") + model = base_rope.to(device=device) + + g = torch.Generator().manual_seed(0) + value = torch.randn(batch, num_heads, seq_len, head_dim, generator=g, dtype=torch_dtype) + position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(batch, -1) + + val_dev = value.to(device) + pos_dev = position_ids.to(device) + + print("Compiling LlamaRotaryEmbedding with torch.compile(...)") + compiled_rope = torch.compile(model, dynamic=False) + + cos_cpu, sin_cpu = cpu_rope(value, position_ids) + cos_dev, sin_dev = compiled_rope(val_dev, pos_dev) + + print(f"Output dtype check - CPU: {cos_cpu.dtype}, Device: {cos_dev.dtype}") + + test_result("LlamaRotaryEmbedding (Cos)", cos_dev, cos_cpu, rtol=rtol, atol=atol) + test_result("LlamaRotaryEmbedding (Sin)", sin_dev, sin_cpu, rtol=rtol, atol=atol) + + diff_cos = (cos_dev.detach().cpu() - cos_cpu.detach().cpu()).abs().max().item() + diff_sin = (sin_dev.detach().cpu() - sin_cpu.detach().cpu()).abs().max().item() + print(f"Max diff (Cos) > {diff_cos}") + print(f"Max diff (Sin) > {diff_sin}") + +@torch.no_grad() +def run_decoder_layer_test( + device, + batch=1, + seq_len=32, + dtype="float32", + 
rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaDecoderLayer Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + cfg = LlamaConfig( + _name_or_path="custom-llama", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=4096, + mlp_bias=False, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-06, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype=dtype, + transformers_version="4.43.4", + use_cache=True, + vocab_size=8192, + _attn_implementation = "sdpa" + ) + + print("Building LlamaDecoderLayer from custom config.") + base_layer = LlamaDecoderLayer(cfg, layer_idx=0).eval() + cpu_layer = copy.deepcopy(base_layer).eval() + + cpu_layer.to(dtype=torch_dtype, device="cpu") + model = base_layer.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + hidden_states = torch.randn(batch, seq_len, cfg.hidden_size, generator=g, dtype=torch_dtype) + position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(batch, -1) + + attention_mask = torch.zeros(batch, 1, seq_len, seq_len, dtype=torch_dtype) + mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1) + attention_mask.masked_fill_(mask, torch.finfo(torch_dtype).min) + + # Shape: (1, seq_len, head_dim) or (batch, seq_len, head_dim) + head_dim = cfg.hidden_size // cfg.num_attention_heads + cos = torch.randn(1, seq_len, head_dim, generator=g, dtype=torch_dtype) + sin = torch.randn(1, seq_len, head_dim, generator=g, dtype=torch_dtype) + position_embeddings = (cos, sin) + + hs_dev = hidden_states.to(device) + pos_dev = position_ids.to(device) + att_dev = attention_mask.to(device) + pos_emb_dev = (cos.to(device), sin.to(device)) + + print("Compiling LlamaDecoderLayer with torch.compile(...)") + compiled_layer = torch.compile(model, dynamic=False) + + out_cpu = cpu_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings + ) + if isinstance(out_cpu, tuple): + out_cpu = out_cpu[0] + + out_dev = compiled_layer( + hidden_states=hs_dev, + attention_mask=att_dev, + position_ids=pos_dev, + position_embeddings=pos_emb_dev + ) + if isinstance(out_dev, tuple): + out_dev = out_dev[0] + + test_result("LlamaDecoderLayer forward", out_dev, out_cpu, rtol=rtol, atol=atol) + print("Max diff >", (out_dev.detach().cpu() - out_cpu.detach().cpu()).abs().max().item()) + +@torch.no_grad() +def run_custom_llama_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, + max_new_tokens=16, +): + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + cfg = LlamaConfig( + _name_or_path="custom-llama", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=4096, + mlp_bias=False, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-06, + 
rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype=dtype, + transformers_version="4.43.4", + use_cache=True, + vocab_size=8192, + ) + + print("Building LlamaForCausalLM from custom config (random init).") + base_model = LlamaForCausalLM(cfg).eval() + cpu_model = copy.deepcopy(base_model).eval() + + cpu_model.to(dtype=torch_dtype, device="cpu") + model = base_model.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + vocab = cfg.vocab_size + input_ids_cpu = torch.randint(low=0, high=vocab, size=(batch, seq_len), generator=g, dtype=torch.long) + + min_dtype = torch.finfo(torch_dtype).min + causal_mask = torch.zeros((seq_len, seq_len), dtype=torch_dtype, device="cpu") + + if seq_len > 1: + causal_mask = torch.triu(torch.full_like(causal_mask, min_dtype), diagonal=1) + + cache_position = torch.arange(seq_len, device="cpu") + mask_condition = torch.arange(seq_len, device="cpu") > cache_position.reshape(-1, 1) + causal_mask.masked_fill_(mask_condition, min_dtype) + attn_mask_cpu = causal_mask[None, None, :, :].expand(batch, 1, -1, -1) + + input_ids_dev = input_ids_cpu.to(device) + attn_mask_dev = attn_mask_cpu.to(device) + + # ---- forward comparison (compile vs CPU baseline) ---- + print("Compiling model with torch.compile(...)") + compiled = torch.compile(model, dynamic=False) + + logits_cpu = cpu_model(input_ids=input_ids_cpu, attention_mask=attn_mask_cpu)#.logits + logits_dev = compiled(input_ids=input_ids_dev, attention_mask=attn_mask_dev)#.logits + + test_result("Custom Llama forward(logits)", logits_dev, logits_cpu, rtol=rtol, atol=atol) + print("Max diff >", (logits_dev.detach().cpu() - logits_cpu.detach().cpu()).abs().max().item()) + +@torch.no_grad() +def run_llama_model_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaModel Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + cfg = LlamaConfig( + vocab_size=8192, + hidden_size=1024, + num_attention_heads=32, + num_key_value_heads=32, + intermediate_size=11008 // 4, + num_hidden_layers=1, + max_position_embeddings=4096, + hidden_act="silu", + use_cache=False, + torch_dtype=dtype, + ) + + print("Building LlamaModel from custom config (random init).") + base_model = LlamaModel(cfg).eval() + cpu_model = copy.deepcopy(base_model).eval() + + cpu_model.to(dtype=torch_dtype, device="cpu") + model = base_model.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + input_ids_cpu = torch.randint(low=0, high=cfg.vocab_size, size=(batch, seq_len), generator=g, dtype=torch.long) + + min_dtype = torch.finfo(torch_dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=min_dtype, dtype=torch_dtype, device="cpu") + if seq_len > 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + attn_mask_cpu = causal_mask[None, None, :, :].expand(batch, 1, -1, -1) + + input_ids_dev = input_ids_cpu.to(device) + attn_mask_dev = attn_mask_cpu.to(device) + + print("Compiling LlamaModel with torch.compile(...)") + compiled_model = torch.compile(model, dynamic=False) + + out_cpu = cpu_model(input_ids=input_ids_cpu, attention_mask=attn_mask_cpu) + out_dev = compiled_model(input_ids=input_ids_dev, attention_mask=attn_mask_dev) + + last_hidden_state_cpu = out_cpu.last_hidden_state + last_hidden_state_dev = out_dev.last_hidden_state + + test_result("LlamaModel (last_hidden_state)", last_hidden_state_dev, 
last_hidden_state_cpu, rtol=rtol, atol=atol) + diff = (last_hidden_state_dev.detach().cpu() - last_hidden_state_cpu.detach().cpu()).abs().max().item() + print(f"Max diff > {diff}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test Custom Llama (random weights, no tokenizer)") + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--seq_len", type=int, default=32) + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--rtol", type=float, default=1e-3) + parser.add_argument("--atol", type=float, default=1e-3) + parser.add_argument("--max_new_tokens", type=int, default=16) + args = parser.parse_args() + + sys.path.append(os.environ.get("PYTORCHSIM_ROOT_PATH", "/workspace/PyTorchSim")) + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + #test_triu(device, size=(32, 128), diagonal=1) + torch.compiler.is_compiling = lambda: True # FIXME. How to fix this? + #run_rmsnorm_test(device) + #run_rotary_embedding_test(device) + #run_decoder_layer_test( + # device=device, + # batch=args.batch, + # seq_len=args.seq_len, + # dtype=args.dtype, + # rtol=args.rtol, + # atol=args.atol, + #) + run_llama_model_test(device) + #run_custom_llama_test( + # device=device, + # batch=args.batch, + # seq_len=args.seq_len, + # dtype=args.dtype, + # rtol=args.rtol, + # atol=args.atol, + #) diff --git a/tests/Mixtral_8x7B/test_attention.py b/tests/Mixtral_8x7B/test_attention.py index 6a7747f7..58955928 100644 --- a/tests/Mixtral_8x7B/test_attention.py +++ b/tests/Mixtral_8x7B/test_attention.py @@ -166,8 +166,8 @@ def test_rmsnorm(device, seq=32): from Scheduler.scheduler import PyTorchSimRunner module = PyTorchSimRunner.setup_device() device = module.custom_device() - test_rmsnorm(device, seq=1) - test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) + #test_rmsnorm(device, seq=1) + #test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) test_decode(device, 32, 3) #test_attention(device) #test_ffn(device) diff --git a/tests/test_compile_overhead.py b/tests/test_compile_overhead.py index 030f548e..449707a5 100644 --- a/tests/test_compile_overhead.py +++ b/tests/test_compile_overhead.py @@ -21,7 +21,7 @@ # shutil.rmtree("/tmp/torchinductor") #except FileNotFoundError: # print("no cache") - scheduler = Scheduler(num_request_queue=1, max_batch=4, engine_select=Scheduler.FIFO_ENGINE, togsim_config=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json") + scheduler = Scheduler(num_request_queue=1, max_batch=4, engine_select=Scheduler.FIFO_ENGINE, togsim_config=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml") # Register compiled model opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last), dynamic=False) SchedulerDNNModel.register_model("resnet18", opt_model1) diff --git a/tests/test_hetro.py b/tests/test_hetro.py index a0716e2d..9fac8c65 100644 --- a/tests/test_hetro.py +++ b/tests/test_hetro.py @@ -17,7 +17,7 @@ def custom_matmul(a, b): parser.add_argument("--N", type=int, default=128, help="Input layer size") parser.add_argument("--K", type=int, default=128, help="Hidden layer size") parser.add_argument("--sparsity", type=float, default=0.9, help="Output layer size") - parser.add_argument("--config", type=str, default="stonne_big_c1_simple_noc.json", help="Output 
layer size") + parser.add_argument("--config", type=str, default="stonne_big_c1_simple_noc.yml", help="Output layer size") parser.add_argument("--mode", type=int, default=0, help="Output layer size") args = parser.parse_args() diff --git a/tests/test_indirect_access.py b/tests/test_indirect_access.py index c6afaf86..6cfa7b58 100644 --- a/tests/test_indirect_access.py +++ b/tests/test_indirect_access.py @@ -43,6 +43,40 @@ def test_embedding(device, vocab_size, dim): cpu_res = cpu_emb(cpu_prompt) test_result("Embedding", res, cpu_res) +def test_scatter_add(device, num_tokens=256, hidden_size=256, num_assignments=3, dtype=torch.float32, seed=0): + torch.manual_seed(seed) + + def scatter_only(out, token_indices, weighted_output): + # token_indices: [N] (long), weighted_output: [N, H] + out.index_add_(0, token_indices, weighted_output) + return out + + out = torch.randn(num_tokens, hidden_size, dtype=dtype) + out_cp = out.clone() + token_indices = torch.randint(0, num_tokens, (num_assignments,)) + weighted_output = torch.randn(num_assignments, hidden_size, dtype=dtype) + + cpu_out = scatter_only(out, token_indices, weighted_output) + + out = out_cp.to(device=device) + token_indices = token_indices.to(device=device) + weighted_output = weighted_output.to(device=device) + opt_fn = torch.compile(dynamic=False)(scatter_only) + res = opt_fn(out, token_indices, weighted_output) + test_result("ScatterAdd(index_add_)", res, cpu_out) + +def test_scatter_full(device, size=(128, 128)): + def vectoradd(a, idx, b): + a[idx, :] = b + return a + x = torch.randn(size, dtype=torch.float32).to(device=device) + idx = torch.randint(0,128, [128]).to(device=device) + y = torch.randn(128, dtype=torch.float32).to(device=device) + opt_fn = torch.compile(dynamic=False)(vectoradd) + res = opt_fn(x, idx, y) + out = vectoradd(x.cpu(), idx.cpu(), y.cpu()) + test_result("Indirect VectorAdd", res, out) + if __name__ == "__main__": import os import sys @@ -51,5 +85,7 @@ def test_embedding(device, vocab_size, dim): from Scheduler.scheduler import PyTorchSimRunner module = PyTorchSimRunner.setup_device() device = module.custom_device() + test_scatter_full(device) + test_scatter_add(device) test_indirect_vectoradd(device) #test_embedding(device, 1024, 2048) \ No newline at end of file diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 4860de56..9c7ca255 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -7,7 +7,7 @@ base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') sys.path.append(base_path) from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request -config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json' +config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml' target_model1 = model1().eval() target_model2 = model2(768, 12).eval() diff --git a/tests/test_scheduler_batching.py b/tests/test_scheduler_batching.py index 53f9256d..65213ef0 100644 --- a/tests/test_scheduler_batching.py +++ b/tests/test_scheduler_batching.py @@ -17,7 +17,7 @@ target_model1 = model1().eval() # Init scheduler - scheduler = Scheduler(num_request_queue=1, max_batch=32, engine_select=Scheduler.FIFO_ENGINE, togsim_config=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json") + scheduler = Scheduler(num_request_queue=1, max_batch=32, engine_select=Scheduler.FIFO_ENGINE, togsim_config=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml") # Register compiled model opt_model1 = 
torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last), dynamic=False) SchedulerDNNModel.register_model("resnet18", opt_model1) diff --git a/tutorial/session1/CompilerOptimization.ipynb b/tutorial/session1/CompilerOptimization.ipynb index 178974c1..ead695c0 100644 --- a/tutorial/session1/CompilerOptimization.ipynb +++ b/tutorial/session1/CompilerOptimization.ipynb @@ -18,7 +18,7 @@ "import sys\n", "base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')\n", "sys.path.append(base_dir)\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"" + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"" ] }, { @@ -71,7 +71,7 @@ "outputs": [], "source": [ "os.environ['TORCHSIM_DUMP_PATH']=os.path.join(os.getcwd(), \"non_fused\")\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", diff --git a/tutorial/session1/ExecutionMode.ipynb b/tutorial/session1/ExecutionMode.ipynb index 22e00bed..b6f0e048 100644 --- a/tutorial/session1/ExecutionMode.ipynb +++ b/tutorial/session1/ExecutionMode.ipynb @@ -56,7 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_functional_only.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_functional_only.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -78,7 +78,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -101,7 +101,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"\n", "\n", "input = torch.randn(2048, 2048).to(device=device)\n", "weight = torch.randn(2048, 2048).to(device=device)\n", @@ -132,7 +132,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_2_cores.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_2_cores.yml\"\n", "\n", "input = torch.randn(2048, 2048).to(device=device)\n", "weight = torch.randn(2048, 2048).to(device=device)\n", diff --git a/tutorial/session1/LogAnalysis.ipynb b/tutorial/session1/LogAnalysis.ipynb index 4f1e17cb..d3207af1 100644 --- a/tutorial/session1/LogAnalysis.ipynb +++ b/tutorial/session1/LogAnalysis.ipynb @@ -18,7 +18,7 @@ "import sys\n", "base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')\n", "sys.path.append(base_dir)\n", - 
"os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"\n", "os.environ['TORCHSIM_DUMP_LOG_PATH']=os.path.join(os.getcwd(), \"togsim_results\")" ] }, diff --git a/tutorial/session1/Mapping.ipynb b/tutorial/session1/Mapping.ipynb index b02c98fe..684b69c0 100644 --- a/tutorial/session1/Mapping.ipynb +++ b/tutorial/session1/Mapping.ipynb @@ -68,7 +68,7 @@ "source": [ "torch._dynamo.reset()\n", "\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_external_mapping.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_external_mapping.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -101,7 +101,7 @@ "source": [ "torch._dynamo.reset()\n", "\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_autotune.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_autotune.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n",