From 8cc90c4f9b0fead5a058e8ad86d1751204af1c0f Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 9 Nov 2025 13:52:28 -0500 Subject: [PATCH 01/26] bring record analysis work over --- src/bloqade/analysis/record/__init__.py | 2 + src/bloqade/analysis/record/analysis.py | 65 +++++++++ src/bloqade/analysis/record/impls.py | 141 +++++++++++++++++++ src/bloqade/analysis/record/lattice.py | 88 ++++++++++++ src/bloqade/stim/passes/soft_flatten.py | 90 ++++++++++++ test/analysis/record/test_record_analysis.py | 63 +++++++++ 6 files changed, 449 insertions(+) create mode 100644 src/bloqade/analysis/record/__init__.py create mode 100644 src/bloqade/analysis/record/analysis.py create mode 100644 src/bloqade/analysis/record/impls.py create mode 100644 src/bloqade/analysis/record/lattice.py create mode 100644 src/bloqade/stim/passes/soft_flatten.py create mode 100644 test/analysis/record/test_record_analysis.py diff --git a/src/bloqade/analysis/record/__init__.py b/src/bloqade/analysis/record/__init__.py new file mode 100644 index 00000000..6741d40a --- /dev/null +++ b/src/bloqade/analysis/record/__init__.py @@ -0,0 +1,2 @@ +from . import impls as impls +from .analysis import RecordAnalysis as RecordAnalysis diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py new file mode 100644 index 00000000..3b0f42a0 --- /dev/null +++ b/src/bloqade/analysis/record/analysis.py @@ -0,0 +1,65 @@ +from typing import TypeVar +from dataclasses import field, dataclass + +from kirin import ir +from kirin.analysis import ForwardExtra, const +from kirin.analysis.forward import ForwardFrame + +from .lattice import Record, RecordIdx + + +@dataclass +class GlobalRecordState: + stack: list[RecordIdx] = field(default_factory=list) + + # assume that this RecordIdx will always be -1 + def increment_record_idx(self) -> RecordIdx: + # adjust all previous indices + for record_idx in self.stack: + record_idx.idx -= 1 + self.stack.append(RecordIdx(-1)) + # Return for usage + return self.stack[-1] + + def drop_record_idx(self, record_to_drop: RecordIdx): + # there is a chance now that the ordering is messed up but + # we can now update the indices to enforce consistency. + # We only have to update UP to the entry that was just removed + # everything else maintains ordering + dropped_idx = record_to_drop.idx + self.stack.remove(record_to_drop) + for record_idx in self.stack: + if record_idx.idx < dropped_idx: + record_idx.idx += 1 + + +@dataclass +class RecordFrame(ForwardFrame): + global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState) + + +class RecordAnalysis(ForwardExtra[RecordFrame, Record]): + keys = ["record"] + lattice = Record + + def initialize_frame(self, code, *, has_parent_access: bool = False) -> RecordFrame: + return RecordFrame(code, has_parent_access=has_parent_access) + + def eval_stmt_fallback(self, frame: RecordFrame, stmt) -> tuple[Record, ...]: + return tuple(self.lattice.bottom() for _ in stmt.results) + + def run_method(self, method, args: tuple[Record, ...]): + # NOTE: we do not support dynamic calls here, thus no need to propagate method object + return self.run_callable(method.code, (self.lattice.bottom(),) + args) + + T = TypeVar("T") + + def get_const_value( + self, input_type: type[T], value: ir.SSAValue + ) -> type[T] | None: + if isinstance(hint := value.hints.get("const"), const.Value): + data = hint.data + if isinstance(data, input_type): + return hint.data + + return None diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py new file mode 100644 index 00000000..9a55f3eb --- /dev/null +++ b/src/bloqade/analysis/record/impls.py @@ -0,0 +1,141 @@ +from copy import deepcopy + +from kirin import types as kirin_types, interp +from kirin.dialects import py, scf, ilist + +from bloqade import qubit, annotate +from bloqade.annotate.stmts import SetDetector, SetObservable + +from .lattice import ( + AnyRecord, + NotRecord, + RecordIdx, + RecordTuple, + InvalidRecord, + ImmutableRecords, +) +from .analysis import RecordFrame, RecordAnalysis + + +@annotate.dialect.register(key="record") +class PhysicalAnnotations(interp.MethodTable): + # Both statements inherit from the base class "ConsumesMeasurementResults" + # both statements consume IList of MeasurementResults, so the input type should be + # expected to be a RecordTuple + @interp.impl(SetObservable) + @interp.impl(SetDetector) + def consumes_measurements( + self, interp: RecordAnalysis, frame: RecordFrame, stmt: SetDetector + ): + # Get the measurement results being consumed + record_tuple_at_stmt = frame.get(stmt.measurements) + + final_record_idxs = [ + deepcopy(record_idx) for record_idx in record_tuple_at_stmt.members + ] + + return (ImmutableRecords(members=tuple(final_record_idxs)),) + + +@qubit.dialect.register(key="record") +class SquinQubit(interp.MethodTable): + + @interp.impl(qubit.stmts.Measure) + def measure_qubit_list( + self, + interp: RecordAnalysis, + frame: RecordFrame, + stmt: qubit.stmts.Measure, + ): + + # try to get the length of the list + ## "...safely assume the type inference will give you what you need" + qubits_type = stmt.qubits.type + # vars[0] is just the type of the elements in the ilist, + # vars[1] can contain a literal with length information + num_qubits = qubits_type.vars[1] + if not isinstance(num_qubits, kirin_types.Literal): + return (AnyRecord(),) + + record_idxs = [] + for _ in range(num_qubits.data): + record_idx = frame.global_record_state.increment_record_idx() + record_idxs.append(record_idx) + + return (RecordTuple(members=tuple(record_idxs)),) + + +@py.indexing.dialect.register(key="record") +class PyIndexing(interp.MethodTable): + @interp.impl(py.GetItem) + def getitem(self, interp: RecordAnalysis, frame: RecordFrame, stmt: py.GetItem): + + idx_or_slice = interp.get_const_value((int, slice), stmt.index) + if idx_or_slice is None: + return (InvalidRecord(),) + + obj = frame.get(stmt.obj) + if isinstance(obj, RecordTuple): + if isinstance(idx_or_slice, slice): + return (RecordTuple(members=obj.members[idx_or_slice]),) + elif isinstance(idx_or_slice, int): + return (obj.members[idx_or_slice],) + else: + return (InvalidRecord(),) + # just propagate these down the line + elif isinstance(obj, (AnyRecord, NotRecord)): + return (obj,) + else: + return (InvalidRecord(),) + + +@ilist.dialect.register(key="record") +class IList(interp.MethodTable): + @interp.impl(ilist.New) + def new_ilist( + self, + interp: RecordAnalysis, + frame: interp.Frame, + stmt: ilist.New, + ): + return (RecordTuple(frame.get_values(stmt.values)),) + + +@py.assign.dialect.register(key="record") +class PyAlias(interp.MethodTable): + @interp.impl(py.Alias) + def alias( + self, + interp: RecordAnalysis, + frame: RecordFrame, + stmt: py.Alias, + ): + value = frame.get(stmt.value) + if isinstance(value, RecordIdx): + frame.global_record_state.drop_record_idx(value) + elif isinstance(value, RecordTuple): + for member in value.members: + frame.global_record_state.drop_record_idx(member) + + return (value,) + + +@scf.dialect.register(key="record") +class LoopHandling(scf.absint.Methods): + @interp.impl(scf.stmts.For) + def for_loop( + self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For + ): + + # this will contain the in-loop measure variable declared outside the loop + loop_vars = frame.get_values(stmt.initializers) + # NotRecord in the beginning just lets the sink have some value + loop_vars = interp_.run_ssacfg_region(frame, stmt.body, loop_vars) + + # need to update the information in the frame + if isinstance(loop_vars, interp.ReturnValue): + return loop_vars + elif loop_vars is None: + loop_vars = () + + return loop_vars diff --git a/src/bloqade/analysis/record/lattice.py b/src/bloqade/analysis/record/lattice.py new file mode 100644 index 00000000..b2c9cf80 --- /dev/null +++ b/src/bloqade/analysis/record/lattice.py @@ -0,0 +1,88 @@ +from typing import final +from dataclasses import dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + SimpleJoinMixin, + SimpleMeetMixin, +) + +# Taken directly from Kai-Hsin Wu's implementation +# with minor changes to names and addition of CanMeasureId type + + +@dataclass +class Record( + SimpleJoinMixin["Record"], + SimpleMeetMixin["Record"], + BoundedLattice["Record"], +): + + @classmethod + def bottom(cls) -> "Record": + return InvalidRecord() + + @classmethod + def top(cls) -> "Record": + return AnyRecord() + + +# Can pop up if user constructs some list containing a mixture +# of bools from measure results and other places, +# in which case the whole list is invalid +@final +@dataclass +class InvalidRecord(Record, metaclass=SingletonMeta): + + def is_subseteq(self, other: Record) -> bool: + return True + + +@final +@dataclass +class AnyRecord(Record, metaclass=SingletonMeta): + + def is_subseteq(self, other: Record) -> bool: + return isinstance(other, AnyRecord) + + +@final +@dataclass +class NotRecord(Record, metaclass=SingletonMeta): + + def is_subseteq(self, other: Record) -> bool: + return isinstance(other, NotRecord) + + +@final +@dataclass +class RecordIdx(Record): + idx: int + + def is_subseteq(self, other: Record) -> bool: + if isinstance(other, RecordIdx): + return self.idx == other.idx + return False + + +@final +@dataclass +class RecordTuple(Record): + members: tuple[RecordIdx, ...] + + def is_subseteq(self, other: Record) -> bool: + if isinstance(other, RecordTuple): + return all(a.is_subseteq(b) for a, b in zip(self.members, other.members)) + return False + + +@final +@dataclass +class ImmutableRecords(Record): + members: tuple[RecordIdx, ...] + + def is_subseteq(self, other: Record) -> bool: + if isinstance(other, ImmutableRecords): + return all(a.is_subseteq(b) for a, b in zip(self.members, other.members)) + return False diff --git a/src/bloqade/stim/passes/soft_flatten.py b/src/bloqade/stim/passes/soft_flatten.py new file mode 100644 index 00000000..cf4fbe6a --- /dev/null +++ b/src/bloqade/stim/passes/soft_flatten.py @@ -0,0 +1,90 @@ +# Taken from Phillip Weinberg's bloqade-shuttle implementation +from typing import Callable +from dataclasses import field, dataclass + +from kirin import ir +from kirin.passes import Fold, Pass, TypeInfer + +# from kirin.passes.aggressive import UnrollScf +from kirin.rewrite import ( + Walk, + Chain, + Inline, + Fixpoint, + CFGCompactify, + DeadCodeElimination, + CommonSubexpressionElimination, +) +from kirin.dialects import scf, ilist +from kirin.rewrite.abc import RewriteResult + +# from bloqade.qasm2.passes.fold import AggressiveUnroll +from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs + + +@dataclass +class AggressiveUnroll(Pass): + """A pass to unroll structured control flow""" + + additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True + + fold: Fold = field(init=False) + typeinfer: TypeInfer = field(init=False) + # scf_unroll: UnrollScf = field(init=False) + + def __post_init__(self): + self.fold = Fold(self.dialects, no_raise=self.no_raise) + self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) + # self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + result = RewriteResult() + # result = self.scf_unroll.unsafe_run(mt).join(result) + result = ( + Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())) + .rewrite(mt.code) + .join(result) + ) + self.typeinfer.unsafe_run(mt) + result = self.fold.unsafe_run(mt).join(result) + result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result) + result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result) + + rule = Chain( + CommonSubexpressionElimination(), + DeadCodeElimination(), + ) + result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) + + return result + + def inline_heuristic(self, node: ir.Statement) -> bool: + """The heuristic to decide whether to inline a function call or not. + inside loops and if-else, only inline simple functions, i.e. + functions with a single block + """ + return not isinstance( + node.parent_stmt, (scf.For, scf.IfElse) + ) and self.additional_inline_heuristic( + node + ) # always inline calls outside of loops and if-else + + +@dataclass +class SoftFlatten(Pass): + """ + like standard Flatten but without unrolling to let analysis go into loops + """ + + unroll: AggressiveUnroll = field(init=False) + simplify_if: StimSimplifyIfs = field(init=False) + + def __post_init__(self): + self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise) + self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + rewrite_result = RewriteResult() + rewrite_result = self.simplify_if(mt).join(rewrite_result) + rewrite_result = self.unroll(mt).join(rewrite_result) + return rewrite_result diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py new file mode 100644 index 00000000..c071f719 --- /dev/null +++ b/test/analysis/record/test_record_analysis.py @@ -0,0 +1,63 @@ +from bloqade import squin +from bloqade.analysis.record import RecordAnalysis +from bloqade.stim.passes.soft_flatten import SoftFlatten + + +@squin.kernel +def test(): + qs = squin.qalloc(5) + data_qs = [qs[0], qs[2], qs[4]] + and_qs = [qs[1], qs[3]] + + init_and_meas_res = squin.broadcast.measure(and_qs) + squin.set_detector([init_and_meas_res[0]], coordinates=[0, 0]) + squin.set_detector([init_and_meas_res[1]], coordinates=[0, 1]) + + and_meas_res = None + for _ in range(10): + and_meas_res = squin.broadcast.measure(and_qs) + + squin.set_detector([and_meas_res[0], init_and_meas_res[0]], coordinates=[0, 0]) + squin.set_detector([and_meas_res[1], init_and_meas_res[1]], coordinates=[1, 1]) + + init_and_meas_res = and_meas_res + + data_meas_res = squin.broadcast.measure(data_qs) + squin.set_detector( + [data_meas_res[0], data_meas_res[1], and_meas_res[0]], coordinates=[2, 0] + ) + squin.set_detector( + [data_meas_res[2], data_meas_res[1], and_meas_res[1]], coordinates=[2, 1] + ) + squin.set_observable([data_meas_res[0]]) + + # return and_meas_res + + +test.print() +SoftFlatten(dialects=test.dialects).fixpoint(test) +test.print() +frame, _ = RecordAnalysis(dialects=test.dialects).run_analysis(test) +test.print(analysis=frame.entries) + + +@squin.kernel +def analysis_demo(): + qs = squin.qalloc(3) + ms0 = squin.broadcast.measure(qs) + ms1 = squin.broadcast.measure(qs) + squin.set_detector(ms0, coordinates=(0, 0)) # -6 -5 -4 + squin.set_detector(ms1, coordinates=(0, 1)) # -3 -2 -1 + # squin.broadcast.measure(qs) # -3 -2 -1 + # physical.set_detector(ms1, coordinates=(0,2)) # -6 -5 -4 + + # get aliasing to work + # ms1 = ms0 + # physical.set_detector(ms1, coordinates=(1,0)) # -9 -8 -7 + # return ms1 + + +# SoftFlatten(dialects=analysis_demo.dialects).fixpoint(analysis_demo) +# analysis_demo.print() +# frame, _ = RecordAnalysis(dialects=analysis_demo.dialects).run_analysis(analysis_demo) +# analysis_demo.print(analysis=frame.entries) From 4e562937c479f4e072a9658acbc3451951230013 Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 10 Nov 2025 09:02:14 -0500 Subject: [PATCH 02/26] save wip test_record_analysis --- test/analysis/record/test_record_analysis.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index c071f719..1691608b 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -1,7 +1,9 @@ from bloqade import squin -from bloqade.analysis.record import RecordAnalysis -from bloqade.stim.passes.soft_flatten import SoftFlatten +# from bloqade.analysis.record import RecordAnalysis +# from bloqade.stim.passes.soft_flatten import SoftFlatten + +""" @squin.kernel def test(): @@ -39,6 +41,7 @@ def test(): test.print() frame, _ = RecordAnalysis(dialects=test.dialects).run_analysis(test) test.print(analysis=frame.entries) +""" @squin.kernel @@ -46,14 +49,14 @@ def analysis_demo(): qs = squin.qalloc(3) ms0 = squin.broadcast.measure(qs) ms1 = squin.broadcast.measure(qs) - squin.set_detector(ms0, coordinates=(0, 0)) # -6 -5 -4 - squin.set_detector(ms1, coordinates=(0, 1)) # -3 -2 -1 - # squin.broadcast.measure(qs) # -3 -2 -1 - # physical.set_detector(ms1, coordinates=(0,2)) # -6 -5 -4 + squin.set_detector(ms0, coordinates=(0, 0)) + squin.set_detector(ms1, coordinates=(0, 1)) + squin.broadcast.measure(qs) + squin.set_detector(ms1, coordinates=(0, 2)) # get aliasing to work - # ms1 = ms0 - # physical.set_detector(ms1, coordinates=(1,0)) # -9 -8 -7 + ms1 = ms0 + squin.set_detector(ms1, coordinates=(1, 0)) # return ms1 From eee712d81ffcd0a5403f4cd4d3fb5167184b1adf Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 11 Nov 2025 15:03:30 -0500 Subject: [PATCH 03/26] update to new kirin --- src/bloqade/analysis/record/analysis.py | 17 ++++++++++----- test/analysis/record/test_record_analysis.py | 22 +++++++++----------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py index 3b0f42a0..97f70692 100644 --- a/src/bloqade/analysis/record/analysis.py +++ b/src/bloqade/analysis/record/analysis.py @@ -42,15 +42,19 @@ class RecordAnalysis(ForwardExtra[RecordFrame, Record]): keys = ["record"] lattice = Record - def initialize_frame(self, code, *, has_parent_access: bool = False) -> RecordFrame: - return RecordFrame(code, has_parent_access=has_parent_access) + def initialize_frame( + self, node: ir.Statement, *, has_parent_access: bool = False + ) -> RecordFrame: + return RecordFrame(node, has_parent_access=has_parent_access) - def eval_stmt_fallback(self, frame: RecordFrame, stmt) -> tuple[Record, ...]: - return tuple(self.lattice.bottom() for _ in stmt.results) + def eval_stmt_fallback( + self, frame: RecordFrame, node: ir.Statement + ) -> tuple[Record, ...]: + return tuple(self.lattice.bottom() for _ in node.results) def run_method(self, method, args: tuple[Record, ...]): # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + return self.run_method(method.code, (self.lattice.bottom(),) + args) T = TypeVar("T") @@ -63,3 +67,6 @@ def get_const_value( return hint.data return None + + def method_self(self, method: ir.Method) -> Record: + return self.lattice.bottom() diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index 1691608b..effb6eeb 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -1,10 +1,8 @@ from bloqade import squin - -# from bloqade.analysis.record import RecordAnalysis -# from bloqade.stim.passes.soft_flatten import SoftFlatten +from bloqade.analysis.record import RecordAnalysis +from bloqade.stim.passes.soft_flatten import SoftFlatten """ - @squin.kernel def test(): qs = squin.qalloc(5) @@ -49,18 +47,18 @@ def analysis_demo(): qs = squin.qalloc(3) ms0 = squin.broadcast.measure(qs) ms1 = squin.broadcast.measure(qs) - squin.set_detector(ms0, coordinates=(0, 0)) - squin.set_detector(ms1, coordinates=(0, 1)) + squin.set_detector(ms0, coordinates=[0, 0]) + squin.set_detector(ms1, coordinates=[0, 1]) squin.broadcast.measure(qs) - squin.set_detector(ms1, coordinates=(0, 2)) + squin.set_detector(ms1, coordinates=[0, 2]) # get aliasing to work ms1 = ms0 - squin.set_detector(ms1, coordinates=(1, 0)) + squin.set_detector(ms1, coordinates=[1, 0]) # return ms1 -# SoftFlatten(dialects=analysis_demo.dialects).fixpoint(analysis_demo) -# analysis_demo.print() -# frame, _ = RecordAnalysis(dialects=analysis_demo.dialects).run_analysis(analysis_demo) -# analysis_demo.print(analysis=frame.entries) +SoftFlatten(dialects=analysis_demo.dialects).fixpoint(analysis_demo) +analysis_demo.print() +frame, _ = RecordAnalysis(dialects=analysis_demo.dialects).run(analysis_demo) +analysis_demo.print(analysis=frame.entries) From 4e1da7b0ad48e88616fd443b398e8a51ad724e3b Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 17 Nov 2025 14:43:04 -0500 Subject: [PATCH 04/26] almost there, still a problem with invariance checking --- src/bloqade/analysis/record/analysis.py | 51 +++---- src/bloqade/analysis/record/impls.py | 93 ++++++++---- src/bloqade/analysis/record/lattice.py | 13 ++ src/bloqade/stim/dialects/cf/__init__.py | 0 src/bloqade/stim/dialects/cf/_dialect.py | 3 + src/bloqade/stim/dialects/cf/stmts.py | 0 src/bloqade/stim/passes/simplify_ifs.py | 2 +- src/bloqade/stim/passes/soft_flatten.py | 6 +- test/analysis/record/test_record_analysis.py | 150 ++++++++++++++++++- 9 files changed, 251 insertions(+), 67 deletions(-) create mode 100644 src/bloqade/stim/dialects/cf/__init__.py create mode 100644 src/bloqade/stim/dialects/cf/_dialect.py create mode 100644 src/bloqade/stim/dialects/cf/stmts.py diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py index 97f70692..2068ec42 100644 --- a/src/bloqade/analysis/record/analysis.py +++ b/src/bloqade/analysis/record/analysis.py @@ -1,8 +1,7 @@ -from typing import TypeVar from dataclasses import field, dataclass from kirin import ir -from kirin.analysis import ForwardExtra, const +from kirin.analysis import ForwardExtra from kirin.analysis.forward import ForwardFrame from .lattice import Record, RecordIdx @@ -10,27 +9,27 @@ @dataclass class GlobalRecordState: - stack: list[RecordIdx] = field(default_factory=list) + buffer: list[RecordIdx] = field(default_factory=list) # assume that this RecordIdx will always be -1 - def increment_record_idx(self) -> RecordIdx: + def add_record_idxs(self, num_new_records: int) -> list[RecordIdx]: # adjust all previous indices - for record_idx in self.stack: - record_idx.idx -= 1 - self.stack.append(RecordIdx(-1)) - # Return for usage - return self.stack[-1] + for record_idx in self.buffer: + record_idx.idx -= num_new_records - def drop_record_idx(self, record_to_drop: RecordIdx): - # there is a chance now that the ordering is messed up but - # we can now update the indices to enforce consistency. - # We only have to update UP to the entry that was just removed - # everything else maintains ordering - dropped_idx = record_to_drop.idx - self.stack.remove(record_to_drop) - for record_idx in self.stack: - if record_idx.idx < dropped_idx: - record_idx.idx += 1 + # generate new indices and add them to the buffer + new_record_idxs = [RecordIdx(-i) for i in range(num_new_records, 0, -1)] + self.buffer += new_record_idxs + # Return for usage, idxs linked to the global state + return new_record_idxs + + """ + Might need a free after use! You can keep the size of the list small + but could be a premature optimization... + """ + # def drop_record_idxs(self, record_tuple: RecordTuple): + # for record_idx in record_tuple.members: + # self.buffer.remove(record_idx) @dataclass @@ -47,7 +46,7 @@ def initialize_frame( ) -> RecordFrame: return RecordFrame(node, has_parent_access=has_parent_access) - def eval_stmt_fallback( + def eval_fallback( self, frame: RecordFrame, node: ir.Statement ) -> tuple[Record, ...]: return tuple(self.lattice.bottom() for _ in node.results) @@ -56,17 +55,5 @@ def run_method(self, method, args: tuple[Record, ...]): # NOTE: we do not support dynamic calls here, thus no need to propagate method object return self.run_method(method.code, (self.lattice.bottom(),) + args) - T = TypeVar("T") - - def get_const_value( - self, input_type: type[T], value: ir.SSAValue - ) -> type[T] | None: - if isinstance(hint := value.hints.get("const"), const.Value): - data = hint.data - if isinstance(data, input_type): - return hint.data - - return None - def method_self(self, method: ir.Method) -> Record: return self.lattice.bottom() diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py index 9a55f3eb..70284c25 100644 --- a/src/bloqade/analysis/record/impls.py +++ b/src/bloqade/analysis/record/impls.py @@ -1,6 +1,7 @@ from copy import deepcopy from kirin import types as kirin_types, interp +from kirin.ir import PyAttr from kirin.dialects import py, scf, ilist from bloqade import qubit, annotate @@ -9,9 +10,9 @@ from .lattice import ( AnyRecord, NotRecord, - RecordIdx, RecordTuple, InvalidRecord, + ConstantCarrier, ImmutableRecords, ) from .analysis import RecordFrame, RecordAnalysis @@ -57,10 +58,7 @@ def measure_qubit_list( if not isinstance(num_qubits, kirin_types.Literal): return (AnyRecord(),) - record_idxs = [] - for _ in range(num_qubits.data): - record_idx = frame.global_record_state.increment_record_idx() - record_idxs.append(record_idx) + record_idxs = frame.global_record_state.add_record_idxs(num_qubits.data) return (RecordTuple(members=tuple(record_idxs)),) @@ -70,9 +68,22 @@ class PyIndexing(interp.MethodTable): @interp.impl(py.GetItem) def getitem(self, interp: RecordAnalysis, frame: RecordFrame, stmt: py.GetItem): - idx_or_slice = interp.get_const_value((int, slice), stmt.index) - if idx_or_slice is None: - return (InvalidRecord(),) + # maybe_const will work fine outside of any loops because + # constprop will put the expected data into a hint. + + # if maybeconst fails, we fall back to getting the value from the frame + # (note that even outside loops, the constant impl will happily + # capture integer/slice constants so if THAT fails, then something + # has truly gone wrong). + possible_idx_or_slice = interp.maybe_const(stmt.index, (int, slice)) + if possible_idx_or_slice is not None: + idx_or_slice = possible_idx_or_slice + else: + idx_or_slice = frame.get(stmt.index) + if not isinstance(idx_or_slice, ConstantCarrier): + return (InvalidRecord(),) + else: + idx_or_slice = idx_or_slice.value obj = frame.get(stmt.obj) if isinstance(obj, RecordTuple): @@ -106,36 +117,68 @@ class PyAlias(interp.MethodTable): @interp.impl(py.Alias) def alias( self, - interp: RecordAnalysis, + interp_: RecordAnalysis, frame: RecordFrame, stmt: py.Alias, ): - value = frame.get(stmt.value) - if isinstance(value, RecordIdx): - frame.global_record_state.drop_record_idx(value) - elif isinstance(value, RecordTuple): - for member in value.members: - frame.global_record_state.drop_record_idx(member) + input = frame.get(stmt.value) # expect this to be a RecordTuple - return (value,) + # two variables share the same references in the global state + return (input,) @scf.dialect.register(key="record") -class LoopHandling(scf.absint.Methods): +class LoopHandling(interp.MethodTable): @interp.impl(scf.stmts.For) def for_loop( self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For ): - # this will contain the in-loop measure variable declared outside the loop loop_vars = frame.get_values(stmt.initializers) - # NotRecord in the beginning just lets the sink have some value - loop_vars = interp_.run_ssacfg_region(frame, stmt.body, loop_vars) - # need to update the information in the frame - if isinstance(loop_vars, interp.ReturnValue): - return loop_vars - elif loop_vars is None: - loop_vars = () + for _ in range(2): + loop_vars = interp_.frame_call_region( + frame, stmt, stmt.body, InvalidRecord(), *loop_vars + ) + + if loop_vars is None: + loop_vars = () + + elif isinstance(loop_vars, interp.ReturnValue): + return loop_vars return loop_vars + + @interp.impl(scf.stmts.Yield) + def for_yield( + self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.Yield + ): + return interp.YieldValue(frame.get_values(stmt.values)) + + +# Only carry about carrying integers for now because +# the current issue is that +@py.dialect.register(key="record") +class ConstantForwarding(interp.MethodTable): + @interp.impl(py.Constant) + def constant( + self, + interp_: RecordAnalysis, + frame: RecordFrame, + stmt: py.Constant, + ): + # can't use interp_.maybe_const/expect_const because it assumes the data is already + # there to begin with... + if not isinstance(stmt.value, PyAttr): + return (InvalidRecord(),) + + expected_int_or_slice = stmt.value.data + + if not isinstance(expected_int_or_slice, (int, slice)): + return (InvalidRecord(),) + + return (ConstantCarrier(value=expected_int_or_slice),) + + +# outside_frame -> create new frame with context manager COPIED from outside frame +# the frame and the stack are separate diff --git a/src/bloqade/analysis/record/lattice.py b/src/bloqade/analysis/record/lattice.py index b2c9cf80..0b4fac9f 100644 --- a/src/bloqade/analysis/record/lattice.py +++ b/src/bloqade/analysis/record/lattice.py @@ -55,6 +55,19 @@ def is_subseteq(self, other: Record) -> bool: return isinstance(other, NotRecord) +# For now I only care about propagating constant integers or slices, +# things that can be used as indices to list of measurements +@final +@dataclass +class ConstantCarrier(Record): + value: int | slice + + def is_subseteq(self, other: Record) -> bool: + if isinstance(other, ConstantCarrier): + return self.value == other.value + return False + + @final @dataclass class RecordIdx(Record): diff --git a/src/bloqade/stim/dialects/cf/__init__.py b/src/bloqade/stim/dialects/cf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/stim/dialects/cf/_dialect.py b/src/bloqade/stim/dialects/cf/_dialect.py new file mode 100644 index 00000000..75b011c9 --- /dev/null +++ b/src/bloqade/stim/dialects/cf/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect("stim.cf") diff --git a/src/bloqade/stim/dialects/cf/stmts.py b/src/bloqade/stim/dialects/cf/stmts.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/stim/passes/simplify_ifs.py b/src/bloqade/stim/passes/simplify_ifs.py index 4db85d23..e2bb47c1 100644 --- a/src/bloqade/stim/passes/simplify_ifs.py +++ b/src/bloqade/stim/passes/simplify_ifs.py @@ -22,7 +22,7 @@ class StimSimplifyIfs(Pass): def unsafe_run(self, mt: ir.Method): result = Chain( - Walk(UnusedYield()), + Walk(UnusedYield()), # this is being too aggressive, need to file an issue Walk(StimLiftThenBody()), # remove yields (if possible), then lift out as much stuff as possible Walk(DeadCodeElimination()), diff --git a/src/bloqade/stim/passes/soft_flatten.py b/src/bloqade/stim/passes/soft_flatten.py index cf4fbe6a..11626788 100644 --- a/src/bloqade/stim/passes/soft_flatten.py +++ b/src/bloqade/stim/passes/soft_flatten.py @@ -81,10 +81,12 @@ class SoftFlatten(Pass): def __post_init__(self): self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise) - self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) + + # DO NOT USE FOR NOW, TrimUnusedYield call messes up loop structure + # self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) def unsafe_run(self, mt: ir.Method) -> RewriteResult: rewrite_result = RewriteResult() - rewrite_result = self.simplify_if(mt).join(rewrite_result) + # rewrite_result = self.simplify_if(mt).join(rewrite_result) rewrite_result = self.unroll(mt).join(rewrite_result) return rewrite_result diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index effb6eeb..5fef779f 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -1,3 +1,5 @@ +# from kirin.passes.fold import Fold + from bloqade import squin from bloqade.analysis.record import RecordAnalysis from bloqade.stim.passes.soft_flatten import SoftFlatten @@ -37,28 +39,162 @@ def test(): test.print() SoftFlatten(dialects=test.dialects).fixpoint(test) test.print() -frame, _ = RecordAnalysis(dialects=test.dialects).run_analysis(test) +frame, _ = RecordAnalysis(dialects=test.dialects).run(test) test.print(analysis=frame.entries) """ +""" +def hint_const_failure(): + + @squin.kernel + def test(): + qs = squin.qalloc(3) + ms0 = squin.broadcast.measure(qs) + i = 0 + for _ in range(5): + ms1 = squin.broadcast.measure(qs) + squin.set_detector([ms0[i], ms1[i]], coordinates=[i, i]) + + # SoftFlatten(dialects=test.dialects).fixpoint(test) + Fold(dialects=test.dialects, no_raise=False).fixpoint(test) + test.print(hint="const") + # frame, _ = RecordAnalysis(dialects=test.dialects).run(test) + # test.print(analysis=frame.entries, hint="const") + + +# Problematic having the variable substitution happen at the end +""" + + +def test_custom_const_carrier(): + + @squin.kernel(fold=False) + def test(x: int): + y = None + z = None + for _ in range(5): + f = [1, 2, 3, 4, 5, 5, 6, 7, 8] + z = slice(0, 2) + y = f[z] + y[0] += x + return y, z + + SoftFlatten(dialects=test.dialects).fixpoint(test) + test.print() + frame, _ = RecordAnalysis(dialects=test.dialects).run(test) + test.print(analysis=frame.entries, hint="const") + + +""" +def assignment_last_rep_code(): + @squin.kernel(fold=True) + def test(): + + qs = squin.qalloc(5) + data_qs = [qs[0], qs[2], qs[4]] + and_qs = [qs[1], qs[3]] + + init_and_ms = squin.broadcast.measure(and_qs) + + squin.set_detector([init_and_ms[0]], coordinates=[0, 0]) + squin.set_detector([init_and_ms[1]], coordinates=[0, 1]) + + # loop_and_ms = None + for _ in range(5): + loop_and_ms = squin.broadcast.measure(and_qs) + squin.annotate.set_detector([loop_and_ms[0], init_and_ms[0]], coordinates=[0,0]) + squin.annotate.set_detector([loop_and_ms[1], init_and_ms[1]], coordinates=[1,1]) + + #for i in range(len(curr_ms)): + # squin.annotate.set_detector([curr_ms[i], prev_ms[i]], coordinates=[1,1]) + + ##init_and_ms = loop_and_ms + + #data_ms = squin.broadcast.measure(data_qs) + #squin.set_detector( + # [data_ms[0], data_ms[1], loop_and_ms[0]], coordinates=[2, 0] + #) + #squin.set_detector( + # [data_ms[2], data_ms[1], loop_and_ms[1]], coordinates=[2, 1] + #) + + + SoftFlatten(dialects=test.dialects).fixpoint(test) + test.print() + frame, _ = RecordAnalysis(dialects=test.dialects).run(test) + test.print(analysis=frame.entries, hint="const") + +""" + +""" +from kirin.prelude import structural_no_opt + +@structural_no_opt +def demo(): + + a = 0 + b= 1 + for _ in range(10): + c = b + b = a + a = c + +demo.print() +""" + +def assignment_first_rep_code(): + @squin.kernel + def test(): + + qs = squin.qalloc(5) + data_qs = [qs[0], qs[2], qs[4]] + and_qs = [qs[1], qs[3]] + + curr_ms = squin.broadcast.measure(and_qs) + squin.set_detector([curr_ms[0]], coordinates=[0, 0]) + squin.set_detector([curr_ms[1]], coordinates=[0, 1]) + + for _ in range(5): + # prev lives entirely in the loop + prev_ms = curr_ms + curr_ms = squin.broadcast.measure(and_qs) + squin.annotate.set_detector([prev_ms[0], curr_ms[0]], coordinates=[0, 0]) + squin.annotate.set_detector([prev_ms[1], curr_ms[1]], coordinates=[0, 1]) + + squin.annotate.set_detector(curr_ms, coordinates=[0, 0]) + data_ms = squin.broadcast.measure(data_qs) + + squin.set_detector([data_ms[0], data_ms[1], curr_ms[0]], coordinates=[2, 0]) + squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1]) + + test.print() + SoftFlatten(dialects=test.dialects).fixpoint(test) + test.print() + frame, _ = RecordAnalysis(dialects=test.dialects).run(test) + test.print(analysis=frame.entries) + + +assignment_first_rep_code() + +""" @squin.kernel def analysis_demo(): qs = squin.qalloc(3) ms0 = squin.broadcast.measure(qs) ms1 = squin.broadcast.measure(qs) - squin.set_detector(ms0, coordinates=[0, 0]) - squin.set_detector(ms1, coordinates=[0, 1]) - squin.broadcast.measure(qs) - squin.set_detector(ms1, coordinates=[0, 2]) + squin.set_detector(ms0, coordinates=[0, 0]) # -4 -5 -6 + squin.set_detector(ms1, coordinates=[0, 1]) # -1 -2 -3 + # squin.broadcast.measure(qs) + squin.set_detector(ms1, coordinates=[0, 2]) # -4 -5 -6 # get aliasing to work ms1 = ms0 - squin.set_detector(ms1, coordinates=[1, 0]) - # return ms1 + squin.set_detector(ms1, coordinates=[1, 0]) # should also be -4 -5 -6 SoftFlatten(dialects=analysis_demo.dialects).fixpoint(analysis_demo) analysis_demo.print() frame, _ = RecordAnalysis(dialects=analysis_demo.dialects).run(analysis_demo) analysis_demo.print(analysis=frame.entries) +""" From 09130c645996d280585fb5620c5b64987149e0be Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 18 Nov 2025 09:15:45 -0500 Subject: [PATCH 05/26] loop invariance support complete --- src/bloqade/analysis/record/analysis.py | 13 ++++ src/bloqade/analysis/record/impls.py | 81 ++++++++++++++++---- test/analysis/record/test_record_analysis.py | 2 - 3 files changed, 81 insertions(+), 15 deletions(-) diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py index 2068ec42..a9c544d4 100644 --- a/src/bloqade/analysis/record/analysis.py +++ b/src/bloqade/analysis/record/analysis.py @@ -23,6 +23,19 @@ def add_record_idxs(self, num_new_records: int) -> list[RecordIdx]: # Return for usage, idxs linked to the global state return new_record_idxs + """ + def clone_record_idxs(self, record_tuple: RecordTuple) -> RecordTuple: + cloned_members = [] + for record_idx in record_tuple.members: + cloned_record_idx = RecordIdx(record_idx.idx) + # put into the global buffer but also + # return an analysis-facing copy + self.buffer.append(cloned_record_idx) + cloned_members.append(cloned_record_idx) + + return RecordTuple(members=tuple(cloned_members)) + """ + """ Might need a free after use! You can keep the size of the list small but could be a premature optimization... diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py index 70284c25..2ff6526c 100644 --- a/src/bloqade/analysis/record/impls.py +++ b/src/bloqade/analysis/record/impls.py @@ -10,6 +10,7 @@ from .lattice import ( AnyRecord, NotRecord, + RecordIdx, RecordTuple, InvalidRecord, ConstantCarrier, @@ -31,6 +32,12 @@ def consumes_measurements( # Get the measurement results being consumed record_tuple_at_stmt = frame.get(stmt.measurements) + if not ( + isinstance(record_tuple_at_stmt, RecordTuple) + and kirin_types.is_tuple_of(record_tuple_at_stmt.members, RecordIdx) + ): + return (InvalidRecord(),) + final_record_idxs = [ deepcopy(record_idx) for record_idx in record_tuple_at_stmt.members ] @@ -122,7 +129,7 @@ def alias( stmt: py.Alias, ): input = frame.get(stmt.value) # expect this to be a RecordTuple - + # frame.global_record_state.clone_record_idxs(input) # two variables share the same references in the global state return (input,) @@ -134,20 +141,70 @@ def for_loop( self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For ): - loop_vars = frame.get_values(stmt.initializers) - - for _ in range(2): - loop_vars = interp_.frame_call_region( - frame, stmt, stmt.body, InvalidRecord(), *loop_vars + init_loop_vars = frame.get_values(stmt.initializers) + + # You go through the loops twice to verify the loop invariant. + # we need to freeze the frame entries right after exiting the loop + + first_loop_frame = RecordFrame( + stmt, + global_record_state=frame.global_record_state, + parent=frame, + has_parent_access=True, + ) + first_loop_vars = interp_.frame_call_region( + first_loop_frame, stmt, stmt.body, InvalidRecord(), *init_loop_vars + ) + + if first_loop_vars is None: + first_loop_vars = () + elif isinstance(first_loop_vars, interp.ReturnValue): + return first_loop_vars + + captured_first_loop_entries = {} + captured_first_loop_vars = deepcopy(first_loop_vars) + + for ssa_val, lattice_element in first_loop_frame.entries.items(): + captured_first_loop_entries[ssa_val] = deepcopy(lattice_element) + + second_loop_frame = RecordFrame( + stmt, + global_record_state=frame.global_record_state, + parent=frame, + has_parent_access=True, + ) + second_loop_vars = interp_.frame_call_region( + second_loop_frame, stmt, stmt.body, InvalidRecord(), *first_loop_vars + ) + + if second_loop_vars is None: + second_loop_vars = () + elif isinstance(second_loop_vars, interp.ReturnValue): + return second_loop_vars + + # take the entries in the first and second loops + # update the parent frame + + unified_frame_buffer = {} + for ssa_val, lattice_element in captured_first_loop_entries.items(): + verified_latticed_element = second_loop_frame.entries[ssa_val].join( + lattice_element ) + # print(f"Joining {lattice_element} and {second_loop_frame.entries[ssa_val]} to get {verified_latticed_element}") + unified_frame_buffer[ssa_val] = verified_latticed_element + + frame.entries.update(unified_frame_buffer) - if loop_vars is None: - loop_vars = () + if captured_first_loop_vars is None or second_loop_vars is None: + return () - elif isinstance(loop_vars, interp.ReturnValue): - return loop_vars + joined_loop_vars = [] + for first_loop_var, second_loop_var in zip( + captured_first_loop_vars, second_loop_vars + ): + joined_loop_vars.append(first_loop_var.join(second_loop_var)) - return loop_vars + return tuple(joined_loop_vars) @interp.impl(scf.stmts.Yield) def for_yield( @@ -156,8 +213,6 @@ def for_yield( return interp.YieldValue(frame.get_values(stmt.values)) -# Only carry about carrying integers for now because -# the current issue is that @py.dialect.register(key="record") class ConstantForwarding(interp.MethodTable): @interp.impl(py.Constant) diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index 5fef779f..b5fc7f5e 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -162,13 +162,11 @@ def test(): squin.annotate.set_detector([prev_ms[0], curr_ms[0]], coordinates=[0, 0]) squin.annotate.set_detector([prev_ms[1], curr_ms[1]], coordinates=[0, 1]) - squin.annotate.set_detector(curr_ms, coordinates=[0, 0]) data_ms = squin.broadcast.measure(data_qs) squin.set_detector([data_ms[0], data_ms[1], curr_ms[0]], coordinates=[2, 0]) squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1]) - test.print() SoftFlatten(dialects=test.dialects).fixpoint(test) test.print() frame, _ = RecordAnalysis(dialects=test.dialects).run(test) From cf5cb1fb08ceb416486a59118e6a4b9debc4a22c Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 18 Nov 2025 09:44:34 -0500 Subject: [PATCH 06/26] add a set observable just to really make sure things work properly --- test/analysis/record/test_record_analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index b5fc7f5e..951b6d01 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -166,6 +166,7 @@ def test(): squin.set_detector([data_ms[0], data_ms[1], curr_ms[0]], coordinates=[2, 0]) squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1]) + squin.set_observable([data_ms[2]]) SoftFlatten(dialects=test.dialects).fixpoint(test) test.print() From 43675fd771fe5cfde8d02017525188258a5230bb Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 20 Nov 2025 09:28:58 -0500 Subject: [PATCH 07/26] fix variable names in test --- _typos.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/_typos.toml b/_typos.toml index 98f4e2e1..beee7372 100644 --- a/_typos.toml +++ b/_typos.toml @@ -14,3 +14,4 @@ mch = "mch" IY = "IY" ket = "ket" typ = "typ" +anc_qs = "anc_qs" # "ancilla qubits" variable abbreivation - used to be autocorrected to AND_QS! :sigh: From d19f3ef76dd9052d29d9195eb2f3b3f9dfb2ac97 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 20 Nov 2025 23:14:27 -0500 Subject: [PATCH 08/26] move record analysis prototype into measure id analysis --- src/bloqade/analysis/measure_id/analysis.py | 44 ++--- src/bloqade/analysis/measure_id/impls.py | 179 +++++++++++++----- src/bloqade/analysis/measure_id/lattice.py | 36 +++- src/bloqade/analysis/record/impls.py | 33 +++- src/bloqade/stim/rewrite/get_record_util.py | 13 +- src/bloqade/stim/rewrite/ifs_to_stim.py | 10 +- .../stim/rewrite/set_detector_to_stim.py | 9 +- .../stim/rewrite/set_observable_to_stim.py | 8 +- .../measure_id/test_new_measure_id.py | 32 ++++ test/analysis/record/test_record_analysis.py | 13 +- 10 files changed, 267 insertions(+), 110 deletions(-) create mode 100644 test/analysis/measure_id/test_new_measure_id.py diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index 8b65b2f3..8151909c 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -1,25 +1,38 @@ -from typing import TypeVar from dataclasses import field, dataclass from kirin import ir -from kirin.analysis import ForwardExtra, const +from kirin.analysis import ForwardExtra from kirin.analysis.forward import ForwardFrame -from .lattice import MeasureId, NotMeasureId +from .lattice import MeasureId, NotMeasureId, KnownMeasureId + + +@dataclass +class GlobalRecordState: + buffer: list[KnownMeasureId] = field(default_factory=list) + + # assume that this KnownMeasureId will always be -1 + def add_record_idxs(self, num_new_records: int) -> list[KnownMeasureId]: + # adjust all previous indices + for record_idx in self.buffer: + record_idx.idx -= num_new_records + + # generate new indices and add them to the buffer + new_record_idxs = [KnownMeasureId(-i) for i in range(num_new_records, 0, -1)] + self.buffer += new_record_idxs + # Return for usage, idxs linked to the global state + return new_record_idxs @dataclass class MeasureIDFrame(ForwardFrame[MeasureId]): - num_measures_at_stmt: dict[ir.Statement, int] = field(default_factory=dict) + global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState) class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]): keys = ["measure_id"] lattice = MeasureId - # for every kind of measurement encountered, increment this - # then use this to generate the negative values for target rec indices - measure_count = 0 def initialize_frame( self, node: ir.Statement, *, has_parent_access: bool = False @@ -33,22 +46,5 @@ def eval_fallback( ) -> tuple[MeasureId, ...]: return tuple(NotMeasureId() for _ in node.results) - # Xiu-zhe (Roger) Luo came up with this in the address analysis, - # reused here for convenience (now modified to be a bit more graceful) - # TODO: Remove this function once upgrade to kirin 0.18 happens, - # method is built-in to interpreter then - - T = TypeVar("T") - - def get_const_value( - self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue - ) -> type[T] | None: - if isinstance(hint := value.hints.get("const"), const.Value): - data = hint.data - if isinstance(data, input_type): - return hint.data - - return None - def method_self(self, method: ir.Method) -> MeasureId: return self.lattice.bottom() diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index eb5abc22..f0052982 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -1,22 +1,22 @@ +from copy import deepcopy + from kirin import types as kirin_types, interp -from kirin.analysis import const from kirin.dialects import py, scf, func, ilist +from kirin.ir.attrs.py import PyAttr from bloqade import qubit, annotate from .lattice import ( AnyMeasureId, NotMeasureId, - MeasureIdBool, + KnownMeasureId, MeasureIdTuple, + ConstantCarrier, InvalidMeasureId, + ImmutableMeasureIds, ) from .analysis import MeasureIDFrame, MeasurementIDAnalysis -## Can't do wire right now because of -## unresolved RFC on return type -# from bloqade.squin import wire - @qubit.dialect.register(key="measure_id") class SquinQubit(interp.MethodTable): @@ -25,7 +25,7 @@ class SquinQubit(interp.MethodTable): def measure_qubit_list( self, interp: MeasurementIDAnalysis, - frame: interp.Frame, + frame: MeasureIDFrame, stmt: qubit.stmts.Measure, ): @@ -38,26 +38,34 @@ def measure_qubit_list( if not isinstance(num_qubits, kirin_types.Literal): return (AnyMeasureId(),) - measure_id_bools = [] - for _ in range(num_qubits.data): - interp.measure_count += 1 - measure_id_bools.append(MeasureIdBool(interp.measure_count)) + record_idxs = frame.global_record_state.add_record_idxs(num_qubits.data) - return (MeasureIdTuple(data=tuple(measure_id_bools)),) + return (MeasureIdTuple(data=tuple(record_idxs)),) @annotate.dialect.register(key="measure_id") class Annotate(interp.MethodTable): @interp.impl(annotate.stmts.SetObservable) @interp.impl(annotate.stmts.SetDetector) - def consumes_measurement_results( + def consumes_measurements( self, interp: MeasurementIDAnalysis, frame: MeasureIDFrame, stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector, ): - frame.num_measures_at_stmt[stmt] = interp.measure_count - return (NotMeasureId(),) + measure_id_tuple_at_stmt = frame.get(stmt.measurements) + + if not ( + isinstance(measure_id_tuple_at_stmt, MeasureIdTuple) + and kirin_types.is_tuple_of(measure_id_tuple_at_stmt.data, KnownMeasureId) + ): + return (InvalidMeasureId(),) + + final_record_idxs = [ + deepcopy(record_idx) for record_idx in measure_id_tuple_at_stmt.data + ] + + return (ImmutableMeasureIds(data=tuple(final_record_idxs)),) @ilist.dialect.register(key="measure_id") @@ -73,8 +81,7 @@ def new_ilist( stmt: ilist.New, ): - measure_ids_in_ilist = frame.get_values(stmt.values) - return (MeasureIdTuple(data=tuple(measure_ids_in_ilist)),) + return (MeasureIdTuple(frame.get_values(stmt.values)),) @py.tuple.dialect.register(key="measure_id") @@ -94,13 +101,15 @@ def getitem( self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem ): - idx_or_slice = interp.get_const_value((int, slice), stmt.index) - if idx_or_slice is None: - return (InvalidMeasureId(),) - - # hint = stmt.index.hints.get("const") - # if hint is None or not isinstance(hint, const.Value): - # return (InvalidMeasureId(),) + possible_idx_or_slice = interp.maybe_const(stmt.index, (int, slice)) + if possible_idx_or_slice is not None: + idx_or_slice = possible_idx_or_slice + else: + idx_or_slice = frame.get(stmt.index) + if not isinstance(idx_or_slice, ConstantCarrier): + return (InvalidMeasureId(),) + else: + idx_or_slice = idx_or_slice.value obj = frame.get(stmt.obj) if isinstance(obj, MeasureIdTuple): @@ -123,7 +132,9 @@ class PyAssign(interp.MethodTable): def alias( self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.assign.Alias ): - return (frame.get(stmt.value),) + + input = frame.get(stmt.value) + return (input,) @py.binop.dialect.register(key="measure_id") @@ -160,37 +171,105 @@ def invoke( return (ret,) -# Just let analysis propagate through -# scf, particularly IfElse @scf.dialect.register(key="measure_id") -class Scf(scf.absint.Methods): +class LoopHandling(interp.MethodTable): + @interp.impl(scf.stmts.For) + def for_loop( + self, interp_: MeasurementIDAnalysis, frame: MeasureIDFrame, stmt: scf.stmts.For + ): + + init_loop_vars = frame.get_values(stmt.initializers) + + # You go through the loops twice to verify the loop invariant. + # we need to freeze the frame entries right after exiting the loop + + first_loop_frame = MeasureIDFrame( + stmt, + global_record_state=frame.global_record_state, + parent=frame, + has_parent_access=True, + ) + first_loop_vars = interp_.frame_call_region( + first_loop_frame, stmt, stmt.body, InvalidMeasureId(), *init_loop_vars + ) + + if first_loop_vars is None: + first_loop_vars = () + elif isinstance(first_loop_vars, interp.ReturnValue): + return first_loop_vars - @interp.impl(scf.IfElse) - def if_else( + captured_first_loop_entries = {} + captured_first_loop_vars = deepcopy(first_loop_vars) + + for ssa_val, lattice_element in first_loop_frame.entries.items(): + captured_first_loop_entries[ssa_val] = deepcopy(lattice_element) + + second_loop_frame = MeasureIDFrame( + stmt, + global_record_state=frame.global_record_state, + parent=frame, + has_parent_access=True, + ) + second_loop_vars = interp_.frame_call_region( + second_loop_frame, stmt, stmt.body, InvalidMeasureId(), *first_loop_vars + ) + + if second_loop_vars is None: + second_loop_vars = () + elif isinstance(second_loop_vars, interp.ReturnValue): + return second_loop_vars + + # take the entries in the first and second loops + # update the parent frame + + unified_frame_buffer = {} + for ssa_val, lattice_element in captured_first_loop_entries.items(): + verified_latticed_element = second_loop_frame.entries[ssa_val].join( + lattice_element + ) + # print(f"Joining {lattice_element} and {second_loop_frame.entries[ssa_val]} to get {verified_latticed_element}") + unified_frame_buffer[ssa_val] = verified_latticed_element + + frame.entries.update(unified_frame_buffer) + + if captured_first_loop_vars is None or second_loop_vars is None: + return () + + joined_loop_vars = [] + for first_loop_var, second_loop_var in zip( + captured_first_loop_vars, second_loop_vars + ): + joined_loop_vars.append(first_loop_var.join(second_loop_var)) + + return tuple(joined_loop_vars) + + @interp.impl(scf.stmts.Yield) + def for_yield( + self, + interp_: MeasurementIDAnalysis, + frame: MeasureIDFrame, + stmt: scf.stmts.Yield, + ): + return interp.YieldValue(frame.get_values(stmt.values)) + + +@py.dialect.register(key="measure_id") +class ConstantForwarding(interp.MethodTable): + @interp.impl(py.Constant) + def constant( self, interp_: MeasurementIDAnalysis, frame: MeasureIDFrame, - stmt: scf.IfElse, + stmt: py.Constant, ): + # can't use interp_.maybe_const/expect_const because it assumes the data is already + # there to begin with... + if not isinstance(stmt.value, PyAttr): + return (InvalidMeasureId(),) - frame.num_measures_at_stmt[stmt] = interp_.measure_count + expected_int_or_slice = stmt.value.data - # rest of the code taken directly from scf.absint.Methods base implementation + if not isinstance(expected_int_or_slice, (int, slice)): + return (InvalidMeasureId(),) - if isinstance(hint := stmt.cond.hints.get("const"), const.Value): - if hint.data: - return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body) - else: - return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body) - then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body) - else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body) - - match (then_results, else_results): - case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)): - return interp.ReturnValue(then_value.join(else_value)) - case (interp.ReturnValue(then_value), _): - return then_results - case (_, interp.ReturnValue(else_value)): - return else_results - case _: - return interp_.join_results(then_results, else_results) + return (ConstantCarrier(value=expected_int_or_slice),) diff --git a/src/bloqade/analysis/measure_id/lattice.py b/src/bloqade/analysis/measure_id/lattice.py index 34d78b3c..a29f8cb7 100644 --- a/src/bloqade/analysis/measure_id/lattice.py +++ b/src/bloqade/analysis/measure_id/lattice.py @@ -28,9 +28,6 @@ def top(cls) -> "MeasureId": return AnyMeasureId() -# Can pop up if user constructs some list containing a mixture -# of bools from measure results and other places, -# in which case the whole list is invalid @final @dataclass class InvalidMeasureId(MeasureId, metaclass=SingletonMeta): @@ -57,20 +54,15 @@ def is_subseteq(self, other: MeasureId) -> bool: @final @dataclass -class MeasureIdBool(MeasureId): +class KnownMeasureId(MeasureId): idx: int def is_subseteq(self, other: MeasureId) -> bool: - if isinstance(other, MeasureIdBool): + if isinstance(other, KnownMeasureId): return self.idx == other.idx return False -# Might be nice to have some print override -# here so all the CanMeasureId's/other types are consolidated for -# readability - - @final @dataclass class MeasureIdTuple(MeasureId): @@ -80,3 +72,27 @@ def is_subseteq(self, other: MeasureId) -> bool: if isinstance(other, MeasureIdTuple): return all(a.is_subseteq(b) for a, b in zip(self.data, other.data)) return False + + +@final +@dataclass +class ImmutableMeasureIds(MeasureId): + data: tuple[KnownMeasureId, ...] + + def is_subseteq(self, other: MeasureId) -> bool: + if isinstance(other, ImmutableMeasureIds): + return all(a.is_subseteq(b) for a, b in zip(self.data, other.data)) + return False + + +# For now I only care about propagating constant integers or slices, +# things that can be used as indices to list of measurements +@final +@dataclass +class ConstantCarrier(MeasureId): + value: int | slice + + def is_subseteq(self, other: MeasureId) -> bool: + if isinstance(other, ConstantCarrier): + return self.value == other.value + return False diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py index 2ff6526c..e6e450f2 100644 --- a/src/bloqade/analysis/record/impls.py +++ b/src/bloqade/analysis/record/impls.py @@ -136,8 +136,39 @@ def alias( @scf.dialect.register(key="record") class LoopHandling(interp.MethodTable): + """ @interp.impl(scf.stmts.For) - def for_loop( + def for_loop_single_pass( + self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For + ): + + init_loop_vars = frame.get_values(stmt.initializers) + + loop_frame = RecordFrame( + stmt, + global_record_state=frame.global_record_state, + parent=frame, + has_parent_access=True, + ) + loop_vars = interp_.frame_call_region( + loop_frame, stmt, stmt.body, InvalidRecord(), *init_loop_vars + ) + + print(frame.global_record_state) + + if loop_vars is None: + return () + elif isinstance(loop_vars, interp.ReturnValue): + return loop_vars + + # update the parent frame with the loop frame entries + frame.entries.update(loop_frame.entries) + + return loop_vars + """ + + @interp.impl(scf.stmts.For) + def for_loop_double_pass( self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For ): diff --git a/src/bloqade/stim/rewrite/get_record_util.py b/src/bloqade/stim/rewrite/get_record_util.py index aaa28261..c06015ac 100644 --- a/src/bloqade/stim/rewrite/get_record_util.py +++ b/src/bloqade/stim/rewrite/get_record_util.py @@ -2,20 +2,17 @@ from kirin.dialects import py from bloqade.stim.dialects import auxiliary -from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple +from bloqade.analysis.measure_id.lattice import KnownMeasureId, MeasureIdTuple -def insert_get_records( - node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int -): +def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple): """ Insert GetRecord statements before the given node """ get_record_ssas = [] - for measure_id_bool in measure_id_tuple.data: - assert isinstance(measure_id_bool, MeasureIdBool) - target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt - idx_stmt = py.constant.Constant(target_rec_idx) + for known_measure_id in measure_id_tuple.data: + assert isinstance(known_measure_id, KnownMeasureId) + idx_stmt = py.constant.Constant(known_measure_id.idx) idx_stmt.insert_before(node) get_record_stmt = auxiliary.GetRecord(idx_stmt.result) get_record_stmt.insert_before(node) diff --git a/src/bloqade/stim/rewrite/ifs_to_stim.py b/src/bloqade/stim/rewrite/ifs_to_stim.py index 6a15253b..9a9087df 100644 --- a/src/bloqade/stim/rewrite/ifs_to_stim.py +++ b/src/bloqade/stim/rewrite/ifs_to_stim.py @@ -14,7 +14,7 @@ from bloqade.analysis.measure_id import MeasureIDFrame from bloqade.stim.dialects.auxiliary import GetRecord from bloqade.analysis.measure_id.lattice import ( - MeasureIdBool, + KnownMeasureId, ) @@ -140,7 +140,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: # Check the condition is a singular MeasurementIdBool - if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool): + if not isinstance(self.measure_frame.entries[stmt.cond], KnownMeasureId): return RewriteResult() # Reusing code from SplitIf, @@ -160,12 +160,10 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: # get necessary measurement ID type from analysis measure_id_bool = self.measure_frame.entries[stmt.cond] - assert isinstance(measure_id_bool, MeasureIdBool) + assert isinstance(measure_id_bool, KnownMeasureId) # generate get record statement - measure_id_idx_stmt = py.Constant( - (measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt] - ) + measure_id_idx_stmt = py.Constant(value=measure_id_bool.idx) get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841 address_attr = stmts[0].qubits.hints.get("address") diff --git a/src/bloqade/stim/rewrite/set_detector_to_stim.py b/src/bloqade/stim/rewrite/set_detector_to_stim.py index 229067a2..ed5dfa3b 100644 --- a/src/bloqade/stim/rewrite/set_detector_to_stim.py +++ b/src/bloqade/stim/rewrite/set_detector_to_stim.py @@ -52,12 +52,13 @@ def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: coord_ssas.append(coord_stmt.result) coord_stmt.insert_before(node) - measure_ids = self.measure_id_frame.entries[node.measurements] + measure_ids = self.measure_id_frame.entries.get(node.measurements, None) + if measure_ids is None: + return RewriteResult() + assert isinstance(measure_ids, MeasureIdTuple) - get_record_list = insert_get_records( - node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] - ) + get_record_list = insert_get_records(node, measure_ids) detector_stmt = Detector( coord=tuple(coord_ssas), targets=tuple(get_record_list) diff --git a/src/bloqade/stim/rewrite/set_observable_to_stim.py b/src/bloqade/stim/rewrite/set_observable_to_stim.py index 39ac14fe..cf9d16e0 100644 --- a/src/bloqade/stim/rewrite/set_observable_to_stim.py +++ b/src/bloqade/stim/rewrite/set_observable_to_stim.py @@ -36,12 +36,12 @@ def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult: idx_stmt = auxiliary.ConstInt(value=0) idx_stmt.insert_before(node) - measure_ids = self.measure_id_frame.entries[node.measurements] + measure_ids = self.measure_id_frame.entries.get(node.measurements, None) + if measure_ids is None: + return RewriteResult() assert isinstance(measure_ids, MeasureIdTuple) - get_record_list = insert_get_records( - node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] - ) + get_record_list = insert_get_records(node, measure_ids) observable_include_stmt = ObservableInclude( idx=idx_stmt.result, targets=tuple(get_record_list) diff --git a/test/analysis/measure_id/test_new_measure_id.py b/test/analysis/measure_id/test_new_measure_id.py new file mode 100644 index 00000000..107ed1cf --- /dev/null +++ b/test/analysis/measure_id/test_new_measure_id.py @@ -0,0 +1,32 @@ +import io + +from kirin import ir + +from bloqade import stim, squin +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.passes import SquinToStimPass + + +def codegen(mt: ir.Method): + # method should not have any arguments! + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) + emit.initialize() + emit.run(mt) + return buf.getvalue().strip() + + +@squin.kernel +def test_simple_linear(): + + qs = squin.qalloc(4) + m0 = squin.broadcast.measure(qs) + squin.set_detector([m0[0], m0[1]], coordinates=[0, 0]) + m1 = squin.broadcast.measure(qs) + squin.set_detector([m1[0], m1[1]], coordinates=[1, 1]) + + +test_simple_linear.print() +SquinToStimPass(dialects=test_simple_linear.dialects)(test_simple_linear) +test_simple_linear.print() +print(codegen(test_simple_linear)) diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index 951b6d01..07cfb8a7 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -1,7 +1,11 @@ # from kirin.passes.fold import Fold from bloqade import squin + +# from bloqade.stim.passes import SquinToStimPass from bloqade.analysis.record import RecordAnalysis + +# from bloqade.analysis.measure_id import MeasurementIDAnalysis from bloqade.stim.passes.soft_flatten import SoftFlatten """ @@ -151,18 +155,18 @@ def test(): data_qs = [qs[0], qs[2], qs[4]] and_qs = [qs[1], qs[3]] - curr_ms = squin.broadcast.measure(and_qs) + curr_ms = squin.broadcast.measure(and_qs) # 2 meas squin.set_detector([curr_ms[0]], coordinates=[0, 0]) squin.set_detector([curr_ms[1]], coordinates=[0, 1]) for _ in range(5): # prev lives entirely in the loop prev_ms = curr_ms - curr_ms = squin.broadcast.measure(and_qs) + curr_ms = squin.broadcast.measure(and_qs) # another 2 meas squin.annotate.set_detector([prev_ms[0], curr_ms[0]], coordinates=[0, 0]) squin.annotate.set_detector([prev_ms[1], curr_ms[1]], coordinates=[0, 1]) - data_ms = squin.broadcast.measure(data_qs) + data_ms = squin.broadcast.measure(data_qs) # 3 meas squin.set_detector([data_ms[0], data_ms[1], curr_ms[0]], coordinates=[2, 0]) squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1]) @@ -173,6 +177,9 @@ def test(): frame, _ = RecordAnalysis(dialects=test.dialects).run(test) test.print(analysis=frame.entries) + # frame, _ = MeasurementIDAnalysis(dialects=test.dialects).run(test) + # test.print(analysis=frame.entries) + assignment_first_rep_code() From 6bbf11f92d61e2b4fdeae399f036e3ccd556cd10 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 21 Nov 2025 10:25:23 -0500 Subject: [PATCH 09/26] add REPEAT statement in stim --- src/bloqade/stim/dialects/__init__.py | 8 ++- src/bloqade/stim/dialects/cf/__init__.py | 3 + src/bloqade/stim/dialects/cf/emit.py | 32 +++++++++++ src/bloqade/stim/dialects/cf/stmts.py | 36 ++++++++++++ src/bloqade/stim/groups.py | 3 +- .../dialects/stim/emit/test_stim_repeat.py | 56 +++++++++++++++++++ 6 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 src/bloqade/stim/dialects/cf/emit.py create mode 100644 test/stim/dialects/stim/emit/test_stim_repeat.py diff --git a/src/bloqade/stim/dialects/__init__.py b/src/bloqade/stim/dialects/__init__.py index deeadc2a..31a7434e 100644 --- a/src/bloqade/stim/dialects/__init__.py +++ b/src/bloqade/stim/dialects/__init__.py @@ -1 +1,7 @@ -from . import gate as gate, noise as noise, collapse as collapse, auxiliary as auxiliary +from . import ( + cf as cf, + gate as gate, + noise as noise, + collapse as collapse, + auxiliary as auxiliary, +) diff --git a/src/bloqade/stim/dialects/cf/__init__.py b/src/bloqade/stim/dialects/cf/__init__.py index e69de29b..873dc7b3 100644 --- a/src/bloqade/stim/dialects/cf/__init__.py +++ b/src/bloqade/stim/dialects/cf/__init__.py @@ -0,0 +1,3 @@ +from .emit import EmitStimCfMethods as EmitStimCfMethods +from .stmts import * # noqa F403 +from ._dialect import dialect as dialect diff --git a/src/bloqade/stim/dialects/cf/emit.py b/src/bloqade/stim/dialects/cf/emit.py new file mode 100644 index 00000000..992f06bd --- /dev/null +++ b/src/bloqade/stim/dialects/cf/emit.py @@ -0,0 +1,32 @@ +from kirin.interp import MethodTable, impl + +from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame + +from . import stmts +from ._dialect import dialect + + +@dialect.register(key="emit.stim") +class EmitStimCfMethods(MethodTable): + + @impl(stmts.REPEAT) + def repeat(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.REPEAT): + + print(stmt.count) + count = frame.get(stmt.count) + frame.write_line(f"REPEAT {count} {{") + with frame.indent(): + + # Assume single block in REPEAT + for inner_stmt in stmt.body.blocks[0].stmts: + inner_stmt_results = emit.frame_eval(frame, inner_stmt) + + match inner_stmt_results: + case tuple(): + frame.set_values(inner_stmt._results, inner_stmt_results) + case _: + continue + + frame.write_line("}") + + return () diff --git a/src/bloqade/stim/dialects/cf/stmts.py b/src/bloqade/stim/dialects/cf/stmts.py index e69de29b..e3d2e64f 100644 --- a/src/bloqade/stim/dialects/cf/stmts.py +++ b/src/bloqade/stim/dialects/cf/stmts.py @@ -0,0 +1,36 @@ +from typing import cast + +from kirin import ir, types +from kirin.decl import info, statement + +from ._dialect import dialect + + +@statement(dialect=dialect, init=False) +class REPEAT(ir.Statement): + """Repeat statement for looping a fixed number of times. + + This statement has a loop count and a body. + """ + + name = "REPEAT" + traits = frozenset({ir.MaybePure(), ir.HasCFG(), ir.SSACFG()}) + count: ir.SSAValue = info.argument(types.Int) + body: ir.Region = info.region(multi=False) + + def __init__( + self, + count: ir.SSAValue, + body: ir.Region | ir.Block, + ): + if body.IS_REGION: + body_region = cast(ir.Region, body) + if body_region.blocks: + body_block = body_region.blocks[0] + else: + body_block = None + else: + body_block = cast(ir.Block, body) + body_region = ir.Region(body_block) + + super().__init__(args=(count,), regions=(body_region,), args_slice={"count": 0}) diff --git a/src/bloqade/stim/groups.py b/src/bloqade/stim/groups.py index fdf6cde8..66d0afec 100644 --- a/src/bloqade/stim/groups.py +++ b/src/bloqade/stim/groups.py @@ -2,11 +2,12 @@ from kirin.passes import Fold, TypeInfer from kirin.dialects import func, debug, ssacfg, lowering -from .dialects import gate, noise, collapse, auxiliary +from .dialects import cf, gate, noise, collapse, auxiliary @ir.dialect_group( [ + cf, noise, gate, auxiliary, diff --git a/test/stim/dialects/stim/emit/test_stim_repeat.py b/test/stim/dialects/stim/emit/test_stim_repeat.py new file mode 100644 index 00000000..e26dee8e --- /dev/null +++ b/test/stim/dialects/stim/emit/test_stim_repeat.py @@ -0,0 +1,56 @@ +import io + +from kirin import ir, types +from kirin.dialects import func + +from bloqade import stim +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.dialects.cf.stmts import REPEAT +from bloqade.stim.dialects.auxiliary.stmts import ConstInt +from bloqade.stim.dialects.gate.stmts.clifford_1q import H, Z + + +def codegen(mt): + # method should not have any arguments! + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) + # emit.initialize() + emit.run(mt) + return buf.getvalue().strip() + + +def test_repeat_emit(): + + num_iter = ConstInt(value=5) + body = ir.Region(ir.Block([])) + q0 = ConstInt(value=0) + q1 = ConstInt(value=1) + body.blocks[0].stmts.append(q0) + body.blocks[0].stmts.append(q1) + targets = (q0.result, q1.result) + body.blocks[0].stmts.append(H(targets=targets)) + body.blocks[0].stmts.append(Z(targets=targets)) + repeat_stmt = REPEAT(count=num_iter.result, body=body) + + repeat_stmt.print() + + block = ir.Block() + block.stmts.append(num_iter) + block.stmts.append(repeat_stmt) + + block.args.append_from(types.MethodType, "self") + gen_func = func.Function( + sym_name="main", + signature=func.Signature( + inputs=(), + output=types.NoneType, + ), + body=ir.Region(block), + ) + + gen_func.print() + + print(codegen(gen_func)) + + +test_repeat_emit() From ca354affcf92be0464213993bb4343a22075ab10 Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 23 Nov 2025 14:45:13 -0500 Subject: [PATCH 10/26] save work before trying to hammer down obnoxious ownership issues --- src/bloqade/analysis/record/analysis.py | 17 ++++++-- src/bloqade/analysis/record/impls.py | 45 ++++++++++++++++++-- test/analysis/address/test_qubit_analysis.py | 20 +++++++++ test/analysis/record/test_record_analysis.py | 2 +- 4 files changed, 76 insertions(+), 8 deletions(-) diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py index a9c544d4..5b5efeba 100644 --- a/src/bloqade/analysis/record/analysis.py +++ b/src/bloqade/analysis/record/analysis.py @@ -4,7 +4,7 @@ from kirin.analysis import ForwardExtra from kirin.analysis.forward import ForwardFrame -from .lattice import Record, RecordIdx +from .lattice import Record, RecordIdx, RecordTuple @dataclass @@ -23,7 +23,11 @@ def add_record_idxs(self, num_new_records: int) -> list[RecordIdx]: # Return for usage, idxs linked to the global state return new_record_idxs - """ + # Need for loop invariance, especially when you + # run the loop twice "behind the scenes". Then + # it isn't sufficient to just have two + # copies of a lattice element point to one entry on the + # buffer def clone_record_idxs(self, record_tuple: RecordTuple) -> RecordTuple: cloned_members = [] for record_idx in record_tuple.members: @@ -34,7 +38,13 @@ def clone_record_idxs(self, record_tuple: RecordTuple) -> RecordTuple: cloned_members.append(cloned_record_idx) return RecordTuple(members=tuple(cloned_members)) - """ + + def offset_existing_records(self, offset: int): + for record_idx in self.buffer: + record_idx.idx -= offset + print("offset is now:", offset) + print("The record idx is now:", record_idx.idx) + # print the record_idx after offsetting """ Might need a free after use! You can keep the size of the list small @@ -48,6 +58,7 @@ def clone_record_idxs(self, record_tuple: RecordTuple) -> RecordTuple: @dataclass class RecordFrame(ForwardFrame): global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState) + measure_count_offset: int = 0 class RecordAnalysis(ForwardExtra[RecordFrame, Record]): diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py index e6e450f2..05b7de97 100644 --- a/src/bloqade/analysis/record/impls.py +++ b/src/bloqade/analysis/record/impls.py @@ -65,6 +65,12 @@ def measure_qubit_list( if not isinstance(num_qubits, kirin_types.Literal): return (AnyRecord(),) + # increment the parent frame measure count offset. + # Loop analysis relies on local state tracking + # so we use this data after exiting a loop to + # readjust the previous global measure count. + frame.measure_count_offset += num_qubits.data + record_idxs = frame.global_record_state.add_record_idxs(num_qubits.data) return (RecordTuple(members=tuple(record_idxs)),) @@ -129,9 +135,9 @@ def alias( stmt: py.Alias, ): input = frame.get(stmt.value) # expect this to be a RecordTuple - # frame.global_record_state.clone_record_idxs(input) + new_input = frame.global_record_state.clone_record_idxs(input) # two variables share the same references in the global state - return (input,) + return (new_input,) @scf.dialect.register(key="record") @@ -174,15 +180,23 @@ def for_loop_double_pass( init_loop_vars = frame.get_values(stmt.initializers) + # for ssa_val, lattice_element in frame.entries.items(): + # print(f"Before loop: {ssa_val} -> {lattice_element}") + # You go through the loops twice to verify the loop invariant. # we need to freeze the frame entries right after exiting the loop + local_state = deepcopy(frame.global_record_state) + # local_state = GlobalRecordState() + first_loop_frame = RecordFrame( stmt, - global_record_state=frame.global_record_state, + # frame_id = frame.frame_id + 1, + global_record_state=local_state, parent=frame, has_parent_access=True, ) + first_loop_vars = interp_.frame_call_region( first_loop_frame, stmt, stmt.body, InvalidRecord(), *init_loop_vars ) @@ -200,7 +214,8 @@ def for_loop_double_pass( second_loop_frame = RecordFrame( stmt, - global_record_state=frame.global_record_state, + # frame_id = frame.frame_id + 2, + global_record_state=local_state, parent=frame, has_parent_access=True, ) @@ -216,6 +231,18 @@ def for_loop_double_pass( # take the entries in the first and second loops # update the parent frame + # + # debug prints + # print("First loop entries (captured + preserved):") + # stmt.body.print(analysis=captured_first_loop_entries) + # print("First loop entries (based off existing frame values):") + # stmt.body.print(analysis=first_loop_frame.entries) + # print("Second loop entries (via local state)") + # stmt.body.print(analysis=second_loop_frame.entries) + # print("local state after being passed through two loops") + # print(local_state) + # print(frame.global_record_state) + # unified_frame_buffer = {} for ssa_val, lattice_element in captured_first_loop_entries.items(): verified_latticed_element = second_loop_frame.entries[ssa_val].join( @@ -225,6 +252,16 @@ def for_loop_double_pass( unified_frame_buffer[ssa_val] = verified_latticed_element frame.entries.update(unified_frame_buffer) + print( + "number measurements in first loop:", first_loop_frame.measure_count_offset + ) + print("local state after two loops:", local_state) + print("global (parent frame) state", frame.global_record_state) + frame.global_record_state.offset_existing_records( + first_loop_frame.measure_count_offset + ) + print("parent frame after update: should be -4 -3 resp.") + print(frame.global_record_state) if captured_first_loop_vars is None or second_loop_vars is None: return () diff --git a/test/analysis/address/test_qubit_analysis.py b/test/analysis/address/test_qubit_analysis.py index dddf825a..578e676e 100644 --- a/test/analysis/address/test_qubit_analysis.py +++ b/test/analysis/address/test_qubit_analysis.py @@ -5,6 +5,7 @@ from bloqade import qubit, squin from bloqade.analysis import address +from bloqade.stim.passes.soft_flatten import SoftFlatten # test tuple and indexing @@ -265,3 +266,22 @@ def main(): assert ret == address.AddressReg(data=tuple(range(20))) assert analysis.qubit_count == 20 + + +def test_loop_propagation(): + + @squin.kernel + def main(n: int): + qs = squin.qalloc(n) + for _ in range(10): + sub_qs = [qs[0], qs[5]] + squin.cx(sub_qs[0], sub_qs[1]) + + # qalloc needs to be flattened for anything to go through + SoftFlatten(dialects=main.dialects).fixpoint(main) + address_analysis = address.AddressAnalysis(main.dialects) + frame, _ = address_analysis.run(main) + main.print(analysis=frame.entries) + + +test_loop_propagation() diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index 07cfb8a7..065213bf 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -173,7 +173,7 @@ def test(): squin.set_observable([data_ms[2]]) SoftFlatten(dialects=test.dialects).fixpoint(test) - test.print() + # test.print() frame, _ = RecordAnalysis(dialects=test.dialects).run(test) test.print(analysis=frame.entries) From 3f7ad41db2926ae466585b3fa5d58112bb3b479d Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 23 Nov 2025 19:37:44 -0500 Subject: [PATCH 11/26] corrected logic in record analysis prototype with Kai's advice, can start porting things to MeasureIDAnalysis --- src/bloqade/analysis/record/analysis.py | 9 ++++--- src/bloqade/analysis/record/impls.py | 36 +++++++++++++++++-------- src/bloqade/analysis/record/lattice.py | 1 + 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py index 5b5efeba..8e66330a 100644 --- a/src/bloqade/analysis/record/analysis.py +++ b/src/bloqade/analysis/record/analysis.py @@ -12,13 +12,13 @@ class GlobalRecordState: buffer: list[RecordIdx] = field(default_factory=list) # assume that this RecordIdx will always be -1 - def add_record_idxs(self, num_new_records: int) -> list[RecordIdx]: + def add_record_idxs(self, num_new_records: int, id: int) -> list[RecordIdx]: # adjust all previous indices for record_idx in self.buffer: record_idx.idx -= num_new_records # generate new indices and add them to the buffer - new_record_idxs = [RecordIdx(-i) for i in range(num_new_records, 0, -1)] + new_record_idxs = [RecordIdx(-i, id) for i in range(num_new_records, 0, -1)] self.buffer += new_record_idxs # Return for usage, idxs linked to the global state return new_record_idxs @@ -28,10 +28,10 @@ def add_record_idxs(self, num_new_records: int) -> list[RecordIdx]: # it isn't sufficient to just have two # copies of a lattice element point to one entry on the # buffer - def clone_record_idxs(self, record_tuple: RecordTuple) -> RecordTuple: + def clone_record_idxs(self, record_tuple: RecordTuple, id: int) -> RecordTuple: cloned_members = [] for record_idx in record_tuple.members: - cloned_record_idx = RecordIdx(record_idx.idx) + cloned_record_idx = RecordIdx(record_idx.idx, id) # put into the global buffer but also # return an analysis-facing copy self.buffer.append(cloned_record_idx) @@ -59,6 +59,7 @@ def offset_existing_records(self, offset: int): class RecordFrame(ForwardFrame): global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState) measure_count_offset: int = 0 + frame_id: int = 0 class RecordAnalysis(ForwardExtra[RecordFrame, Record]): diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py index 05b7de97..abb80e19 100644 --- a/src/bloqade/analysis/record/impls.py +++ b/src/bloqade/analysis/record/impls.py @@ -71,7 +71,9 @@ def measure_qubit_list( # readjust the previous global measure count. frame.measure_count_offset += num_qubits.data - record_idxs = frame.global_record_state.add_record_idxs(num_qubits.data) + record_idxs = frame.global_record_state.add_record_idxs( + num_qubits.data, id=frame.frame_id + ) return (RecordTuple(members=tuple(record_idxs)),) @@ -135,7 +137,11 @@ def alias( stmt: py.Alias, ): input = frame.get(stmt.value) # expect this to be a RecordTuple - new_input = frame.global_record_state.clone_record_idxs(input) + # input could belong to another frame and get repossessed with an + # independent copy in this frame. Might need to set a new frame_id here + new_input = frame.global_record_state.clone_record_idxs( + input, id=frame.frame_id + ) # two variables share the same references in the global state return (new_input,) @@ -191,7 +197,7 @@ def for_loop_double_pass( first_loop_frame = RecordFrame( stmt, - # frame_id = frame.frame_id + 1, + frame_id=frame.frame_id + 1, global_record_state=local_state, parent=frame, has_parent_access=True, @@ -214,7 +220,7 @@ def for_loop_double_pass( second_loop_frame = RecordFrame( stmt, - # frame_id = frame.frame_id + 2, + frame_id=frame.frame_id + 2, global_record_state=local_state, parent=frame, has_parent_access=True, @@ -252,16 +258,9 @@ def for_loop_double_pass( unified_frame_buffer[ssa_val] = verified_latticed_element frame.entries.update(unified_frame_buffer) - print( - "number measurements in first loop:", first_loop_frame.measure_count_offset - ) - print("local state after two loops:", local_state) - print("global (parent frame) state", frame.global_record_state) frame.global_record_state.offset_existing_records( first_loop_frame.measure_count_offset ) - print("parent frame after update: should be -4 -3 resp.") - print(frame.global_record_state) if captured_first_loop_vars is None or second_loop_vars is None: return () @@ -272,12 +271,27 @@ def for_loop_double_pass( ): joined_loop_vars.append(first_loop_var.join(second_loop_var)) + # TrimYield is currently disabled meaning that the same RecordIdx + # can get copied into the parent frame twice! As a result + # we need to be careful to only add unique RecordIdx entries + witnessed_record_idxs = set() + for var in joined_loop_vars: + if isinstance(var, RecordTuple): + for member in var.members: + if ( + isinstance(member, RecordIdx) + and member.idx not in witnessed_record_idxs + ): + witnessed_record_idxs.add(member.idx) + frame.global_record_state.buffer.append(member) + return tuple(joined_loop_vars) @interp.impl(scf.stmts.Yield) def for_yield( self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.Yield ): + print("yield encountered, yielding values:", frame.get_values(stmt.values)) return interp.YieldValue(frame.get_values(stmt.values)) diff --git a/src/bloqade/analysis/record/lattice.py b/src/bloqade/analysis/record/lattice.py index 0b4fac9f..2fd065a8 100644 --- a/src/bloqade/analysis/record/lattice.py +++ b/src/bloqade/analysis/record/lattice.py @@ -72,6 +72,7 @@ def is_subseteq(self, other: Record) -> bool: @dataclass class RecordIdx(Record): idx: int + id: int def is_subseteq(self, other: Record) -> bool: if isinstance(other, RecordIdx): From d4113e601e50211ef4b49cfef1d4216f8750c57d Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 23 Nov 2025 19:39:47 -0500 Subject: [PATCH 12/26] get rid of a bunch of debug print statements --- src/bloqade/analysis/record/impls.py | 46 ---------------------------- 1 file changed, 46 deletions(-) diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py index abb80e19..b6223be3 100644 --- a/src/bloqade/analysis/record/impls.py +++ b/src/bloqade/analysis/record/impls.py @@ -148,36 +148,6 @@ def alias( @scf.dialect.register(key="record") class LoopHandling(interp.MethodTable): - """ - @interp.impl(scf.stmts.For) - def for_loop_single_pass( - self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For - ): - - init_loop_vars = frame.get_values(stmt.initializers) - - loop_frame = RecordFrame( - stmt, - global_record_state=frame.global_record_state, - parent=frame, - has_parent_access=True, - ) - loop_vars = interp_.frame_call_region( - loop_frame, stmt, stmt.body, InvalidRecord(), *init_loop_vars - ) - - print(frame.global_record_state) - - if loop_vars is None: - return () - elif isinstance(loop_vars, interp.ReturnValue): - return loop_vars - - # update the parent frame with the loop frame entries - frame.entries.update(loop_frame.entries) - - return loop_vars - """ @interp.impl(scf.stmts.For) def for_loop_double_pass( @@ -186,9 +156,6 @@ def for_loop_double_pass( init_loop_vars = frame.get_values(stmt.initializers) - # for ssa_val, lattice_element in frame.entries.items(): - # print(f"Before loop: {ssa_val} -> {lattice_element}") - # You go through the loops twice to verify the loop invariant. # we need to freeze the frame entries right after exiting the loop @@ -237,24 +204,11 @@ def for_loop_double_pass( # take the entries in the first and second loops # update the parent frame - # - # debug prints - # print("First loop entries (captured + preserved):") - # stmt.body.print(analysis=captured_first_loop_entries) - # print("First loop entries (based off existing frame values):") - # stmt.body.print(analysis=first_loop_frame.entries) - # print("Second loop entries (via local state)") - # stmt.body.print(analysis=second_loop_frame.entries) - # print("local state after being passed through two loops") - # print(local_state) - # print(frame.global_record_state) - # unified_frame_buffer = {} for ssa_val, lattice_element in captured_first_loop_entries.items(): verified_latticed_element = second_loop_frame.entries[ssa_val].join( lattice_element ) - # print(f"Joining {lattice_element} and {second_loop_frame.entries[ssa_val]} to get {verified_latticed_element}") unified_frame_buffer[ssa_val] = verified_latticed_element frame.entries.update(unified_frame_buffer) From 49202731faa29baef41d1fb672a4f6c6997d21ed Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 23 Nov 2025 20:01:01 -0500 Subject: [PATCH 13/26] remove more record analysis debug prints and move everything into measurement_id --- src/bloqade/analysis/measure_id/analysis.py | 27 +++++++++++-- src/bloqade/analysis/measure_id/impls.py | 42 ++++++++++++++++++--- src/bloqade/analysis/record/analysis.py | 3 -- src/bloqade/analysis/record/impls.py | 1 - 4 files changed, 60 insertions(+), 13 deletions(-) diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index 8151909c..36fc3fc9 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -4,7 +4,7 @@ from kirin.analysis import ForwardExtra from kirin.analysis.forward import ForwardFrame -from .lattice import MeasureId, NotMeasureId, KnownMeasureId +from .lattice import MeasureId, NotMeasureId, KnownMeasureId, MeasureIdTuple @dataclass @@ -12,7 +12,7 @@ class GlobalRecordState: buffer: list[KnownMeasureId] = field(default_factory=list) # assume that this KnownMeasureId will always be -1 - def add_record_idxs(self, num_new_records: int) -> list[KnownMeasureId]: + def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple: # adjust all previous indices for record_idx in self.buffer: record_idx.idx -= num_new_records @@ -21,12 +21,33 @@ def add_record_idxs(self, num_new_records: int) -> list[KnownMeasureId]: new_record_idxs = [KnownMeasureId(-i) for i in range(num_new_records, 0, -1)] self.buffer += new_record_idxs # Return for usage, idxs linked to the global state - return new_record_idxs + return MeasureIdTuple(data=tuple(new_record_idxs)) + + # Need for loop invariance, especially when you + # run the loop twice "behind the scenes". Then + # it isn't sufficient to just have two + # copies of a lattice element point to one entry on the + # buffer + def clone_record_idxs(self, measure_id_tuple: MeasureIdTuple) -> MeasureIdTuple: + cloned_members = [] + for known_measure_id in measure_id_tuple.data: + assert isinstance(known_measure_id, KnownMeasureId) + cloned_known_measure_id = KnownMeasureId(known_measure_id.idx) + # put into the global buffer but also + # return an analysis-facing copy + self.buffer.append(cloned_known_measure_id) + cloned_members.append(cloned_known_measure_id) + return MeasureIdTuple(data=tuple(cloned_members)) + + def offset_existing_records(self, offset: int): + for record_idx in self.buffer: + record_idx.idx -= offset @dataclass class MeasureIDFrame(ForwardFrame[MeasureId]): global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState) + measure_count_offset: int = 0 class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]): diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index f0052982..83666901 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -38,9 +38,15 @@ def measure_qubit_list( if not isinstance(num_qubits, kirin_types.Literal): return (AnyMeasureId(),) - record_idxs = frame.global_record_state.add_record_idxs(num_qubits.data) + # increment the parent frame measure count offset. + # Loop analysis relies on local state tracking + # so we use this data after exiting a loop to + # readjust the previous global measure count. + frame.measure_count_offset += num_qubits.data - return (MeasureIdTuple(data=tuple(record_idxs)),) + measure_id_tuple = frame.global_record_state.add_record_idxs(num_qubits.data) + + return (measure_id_tuple,) @annotate.dialect.register(key="measure_id") @@ -130,11 +136,16 @@ def getitem( class PyAssign(interp.MethodTable): @interp.impl(py.Alias) def alias( - self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.assign.Alias + self, + interp: MeasurementIDAnalysis, + frame: MeasureIDFrame, + stmt: py.assign.Alias, ): input = frame.get(stmt.value) - return (input,) + + new_input = frame.global_record_state.clone_record_idxs(input) + return (new_input,) @py.binop.dialect.register(key="measure_id") @@ -183,9 +194,11 @@ def for_loop( # You go through the loops twice to verify the loop invariant. # we need to freeze the frame entries right after exiting the loop + local_state = deepcopy(frame.global_record_state) + first_loop_frame = MeasureIDFrame( stmt, - global_record_state=frame.global_record_state, + global_record_state=local_state, parent=frame, has_parent_access=True, ) @@ -206,7 +219,7 @@ def for_loop( second_loop_frame = MeasureIDFrame( stmt, - global_record_state=frame.global_record_state, + global_record_state=local_state, parent=frame, has_parent_access=True, ) @@ -231,6 +244,9 @@ def for_loop( unified_frame_buffer[ssa_val] = verified_latticed_element frame.entries.update(unified_frame_buffer) + frame.global_record_state.offset_existing_records( + first_loop_frame.measure_count_offset + ) if captured_first_loop_vars is None or second_loop_vars is None: return () @@ -241,6 +257,20 @@ def for_loop( ): joined_loop_vars.append(first_loop_var.join(second_loop_var)) + # TrimYield is currently disabled meaning that the same RecordIdx + # can get copied into the parent frame twice! As a result + # we need to be careful to only add unique RecordIdx entries + witnessed_record_idxs = set() + for var in joined_loop_vars: + if isinstance(var, MeasureIdTuple): + for member in var.data: + if ( + isinstance(member, KnownMeasureId) + and member.idx not in witnessed_record_idxs + ): + witnessed_record_idxs.add(member.idx) + frame.global_record_state.buffer.append(member) + return tuple(joined_loop_vars) @interp.impl(scf.stmts.Yield) diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py index 8e66330a..4ef20a0e 100644 --- a/src/bloqade/analysis/record/analysis.py +++ b/src/bloqade/analysis/record/analysis.py @@ -42,9 +42,6 @@ def clone_record_idxs(self, record_tuple: RecordTuple, id: int) -> RecordTuple: def offset_existing_records(self, offset: int): for record_idx in self.buffer: record_idx.idx -= offset - print("offset is now:", offset) - print("The record idx is now:", record_idx.idx) - # print the record_idx after offsetting """ Might need a free after use! You can keep the size of the list small diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py index b6223be3..f4a7deaa 100644 --- a/src/bloqade/analysis/record/impls.py +++ b/src/bloqade/analysis/record/impls.py @@ -245,7 +245,6 @@ def for_loop_double_pass( def for_yield( self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.Yield ): - print("yield encountered, yielding values:", frame.get_values(stmt.values)) return interp.YieldValue(frame.get_values(stmt.values)) From a39a126900650922d0f118cb8fe907f1b0d3da35 Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 8 Dec 2025 10:20:05 -0500 Subject: [PATCH 14/26] latest attempt to try to reconcile type lattices --- src/bloqade/analysis/measure_id/analysis.py | 21 +++-- src/bloqade/analysis/measure_id/impls.py | 27 +++--- src/bloqade/analysis/measure_id/lattice.py | 20 ++++- src/bloqade/stim/passes/squin_to_stim.py | 2 + src/bloqade/stim/rewrite/get_record_util.py | 4 +- src/bloqade/stim/rewrite/ifs_to_stim.py | 12 +-- test/analysis/measure_id/test_measure_id.py | 88 +++++++++---------- .../measure_id/test_new_measure_id.py | 32 ------- test/stim/passes/test_squin_qubit_to_stim.py | 1 + 9 files changed, 96 insertions(+), 111 deletions(-) delete mode 100644 test/analysis/measure_id/test_new_measure_id.py diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index 36fc3fc9..fcae8fed 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -4,12 +4,17 @@ from kirin.analysis import ForwardExtra from kirin.analysis.forward import ForwardFrame -from .lattice import MeasureId, NotMeasureId, KnownMeasureId, MeasureIdTuple +from .lattice import ( + MeasureId, + NotMeasureId, + RawMeasureId, + MeasureIdTuple, +) @dataclass class GlobalRecordState: - buffer: list[KnownMeasureId] = field(default_factory=list) + buffer: list[RawMeasureId] = field(default_factory=list) # assume that this KnownMeasureId will always be -1 def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple: @@ -18,7 +23,7 @@ def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple: record_idx.idx -= num_new_records # generate new indices and add them to the buffer - new_record_idxs = [KnownMeasureId(-i) for i in range(num_new_records, 0, -1)] + new_record_idxs = [RawMeasureId(-i) for i in range(num_new_records, 0, -1)] self.buffer += new_record_idxs # Return for usage, idxs linked to the global state return MeasureIdTuple(data=tuple(new_record_idxs)) @@ -30,13 +35,13 @@ def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple: # buffer def clone_record_idxs(self, measure_id_tuple: MeasureIdTuple) -> MeasureIdTuple: cloned_members = [] - for known_measure_id in measure_id_tuple.data: - assert isinstance(known_measure_id, KnownMeasureId) - cloned_known_measure_id = KnownMeasureId(known_measure_id.idx) + for raw_measure_id in measure_id_tuple.data: + assert isinstance(raw_measure_id, RawMeasureId) + cloned_raw_measure_id = RawMeasureId(raw_measure_id.idx) # put into the global buffer but also # return an analysis-facing copy - self.buffer.append(cloned_known_measure_id) - cloned_members.append(cloned_known_measure_id) + self.buffer.append(cloned_raw_measure_id) + cloned_members.append(cloned_raw_measure_id) return MeasureIdTuple(data=tuple(cloned_members)) def offset_existing_records(self, offset: int): diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 61b30ac9..c4283507 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -10,11 +10,12 @@ Predicate, AnyMeasureId, NotMeasureId, - KnownMeasureId, + RawMeasureId, MeasureIdTuple, ConstantCarrier, InvalidMeasureId, ImmutableMeasureIds, + PredicatedMeasureId, ) from .analysis import MeasureIDFrame, MeasurementIDAnalysis @@ -61,7 +62,7 @@ def measurement_predicate( ): original_measure_id_tuple = frame.get(stmt.measurements) if not all( - isinstance(measure_id, KnownMeasureId) + isinstance(measure_id, RawMeasureId) for measure_id in original_measure_id_tuple.data ): return (InvalidMeasureId(),) @@ -76,10 +77,10 @@ def measurement_predicate( return (InvalidMeasureId(),) predicate_measure_ids = [ - KnownMeasureId(measure_id.idx, predicate) + PredicatedMeasureId(measure_id.idx, predicate) for measure_id in original_measure_id_tuple.data ] - return (MeasureIdTuple(data=tuple(predicate_measure_ids)),) + return (ImmutableMeasureIds(data=tuple(predicate_measure_ids)),) @gemini.logical.dialect.register(key="measure_id") @@ -100,11 +101,13 @@ def terminal_measurement( return (AnyMeasureId(),) measure_id_bools = [] - for _ in range(num_qubits.data): - interp.measure_count += 1 - measure_id_bools.append(RawMeasureId(interp.measure_count)) + for i in range(num_qubits.data): + measure_id_bools.append(RawMeasureId(idx=-(i + 1))) - return (MeasureIdTuple(data=tuple(measure_id_bools)),) + # Immutable usually desired for stim generation + # but we can reuse it here to indicate + # the measurement ids should not change anymore. + return (ImmutableMeasureIds(data=tuple(measure_id_bools)),) @annotate.dialect.register(key="measure_id") @@ -121,7 +124,9 @@ def consumes_measurements( if not ( isinstance(measure_id_tuple_at_stmt, MeasureIdTuple) - and kirin_types.is_tuple_of(measure_id_tuple_at_stmt.data, KnownMeasureId) + and kirin_types.is_tuple_of( + measure_id_tuple_at_stmt.data, PredicatedMeasureId + ) ): return (InvalidMeasureId(),) @@ -241,7 +246,7 @@ def invoke( @scf.dialect.register(key="measure_id") -class LoopHandling(interp.MethodTable): +class ScfHandling(interp.MethodTable): @interp.impl(scf.stmts.For) def for_loop( self, interp_: MeasurementIDAnalysis, frame: MeasureIDFrame, stmt: scf.stmts.For @@ -323,7 +328,7 @@ def for_loop( if isinstance(var, MeasureIdTuple): for member in var.data: if ( - isinstance(member, KnownMeasureId) + isinstance(member, RawMeasureId) and member.idx not in witnessed_record_idxs ): witnessed_record_idxs.add(member.idx) diff --git a/src/bloqade/analysis/measure_id/lattice.py b/src/bloqade/analysis/measure_id/lattice.py index 5b7c3e6d..415dab2c 100644 --- a/src/bloqade/analysis/measure_id/lattice.py +++ b/src/bloqade/analysis/measure_id/lattice.py @@ -68,12 +68,23 @@ def is_subseteq(self, other: MeasureId) -> bool: @final @dataclass -class KnownMeasureId(MeasureId): +class RawMeasureId(MeasureId): + idx: int + + def is_subseteq(self, other: MeasureId) -> bool: + if isinstance(other, RawMeasureId): + return self.idx == other.idx + return False + + +@final +@dataclass +class PredicatedMeasureId(MeasureId): idx: int predicate: Predicate def is_subseteq(self, other: MeasureId) -> bool: - if isinstance(other, KnownMeasureId): + if isinstance(other, PredicatedMeasureId): return self.idx == other.idx and self.predicate == other.predicate return False @@ -92,7 +103,10 @@ def is_subseteq(self, other: MeasureId) -> bool: @final @dataclass class ImmutableMeasureIds(MeasureId): - data: tuple[KnownMeasureId, ...] + # SetDetector happily consumes RawMeasureIds, but + # for scf.IfElse rewrite with predicates I need to allow + # PredicatedMeasureIds as well. + data: tuple[PredicatedMeasureId | RawMeasureId, ...] def is_subseteq(self, other: MeasureId) -> bool: if isinstance(other, ImmutableMeasureIds): diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index de40986c..2ec96bd0 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -29,6 +29,8 @@ from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim +# from bloqade.stim.passes.soft_flatten import SoftFlatten + @dataclass class SquinToStimPass(Pass): diff --git a/src/bloqade/stim/rewrite/get_record_util.py b/src/bloqade/stim/rewrite/get_record_util.py index c06015ac..f522c0f2 100644 --- a/src/bloqade/stim/rewrite/get_record_util.py +++ b/src/bloqade/stim/rewrite/get_record_util.py @@ -2,7 +2,7 @@ from kirin.dialects import py from bloqade.stim.dialects import auxiliary -from bloqade.analysis.measure_id.lattice import KnownMeasureId, MeasureIdTuple +from bloqade.analysis.measure_id.lattice import MeasureIdTuple, PredicatedMeasureId def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple): @@ -11,7 +11,7 @@ def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple): """ get_record_ssas = [] for known_measure_id in measure_id_tuple.data: - assert isinstance(known_measure_id, KnownMeasureId) + assert isinstance(known_measure_id, PredicatedMeasureId) idx_stmt = py.constant.Constant(known_measure_id.idx) idx_stmt.insert_before(node) get_record_stmt = auxiliary.GetRecord(idx_stmt.result) diff --git a/src/bloqade/stim/rewrite/ifs_to_stim.py b/src/bloqade/stim/rewrite/ifs_to_stim.py index cf9b45d6..52170a3e 100644 --- a/src/bloqade/stim/rewrite/ifs_to_stim.py +++ b/src/bloqade/stim/rewrite/ifs_to_stim.py @@ -13,7 +13,7 @@ from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ from bloqade.analysis.measure_id import MeasureIDFrame from bloqade.stim.dialects.auxiliary import GetRecord -from bloqade.analysis.measure_id.lattice import Predicate, KnownMeasureId +from bloqade.analysis.measure_id.lattice import Predicate, PredicatedMeasureId @dataclass @@ -140,7 +140,7 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: condition_type = self.measure_frame.entries.get(stmt.cond) # Check the condition is a singular MeasurementIdBool and that # it was generated by querying if the measurement is equivalent to the one state - if not isinstance(condition_type, KnownMeasureId): + if not isinstance(condition_type, PredicatedMeasureId): return RewriteResult() if condition_type.predicate != Predicate.IS_ONE: @@ -162,12 +162,8 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: return RewriteResult() # generate get record statement - num_measures = self.measure_frame.num_measures_at_stmt.get(stmt) - if num_measures is None: - return RewriteResult() - - measure_id_idx_stmt = py.Constant((condition_type.idx - 1) - num_measures) - get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841 + measure_id_idx_stmt = py.Constant(condition_type.idx) + get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) address_attr = stmts[0].qubits.hints.get("address") diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index 9745e1d7..1b6f8b32 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -1,4 +1,4 @@ -from kirin.dialects import scf +import pytest from kirin.passes.inline import InlinePass from bloqade import squin, gemini @@ -8,9 +8,10 @@ Predicate, NotMeasureId, RawMeasureId, - MeasureIdBool, MeasureIdTuple, InvalidMeasureId, + ImmutableMeasureIds, + PredicatedMeasureId, ) @@ -28,21 +29,21 @@ def results_of_variables(kernel, variable_names): return results -def test_subset_eq_MeasureIdBool(): +def test_subset_eq_PredicatedMeasureId(): - m0 = MeasureIdBool(idx=1, predicate=Predicate.IS_ONE) - m1 = MeasureIdBool(idx=1, predicate=Predicate.IS_ONE) + m0 = PredicatedMeasureId(idx=1, predicate=Predicate.IS_ONE) + m1 = PredicatedMeasureId(idx=1, predicate=Predicate.IS_ONE) assert m0.is_subseteq(m1) # not equivalent if predicate is different - m2 = MeasureIdBool(idx=1, predicate=Predicate.IS_ZERO) + m2 = PredicatedMeasureId(idx=1, predicate=Predicate.IS_ZERO) assert not m0.is_subseteq(m2) # not equivalent if index is different either, # they are only equivalent if both index and predicate match - m3 = MeasureIdBool(idx=2, predicate=Predicate.IS_ONE) + m3 = PredicatedMeasureId(idx=2, predicate=Predicate.IS_ONE) assert not m0.is_subseteq(m3) @@ -69,8 +70,9 @@ def test(): # construct expected MeasureIdTuple expected_measure_id_tuple = MeasureIdTuple( - data=tuple([RawMeasureId(idx=i) for i in range(1, 11)]) + data=tuple([RawMeasureId(idx=i) for i in range(-10, 0)]) ) + assert measure_id_tuples[-1] == expected_measure_id_tuple @@ -94,7 +96,7 @@ def test(): # construct expected MeasureIdTuples measure_id_tuple_with_id_bools = MeasureIdTuple( - data=tuple([RawMeasureId(idx=i) for i in range(1, 6)]) + data=tuple([RawMeasureId(idx=i) for i in range(-5, 0)]) ) measure_id_tuple_with_not_measures = MeasureIdTuple( data=tuple([NotMeasureId() for _ in range(5)]) @@ -111,30 +113,7 @@ def test(): ) -def test_measure_count_at_if_else(): - - @squin.kernel - def test(): - q = squin.qalloc(5) - squin.x(q[2]) - ms = squin.broadcast.measure(q) - - if ms[1]: - squin.x(q[0]) - - if ms[3]: - squin.y(q[1]) - - Flatten(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run(test) - - assert all( - isinstance(stmt, scf.IfElse) and measures_accumulated == 5 - for stmt, measures_accumulated in frame.num_measures_at_stmt.items() - ) - - -def test_scf_cond_true(): +def scf_cond_true(): @squin.kernel def test(): q = squin.qalloc(3) @@ -143,7 +122,7 @@ def test(): ms = None cond = True if cond: - ms = squin.measure(q[1]) + ms = squin.measure(q[1]) # need to enter the if-else else: ms = squin.measure(q[0]) @@ -151,6 +130,7 @@ def test(): InlinePass(test.dialects).fixpoint(test) frame, _ = MeasurementIDAnalysis(test.dialects).run(test) + test.print(analysis=frame.entries) # MeasureIdBool(idx=1) should occur twice: # First from the measurement in the true branch, then @@ -161,6 +141,7 @@ def test(): assert len(analysis_results) == 2 +@pytest.mark.xfail def test_scf_cond_false(): @squin.kernel @@ -191,6 +172,7 @@ def test(): assert len(analysis_results) == 2 +@pytest.mark.xfail def test_scf_cond_unknown(): @squin.kernel @@ -242,15 +224,15 @@ def test(): # This is an assertion against `msi` NOT the initial list of measurements assert frame.get(results["msi"]) == MeasureIdTuple( - data=tuple(list(RawMeasureId(idx=i) for i in range(2, 7))) + data=tuple(list(RawMeasureId(idx=i) for i in range(-5, 0))) ) # msi2 assert frame.get(results["msi2"]) == MeasureIdTuple( - data=tuple(list(RawMeasureId(idx=i) for i in range(3, 7))) + data=tuple(list(RawMeasureId(idx=i) for i in range(-4, 0))) ) # ms_final assert frame.get(results["ms_final"]) == MeasureIdTuple( - data=(RawMeasureId(idx=3), RawMeasureId(idx=5)) + data=(RawMeasureId(idx=-4), RawMeasureId(idx=-2)) ) @@ -320,19 +302,28 @@ def test(): test, ("is_zero_bools", "is_one_bools", "is_lost_bools") ) - expected_is_zero_bools = MeasureIdTuple( + expected_is_zero_bools = ImmutableMeasureIds( data=tuple( - [MeasureIdBool(idx=i, predicate=Predicate.IS_ZERO) for i in range(1, 4)] + [ + PredicatedMeasureId(idx=i, predicate=Predicate.IS_ZERO) + for i in range(-3, 0) + ] ) ) - expected_is_one_bools = MeasureIdTuple( + expected_is_one_bools = ImmutableMeasureIds( data=tuple( - [MeasureIdBool(idx=i, predicate=Predicate.IS_ONE) for i in range(1, 4)] + [ + PredicatedMeasureId(idx=i, predicate=Predicate.IS_ONE) + for i in range(-3, 0) + ] ) ) - expected_is_lost_bools = MeasureIdTuple( + expected_is_lost_bools = ImmutableMeasureIds( data=tuple( - [MeasureIdBool(idx=i, predicate=Predicate.IS_LOST) for i in range(1, 4)] + [ + PredicatedMeasureId(idx=i, predicate=Predicate.IS_LOST) + for i in range(-3, 0) + ] ) ) @@ -343,7 +334,9 @@ def test(): def test_terminal_logical_measurement(): - @gemini.logical.kernel(no_raise=False, typeinfer=True, aggressive_unroll=True) + @gemini.logical.kernel( + no_raise=False, typeinfer=True, aggressive_unroll=True, verify=False + ) def tm_logical_kernel(): q = squin.qalloc(3) tm = gemini.logical.terminal_measure(q) @@ -352,10 +345,11 @@ def tm_logical_kernel(): frame, _ = MeasurementIDAnalysis(tm_logical_kernel.dialects).run(tm_logical_kernel) # will have a MeasureIdTuple that's not from the terminal measurement, # basically a container of InvalidMeasureIds from the qubits that get allocated + tm_logical_kernel.print(analysis=frame.entries) analysis_results = [ - val for val in frame.entries.values() if isinstance(val, MeasureIdTuple) + val for val in frame.entries.values() if isinstance(val, ImmutableMeasureIds) ] - expected_result = MeasureIdTuple( - data=tuple([RawMeasureId(idx=i) for i in range(1, 4)]) + expected_result = ImmutableMeasureIds( + data=tuple([RawMeasureId(idx=-i) for i in range(1, 4)]) ) assert expected_result in analysis_results diff --git a/test/analysis/measure_id/test_new_measure_id.py b/test/analysis/measure_id/test_new_measure_id.py deleted file mode 100644 index 107ed1cf..00000000 --- a/test/analysis/measure_id/test_new_measure_id.py +++ /dev/null @@ -1,32 +0,0 @@ -import io - -from kirin import ir - -from bloqade import stim, squin -from bloqade.stim.emit import EmitStimMain -from bloqade.stim.passes import SquinToStimPass - - -def codegen(mt: ir.Method): - # method should not have any arguments! - buf = io.StringIO() - emit = EmitStimMain(dialects=stim.main, io=buf) - emit.initialize() - emit.run(mt) - return buf.getvalue().strip() - - -@squin.kernel -def test_simple_linear(): - - qs = squin.qalloc(4) - m0 = squin.broadcast.measure(qs) - squin.set_detector([m0[0], m0[1]], coordinates=[0, 0]) - m1 = squin.broadcast.measure(qs) - squin.set_detector([m1[0], m1[1]], coordinates=[1, 1]) - - -test_simple_linear.print() -SquinToStimPass(dialects=test_simple_linear.dialects)(test_simple_linear) -test_simple_linear.print() -print(codegen(test_simple_linear)) diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index 10e14883..f4768f6f 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -285,6 +285,7 @@ def test(): sq.z(q[2]) SquinToStimPass(test.dialects)(test) + test.print() base_stim_prog = load_reference_program("valid_if_measure_predicate.stim") assert codegen(test) == base_stim_prog.rstrip() From b4c455968b29122b53f0b6140b88db71716a447d Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 9 Dec 2025 08:44:09 -0500 Subject: [PATCH 15/26] still need to find a solution to proper IfElse handling --- src/bloqade/analysis/measure_id/impls.py | 28 +++++++---- src/bloqade/analysis/measure_id/lattice.py | 15 +----- src/bloqade/stim/passes/soft_flatten.py | 6 +-- src/bloqade/stim/rewrite/get_record_util.py | 4 +- test/analysis/measure_id/test_measure_id.py | 51 +++++++++++++++++---- test/stim/passes/test_annotation_to_stim.py | 5 ++ 6 files changed, 71 insertions(+), 38 deletions(-) diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index c4283507..9a22c41f 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -14,7 +14,6 @@ MeasureIdTuple, ConstantCarrier, InvalidMeasureId, - ImmutableMeasureIds, PredicatedMeasureId, ) from .analysis import MeasureIDFrame, MeasurementIDAnalysis @@ -80,7 +79,7 @@ def measurement_predicate( PredicatedMeasureId(measure_id.idx, predicate) for measure_id in original_measure_id_tuple.data ] - return (ImmutableMeasureIds(data=tuple(predicate_measure_ids)),) + return (MeasureIdTuple(data=tuple(predicate_measure_ids)),) @gemini.logical.dialect.register(key="measure_id") @@ -107,7 +106,7 @@ def terminal_measurement( # Immutable usually desired for stim generation # but we can reuse it here to indicate # the measurement ids should not change anymore. - return (ImmutableMeasureIds(data=tuple(measure_id_bools)),) + return (MeasureIdTuple(data=tuple(measure_id_bools), immutable=True),) @annotate.dialect.register(key="measure_id") @@ -124,9 +123,7 @@ def consumes_measurements( if not ( isinstance(measure_id_tuple_at_stmt, MeasureIdTuple) - and kirin_types.is_tuple_of( - measure_id_tuple_at_stmt.data, PredicatedMeasureId - ) + and kirin_types.is_tuple_of(measure_id_tuple_at_stmt.data, RawMeasureId) ): return (InvalidMeasureId(),) @@ -134,7 +131,7 @@ def consumes_measurements( deepcopy(record_idx) for record_idx in measure_id_tuple_at_stmt.data ] - return (ImmutableMeasureIds(data=tuple(final_record_idxs)),) + return (MeasureIdTuple(data=tuple(final_record_idxs), immutable=True),) @ilist.dialect.register(key="measure_id") @@ -150,7 +147,7 @@ def new_ilist( stmt: ilist.New, ): - return (MeasureIdTuple(frame.get_values(stmt.values)),) + return (MeasureIdTuple(data=frame.get_values(stmt.values)),) @py.tuple.dialect.register(key="measure_id") @@ -345,6 +342,21 @@ def for_yield( ): return interp.YieldValue(frame.get_values(stmt.values)) + @interp.impl(scf.stmts.IfElse) + def if_else( + self, + interp_: MeasurementIDAnalysis, + frame: MeasureIDFrame, + stmt: scf.stmts.IfElse, + ): + cond_measure_id = frame.get(stmt.cond) + assert type(cond_measure_id) is PredicatedMeasureId + detached_cond_measure_id = PredicatedMeasureId( + idx=deepcopy(cond_measure_id.idx), predicate=cond_measure_id.predicate + ) + # remove underlying reference to the frame + frame.set(stmt.cond, detached_cond_measure_id) + @py.dialect.register(key="measure_id") class ConstantForwarding(interp.MethodTable): diff --git a/src/bloqade/analysis/measure_id/lattice.py b/src/bloqade/analysis/measure_id/lattice.py index 415dab2c..5bf71ccf 100644 --- a/src/bloqade/analysis/measure_id/lattice.py +++ b/src/bloqade/analysis/measure_id/lattice.py @@ -93,6 +93,7 @@ def is_subseteq(self, other: MeasureId) -> bool: @dataclass class MeasureIdTuple(MeasureId): data: tuple[MeasureId, ...] + immutable: bool = False def is_subseteq(self, other: MeasureId) -> bool: if isinstance(other, MeasureIdTuple): @@ -100,20 +101,6 @@ def is_subseteq(self, other: MeasureId) -> bool: return False -@final -@dataclass -class ImmutableMeasureIds(MeasureId): - # SetDetector happily consumes RawMeasureIds, but - # for scf.IfElse rewrite with predicates I need to allow - # PredicatedMeasureIds as well. - data: tuple[PredicatedMeasureId | RawMeasureId, ...] - - def is_subseteq(self, other: MeasureId) -> bool: - if isinstance(other, ImmutableMeasureIds): - return all(a.is_subseteq(b) for a, b in zip(self.data, other.data)) - return False - - # For now I only care about propagating constant integers or slices, # things that can be used as indices to list of measurements @final diff --git a/src/bloqade/stim/passes/soft_flatten.py b/src/bloqade/stim/passes/soft_flatten.py index 11626788..cf4fbe6a 100644 --- a/src/bloqade/stim/passes/soft_flatten.py +++ b/src/bloqade/stim/passes/soft_flatten.py @@ -81,12 +81,10 @@ class SoftFlatten(Pass): def __post_init__(self): self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise) - - # DO NOT USE FOR NOW, TrimUnusedYield call messes up loop structure - # self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) + self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) def unsafe_run(self, mt: ir.Method) -> RewriteResult: rewrite_result = RewriteResult() - # rewrite_result = self.simplify_if(mt).join(rewrite_result) + rewrite_result = self.simplify_if(mt).join(rewrite_result) rewrite_result = self.unroll(mt).join(rewrite_result) return rewrite_result diff --git a/src/bloqade/stim/rewrite/get_record_util.py b/src/bloqade/stim/rewrite/get_record_util.py index f522c0f2..1db02bd6 100644 --- a/src/bloqade/stim/rewrite/get_record_util.py +++ b/src/bloqade/stim/rewrite/get_record_util.py @@ -2,7 +2,7 @@ from kirin.dialects import py from bloqade.stim.dialects import auxiliary -from bloqade.analysis.measure_id.lattice import MeasureIdTuple, PredicatedMeasureId +from bloqade.analysis.measure_id.lattice import RawMeasureId, MeasureIdTuple def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple): @@ -11,7 +11,7 @@ def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple): """ get_record_ssas = [] for known_measure_id in measure_id_tuple.data: - assert isinstance(known_measure_id, PredicatedMeasureId) + assert isinstance(known_measure_id, RawMeasureId) idx_stmt = py.constant.Constant(known_measure_id.idx) idx_stmt.insert_before(node) get_record_stmt = auxiliary.GetRecord(idx_stmt.result) diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index 1b6f8b32..d9b1b164 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -4,13 +4,15 @@ from bloqade import squin, gemini from bloqade.analysis.measure_id import MeasurementIDAnalysis from bloqade.stim.passes.flatten import Flatten + +# from bloqade.stim.passes.soft_flatten import SoftFlatten +from bloqade.stim.passes.squin_to_stim import SquinToStimPass from bloqade.analysis.measure_id.lattice import ( Predicate, NotMeasureId, RawMeasureId, MeasureIdTuple, InvalidMeasureId, - ImmutableMeasureIds, PredicatedMeasureId, ) @@ -302,29 +304,29 @@ def test(): test, ("is_zero_bools", "is_one_bools", "is_lost_bools") ) - expected_is_zero_bools = ImmutableMeasureIds( + expected_is_zero_bools = MeasureIdTuple( data=tuple( [ PredicatedMeasureId(idx=i, predicate=Predicate.IS_ZERO) for i in range(-3, 0) ] - ) + ), ) - expected_is_one_bools = ImmutableMeasureIds( + expected_is_one_bools = MeasureIdTuple( data=tuple( [ PredicatedMeasureId(idx=i, predicate=Predicate.IS_ONE) for i in range(-3, 0) ] - ) + ), ) - expected_is_lost_bools = ImmutableMeasureIds( + expected_is_lost_bools = MeasureIdTuple( data=tuple( [ PredicatedMeasureId(idx=i, predicate=Predicate.IS_LOST) for i in range(-3, 0) ] - ) + ), ) assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools @@ -347,9 +349,38 @@ def tm_logical_kernel(): # basically a container of InvalidMeasureIds from the qubits that get allocated tm_logical_kernel.print(analysis=frame.entries) analysis_results = [ - val for val in frame.entries.values() if isinstance(val, ImmutableMeasureIds) + val for val in frame.entries.values() if isinstance(val, MeasureIdTuple) ] - expected_result = ImmutableMeasureIds( - data=tuple([RawMeasureId(idx=-i) for i in range(1, 4)]) + expected_result = MeasureIdTuple( + data=tuple([RawMeasureId(idx=-i) for i in range(1, 4)]), + immutable=True, ) assert expected_result in analysis_results + + +def test_if_else_happy_path(): + + @squin.kernel + def test(): + qs = squin.qalloc(3) + ms = squin.broadcast.measure(qs) + # predicate + pred_ms = squin.broadcast.is_one(ms) + squin.broadcast.measure(qs) + squin.broadcast.measure(qs) + if pred_ms[0]: + squin.x(qs[1]) + + return + + # Flatten(test.dialects).fixpoint(test) + # SoftFlatten(test.dialects).fixpoint(test) + test.print() + SquinToStimPass(test.dialects)(test) + test.print() + # test.print() + # frame, _ = MeasurementIDAnalysis(test.dialects).run(test) + # test.print(analysis=frame.entries) + + +test_if_else_happy_path() diff --git a/test/stim/passes/test_annotation_to_stim.py b/test/stim/passes/test_annotation_to_stim.py index bf79cca5..ba32fa9b 100644 --- a/test/stim/passes/test_annotation_to_stim.py +++ b/test/stim/passes/test_annotation_to_stim.py @@ -148,10 +148,15 @@ def main(): return + main.print() SquinToStimPass(main.dialects, no_raise=True)(main) + main.print() assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) +test_missing_predicate() + + def test_incorrect_predicate(): # You can only rewrite squin.is_one(...) predicates to From 927fcd8e936ffa18f019376dc5e2defa6371991e Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 9 Dec 2025 09:57:05 -0500 Subject: [PATCH 16/26] figured out way to handle scf.IfElse with new lattice --- src/bloqade/analysis/measure_id/analysis.py | 9 ++++++++- src/bloqade/analysis/measure_id/impls.py | 17 +++++++++++------ src/bloqade/stim/rewrite/ifs_to_stim.py | 10 ++++++++-- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index fcae8fed..580afc2c 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -14,9 +14,12 @@ @dataclass class GlobalRecordState: + # every time a cond value is encountered inside scf + # detach and save it here because I need to let it update + # if it gets used again somewhere else + type_for_scf_conds: dict[ir.Statement, MeasureId] = field(default_factory=dict) buffer: list[RawMeasureId] = field(default_factory=list) - # assume that this KnownMeasureId will always be -1 def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple: # adjust all previous indices for record_idx in self.buffer: @@ -52,6 +55,10 @@ def offset_existing_records(self, offset: int): @dataclass class MeasureIDFrame(ForwardFrame[MeasureId]): global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState) + # every time a cond value is encountered inside scf + # detach and save it here because I need to let it update + # if it gets used again somewhere else + type_for_scf_conds: dict[ir.Statement, MeasureId] = field(default_factory=dict) measure_count_offset: int = 0 diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 9a22c41f..7462db26 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -350,12 +350,17 @@ def if_else( stmt: scf.stmts.IfElse, ): cond_measure_id = frame.get(stmt.cond) - assert type(cond_measure_id) is PredicatedMeasureId - detached_cond_measure_id = PredicatedMeasureId( - idx=deepcopy(cond_measure_id.idx), predicate=cond_measure_id.predicate - ) - # remove underlying reference to the frame - frame.set(stmt.cond, detached_cond_measure_id) + if isinstance(cond_measure_id, PredicatedMeasureId): + detached_cond_measure_id = PredicatedMeasureId( + idx=deepcopy(cond_measure_id.idx), predicate=cond_measure_id.predicate + ) + frame.type_for_scf_conds[stmt] = detached_cond_measure_id + return + + # If you don't get a PredicatedMeasureId, don't bother + # converting anything + frame.type_for_scf_conds[stmt] = InvalidMeasureId() + # nothing to return, this thing already lives on the @py.dialect.register(key="measure_id") diff --git a/src/bloqade/stim/rewrite/ifs_to_stim.py b/src/bloqade/stim/rewrite/ifs_to_stim.py index 52170a3e..a3f9063c 100644 --- a/src/bloqade/stim/rewrite/ifs_to_stim.py +++ b/src/bloqade/stim/rewrite/ifs_to_stim.py @@ -13,7 +13,11 @@ from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ from bloqade.analysis.measure_id import MeasureIDFrame from bloqade.stim.dialects.auxiliary import GetRecord -from bloqade.analysis.measure_id.lattice import Predicate, PredicatedMeasureId +from bloqade.analysis.measure_id.lattice import ( + Predicate, + InvalidMeasureId, + PredicatedMeasureId, +) @dataclass @@ -137,7 +141,9 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: - condition_type = self.measure_frame.entries.get(stmt.cond) + condition_type = self.measure_frame.type_for_scf_conds.get(stmt) + if condition_type is None or condition_type is InvalidMeasureId(): + return RewriteResult() # Check the condition is a singular MeasurementIdBool and that # it was generated by querying if the measurement is equivalent to the one state if not isinstance(condition_type, PredicatedMeasureId): From 4d62a473c3718216532f5102336f94ee4b58bbb5 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 10 Dec 2025 10:03:19 -0500 Subject: [PATCH 17/26] get decent portion of qubit and annotate to stim tests fully working, keep some parity with the Bloqade version of Fold et al. --- src/bloqade/analysis/measure_id/analysis.py | 46 ++++-- src/bloqade/analysis/measure_id/impls.py | 13 +- .../stim/passes/flatten_except_loops.py | 150 ++++++++++++++++++ src/bloqade/stim/passes/soft_flatten.py | 90 ----------- src/bloqade/stim/passes/squin_to_stim.py | 10 +- .../stim/rewrite/set_detector_to_stim.py | 3 +- test/analysis/record/test_record_analysis.py | 6 +- test/stim/passes/test_annotation_to_stim.py | 5 +- test/stim/passes/test_squin_qubit_to_stim.py | 5 +- 9 files changed, 208 insertions(+), 120 deletions(-) create mode 100644 src/bloqade/stim/passes/flatten_except_loops.py delete mode 100644 src/bloqade/stim/passes/soft_flatten.py diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index 580afc2c..423fd589 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -9,6 +9,7 @@ NotMeasureId, RawMeasureId, MeasureIdTuple, + PredicatedMeasureId, ) @@ -18,7 +19,7 @@ class GlobalRecordState: # detach and save it here because I need to let it update # if it gets used again somewhere else type_for_scf_conds: dict[ir.Statement, MeasureId] = field(default_factory=dict) - buffer: list[RawMeasureId] = field(default_factory=list) + buffer: list[RawMeasureId | PredicatedMeasureId] = field(default_factory=list) def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple: # adjust all previous indices @@ -36,17 +37,44 @@ def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple: # it isn't sufficient to just have two # copies of a lattice element point to one entry on the # buffer - def clone_record_idxs(self, measure_id_tuple: MeasureIdTuple) -> MeasureIdTuple: + + def clone_measure_id_tuple( + self, measure_id_tuple: MeasureIdTuple + ) -> MeasureIdTuple: cloned_members = [] - for raw_measure_id in measure_id_tuple.data: - assert isinstance(raw_measure_id, RawMeasureId) - cloned_raw_measure_id = RawMeasureId(raw_measure_id.idx) - # put into the global buffer but also - # return an analysis-facing copy - self.buffer.append(cloned_raw_measure_id) - cloned_members.append(cloned_raw_measure_id) + for measure_id in measure_id_tuple.data: + cloned_measure_id = self.clone_measure_ids(measure_id) + cloned_members.append(cloned_measure_id) return MeasureIdTuple(data=tuple(cloned_members)) + def clone_raw_measure_id(self, raw_measure_id: RawMeasureId) -> RawMeasureId: + cloned_raw_measure_id = RawMeasureId(raw_measure_id.idx) + self.buffer.append(cloned_raw_measure_id) + return cloned_raw_measure_id + + def clone_predicated_measure_id( + self, predicated_measure_id: PredicatedMeasureId + ) -> PredicatedMeasureId: + cloned_predicated_measure_id = PredicatedMeasureId( + idx=predicated_measure_id.idx, + predicate=predicated_measure_id.predicate, + ) + self.buffer.append(cloned_predicated_measure_id) + return cloned_predicated_measure_id + + def clone_measure_ids(self, measure_id_type: MeasureId) -> MeasureId: + + if isinstance(measure_id_type, RawMeasureId): + return self.clone_raw_measure_id(measure_id_type) + elif isinstance(measure_id_type, PredicatedMeasureId): + return self.clone_predicated_measure_id(measure_id_type) + elif isinstance(measure_id_type, MeasureIdTuple): + cloned_members = [] + for member in measure_id_type.data: + cloned_member = self.clone_measure_ids(member) + cloned_members.append(cloned_member) + return MeasureIdTuple(data=tuple(cloned_members)) + def offset_existing_records(self, offset: int): for record_idx in self.buffer: record_idx.idx -= offset diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 7462db26..2e165df2 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -127,11 +127,12 @@ def consumes_measurements( ): return (InvalidMeasureId(),) - final_record_idxs = [ - deepcopy(record_idx) for record_idx in measure_id_tuple_at_stmt.data + final_measure_ids = [ + deepcopy(measure_id_element) + for measure_id_element in measure_id_tuple_at_stmt.data ] - return (MeasureIdTuple(data=tuple(final_record_idxs), immutable=True),) + return (MeasureIdTuple(data=tuple(final_measure_ids), immutable=True),) @ilist.dialect.register(key="measure_id") @@ -203,9 +204,11 @@ def alias( ): input = frame.get(stmt.value) + attempted_cloned_input = frame.global_record_state.clone_measure_ids(input) + if attempted_cloned_input is None: + return (input,) - new_input = frame.global_record_state.clone_record_idxs(input) - return (new_input,) + return (attempted_cloned_input,) @py.binop.dialect.register(key="measure_id") diff --git a/src/bloqade/stim/passes/flatten_except_loops.py b/src/bloqade/stim/passes/flatten_except_loops.py new file mode 100644 index 00000000..e7b3a630 --- /dev/null +++ b/src/bloqade/stim/passes/flatten_except_loops.py @@ -0,0 +1,150 @@ +# Taken from Phillip Weinberg's bloqade-shuttle implementation +from typing import Callable +from dataclasses import field, dataclass + +from kirin import ir +from kirin.passes import Pass, TypeInfer + +# from kirin.passes.aggressive import UnrollScf +from kirin.rewrite import ( + Walk, + Chain, + Inline, + Fixpoint, + CFGCompactify, + DeadCodeElimination, + CommonSubexpressionElimination, +) +from kirin.analysis import const +from kirin.dialects import py, scf, ilist +from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.dialects.scf.unroll import PickIfElse + +# from bloqade.qasm2.passes.fold import AggressiveUnroll +from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs + +# this fold is different from the one in Kirin +from bloqade.rewrite.passes.aggressive_unroll import Fold as BloqadeFold +from bloqade.rewrite.passes.canonicalize_ilist import CanonicalizeIList + + +class ForLoopNoIterDependance(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if not isinstance(node, scf.For): + return RewriteResult() + + # If the iterator is not being depended on at all, + # take that as a sign that REPEAT should be generated. + # If the iterator IS being dependent on, we can to fall back + # to unrolling. + if not bool(node.body.blocks[0].args[0].uses): + return RewriteResult() + + # TODO: support for PartialTuple and IList with known length + if not isinstance(hint := node.iterable.hints.get("const"), const.Value): + return RewriteResult() + + loop_vars = node.initializers + for item in hint.data: + body = node.body.clone() + block = body.blocks[0] + item_stmt = py.Constant(item) + item_stmt.insert_before(node) + block.args[0].replace_by(item_stmt.result) + for var, input in zip(block.args[1:], loop_vars): + var.replace_by(input) + + block_stmt = block.first_stmt + while block_stmt and not block_stmt.has_trait(ir.IsTerminator): + block_stmt.detach() + block_stmt.insert_before(node) + block_stmt = block.first_stmt + + terminator = block.last_stmt + # we assume Yield has the same # of values as initializers + # TODO: check this in validation + if isinstance(terminator, scf.Yield): + loop_vars = terminator.values + terminator.delete() + + for result, output in zip(node.results, loop_vars): + result.replace_by(output) + node.delete() + return RewriteResult(has_done_something=True) + + +@dataclass +class UnrollNoLoops(Pass): + """A pass to unroll structured control flow""" + + additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True + + fold: BloqadeFold = field(init=False) + typeinfer: TypeInfer = field(init=False) + # scf_unroll: UnrollScf = field(init=False) + canonicalize_ilist: CanonicalizeIList = field(init=False) + + def __post_init__(self): + self.fold = BloqadeFold(self.dialects, no_raise=self.no_raise) + self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) + self.canonicalize_ilist = CanonicalizeIList( + self.dialects, no_raise=self.no_raise + ) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + + result = RewriteResult() + result = self.fold.unsafe_run(mt).join(result) + + # equivalent of ScfUnroll but now customized + result = Walk(PickIfElse()).rewrite(mt.code).join(result) + result = Walk(ForLoopNoIterDependance()).rewrite(mt.code).join(result) + + # Do not join result of typeinfer or fixpoint will waste time + result = ( + Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())) + .rewrite(mt.code) + .join(result) + ) + result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result) + result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result) + result = self.canonicalize_ilist.fixpoint(mt).join(result) + rule = Chain( + CommonSubexpressionElimination(), + DeadCodeElimination(), + ) + result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) + + return result + + def inline_heuristic(self, node: ir.Statement) -> bool: + """The heuristic to decide whether to inline a function call or not. + inside loops and if-else, only inline simple functions, i.e. + functions with a single block + """ + return not isinstance( + node.parent_stmt, (scf.For, scf.IfElse) + ) and self.additional_inline_heuristic( + node + ) # always inline calls outside of loops and if-else + + +@dataclass +class FlattenExceptLoops(Pass): + """ + like standard Flatten but without unrolling to let analysis go into loops + """ + + unroll: UnrollNoLoops = field(init=False) + simplify_if: StimSimplifyIfs = field(init=False) + + def __post_init__(self): + self.unroll = UnrollNoLoops(self.dialects, no_raise=self.no_raise) + self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + rewrite_result = RewriteResult() + rewrite_result = self.simplify_if(mt).join(rewrite_result) + rewrite_result = self.unroll(mt).join(rewrite_result) + return rewrite_result diff --git a/src/bloqade/stim/passes/soft_flatten.py b/src/bloqade/stim/passes/soft_flatten.py deleted file mode 100644 index cf4fbe6a..00000000 --- a/src/bloqade/stim/passes/soft_flatten.py +++ /dev/null @@ -1,90 +0,0 @@ -# Taken from Phillip Weinberg's bloqade-shuttle implementation -from typing import Callable -from dataclasses import field, dataclass - -from kirin import ir -from kirin.passes import Fold, Pass, TypeInfer - -# from kirin.passes.aggressive import UnrollScf -from kirin.rewrite import ( - Walk, - Chain, - Inline, - Fixpoint, - CFGCompactify, - DeadCodeElimination, - CommonSubexpressionElimination, -) -from kirin.dialects import scf, ilist -from kirin.rewrite.abc import RewriteResult - -# from bloqade.qasm2.passes.fold import AggressiveUnroll -from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs - - -@dataclass -class AggressiveUnroll(Pass): - """A pass to unroll structured control flow""" - - additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True - - fold: Fold = field(init=False) - typeinfer: TypeInfer = field(init=False) - # scf_unroll: UnrollScf = field(init=False) - - def __post_init__(self): - self.fold = Fold(self.dialects, no_raise=self.no_raise) - self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) - # self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise) - - def unsafe_run(self, mt: ir.Method) -> RewriteResult: - result = RewriteResult() - # result = self.scf_unroll.unsafe_run(mt).join(result) - result = ( - Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())) - .rewrite(mt.code) - .join(result) - ) - self.typeinfer.unsafe_run(mt) - result = self.fold.unsafe_run(mt).join(result) - result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result) - result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result) - - rule = Chain( - CommonSubexpressionElimination(), - DeadCodeElimination(), - ) - result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) - - return result - - def inline_heuristic(self, node: ir.Statement) -> bool: - """The heuristic to decide whether to inline a function call or not. - inside loops and if-else, only inline simple functions, i.e. - functions with a single block - """ - return not isinstance( - node.parent_stmt, (scf.For, scf.IfElse) - ) and self.additional_inline_heuristic( - node - ) # always inline calls outside of loops and if-else - - -@dataclass -class SoftFlatten(Pass): - """ - like standard Flatten but without unrolling to let analysis go into loops - """ - - unroll: AggressiveUnroll = field(init=False) - simplify_if: StimSimplifyIfs = field(init=False) - - def __post_init__(self): - self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise) - self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) - - def unsafe_run(self, mt: ir.Method) -> RewriteResult: - rewrite_result = RewriteResult() - rewrite_result = self.simplify_if(mt).join(rewrite_result) - rewrite_result = self.unroll(mt).join(rewrite_result) - return rewrite_result diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 2ec96bd0..c00f0032 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -25,12 +25,10 @@ from bloqade.rewrite.passes import CanonicalizeIList from bloqade.analysis.address import AddressAnalysis from bloqade.analysis.measure_id import MeasurementIDAnalysis -from bloqade.stim.passes.flatten import Flatten +from bloqade.stim.passes.flatten_except_loops import FlattenExceptLoops from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim -# from bloqade.stim.passes.soft_flatten import SoftFlatten - @dataclass class SquinToStimPass(Pass): @@ -38,9 +36,9 @@ class SquinToStimPass(Pass): def unsafe_run(self, mt: Method) -> RewriteResult: # inline aggressively: - rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint( - mt - ) + rewrite_result = FlattenExceptLoops( + dialects=mt.dialects, no_raise=self.no_raise + ).fixpoint(mt) # after this the program should be in a state where it is analyzable # ------------------------------------------------------------------- diff --git a/src/bloqade/stim/rewrite/set_detector_to_stim.py b/src/bloqade/stim/rewrite/set_detector_to_stim.py index ed5dfa3b..9d768a6f 100644 --- a/src/bloqade/stim/rewrite/set_detector_to_stim.py +++ b/src/bloqade/stim/rewrite/set_detector_to_stim.py @@ -52,11 +52,12 @@ def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: coord_ssas.append(coord_stmt.result) coord_stmt.insert_before(node) - measure_ids = self.measure_id_frame.entries.get(node.measurements, None) + measure_ids = self.measure_id_frame.entries.get(node.result, None) if measure_ids is None: return RewriteResult() assert isinstance(measure_ids, MeasureIdTuple) + assert measure_ids.immutable get_record_list = insert_get_records(node, measure_ids) diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py index 065213bf..e8c436df 100644 --- a/test/analysis/record/test_record_analysis.py +++ b/test/analysis/record/test_record_analysis.py @@ -6,7 +6,7 @@ from bloqade.analysis.record import RecordAnalysis # from bloqade.analysis.measure_id import MeasurementIDAnalysis -from bloqade.stim.passes.soft_flatten import SoftFlatten +from bloqade.stim.passes.flatten_except_loops import FlattenExceptForLoop """ @squin.kernel @@ -83,7 +83,7 @@ def test(x: int): y[0] += x return y, z - SoftFlatten(dialects=test.dialects).fixpoint(test) + FlattenExceptForLoop(dialects=test.dialects).fixpoint(test) test.print() frame, _ = RecordAnalysis(dialects=test.dialects).run(test) test.print(analysis=frame.entries, hint="const") @@ -172,7 +172,7 @@ def test(): squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1]) squin.set_observable([data_ms[2]]) - SoftFlatten(dialects=test.dialects).fixpoint(test) + FlattenExceptForLoop(dialects=test.dialects).fixpoint(test) # test.print() frame, _ = RecordAnalysis(dialects=test.dialects).run(test) test.print(analysis=frame.entries) diff --git a/test/stim/passes/test_annotation_to_stim.py b/test/stim/passes/test_annotation_to_stim.py index ba32fa9b..124d9353 100644 --- a/test/stim/passes/test_annotation_to_stim.py +++ b/test/stim/passes/test_annotation_to_stim.py @@ -1,6 +1,7 @@ import io import os +import pytest from kirin import ir from kirin.dialects import scf, ilist @@ -154,9 +155,6 @@ def main(): assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) -test_missing_predicate() - - def test_incorrect_predicate(): # You can only rewrite squin.is_one(...) predicates to @@ -179,6 +177,7 @@ def main(): assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) +@pytest.mark.xfail(reason="nested looping not targeted for conversion yet") def test_nested_for(): @squin.kernel diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index f4768f6f..a8e7bc79 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -3,6 +3,7 @@ import math from math import pi +import pytest from kirin import ir from kirin.dialects import py, scf @@ -210,6 +211,7 @@ def main(): assert codegen(main) == base_stim_prog.rstrip() +@pytest.mark.xfail(reason="Holding off on in-loop unrolling") def test_nested_for_loop_rewrite(): @sq.kernel @@ -315,9 +317,6 @@ def test(): assert len(remaining_if_else) == 2 -test_invalid_if_measure_predicate() - - def test_non_pure_loop_iterator(): @kernel def test_squin_kernel(): From 5013602e7edeaedb83ef74a24829ea07995e8fad Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 10 Dec 2025 11:49:28 -0500 Subject: [PATCH 18/26] delay CSE due to getitem uniqueness issues --- src/bloqade/analysis/measure_id/impls.py | 2 ++ src/bloqade/stim/passes/flatten_except_loops.py | 16 +++++++++------- src/bloqade/stim/passes/simplify_ifs.py | 5 ++--- src/bloqade/stim/passes/squin_to_stim.py | 2 +- test/stim/passes/test_squin_meas_to_stim.py | 1 - 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 2e165df2..5a96db47 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -353,6 +353,8 @@ def if_else( stmt: scf.stmts.IfElse, ): cond_measure_id = frame.get(stmt.cond) + print("cond measure id encountered:") + print(cond_measure_id) if isinstance(cond_measure_id, PredicatedMeasureId): detached_cond_measure_id = PredicatedMeasureId( idx=deepcopy(cond_measure_id.idx), predicate=cond_measure_id.predicate diff --git a/src/bloqade/stim/passes/flatten_except_loops.py b/src/bloqade/stim/passes/flatten_except_loops.py index e7b3a630..7562a1ee 100644 --- a/src/bloqade/stim/passes/flatten_except_loops.py +++ b/src/bloqade/stim/passes/flatten_except_loops.py @@ -6,14 +6,13 @@ from kirin.passes import Pass, TypeInfer # from kirin.passes.aggressive import UnrollScf -from kirin.rewrite import ( +from kirin.rewrite import ( # CommonSubexpressionElimination, Walk, Chain, Inline, Fixpoint, CFGCompactify, DeadCodeElimination, - CommonSubexpressionElimination, ) from kirin.analysis import const from kirin.dialects import py, scf, ilist @@ -75,7 +74,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @dataclass -class UnrollNoLoops(Pass): +class RestrictedLoopUnroll(Pass): """A pass to unroll structured control flow""" additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True @@ -97,9 +96,12 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: result = RewriteResult() result = self.fold.unsafe_run(mt).join(result) - # equivalent of ScfUnroll but now customized + # equivalent of ScfUnroll but now customized and + # essentially inlined hear for development purposes result = Walk(PickIfElse()).rewrite(mt.code).join(result) result = Walk(ForLoopNoIterDependance()).rewrite(mt.code).join(result) + result = self.fold.unsafe_run(mt).join(result) + result = self.typeinfer.unsafe_run(mt) # no join here, avoid fixpoint issues # Do not join result of typeinfer or fixpoint will waste time result = ( @@ -111,7 +113,7 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result) result = self.canonicalize_ilist.fixpoint(mt).join(result) rule = Chain( - CommonSubexpressionElimination(), + # CommonSubexpressionElimination(), - delay until later DeadCodeElimination(), ) result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) @@ -136,11 +138,11 @@ class FlattenExceptLoops(Pass): like standard Flatten but without unrolling to let analysis go into loops """ - unroll: UnrollNoLoops = field(init=False) + unroll: RestrictedLoopUnroll = field(init=False) simplify_if: StimSimplifyIfs = field(init=False) def __post_init__(self): - self.unroll = UnrollNoLoops(self.dialects, no_raise=self.no_raise) + self.unroll = RestrictedLoopUnroll(self.dialects, no_raise=self.no_raise) self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) def unsafe_run(self, mt: ir.Method) -> RewriteResult: diff --git a/src/bloqade/stim/passes/simplify_ifs.py b/src/bloqade/stim/passes/simplify_ifs.py index e2bb47c1..0f3615c4 100644 --- a/src/bloqade/stim/passes/simplify_ifs.py +++ b/src/bloqade/stim/passes/simplify_ifs.py @@ -2,13 +2,12 @@ from kirin import ir from kirin.passes import Pass -from kirin.rewrite import ( +from kirin.rewrite import ( # CommonSubexpressionElimination, Walk, Chain, Fixpoint, ConstantFold, DeadCodeElimination, - CommonSubexpressionElimination, ) from kirin.dialects.scf.trim import UnusedYield from kirin.dialects.ilist.passes import ConstList2IList @@ -37,7 +36,7 @@ def unsafe_run(self, mt: ir.Method): Chain( Fixpoint(Walk(ConstantFold())), Walk(ConstList2IList()), - Walk(CommonSubexpressionElimination()), + # Walk(CommonSubexpressionElimination()), - delay until later ) .rewrite(mt.code) .join(result) diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index c00f0032..53c1e645 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -35,7 +35,7 @@ class SquinToStimPass(Pass): def unsafe_run(self, mt: Method) -> RewriteResult: - # inline aggressively: + # massage things rewrite_result = FlattenExceptLoops( dialects=mt.dialects, no_raise=self.no_raise ).fixpoint(mt) diff --git a/test/stim/passes/test_squin_meas_to_stim.py b/test/stim/passes/test_squin_meas_to_stim.py index 78a86476..92253f4e 100644 --- a/test/stim/passes/test_squin_meas_to_stim.py +++ b/test/stim/passes/test_squin_meas_to_stim.py @@ -104,7 +104,6 @@ def main(): SquinToStimPass(main.dialects)(main) base_stim_prog = load_reference_program("record_index_order.stim") - assert base_stim_prog == codegen(main) From 9be26cd1675ba26649253491d9c73c56a65631e6 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 11 Dec 2025 08:55:39 -0500 Subject: [PATCH 19/26] fix coordinate conversion with ilist as the new type for set_detector --- src/bloqade/stim/passes/squin_to_stim.py | 11 +++ src/bloqade/stim/rewrite/__init__.py | 1 + src/bloqade/stim/rewrite/scf_for_to_stim.py | 18 +++++ .../stim/rewrite/set_detector_to_stim.py | 21 ++++- test/stim/passes/test_repetition_code.py | 80 +++++++++++++++++++ 5 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 src/bloqade/stim/rewrite/scf_for_to_stim.py create mode 100644 test/stim/passes/test_repetition_code.py diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 53c1e645..e82727d3 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -46,8 +46,13 @@ def unsafe_run(self, mt: Method) -> RewriteResult: mia = MeasurementIDAnalysis(dialects=mt.dialects) meas_analysis_frame, _ = mia.run(mt) + print("measure_id analysis") + mt.print(analysis=meas_analysis_frame.entries) + aa = AddressAnalysis(dialects=mt.dialects) address_analysis_frame, _ = aa.run(mt) + print("address analysis") + mt.print(analysis=address_analysis_frame.entries) # wrap the address analysis result rewrite_result = ( @@ -72,6 +77,9 @@ def unsafe_run(self, mt: Method) -> RewriteResult: .join(rewrite_result) ) + # print("after if-else, set_detector, set_observable rewrites") + # mt.print() + # Rewrite the noise statements first. rewrite_result = Walk(SquinNoiseToStim()).rewrite(mt.code).join(rewrite_result) @@ -89,6 +97,9 @@ def unsafe_run(self, mt: Method) -> RewriteResult: .join(rewrite_result) ) + print("after squin qubit and measure rewrites") + mt.print() + rewrite_result = ( CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise) .unsafe_run(mt) diff --git a/src/bloqade/stim/rewrite/__init__.py b/src/bloqade/stim/rewrite/__init__.py index 6b04bdc2..7c3b5165 100644 --- a/src/bloqade/stim/rewrite/__init__.py +++ b/src/bloqade/stim/rewrite/__init__.py @@ -2,6 +2,7 @@ from .squin_noise import SquinNoiseToStim as SquinNoiseToStim from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim from .squin_measure import SquinMeasureToStim as SquinMeasureToStim +from .scf_for_to_stim import ScfForToStim as ScfForToStim from .py_constant_to_stim import PyConstantToStim as PyConstantToStim from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim from .set_observable_to_stim import SetObservableToStim as SetObservableToStim diff --git a/src/bloqade/stim/rewrite/scf_for_to_stim.py b/src/bloqade/stim/rewrite/scf_for_to_stim.py new file mode 100644 index 00000000..7fc12452 --- /dev/null +++ b/src/bloqade/stim/rewrite/scf_for_to_stim.py @@ -0,0 +1,18 @@ +from kirin import ir +from kirin.dialects import scf +from kirin.rewrite.abc import RewriteRule, RewriteResult + + +class ScfForToStim(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement): + if not isinstance(node, scf.stmts.For): + return RewriteResult() + + return RewriteResult(has_done_something=True) + + def rewrite_Region(self, node: ir.Region): + return RewriteResult() + + def rewrite_Block(self, node: ir.Block): + return RewriteResult() diff --git a/src/bloqade/stim/rewrite/set_detector_to_stim.py b/src/bloqade/stim/rewrite/set_detector_to_stim.py index 9d768a6f..fccce4ef 100644 --- a/src/bloqade/stim/rewrite/set_detector_to_stim.py +++ b/src/bloqade/stim/rewrite/set_detector_to_stim.py @@ -1,7 +1,7 @@ -from typing import Iterable from dataclasses import dataclass from kirin import ir +from kirin.dialects import ilist from kirin.dialects.py import Constant from kirin.rewrite.abc import RewriteRule, RewriteResult @@ -33,9 +33,25 @@ def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: # get coordinates and generate correct consts coord_ssas = [] - if not isinstance(node.coordinates.owner, Constant): + print(node.coordinates.owner) + if not isinstance(node.coordinates.owner, ilist.New): return RewriteResult() + for coord_value_ssa in node.coordinates.owner.values: + if isinstance(coord_value_ssa.owner, Constant): + value = coord_value_ssa.owner.value.unwrap() + if isinstance(value, float): + coord_stmt = auxiliary.ConstFloat(value=value) + elif isinstance(value, int): + coord_stmt = auxiliary.ConstInt(value=value) + else: + return RewriteResult() + coord_ssas.append(coord_stmt.result) + coord_stmt.insert_before(node) + else: + return RewriteResult() + + """ coord_values = node.coordinates.owner.value.unwrap() if not isinstance(coord_values, Iterable): @@ -51,6 +67,7 @@ def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: coord_stmt = auxiliary.ConstInt(value=coord_value) coord_ssas.append(coord_stmt.result) coord_stmt.insert_before(node) + """ measure_ids = self.measure_id_frame.entries.get(node.result, None) if measure_ids is None: diff --git a/test/stim/passes/test_repetition_code.py b/test/stim/passes/test_repetition_code.py new file mode 100644 index 00000000..f8ad1682 --- /dev/null +++ b/test/stim/passes/test_repetition_code.py @@ -0,0 +1,80 @@ +from bloqade import squin +from bloqade.stim.passes import SquinToStimPass + + +def test_repeat_on_gates_only(): + + @squin.kernel + def test(): + + qs = squin.qalloc(3) + + squin.broadcast.reset(qs) + + for _ in range(5): + squin.broadcast.h(qs) + squin.broadcast.x(qs) + + SquinToStimPass(dialects=test.dialects)(test) + test.print() + + +def test_repeat_with_invariant_measure(): + + @squin.kernel + def test(): + + qs = squin.qalloc(3) + curr_ms = squin.broadcast.measure(qs) + + for _ in range(5): + prev_ms = curr_ms + squin.broadcast.h(qs) + curr_ms = squin.broadcast.measure(qs) + squin.set_detector( + measurements=[curr_ms[0], prev_ms[0]], coordinates=[0, 0] + ) + + SquinToStimPass(dialects=test.dialects)(test) + test.print() + + +test_repeat_with_invariant_measure() + + +def test_rep_code(): + @squin.kernel + def test(): + + qs = squin.qalloc(5) + data_qs = [qs[0], qs[2], qs[4]] + and_qs = [qs[1], qs[3]] + + squin.broadcast.reset(qs) + squin.broadcast.cx(controls=[qs[0], qs[2]], targets=[qs[1], qs[3]]) + squin.broadcast.cx(controls=[qs[2], qs[4]], targets=[qs[1], qs[3]]) + + curr_ms = squin.broadcast.measure(and_qs) + squin.set_detector([curr_ms[0]], coordinates=[0, 0]) + squin.set_detector([curr_ms[1]], coordinates=[0, 1]) + + for _ in range(3): + + prev_ms = curr_ms + + squin.broadcast.cx(controls=[qs[0], qs[2]], targets=[qs[1], qs[3]]) + squin.broadcast.cx(controls=[qs[2], qs[4]], targets=[qs[1], qs[3]]) + + curr_ms = squin.broadcast.measure(and_qs) + + squin.annotate.set_detector([prev_ms[0], curr_ms[0]], coordinates=[0, 0]) + squin.annotate.set_detector([prev_ms[1], curr_ms[1]], coordinates=[0, 1]) + + data_ms = squin.broadcast.measure(data_qs) + + squin.set_detector([data_ms[0], data_ms[1], curr_ms[0]], coordinates=[2, 0]) + squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1]) + squin.set_observable([data_ms[2]]) + + SquinToStimPass(dialects=test.dialects)(test) + test.print() From 782ca808cff3888200f85059ac4f8396cbd49945 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 12 Dec 2025 09:31:04 -0500 Subject: [PATCH 20/26] caught some problems with coordinate rewrite as well as address analysis data just not being wrapped properly. Opted for passing dictionary around instead --- src/bloqade/squin/rewrite/__init__.py | 4 -- src/bloqade/squin/rewrite/wrap_analysis.py | 56 --------------- src/bloqade/stim/dialects/cf/stmts.py | 1 - src/bloqade/stim/passes/squin_to_stim.py | 45 +++++++----- src/bloqade/stim/rewrite/ifs_to_stim.py | 11 ++- src/bloqade/stim/rewrite/qubit_to_stim.py | 39 +++++------ src/bloqade/stim/rewrite/scf_for_to_stim.py | 42 +++++++++-- .../stim/rewrite/set_detector_to_stim.py | 69 ++++++++++--------- src/bloqade/stim/rewrite/squin_measure.py | 32 ++++----- src/bloqade/stim/rewrite/squin_noise.py | 38 ++++++---- src/bloqade/stim/rewrite/util.py | 15 ++-- test/stim/passes/test_repetition_code.py | 28 ++++++-- 12 files changed, 190 insertions(+), 190 deletions(-) delete mode 100644 src/bloqade/squin/rewrite/wrap_analysis.py diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index af45435a..89f51a17 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,6 +1,2 @@ -from .wrap_analysis import ( - AddressAttribute as AddressAttribute, - WrapAddressAnalysis as WrapAddressAnalysis, -) from .U3_to_clifford import SquinU3ToClifford as SquinU3ToClifford from .remove_dangling_qubits import RemoveDeadRegister as RemoveDeadRegister diff --git a/src/bloqade/squin/rewrite/wrap_analysis.py b/src/bloqade/squin/rewrite/wrap_analysis.py deleted file mode 100644 index e1a35440..00000000 --- a/src/bloqade/squin/rewrite/wrap_analysis.py +++ /dev/null @@ -1,56 +0,0 @@ -from abc import abstractmethod -from dataclasses import dataclass - -from kirin import ir -from kirin.rewrite.abc import RewriteRule, RewriteResult -from kirin.print.printer import Printer - -from bloqade import qubit -from bloqade.analysis.address import Address - - -@qubit.dialect.register -@dataclass -class AddressAttribute(ir.Attribute): - - name = "Address" - address: Address - - def __hash__(self) -> int: - return hash(self.address) - - def print_impl(self, printer: Printer) -> None: - # Can return to implementing this later - printer.print(self.address) - - -@dataclass -class WrapAnalysis(RewriteRule): - - @abstractmethod - def wrap(self, value: ir.SSAValue) -> bool: - pass - - def rewrite_Block(self, node: ir.Block) -> RewriteResult: - has_done_something = any(self.wrap(arg) for arg in node.args) - return RewriteResult(has_done_something=has_done_something) - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - has_done_something = any(self.wrap(result) for result in node.results) - return RewriteResult(has_done_something=has_done_something) - - -@dataclass -class WrapAddressAnalysis(WrapAnalysis): - address_analysis: dict[ir.SSAValue, Address] - - def wrap(self, value: ir.SSAValue) -> bool: - if (address_analysis_result := self.address_analysis.get(value)) is None: - return False - - if value.hints.get("address") is not None: - return False - - value.hints["address"] = AddressAttribute(address_analysis_result) - - return True diff --git a/src/bloqade/stim/dialects/cf/stmts.py b/src/bloqade/stim/dialects/cf/stmts.py index e3d2e64f..346184b4 100644 --- a/src/bloqade/stim/dialects/cf/stmts.py +++ b/src/bloqade/stim/dialects/cf/stmts.py @@ -14,7 +14,6 @@ class REPEAT(ir.Statement): """ name = "REPEAT" - traits = frozenset({ir.MaybePure(), ir.HasCFG(), ir.SSACFG()}) count: ir.SSAValue = info.argument(types.Int) body: ir.Region = info.region(multi=False) diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index e82727d3..4b15d6d7 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -11,7 +11,7 @@ from kirin.passes.abc import Pass from kirin.rewrite.abc import RewriteResult -from bloqade.stim.rewrite import ( +from bloqade.stim.rewrite import ( # ScfForToStim, PyConstantToStim, SquinNoiseToStim, SquinQubitToStim, @@ -20,7 +20,6 @@ from bloqade.squin.rewrite import ( SquinU3ToClifford, RemoveDeadRegister, - WrapAddressAnalysis, ) from bloqade.rewrite.passes import CanonicalizeIList from bloqade.analysis.address import AddressAnalysis @@ -35,7 +34,8 @@ class SquinToStimPass(Pass): def unsafe_run(self, mt: Method) -> RewriteResult: - # massage things + # There's some logic here to not touch loops that look like they should be + # rewritten to REPEATs. rewrite_result = FlattenExceptLoops( dialects=mt.dialects, no_raise=self.no_raise ).fixpoint(mt) @@ -51,15 +51,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult: aa = AddressAnalysis(dialects=mt.dialects) address_analysis_frame, _ = aa.run(mt) - print("address analysis") - mt.print(analysis=address_analysis_frame.entries) - - # wrap the address analysis result - rewrite_result = ( - Walk(WrapAddressAnalysis(address_analysis=address_analysis_frame.entries)) - .rewrite(mt.code) - .join(rewrite_result) - ) # 2. rewrite ## Invoke DCE afterwards to eliminate any GetItems @@ -68,7 +59,12 @@ def unsafe_run(self, mt: Method) -> RewriteResult: ## unused measure statements. rewrite_result = ( Chain( - Walk(IfToStim(measure_frame=meas_analysis_frame)), + Walk( + IfToStim( + measure_frame=meas_analysis_frame, + address_frame=address_analysis_frame, + ) + ), Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)), Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)), Fixpoint(Walk(DeadCodeElimination())), @@ -81,7 +77,11 @@ def unsafe_run(self, mt: Method) -> RewriteResult: # mt.print() # Rewrite the noise statements first. - rewrite_result = Walk(SquinNoiseToStim()).rewrite(mt.code).join(rewrite_result) + rewrite_result = ( + Walk(SquinNoiseToStim(address_frame=address_analysis_frame)) + .rewrite(mt.code) + .join(rewrite_result) + ) # Wrap Rewrite + SquinToStim can happen w/ standard walk rewrite_result = Walk(SquinU3ToClifford()).rewrite(mt.code).join(rewrite_result) @@ -89,17 +89,14 @@ def unsafe_run(self, mt: Method) -> RewriteResult: rewrite_result = ( Walk( Chain( - SquinQubitToStim(), - SquinMeasureToStim(), + SquinQubitToStim(address_frame=address_analysis_frame), + SquinMeasureToStim(address_frame=address_analysis_frame), ) ) .rewrite(mt.code) .join(rewrite_result) ) - print("after squin qubit and measure rewrites") - mt.print() - rewrite_result = ( CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise) .unsafe_run(mt) @@ -124,5 +121,15 @@ def unsafe_run(self, mt: Method) -> RewriteResult: .rewrite(mt.code) .join(rewrite_result) ) + # print("before final loop rewrites") + # mt.print() + # return rewrite_result + # Remaining loops should be safe to convert to REPEAT + # Also make sure to DCE the IList(range) from the for loop lowering + """ + rewrite_result = Walk( + Chain(ScfForToStim()) + ).rewrite(mt.code).join(rewrite_result) + """ return rewrite_result diff --git a/src/bloqade/stim/rewrite/ifs_to_stim.py b/src/bloqade/stim/rewrite/ifs_to_stim.py index a3f9063c..1ce4253f 100644 --- a/src/bloqade/stim/rewrite/ifs_to_stim.py +++ b/src/bloqade/stim/rewrite/ifs_to_stim.py @@ -1,12 +1,12 @@ from dataclasses import field, dataclass from kirin import ir +from kirin.analysis import ForwardFrame from kirin.dialects import py, scf, func from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade.squin import gate from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts -from bloqade.squin.rewrite import AddressAttribute from bloqade.stim.rewrite.util import ( insert_qubit_idx_from_address, ) @@ -130,6 +130,7 @@ class IfToStim(IfElseSimplification, RewriteRule): """ measure_frame: MeasureIDFrame + address_frame: ForwardFrame def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @@ -171,15 +172,13 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: measure_id_idx_stmt = py.Constant(condition_type.idx) get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) - address_attr = stmts[0].qubits.hints.get("address") + address_lattice_elem = self.address_frame.entries.get(stmts[0].qubits) - if address_attr is None: + if address_lattice_elem is None: return RewriteResult() - assert isinstance(address_attr, AddressAttribute) - # note: insert things before (literally above/outside) the If qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=stmt + address=address_lattice_elem, stmt_to_insert_before=stmt ) if qubit_idx_ssas is None: return RewriteResult() diff --git a/src/bloqade/stim/rewrite/qubit_to_stim.py b/src/bloqade/stim/rewrite/qubit_to_stim.py index c6055b8f..654616eb 100644 --- a/src/bloqade/stim/rewrite/qubit_to_stim.py +++ b/src/bloqade/stim/rewrite/qubit_to_stim.py @@ -1,16 +1,23 @@ +from dataclasses import dataclass + from kirin import ir +from kirin.analysis import ForwardFrame from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade import qubit from bloqade.squin import gate -from bloqade.squin.rewrite import AddressAttribute from bloqade.stim.dialects import gate as stim_gate, collapse as stim_collapse +from bloqade.analysis.address import Address from bloqade.stim.rewrite.util import ( insert_qubit_idx_from_address, ) +@dataclass class SquinQubitToStim(RewriteRule): + + address_frame: ForwardFrame[Address] + """ NOTE this require address analysis result to be wrapped before using this rule. """ @@ -33,15 +40,12 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: def rewrite_Reset(self, stmt: qubit.stmts.Reset) -> RewriteResult: - qubit_addr_attr = stmt.qubits.hints.get("address", None) - - if qubit_addr_attr is None: + address_lattice_elem = self.address_frame.entries.get(stmt.qubits) + if address_lattice_elem is None: return RewriteResult() - assert isinstance(qubit_addr_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address=qubit_addr_attr, stmt_to_insert_before=stmt + address=address_lattice_elem, stmt_to_insert_before=stmt ) if qubit_idx_ssas is None: @@ -60,14 +64,12 @@ def rewrite_SingleQubitGate( Address Analysis should have been run along with Wrap Analysis before this rewrite is applied. """ - qubit_addr_attr = stmt.qubits.hints.get("address", None) - if qubit_addr_attr is None: + address_lattice_elem = self.address_frame.entries.get(stmt.qubits) + if address_lattice_elem is None: return RewriteResult() - assert isinstance(qubit_addr_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address=qubit_addr_attr, stmt_to_insert_before=stmt + address=address_lattice_elem, stmt_to_insert_before=stmt ) if qubit_idx_ssas is None: @@ -97,20 +99,17 @@ def rewrite_ControlledGate(self, stmt: gate.stmts.ControlledGate) -> RewriteResu Address Analysis should have been run along with Wrap Analysis before this rewrite is applied. """ - controls_addr_attr = stmt.controls.hints.get("address", None) - targets_addr_attr = stmt.targets.hints.get("address", None) + controls_addr_lattice_elem = self.address_frame.entries.get(stmt.controls) + targets_addr_lattice_elem = self.address_frame.entries.get(stmt.targets) - if controls_addr_attr is None or targets_addr_attr is None: + if controls_addr_lattice_elem is None or targets_addr_lattice_elem is None: return RewriteResult() - assert isinstance(controls_addr_attr, AddressAttribute) - assert isinstance(targets_addr_attr, AddressAttribute) - controls_idx_ssas = insert_qubit_idx_from_address( - address=controls_addr_attr, stmt_to_insert_before=stmt + address=controls_addr_lattice_elem, stmt_to_insert_before=stmt ) targets_idx_ssas = insert_qubit_idx_from_address( - address=targets_addr_attr, stmt_to_insert_before=stmt + address=targets_addr_lattice_elem, stmt_to_insert_before=stmt ) if controls_idx_ssas is None or targets_idx_ssas is None: diff --git a/src/bloqade/stim/rewrite/scf_for_to_stim.py b/src/bloqade/stim/rewrite/scf_for_to_stim.py index 7fc12452..d38d122f 100644 --- a/src/bloqade/stim/rewrite/scf_for_to_stim.py +++ b/src/bloqade/stim/rewrite/scf_for_to_stim.py @@ -1,7 +1,11 @@ from kirin import ir -from kirin.dialects import scf +from kirin.dialects import scf, ilist +from kirin.dialects.py import Constant from kirin.rewrite.abc import RewriteRule, RewriteResult +from bloqade.stim.dialects import cf +from bloqade.stim.dialects.auxiliary import ConstInt + class ScfForToStim(RewriteRule): @@ -9,10 +13,36 @@ def rewrite_Statement(self, node: ir.Statement): if not isinstance(node, scf.stmts.For): return RewriteResult() - return RewriteResult(has_done_something=True) + # Convert the scf.For iterable to + # a single integer constant + ## Detach to allow DCE to do its job later + loop_iterable_stmt = node.iterable.owner + assert isinstance(loop_iterable_stmt, Constant) + assert isinstance(loop_iterable_stmt.value, ilist.IList) + loop_range = loop_iterable_stmt.value.data + assert isinstance(loop_range, range) + num_times_to_repeat = len(loop_range) + + const_repeat_num = ConstInt(value=num_times_to_repeat) + const_repeat_num.insert_before(node) - def rewrite_Region(self, node: ir.Region): - return RewriteResult() + # figured out from scf2cf, can't just + # point the old body into the new REPEAT body + new_block = ir.Block() + for stmt in node.body.blocks[0].stmts: + if isinstance(stmt, scf.stmts.Yield): + print(stmt.values) + continue + stmt.detach() + new_block.stmts.append(stmt) - def rewrite_Block(self, node: ir.Block): - return RewriteResult() + new_region = ir.Region(new_block) + + # Create the REPEAT statement + repeat_stmt = cf.stmts.REPEAT( + count=const_repeat_num.result, + body=new_region, + ) + node.replace_by(repeat_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/stim/rewrite/set_detector_to_stim.py b/src/bloqade/stim/rewrite/set_detector_to_stim.py index fccce4ef..0c363cba 100644 --- a/src/bloqade/stim/rewrite/set_detector_to_stim.py +++ b/src/bloqade/stim/rewrite/set_detector_to_stim.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from kirin import ir -from kirin.dialects import ilist +from kirin.dialects import py, ilist from kirin.dialects.py import Constant from kirin.rewrite.abc import RewriteRule, RewriteResult @@ -14,6 +14,17 @@ from ..rewrite.get_record_util import insert_get_records +def python_num_val_to_stim_const(value: int | float) -> ir.Statement | None: + if isinstance(value, float): + const_stmt = auxiliary.ConstFloat(value=value) + elif isinstance(value, int): + const_stmt = auxiliary.ConstInt(value=value) + else: + return None + + return const_stmt + + @dataclass class SetDetectorToStim(RewriteRule): """ @@ -25,49 +36,45 @@ class SetDetectorToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: case SetDetector(): + print("detector encountered") return self.rewrite_SetDetector(node) case _: return RewriteResult() def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: - # get coordinates and generate correct consts coord_ssas = [] - print(node.coordinates.owner) - if not isinstance(node.coordinates.owner, ilist.New): + + # coordinates can be a py.Constant with an ilist or a raw ilist + if not isinstance(node.coordinates.owner, (ilist.New, py.Constant)): return RewriteResult() - for coord_value_ssa in node.coordinates.owner.values: - if isinstance(coord_value_ssa.owner, Constant): - value = coord_value_ssa.owner.value.unwrap() - if isinstance(value, float): - coord_stmt = auxiliary.ConstFloat(value=value) - elif isinstance(value, int): - coord_stmt = auxiliary.ConstInt(value=value) + if isinstance(node.coordinates.owner, ilist.New): + coord_values_ssas = node.coordinates.owner.values + for coord_value_ssa in coord_values_ssas: + if isinstance(coord_value_ssa.owner, Constant): + value = coord_value_ssa.owner.value.unwrap() + coord_stmt = python_num_val_to_stim_const(value) + if coord_stmt is None: + return RewriteResult() + coord_ssas.append(coord_stmt.result) + coord_stmt.insert_before(node) else: return RewriteResult() + + if isinstance(node.coordinates.owner, py.Constant): + const_value = node.coordinates.owner.value.unwrap() + if not isinstance(const_value, ilist.IList): + return RewriteResult() + ilist_value = const_value.data + if not isinstance(ilist_value, list): + return RewriteResult() + for value in ilist_value: + coord_stmt = python_num_val_to_stim_const(value) + if coord_stmt is None: + return RewriteResult() coord_ssas.append(coord_stmt.result) coord_stmt.insert_before(node) - else: - return RewriteResult() - - """ - coord_values = node.coordinates.owner.value.unwrap() - - if not isinstance(coord_values, Iterable): - return RewriteResult() - - if any(not isinstance(value, (int, float)) for value in coord_values): - return RewriteResult() - - for coord_value in coord_values: - if isinstance(coord_value, float): - coord_stmt = auxiliary.ConstFloat(value=coord_value) - else: # int - coord_stmt = auxiliary.ConstInt(value=coord_value) - coord_ssas.append(coord_stmt.result) - coord_stmt.insert_before(node) - """ measure_ids = self.measure_id_frame.entries.get(node.result, None) if measure_ids is None: diff --git a/src/bloqade/stim/rewrite/squin_measure.py b/src/bloqade/stim/rewrite/squin_measure.py index 25f4d759..f752cc41 100644 --- a/src/bloqade/stim/rewrite/squin_measure.py +++ b/src/bloqade/stim/rewrite/squin_measure.py @@ -2,12 +2,13 @@ from dataclasses import dataclass from kirin import ir +from kirin.analysis import ForwardFrame from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade import qubit -from bloqade.squin.rewrite import AddressAttribute from bloqade.stim.dialects import collapse +from bloqade.analysis.address import Address from bloqade.stim.rewrite.util import ( insert_qubit_idx_from_address, ) @@ -19,18 +20,27 @@ class SquinMeasureToStim(RewriteRule): Rewrite squin measure-related statements to stim statements. """ + address_frame: ForwardFrame[Address] + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: case qubit.stmts.Measure(): + print("measure encountered") return self.rewrite_Measure(node) case _: return RewriteResult() def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult: - qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt) + address_lattice_elem = self.address_frame.entries.get(measure_stmt.qubits) + if address_lattice_elem is None: + return RewriteResult() + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_lattice_elem, stmt_to_insert_before=measure_stmt + ) if qubit_idx_ssas is None: + print(f"no qubit idx ssas found for measure {measure_stmt}") return RewriteResult() prob_noise_stmt = py.constant.Constant(0.0) @@ -48,21 +58,3 @@ def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult: measure_stmt.delete() return RewriteResult(has_done_something=True) - - def get_qubit_idx_ssas( - self, measure_stmt: qubit.stmts.Measure - ) -> tuple[ir.SSAValue, ...] | None: - """ - Extract the address attribute and insert qubit indices for the given measure statement. - """ - address_attr = measure_stmt.qubits.hints.get("address") - if address_attr is None: - return None - - assert isinstance(address_attr, AddressAttribute) - - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=measure_stmt - ) - - return qubit_idx_ssas diff --git a/src/bloqade/stim/rewrite/squin_noise.py b/src/bloqade/stim/rewrite/squin_noise.py index 955c3167..efd6183b 100644 --- a/src/bloqade/stim/rewrite/squin_noise.py +++ b/src/bloqade/stim/rewrite/squin_noise.py @@ -4,19 +4,22 @@ from kirin import types from kirin.ir import SSAValue, Statement +from kirin.analysis import ForwardFrame from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade.squin import noise as squin_noise from bloqade.stim.dialects import noise as stim_noise +from bloqade.analysis.address import Address from bloqade.stim.rewrite.util import insert_qubit_idx_from_address from bloqade.analysis.address.lattice import AddressReg, PartialIList -from bloqade.squin.rewrite.wrap_analysis import AddressAttribute @dataclass class SquinNoiseToStim(RewriteRule): + address_frame: ForwardFrame[Address] + def rewrite_Statement(self, node: Statement) -> RewriteResult: match node: case squin_noise.stmts.NoiseChannel(): @@ -38,12 +41,12 @@ def rewrite_NoiseChannel( # CorrelatedQubitLoss represents a broadcast operation, but Stim does not # support broadcasting for multi-qubit noise channels. # Therefore, we must expand the broadcast into individual stim statements. - qubit_address_attr = stmt.qubits.hints.get("address", None) - if not isinstance(qubit_address_attr, AddressAttribute): + address_lattice_elem = self.address_frame.entries.get(stmt.qubits) + if address_lattice_elem is None: return RewriteResult() - if not isinstance(address := qubit_address_attr.address, PartialIList): + if not isinstance(address := address_lattice_elem, PartialIList): return RewriteResult() if not types.is_tuple_of(data := address.data, AddressReg): @@ -51,9 +54,7 @@ def rewrite_NoiseChannel( for address_reg in data: - qubit_idx_ssas = insert_qubit_idx_from_address( - AddressAttribute(address_reg), stmt - ) + qubit_idx_ssas = insert_qubit_idx_from_address(address_reg, stmt) stim_stmt = rewrite_method(stmt, qubit_idx_ssas) stim_stmt.insert_before(stmt) @@ -63,21 +64,28 @@ def rewrite_NoiseChannel( return RewriteResult(has_done_something=True) if isinstance(stmt, squin_noise.stmts.SingleQubitNoiseChannel): - qubit_address_attr = stmt.qubits.hints.get("address", None) - if qubit_address_attr is None: + + address_lattice_elem = self.address_frame.entries.get(stmt.qubits) + if address_lattice_elem is None: + return RewriteResult() + + qubit_idx_ssas = insert_qubit_idx_from_address(address_lattice_elem, stmt) + if qubit_idx_ssas is None: return RewriteResult() - qubit_idx_ssas = insert_qubit_idx_from_address(qubit_address_attr, stmt) elif isinstance(stmt, squin_noise.stmts.TwoQubitNoiseChannel): - control_address_attr = stmt.controls.hints.get("address", None) - target_address_attr = stmt.targets.hints.get("address", None) - if control_address_attr is None or target_address_attr is None: + control_address_lattice_elem = self.address_frame.entries.get(stmt.controls) + target_address_lattice_elem = self.address_frame.entries.get(stmt.targets) + if ( + control_address_lattice_elem is None + or target_address_lattice_elem is None + ): return RewriteResult() control_qubit_idx_ssas = insert_qubit_idx_from_address( - control_address_attr, stmt + control_address_lattice_elem, stmt ) target_qubit_idx_ssas = insert_qubit_idx_from_address( - target_address_attr, stmt + target_address_lattice_elem, stmt ) if control_qubit_idx_ssas is None or target_qubit_idx_ssas is None: return RewriteResult() diff --git a/src/bloqade/stim/rewrite/util.py b/src/bloqade/stim/rewrite/util.py index ee5ede41..84b17cc8 100644 --- a/src/bloqade/stim/rewrite/util.py +++ b/src/bloqade/stim/rewrite/util.py @@ -1,8 +1,7 @@ from kirin import ir from kirin.dialects import py -from bloqade.squin.rewrite import AddressAttribute -from bloqade.analysis.address import AddressReg, AddressQubit +from bloqade.analysis.address import Address, AddressReg, AddressQubit def create_and_insert_qubit_idx_stmt( @@ -14,20 +13,20 @@ def create_and_insert_qubit_idx_stmt( def insert_qubit_idx_from_address( - address: AddressAttribute, stmt_to_insert_before: ir.Statement + address: Address, stmt_to_insert_before: ir.Statement ) -> tuple[ir.SSAValue, ...] | None: """ - Extract qubit indices from an AddressAttribute and insert them into the SSA form. + Extract qubit indices from an address analysis lattice element and insert them into the SSA form. """ qubit_idx_ssas = [] - if isinstance(address_data := address.address, AddressReg): - for qubit_idx in address_data.qubits: + if isinstance(address, AddressReg): + for qubit_idx in address.qubits: create_and_insert_qubit_idx_stmt( qubit_idx.data, stmt_to_insert_before, qubit_idx_ssas ) - elif isinstance(address_data, AddressQubit): + elif isinstance(address, AddressQubit): create_and_insert_qubit_idx_stmt( - address_data.data, stmt_to_insert_before, qubit_idx_ssas + address.data, stmt_to_insert_before, qubit_idx_ssas ) else: return diff --git a/test/stim/passes/test_repetition_code.py b/test/stim/passes/test_repetition_code.py index f8ad1682..7fe04fa9 100644 --- a/test/stim/passes/test_repetition_code.py +++ b/test/stim/passes/test_repetition_code.py @@ -1,7 +1,21 @@ -from bloqade import squin +import io + +from kirin import ir + +from bloqade import stim, squin +from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass +def codegen(mt: ir.Method): + # method should not have any arguments! + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) + emit.initialize() + emit.run(mt) + return buf.getvalue().strip() + + def test_repeat_on_gates_only(): @squin.kernel @@ -19,7 +33,9 @@ def test(): test.print() -def test_repeat_with_invariant_measure(): +# Very similar to a full repetition code +# but simplified +def test_repetition_code_structure(): @squin.kernel def test(): @@ -35,14 +51,18 @@ def test(): measurements=[curr_ms[0], prev_ms[0]], coordinates=[0, 0] ) + final_ms = squin.broadcast.measure(qs) + squin.set_detector(measurements=[final_ms[0], curr_ms[0]], coordinates=[1, 0]) + squin.set_observable([final_ms[0]]) + SquinToStimPass(dialects=test.dialects)(test) test.print() -test_repeat_with_invariant_measure() +test_repetition_code_structure() -def test_rep_code(): +def test_full_repetition_code(): @squin.kernel def test(): From 356dcfde76e3fd4d0768fec943f9ca3157af250c Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 12 Dec 2025 10:44:33 -0500 Subject: [PATCH 21/26] need to remove dead measures, confirm full repetition code works --- src/bloqade/squin/rewrite/__init__.py | 3 ++- .../squin/rewrite/remove_dead_measure.py | 19 +++++++++++++++++ ...ling_qubits.py => remove_dead_register.py} | 0 src/bloqade/stim/passes/squin_to_stim.py | 21 ++++++++++++++----- src/bloqade/stim/rewrite/scf_for_to_stim.py | 1 - test/stim/passes/test_repetition_code.py | 5 ++--- 6 files changed, 39 insertions(+), 10 deletions(-) create mode 100644 src/bloqade/squin/rewrite/remove_dead_measure.py rename src/bloqade/squin/rewrite/{remove_dangling_qubits.py => remove_dead_register.py} (100%) diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index 89f51a17..b1e03f06 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,2 +1,3 @@ from .U3_to_clifford import SquinU3ToClifford as SquinU3ToClifford -from .remove_dangling_qubits import RemoveDeadRegister as RemoveDeadRegister +from .remove_dead_measure import RemoveDeadMeasure as RemoveDeadMeasure +from .remove_dead_register import RemoveDeadRegister as RemoveDeadRegister diff --git a/src/bloqade/squin/rewrite/remove_dead_measure.py b/src/bloqade/squin/rewrite/remove_dead_measure.py new file mode 100644 index 00000000..ed065385 --- /dev/null +++ b/src/bloqade/squin/rewrite/remove_dead_measure.py @@ -0,0 +1,19 @@ +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import qubit + + +class RemoveDeadMeasure(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if not isinstance(node, qubit.stmts.Measure): + return RewriteResult() + + if bool(node.result.uses): + return RewriteResult() + else: + node.delete() + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/remove_dangling_qubits.py b/src/bloqade/squin/rewrite/remove_dead_register.py similarity index 100% rename from src/bloqade/squin/rewrite/remove_dangling_qubits.py rename to src/bloqade/squin/rewrite/remove_dead_register.py diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 4b15d6d7..2c9a6353 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -11,13 +11,15 @@ from kirin.passes.abc import Pass from kirin.rewrite.abc import RewriteResult -from bloqade.stim.rewrite import ( # ScfForToStim, +from bloqade.stim.rewrite import ( + ScfForToStim, PyConstantToStim, SquinNoiseToStim, SquinQubitToStim, SquinMeasureToStim, ) from bloqade.squin.rewrite import ( + RemoveDeadMeasure, SquinU3ToClifford, RemoveDeadRegister, ) @@ -126,10 +128,19 @@ def unsafe_run(self, mt: Method) -> RewriteResult: # return rewrite_result # Remaining loops should be safe to convert to REPEAT # Also make sure to DCE the IList(range) from the for loop lowering - """ - rewrite_result = Walk( - Chain(ScfForToStim()) + + rewrite_result = ( + Chain( + Walk(ScfForToStim()), + ) + .rewrite(mt.code) + .join(rewrite_result) + ) + + Fixpoint( + Walk( + Chain(DeadCodeElimination(), RemoveDeadMeasure(), RemoveDeadRegister()) + ) ).rewrite(mt.code).join(rewrite_result) - """ return rewrite_result diff --git a/src/bloqade/stim/rewrite/scf_for_to_stim.py b/src/bloqade/stim/rewrite/scf_for_to_stim.py index d38d122f..c3698666 100644 --- a/src/bloqade/stim/rewrite/scf_for_to_stim.py +++ b/src/bloqade/stim/rewrite/scf_for_to_stim.py @@ -31,7 +31,6 @@ def rewrite_Statement(self, node: ir.Statement): new_block = ir.Block() for stmt in node.body.blocks[0].stmts: if isinstance(stmt, scf.stmts.Yield): - print(stmt.values) continue stmt.detach() new_block.stmts.append(stmt) diff --git a/test/stim/passes/test_repetition_code.py b/test/stim/passes/test_repetition_code.py index 7fe04fa9..b1d8ec2a 100644 --- a/test/stim/passes/test_repetition_code.py +++ b/test/stim/passes/test_repetition_code.py @@ -59,9 +59,6 @@ def test(): test.print() -test_repetition_code_structure() - - def test_full_repetition_code(): @squin.kernel def test(): @@ -98,3 +95,5 @@ def test(): SquinToStimPass(dialects=test.dialects)(test) test.print() + + print(codegen(test)) From d4858d3c78c443b10c90d6c68317db21fbfe30ea Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 12 Dec 2025 11:33:00 -0500 Subject: [PATCH 22/26] remove debug print statement --- src/bloqade/analysis/measure_id/impls.py | 2 -- src/bloqade/stim/dialects/cf/emit.py | 1 - src/bloqade/stim/passes/squin_to_stim.py | 11 ----------- src/bloqade/stim/rewrite/set_detector_to_stim.py | 1 - src/bloqade/stim/rewrite/squin_measure.py | 2 -- test/stim/passes/test_repetition_code.py | 5 ++++- test/stim/passes/test_squin_noise_to_stim.py | 11 ++++------- 7 files changed, 8 insertions(+), 25 deletions(-) diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 5a96db47..2e165df2 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -353,8 +353,6 @@ def if_else( stmt: scf.stmts.IfElse, ): cond_measure_id = frame.get(stmt.cond) - print("cond measure id encountered:") - print(cond_measure_id) if isinstance(cond_measure_id, PredicatedMeasureId): detached_cond_measure_id = PredicatedMeasureId( idx=deepcopy(cond_measure_id.idx), predicate=cond_measure_id.predicate diff --git a/src/bloqade/stim/dialects/cf/emit.py b/src/bloqade/stim/dialects/cf/emit.py index 992f06bd..122f1c0e 100644 --- a/src/bloqade/stim/dialects/cf/emit.py +++ b/src/bloqade/stim/dialects/cf/emit.py @@ -12,7 +12,6 @@ class EmitStimCfMethods(MethodTable): @impl(stmts.REPEAT) def repeat(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.REPEAT): - print(stmt.count) count = frame.get(stmt.count) frame.write_line(f"REPEAT {count} {{") with frame.indent(): diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 2c9a6353..69b36cfa 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -48,9 +48,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult: mia = MeasurementIDAnalysis(dialects=mt.dialects) meas_analysis_frame, _ = mia.run(mt) - print("measure_id analysis") - mt.print(analysis=meas_analysis_frame.entries) - aa = AddressAnalysis(dialects=mt.dialects) address_analysis_frame, _ = aa.run(mt) @@ -75,9 +72,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult: .join(rewrite_result) ) - # print("after if-else, set_detector, set_observable rewrites") - # mt.print() - # Rewrite the noise statements first. rewrite_result = ( Walk(SquinNoiseToStim(address_frame=address_analysis_frame)) @@ -123,11 +117,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult: .rewrite(mt.code) .join(rewrite_result) ) - # print("before final loop rewrites") - # mt.print() - # return rewrite_result - # Remaining loops should be safe to convert to REPEAT - # Also make sure to DCE the IList(range) from the for loop lowering rewrite_result = ( Chain( diff --git a/src/bloqade/stim/rewrite/set_detector_to_stim.py b/src/bloqade/stim/rewrite/set_detector_to_stim.py index 0c363cba..b2ac0296 100644 --- a/src/bloqade/stim/rewrite/set_detector_to_stim.py +++ b/src/bloqade/stim/rewrite/set_detector_to_stim.py @@ -36,7 +36,6 @@ class SetDetectorToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: case SetDetector(): - print("detector encountered") return self.rewrite_SetDetector(node) case _: return RewriteResult() diff --git a/src/bloqade/stim/rewrite/squin_measure.py b/src/bloqade/stim/rewrite/squin_measure.py index f752cc41..62402117 100644 --- a/src/bloqade/stim/rewrite/squin_measure.py +++ b/src/bloqade/stim/rewrite/squin_measure.py @@ -26,7 +26,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: case qubit.stmts.Measure(): - print("measure encountered") return self.rewrite_Measure(node) case _: return RewriteResult() @@ -40,7 +39,6 @@ def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult: address=address_lattice_elem, stmt_to_insert_before=measure_stmt ) if qubit_idx_ssas is None: - print(f"no qubit idx ssas found for measure {measure_stmt}") return RewriteResult() prob_noise_stmt = py.constant.Constant(0.0) diff --git a/test/stim/passes/test_repetition_code.py b/test/stim/passes/test_repetition_code.py index b1d8ec2a..63b8ea18 100644 --- a/test/stim/passes/test_repetition_code.py +++ b/test/stim/passes/test_repetition_code.py @@ -75,7 +75,7 @@ def test(): squin.set_detector([curr_ms[0]], coordinates=[0, 0]) squin.set_detector([curr_ms[1]], coordinates=[0, 1]) - for _ in range(3): + for _ in range(10): prev_ms = curr_ms @@ -97,3 +97,6 @@ def test(): test.print() print(codegen(test)) + + +test_full_repetition_code() diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py index 04df2837..6c2b4e53 100644 --- a/test/stim/passes/test_squin_noise_to_stim.py +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -13,7 +13,6 @@ from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass, flatten from bloqade.stim.rewrite import SquinNoiseToStim -from bloqade.squin.rewrite import WrapAddressAnalysis from bloqade.analysis.address import AddressAnalysis @@ -274,7 +273,8 @@ def test(q: ilist.IList[Qubit, kirin_types.Literal]): return flatten.Flatten(dialects=test.dialects).fixpoint(test) - Walk(SquinNoiseToStim()).rewrite(test.code) + frame, _ = AddressAnalysis(test.dialects).run(test) + Walk(SquinNoiseToStim(address_frame=frame)).rewrite(test.code) expected_1q_noise_pauli_channel = get_stmt_at_idx(test, 6) @@ -302,9 +302,8 @@ def test(): return frame, _ = AddressAnalysis(test.dialects).run(test) - WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) - rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) + rewrite_result = Walk(SquinNoiseToStim(address_frame=frame)).rewrite(test.code) expected_noise_channel_stmt = get_stmt_at_idx(test, 2) @@ -323,9 +322,7 @@ def test(): return frame, _ = AddressAnalysis(test.dialects).run(test) - WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) - - rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) + rewrite_result = Walk(SquinNoiseToStim(address_frame=frame)).rewrite(test.code) # Rewrite should not have done anything because target is not a noise channel assert not rewrite_result.has_done_something From 1b4fbd4eebc6f2e893f8e768eae9582f924a3fba Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 15 Dec 2025 14:05:47 -0500 Subject: [PATCH 23/26] fix test failures with coordinates --- test/stim/passes/test_code_basic_operations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/stim/passes/test_code_basic_operations.py b/test/stim/passes/test_code_basic_operations.py index ecef66d7..ffba8a08 100644 --- a/test/stim/passes/test_code_basic_operations.py +++ b/test/stim/passes/test_code_basic_operations.py @@ -59,7 +59,7 @@ def get_qubit(idx: int) -> Qubit: m = squin.broadcast.measure(sub_q) for i in range(len(m)): - squin.annotate.set_detector(measurements=[m[i]], coordinates=(0, 0)) + squin.annotate.set_detector(measurements=[m[i]], coordinates=[0, 0]) SquinToStimPass(dialects=test.dialects)(test) @@ -93,7 +93,7 @@ def main(): mr = measure_out(qubits) for i in range(len(mr)): - squin.set_detector(measurements=[mr[i]], coordinates=(0, 0)) + squin.set_detector(measurements=[mr[i]], coordinates=[0, 0]) SquinToStimPass(main.dialects)(main) From 32a6582c2975d7613cdcb08d659b0a7f09e6db91 Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 15 Dec 2025 16:51:01 -0500 Subject: [PATCH 24/26] restore test functionality, missed the measurement inbetween to get the record idxs to increment --- src/bloqade/analysis/measure_id/analysis.py | 10 --- src/bloqade/analysis/measure_id/impls.py | 47 ++++++++---- src/bloqade/analysis/measure_id/lattice.py | 20 ++--- src/bloqade/stim/passes/squin_to_stim.py | 6 ++ src/bloqade/stim/rewrite/ifs_to_stim.py | 10 ++- test/analysis/measure_id/test_measure_id.py | 41 ++++------ .../qubit/delayed_cse_measure_predicate.stim | 4 + test/stim/passes/test_annotation_to_stim.py | 1 + test/stim/passes/test_squin_qubit_to_stim.py | 74 +++++++++++++++++++ 9 files changed, 148 insertions(+), 65 deletions(-) create mode 100644 test/stim/passes/stim_reference_programs/qubit/delayed_cse_measure_predicate.stim diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index 423fd589..78b898f3 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -52,16 +52,6 @@ def clone_raw_measure_id(self, raw_measure_id: RawMeasureId) -> RawMeasureId: self.buffer.append(cloned_raw_measure_id) return cloned_raw_measure_id - def clone_predicated_measure_id( - self, predicated_measure_id: PredicatedMeasureId - ) -> PredicatedMeasureId: - cloned_predicated_measure_id = PredicatedMeasureId( - idx=predicated_measure_id.idx, - predicate=predicated_measure_id.predicate, - ) - self.buffer.append(cloned_predicated_measure_id) - return cloned_predicated_measure_id - def clone_measure_ids(self, measure_id_type: MeasureId) -> MeasureId: if isinstance(measure_id_type, RawMeasureId): diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 2e165df2..3a217b56 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -75,11 +75,7 @@ def measurement_predicate( else: return (InvalidMeasureId(),) - predicate_measure_ids = [ - PredicatedMeasureId(measure_id.idx, predicate) - for measure_id in original_measure_id_tuple.data - ] - return (MeasureIdTuple(data=tuple(predicate_measure_ids)),) + return (PredicatedMeasureId(on_type=original_measure_id_tuple, cond=predicate),) @gemini.logical.dialect.register(key="measure_id") @@ -179,18 +175,35 @@ def getitem( idx_or_slice = idx_or_slice.value obj = frame.get(stmt.obj) - if isinstance(obj, MeasureIdTuple): - if isinstance(idx_or_slice, slice): - return (MeasureIdTuple(data=obj.data[idx_or_slice]),) - elif isinstance(idx_or_slice, int): - return (obj.data[idx_or_slice],) + + if isinstance(obj, PredicatedMeasureId): + if isinstance(obj.on_type, MeasureIdTuple): + # apply the slice/indexing to the interior MeasureIdTuple and + type_to_wrap = self.measure_id_tuple_handling(obj.on_type, idx_or_slice) + return (PredicatedMeasureId(on_type=type_to_wrap, cond=obj.cond),) else: return (InvalidMeasureId(),) - # just propagate these down the line - elif isinstance(obj, (AnyMeasureId, NotMeasureId)): + + if isinstance(obj, MeasureIdTuple): + return (self.measure_id_tuple_handling(obj, idx_or_slice),) + + # just propagate down the line + if isinstance(obj, (AnyMeasureId, NotMeasureId)): return (obj,) + + # literally everything else failed + return (InvalidMeasureId(),) + + def measure_id_tuple_handling( + self, measure_id_tuple: MeasureIdTuple, idx_or_slice: int | slice + ) -> RawMeasureId | MeasureIdTuple: + + if isinstance(idx_or_slice, slice): + return MeasureIdTuple(data=measure_id_tuple.data[idx_or_slice]) + elif isinstance(idx_or_slice, int): + return measure_id_tuple.data[idx_or_slice] else: - return (InvalidMeasureId(),) + return InvalidMeasureId() @py.assign.dialect.register(key="measure_id") @@ -353,9 +366,12 @@ def if_else( stmt: scf.stmts.IfElse, ): cond_measure_id = frame.get(stmt.cond) - if isinstance(cond_measure_id, PredicatedMeasureId): + if isinstance(cond_measure_id, PredicatedMeasureId) and isinstance( + cond_measure_id.on_type, RawMeasureId + ): detached_cond_measure_id = PredicatedMeasureId( - idx=deepcopy(cond_measure_id.idx), predicate=cond_measure_id.predicate + on_type=deepcopy(RawMeasureId(idx=cond_measure_id.on_type.idx)), + cond=cond_measure_id.cond, ) frame.type_for_scf_conds[stmt] = detached_cond_measure_id return @@ -363,7 +379,6 @@ def if_else( # If you don't get a PredicatedMeasureId, don't bother # converting anything frame.type_for_scf_conds[stmt] = InvalidMeasureId() - # nothing to return, this thing already lives on the @py.dialect.register(key="measure_id") diff --git a/src/bloqade/analysis/measure_id/lattice.py b/src/bloqade/analysis/measure_id/lattice.py index 5bf71ccf..d9430124 100644 --- a/src/bloqade/analysis/measure_id/lattice.py +++ b/src/bloqade/analysis/measure_id/lattice.py @@ -79,25 +79,25 @@ def is_subseteq(self, other: MeasureId) -> bool: @final @dataclass -class PredicatedMeasureId(MeasureId): - idx: int - predicate: Predicate +class MeasureIdTuple(MeasureId): + data: tuple[RawMeasureId, ...] + immutable: bool = False def is_subseteq(self, other: MeasureId) -> bool: - if isinstance(other, PredicatedMeasureId): - return self.idx == other.idx and self.predicate == other.predicate + if isinstance(other, MeasureIdTuple): + return all(a.is_subseteq(b) for a, b in zip(self.data, other.data)) return False @final @dataclass -class MeasureIdTuple(MeasureId): - data: tuple[MeasureId, ...] - immutable: bool = False +class PredicatedMeasureId(MeasureId): + on_type: MeasureIdTuple | RawMeasureId + cond: Predicate def is_subseteq(self, other: MeasureId) -> bool: - if isinstance(other, MeasureIdTuple): - return all(a.is_subseteq(b) for a, b in zip(self.data, other.data)) + if isinstance(other, PredicatedMeasureId): + return self.cond == other.cond and self.on_type.is_subseteq(other.on_type) return False diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 69b36cfa..16de2836 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -42,12 +42,18 @@ def unsafe_run(self, mt: Method) -> RewriteResult: dialects=mt.dialects, no_raise=self.no_raise ).fixpoint(mt) + print("after flattenExceptLoops:") + mt.print() + # after this the program should be in a state where it is analyzable # ------------------------------------------------------------------- mia = MeasurementIDAnalysis(dialects=mt.dialects) meas_analysis_frame, _ = mia.run(mt) + print("after measurement ID analysis:") + mt.print(analysis=meas_analysis_frame.entries) + aa = AddressAnalysis(dialects=mt.dialects) address_analysis_frame, _ = aa.run(mt) diff --git a/src/bloqade/stim/rewrite/ifs_to_stim.py b/src/bloqade/stim/rewrite/ifs_to_stim.py index 1ce4253f..f477c884 100644 --- a/src/bloqade/stim/rewrite/ifs_to_stim.py +++ b/src/bloqade/stim/rewrite/ifs_to_stim.py @@ -15,6 +15,7 @@ from bloqade.stim.dialects.auxiliary import GetRecord from bloqade.analysis.measure_id.lattice import ( Predicate, + RawMeasureId, InvalidMeasureId, PredicatedMeasureId, ) @@ -145,12 +146,13 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: condition_type = self.measure_frame.type_for_scf_conds.get(stmt) if condition_type is None or condition_type is InvalidMeasureId(): return RewriteResult() - # Check the condition is a singular MeasurementIdBool and that - # it was generated by querying if the measurement is equivalent to the one state + if not isinstance(condition_type, PredicatedMeasureId): return RewriteResult() - if condition_type.predicate != Predicate.IS_ONE: + if condition_type.cond != Predicate.IS_ONE or not isinstance( + condition_type.on_type, RawMeasureId + ): return RewriteResult() # Reusing code from SplitIf, @@ -169,7 +171,7 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: return RewriteResult() # generate get record statement - measure_id_idx_stmt = py.Constant(condition_type.idx) + measure_id_idx_stmt = py.Constant(condition_type.on_type.idx) get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) address_lattice_elem = self.address_frame.entries.get(stmts[0].qubits) diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index d9b1b164..f7993bff 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -33,19 +33,20 @@ def results_of_variables(kernel, variable_names): def test_subset_eq_PredicatedMeasureId(): - m0 = PredicatedMeasureId(idx=1, predicate=Predicate.IS_ONE) - m1 = PredicatedMeasureId(idx=1, predicate=Predicate.IS_ONE) + wrapped_type = RawMeasureId(idx=1) + m0 = PredicatedMeasureId(on_type=wrapped_type, cond=Predicate.IS_ONE) + m1 = PredicatedMeasureId(on_type=wrapped_type, cond=Predicate.IS_ONE) assert m0.is_subseteq(m1) # not equivalent if predicate is different - m2 = PredicatedMeasureId(idx=1, predicate=Predicate.IS_ZERO) + m2 = PredicatedMeasureId(on_type=wrapped_type, cond=Predicate.IS_ZERO) assert not m0.is_subseteq(m2) # not equivalent if index is different either, # they are only equivalent if both index and predicate match - m3 = PredicatedMeasureId(idx=2, predicate=Predicate.IS_ONE) + m3 = PredicatedMeasureId(on_type=RawMeasureId(idx=2), cond=Predicate.IS_ONE) assert not m0.is_subseteq(m3) @@ -304,29 +305,19 @@ def test(): test, ("is_zero_bools", "is_one_bools", "is_lost_bools") ) - expected_is_zero_bools = MeasureIdTuple( - data=tuple( - [ - PredicatedMeasureId(idx=i, predicate=Predicate.IS_ZERO) - for i in range(-3, 0) - ] - ), + expected_is_zero_bools = PredicatedMeasureId( + on_type=MeasureIdTuple(data=tuple([RawMeasureId(idx=i) for i in range(-3, 0)])), + cond=Predicate.IS_ZERO, ) - expected_is_one_bools = MeasureIdTuple( - data=tuple( - [ - PredicatedMeasureId(idx=i, predicate=Predicate.IS_ONE) - for i in range(-3, 0) - ] - ), + + expected_is_one_bools = PredicatedMeasureId( + on_type=MeasureIdTuple(data=tuple([RawMeasureId(idx=i) for i in range(-3, 0)])), + cond=Predicate.IS_ONE, ) - expected_is_lost_bools = MeasureIdTuple( - data=tuple( - [ - PredicatedMeasureId(idx=i, predicate=Predicate.IS_LOST) - for i in range(-3, 0) - ] - ), + + expected_is_lost_bools = PredicatedMeasureId( + on_type=MeasureIdTuple(data=tuple([RawMeasureId(idx=i) for i in range(-3, 0)])), + cond=Predicate.IS_LOST, ) assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools diff --git a/test/stim/passes/stim_reference_programs/qubit/delayed_cse_measure_predicate.stim b/test/stim/passes/stim_reference_programs/qubit/delayed_cse_measure_predicate.stim new file mode 100644 index 00000000..783df1f6 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/qubit/delayed_cse_measure_predicate.stim @@ -0,0 +1,4 @@ +MZ(0.00000000) 0 1 2 3 +CZ rec[-4] 0 +MZ(0.00000000) 0 1 2 3 +CX rec[-8] 0 \ No newline at end of file diff --git a/test/stim/passes/test_annotation_to_stim.py b/test/stim/passes/test_annotation_to_stim.py index 124d9353..d8e5a6cc 100644 --- a/test/stim/passes/test_annotation_to_stim.py +++ b/test/stim/passes/test_annotation_to_stim.py @@ -174,6 +174,7 @@ def main(): return SquinToStimPass(main.dialects, no_raise=True)(main) + main.print() assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index a8e7bc79..8819e7c9 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -292,6 +292,80 @@ def test(): assert codegen(test) == base_stim_prog.rstrip() +# The SquinToStimPass has some modified rules in its own unroll to postpone +# running CSE, the reason being is the getitems in a kernel like this +# need to be preserved despite looking the same because the lattice +# element is different. +def test_delayed_cse_measure_predicate(): + + @sq.kernel + def test(): + q = sq.qalloc(4) + ms0 = sq.broadcast.measure(q) + + if sq.is_one(ms0[0]): + sq.z(q[0]) + + sq.broadcast.measure(q) + + if sq.is_one(ms0[0]): + sq.x(q[0]) + + SquinToStimPass(test.dialects).unsafe_run(test) + test.print() + base_stim_prog = load_reference_program("delayed_cse_measure_predicate.stim") + assert codegen(test) == base_stim_prog.rstrip() + + +def test_reused_measure_getitem(): + + @sq.kernel + def test(): + q = sq.qalloc(4) + ms0 = sq.broadcast.measure(q) + reusable_ms = ms0[0] + + if sq.is_one(reusable_ms): + sq.z(q[0]) + + sq.broadcast.measure(q) + + if sq.is_one(reusable_ms): + sq.x(q[0]) + + SquinToStimPass(test.dialects).unsafe_run(test) + test.print() + base_stim_prog = load_reference_program("delayed_cse_measure_predicate.stim") + assert codegen(test) == base_stim_prog.rstrip() + + +def test_reused_predicate_result(): + + @sq.kernel + def test(): + q = sq.qalloc(4) + ms = sq.broadcast.measure(q) + pred_ms = sq.broadcast.is_one(ms) + + if pred_ms[0]: + sq.z(q[0]) + + sq.broadcast.measure(q) + + if pred_ms[ + 0 + ]: # this is no longer rec[-4], should be rec[-8] like in the above scenarios + sq.x(q[0]) + + SquinToStimPass(test.dialects).unsafe_run(test) + test.print() + base_stim_prog = load_reference_program("delayed_cse_measure_predicate.stim") + assert codegen(test) == base_stim_prog.rstrip() + + +test_reused_predicate_result() + + # You can only convert a combination of a predicate type and # scf.IfElse if the predicate type is IS_ONE. Otherwise anything # else is invalid From 0fec8baca12a0fce598094d5b983ec47ff07ce07 Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 15 Dec 2025 22:19:06 -0500 Subject: [PATCH 25/26] finish turning development examples into proper tests --- src/bloqade/analysis/measure_id/impls.py | 10 ++++ src/bloqade/stim/passes/squin_to_stim.py | 6 --- .../scf_for/feedforward_inside_loop.stim | 8 +++ .../scf_for/rep_code.stim | 17 +++++++ .../scf_for/rep_code_structure.stim | 9 ++++ .../scf_for/repeat_on_gates_only.stim | 9 ++++ ...tion_code.py => test_scf_for_to_repeat.py} | 50 ++++++++++++++++--- 7 files changed, 97 insertions(+), 12 deletions(-) create mode 100644 test/stim/passes/stim_reference_programs/scf_for/feedforward_inside_loop.stim create mode 100644 test/stim/passes/stim_reference_programs/scf_for/rep_code.stim create mode 100644 test/stim/passes/stim_reference_programs/scf_for/rep_code_structure.stim create mode 100644 test/stim/passes/stim_reference_programs/scf_for/repeat_on_gates_only.stim rename test/stim/passes/{test_repetition_code.py => test_scf_for_to_repeat.py} (63%) diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 3a217b56..4749606c 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -319,7 +319,17 @@ def for_loop( # print(f"Joining {lattice_element} and {second_loop_frame.entries[ssa_val]} to get {verified_latticed_element}") unified_frame_buffer[ssa_val] = verified_latticed_element + # need to unify the IfElse entries as well + # they should stay the same type in the loop + unified_if_else_cond_types = {} + for ssa_val, lattice_element in first_loop_frame.type_for_scf_conds.items(): + unified_if_else_cond_element = second_loop_frame.type_for_scf_conds[ + ssa_val + ].join(lattice_element) + unified_if_else_cond_types[ssa_val] = unified_if_else_cond_element + frame.entries.update(unified_frame_buffer) + frame.type_for_scf_conds.update(unified_if_else_cond_types) frame.global_record_state.offset_existing_records( first_loop_frame.measure_count_offset ) diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 16de2836..69b36cfa 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -42,18 +42,12 @@ def unsafe_run(self, mt: Method) -> RewriteResult: dialects=mt.dialects, no_raise=self.no_raise ).fixpoint(mt) - print("after flattenExceptLoops:") - mt.print() - # after this the program should be in a state where it is analyzable # ------------------------------------------------------------------- mia = MeasurementIDAnalysis(dialects=mt.dialects) meas_analysis_frame, _ = mia.run(mt) - print("after measurement ID analysis:") - mt.print(analysis=meas_analysis_frame.entries) - aa = AddressAnalysis(dialects=mt.dialects) address_analysis_frame, _ = aa.run(mt) diff --git a/test/stim/passes/stim_reference_programs/scf_for/feedforward_inside_loop.stim b/test/stim/passes/stim_reference_programs/scf_for/feedforward_inside_loop.stim new file mode 100644 index 00000000..f997d98a --- /dev/null +++ b/test/stim/passes/stim_reference_programs/scf_for/feedforward_inside_loop.stim @@ -0,0 +1,8 @@ +MZ(0.00000000) 0 1 2 3 4 +REPEAT 3 { + CY rec[-5] 0 + CX rec[-4] 1 + CZ rec[-4] 2 + MZ(0.00000000) 0 1 2 3 4 +} +DETECTOR(0, 0) rec[-5] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/scf_for/rep_code.stim b/test/stim/passes/stim_reference_programs/scf_for/rep_code.stim new file mode 100644 index 00000000..34b1622e --- /dev/null +++ b/test/stim/passes/stim_reference_programs/scf_for/rep_code.stim @@ -0,0 +1,17 @@ +RZ 0 1 2 3 4 +CX 0 1 2 3 +CX 2 1 4 3 +MZ(0.00000000) 1 3 +DETECTOR(0, 0) rec[-2] +DETECTOR(0, 1) rec[-1] +REPEAT 10 { + CX 0 1 2 3 + CX 2 1 4 3 + MZ(0.00000000) 1 3 + DETECTOR(0, 0) rec[-4] rec[-2] + DETECTOR(0, 1) rec[-3] rec[-1] +} +MZ(0.00000000) 0 2 4 +DETECTOR(2, 0) rec[-3] rec[-2] rec[-5] +DETECTOR(2, 1) rec[-1] rec[-2] rec[-4] +OBSERVABLE_INCLUDE(0) rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/scf_for/rep_code_structure.stim b/test/stim/passes/stim_reference_programs/scf_for/rep_code_structure.stim new file mode 100644 index 00000000..42abe5bc --- /dev/null +++ b/test/stim/passes/stim_reference_programs/scf_for/rep_code_structure.stim @@ -0,0 +1,9 @@ +MZ(0.00000000) 0 1 2 +REPEAT 5 { + H 0 1 2 + MZ(0.00000000) 0 1 2 + DETECTOR(0, 0) rec[-3] rec[-6] +} +MZ(0.00000000) 0 1 2 +DETECTOR(1, 0) rec[-3] rec[-6] +OBSERVABLE_INCLUDE(0) rec[-3] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/scf_for/repeat_on_gates_only.stim b/test/stim/passes/stim_reference_programs/scf_for/repeat_on_gates_only.stim new file mode 100644 index 00000000..f3f00dd4 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/scf_for/repeat_on_gates_only.stim @@ -0,0 +1,9 @@ +RZ 0 1 2 +REPEAT 5 { + H 0 1 2 + X 0 1 2 + CZ 0 1 + DEPOLARIZE1(0.01000000) 0 + I_ERROR[loss](0.02000000) 1 + I_ERROR[loss](0.03000000) 0 1 2 +} \ No newline at end of file diff --git a/test/stim/passes/test_repetition_code.py b/test/stim/passes/test_scf_for_to_repeat.py similarity index 63% rename from test/stim/passes/test_repetition_code.py rename to test/stim/passes/test_scf_for_to_repeat.py index 63b8ea18..e0adaba4 100644 --- a/test/stim/passes/test_repetition_code.py +++ b/test/stim/passes/test_scf_for_to_repeat.py @@ -1,4 +1,5 @@ import io +import os from kirin import ir @@ -16,6 +17,14 @@ def codegen(mt: ir.Method): return buf.getvalue().strip() +def load_reference_program(filename): + path = os.path.join( + os.path.dirname(__file__), "stim_reference_programs", "scf_for", filename + ) + with open(path, "r") as f: + return f.read() + + def test_repeat_on_gates_only(): @squin.kernel @@ -28,13 +37,18 @@ def test(): for _ in range(5): squin.broadcast.h(qs) squin.broadcast.x(qs) + squin.cz(control=qs[0], target=qs[1]) + squin.depolarize(p=0.01, qubit=qs[0]) + squin.qubit_loss(p=0.02, qubit=qs[1]) + squin.broadcast.qubit_loss(p=0.03, qubits=qs) SquinToStimPass(dialects=test.dialects)(test) - test.print() + base_program = load_reference_program("repeat_on_gates_only.stim") + assert codegen(test) == base_program.rstrip() # Very similar to a full repetition code -# but simplified +# but simplified for debugging/development purposes def test_repetition_code_structure(): @squin.kernel @@ -56,7 +70,8 @@ def test(): squin.set_observable([final_ms[0]]) SquinToStimPass(dialects=test.dialects)(test) - test.print() + base_program = load_reference_program("rep_code_structure.stim") + assert codegen(test) == base_program.rstrip() def test_full_repetition_code(): @@ -94,9 +109,32 @@ def test(): squin.set_observable([data_ms[2]]) SquinToStimPass(dialects=test.dialects)(test) - test.print() + base_program = load_reference_program("rep_code.stim") + assert codegen(test) == base_program.rstrip() + - print(codegen(test)) +def test_feedforward_inside_loop(): + @squin.kernel + def test(): -test_full_repetition_code() + qs = squin.qalloc(5) + curr_ms = squin.broadcast.measure(qs) + + for _ in range(3): + prev_ms = curr_ms + + if squin.is_one(prev_ms[0]): + squin.y(qs[0]) + + if squin.is_one(prev_ms[1]): + squin.x(qs[1]) + squin.z(qs[2]) + + curr_ms = squin.broadcast.measure(qs) + + squin.set_detector([curr_ms[0]], coordinates=[0, 0]) + + SquinToStimPass(dialects=test.dialects)(test) + base_program = load_reference_program("feedforward_inside_loop.stim") + assert codegen(test) == base_program.rstrip() From 9ec62d4aab29fbdf2d8202adf8283b6ec898b9dd Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 15 Dec 2025 22:29:16 -0500 Subject: [PATCH 26/26] finally let go of the original record analysis. Farewell old friend --- src/bloqade/analysis/record/__init__.py | 2 - src/bloqade/analysis/record/analysis.py | 81 ------ src/bloqade/analysis/record/impls.py | 274 ------------------- src/bloqade/analysis/record/lattice.py | 102 ------- test/analysis/record/test_record_analysis.py | 206 -------------- test/test_annotate.py | 35 --- 6 files changed, 700 deletions(-) delete mode 100644 src/bloqade/analysis/record/__init__.py delete mode 100644 src/bloqade/analysis/record/analysis.py delete mode 100644 src/bloqade/analysis/record/impls.py delete mode 100644 src/bloqade/analysis/record/lattice.py delete mode 100644 test/analysis/record/test_record_analysis.py delete mode 100644 test/test_annotate.py diff --git a/src/bloqade/analysis/record/__init__.py b/src/bloqade/analysis/record/__init__.py deleted file mode 100644 index 6741d40a..00000000 --- a/src/bloqade/analysis/record/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import impls as impls -from .analysis import RecordAnalysis as RecordAnalysis diff --git a/src/bloqade/analysis/record/analysis.py b/src/bloqade/analysis/record/analysis.py deleted file mode 100644 index 4ef20a0e..00000000 --- a/src/bloqade/analysis/record/analysis.py +++ /dev/null @@ -1,81 +0,0 @@ -from dataclasses import field, dataclass - -from kirin import ir -from kirin.analysis import ForwardExtra -from kirin.analysis.forward import ForwardFrame - -from .lattice import Record, RecordIdx, RecordTuple - - -@dataclass -class GlobalRecordState: - buffer: list[RecordIdx] = field(default_factory=list) - - # assume that this RecordIdx will always be -1 - def add_record_idxs(self, num_new_records: int, id: int) -> list[RecordIdx]: - # adjust all previous indices - for record_idx in self.buffer: - record_idx.idx -= num_new_records - - # generate new indices and add them to the buffer - new_record_idxs = [RecordIdx(-i, id) for i in range(num_new_records, 0, -1)] - self.buffer += new_record_idxs - # Return for usage, idxs linked to the global state - return new_record_idxs - - # Need for loop invariance, especially when you - # run the loop twice "behind the scenes". Then - # it isn't sufficient to just have two - # copies of a lattice element point to one entry on the - # buffer - def clone_record_idxs(self, record_tuple: RecordTuple, id: int) -> RecordTuple: - cloned_members = [] - for record_idx in record_tuple.members: - cloned_record_idx = RecordIdx(record_idx.idx, id) - # put into the global buffer but also - # return an analysis-facing copy - self.buffer.append(cloned_record_idx) - cloned_members.append(cloned_record_idx) - - return RecordTuple(members=tuple(cloned_members)) - - def offset_existing_records(self, offset: int): - for record_idx in self.buffer: - record_idx.idx -= offset - - """ - Might need a free after use! You can keep the size of the list small - but could be a premature optimization... - """ - # def drop_record_idxs(self, record_tuple: RecordTuple): - # for record_idx in record_tuple.members: - # self.buffer.remove(record_idx) - - -@dataclass -class RecordFrame(ForwardFrame): - global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState) - measure_count_offset: int = 0 - frame_id: int = 0 - - -class RecordAnalysis(ForwardExtra[RecordFrame, Record]): - keys = ["record"] - lattice = Record - - def initialize_frame( - self, node: ir.Statement, *, has_parent_access: bool = False - ) -> RecordFrame: - return RecordFrame(node, has_parent_access=has_parent_access) - - def eval_fallback( - self, frame: RecordFrame, node: ir.Statement - ) -> tuple[Record, ...]: - return tuple(self.lattice.bottom() for _ in node.results) - - def run_method(self, method, args: tuple[Record, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_method(method.code, (self.lattice.bottom(),) + args) - - def method_self(self, method: ir.Method) -> Record: - return self.lattice.bottom() diff --git a/src/bloqade/analysis/record/impls.py b/src/bloqade/analysis/record/impls.py deleted file mode 100644 index f4a7deaa..00000000 --- a/src/bloqade/analysis/record/impls.py +++ /dev/null @@ -1,274 +0,0 @@ -from copy import deepcopy - -from kirin import types as kirin_types, interp -from kirin.ir import PyAttr -from kirin.dialects import py, scf, ilist - -from bloqade import qubit, annotate -from bloqade.annotate.stmts import SetDetector, SetObservable - -from .lattice import ( - AnyRecord, - NotRecord, - RecordIdx, - RecordTuple, - InvalidRecord, - ConstantCarrier, - ImmutableRecords, -) -from .analysis import RecordFrame, RecordAnalysis - - -@annotate.dialect.register(key="record") -class PhysicalAnnotations(interp.MethodTable): - # Both statements inherit from the base class "ConsumesMeasurementResults" - # both statements consume IList of MeasurementResults, so the input type should be - # expected to be a RecordTuple - @interp.impl(SetObservable) - @interp.impl(SetDetector) - def consumes_measurements( - self, interp: RecordAnalysis, frame: RecordFrame, stmt: SetDetector - ): - # Get the measurement results being consumed - record_tuple_at_stmt = frame.get(stmt.measurements) - - if not ( - isinstance(record_tuple_at_stmt, RecordTuple) - and kirin_types.is_tuple_of(record_tuple_at_stmt.members, RecordIdx) - ): - return (InvalidRecord(),) - - final_record_idxs = [ - deepcopy(record_idx) for record_idx in record_tuple_at_stmt.members - ] - - return (ImmutableRecords(members=tuple(final_record_idxs)),) - - -@qubit.dialect.register(key="record") -class SquinQubit(interp.MethodTable): - - @interp.impl(qubit.stmts.Measure) - def measure_qubit_list( - self, - interp: RecordAnalysis, - frame: RecordFrame, - stmt: qubit.stmts.Measure, - ): - - # try to get the length of the list - ## "...safely assume the type inference will give you what you need" - qubits_type = stmt.qubits.type - # vars[0] is just the type of the elements in the ilist, - # vars[1] can contain a literal with length information - num_qubits = qubits_type.vars[1] - if not isinstance(num_qubits, kirin_types.Literal): - return (AnyRecord(),) - - # increment the parent frame measure count offset. - # Loop analysis relies on local state tracking - # so we use this data after exiting a loop to - # readjust the previous global measure count. - frame.measure_count_offset += num_qubits.data - - record_idxs = frame.global_record_state.add_record_idxs( - num_qubits.data, id=frame.frame_id - ) - - return (RecordTuple(members=tuple(record_idxs)),) - - -@py.indexing.dialect.register(key="record") -class PyIndexing(interp.MethodTable): - @interp.impl(py.GetItem) - def getitem(self, interp: RecordAnalysis, frame: RecordFrame, stmt: py.GetItem): - - # maybe_const will work fine outside of any loops because - # constprop will put the expected data into a hint. - - # if maybeconst fails, we fall back to getting the value from the frame - # (note that even outside loops, the constant impl will happily - # capture integer/slice constants so if THAT fails, then something - # has truly gone wrong). - possible_idx_or_slice = interp.maybe_const(stmt.index, (int, slice)) - if possible_idx_or_slice is not None: - idx_or_slice = possible_idx_or_slice - else: - idx_or_slice = frame.get(stmt.index) - if not isinstance(idx_or_slice, ConstantCarrier): - return (InvalidRecord(),) - else: - idx_or_slice = idx_or_slice.value - - obj = frame.get(stmt.obj) - if isinstance(obj, RecordTuple): - if isinstance(idx_or_slice, slice): - return (RecordTuple(members=obj.members[idx_or_slice]),) - elif isinstance(idx_or_slice, int): - return (obj.members[idx_or_slice],) - else: - return (InvalidRecord(),) - # just propagate these down the line - elif isinstance(obj, (AnyRecord, NotRecord)): - return (obj,) - else: - return (InvalidRecord(),) - - -@ilist.dialect.register(key="record") -class IList(interp.MethodTable): - @interp.impl(ilist.New) - def new_ilist( - self, - interp: RecordAnalysis, - frame: interp.Frame, - stmt: ilist.New, - ): - return (RecordTuple(frame.get_values(stmt.values)),) - - -@py.assign.dialect.register(key="record") -class PyAlias(interp.MethodTable): - @interp.impl(py.Alias) - def alias( - self, - interp_: RecordAnalysis, - frame: RecordFrame, - stmt: py.Alias, - ): - input = frame.get(stmt.value) # expect this to be a RecordTuple - # input could belong to another frame and get repossessed with an - # independent copy in this frame. Might need to set a new frame_id here - new_input = frame.global_record_state.clone_record_idxs( - input, id=frame.frame_id - ) - # two variables share the same references in the global state - return (new_input,) - - -@scf.dialect.register(key="record") -class LoopHandling(interp.MethodTable): - - @interp.impl(scf.stmts.For) - def for_loop_double_pass( - self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For - ): - - init_loop_vars = frame.get_values(stmt.initializers) - - # You go through the loops twice to verify the loop invariant. - # we need to freeze the frame entries right after exiting the loop - - local_state = deepcopy(frame.global_record_state) - # local_state = GlobalRecordState() - - first_loop_frame = RecordFrame( - stmt, - frame_id=frame.frame_id + 1, - global_record_state=local_state, - parent=frame, - has_parent_access=True, - ) - - first_loop_vars = interp_.frame_call_region( - first_loop_frame, stmt, stmt.body, InvalidRecord(), *init_loop_vars - ) - - if first_loop_vars is None: - first_loop_vars = () - elif isinstance(first_loop_vars, interp.ReturnValue): - return first_loop_vars - - captured_first_loop_entries = {} - captured_first_loop_vars = deepcopy(first_loop_vars) - - for ssa_val, lattice_element in first_loop_frame.entries.items(): - captured_first_loop_entries[ssa_val] = deepcopy(lattice_element) - - second_loop_frame = RecordFrame( - stmt, - frame_id=frame.frame_id + 2, - global_record_state=local_state, - parent=frame, - has_parent_access=True, - ) - second_loop_vars = interp_.frame_call_region( - second_loop_frame, stmt, stmt.body, InvalidRecord(), *first_loop_vars - ) - - if second_loop_vars is None: - second_loop_vars = () - elif isinstance(second_loop_vars, interp.ReturnValue): - return second_loop_vars - - # take the entries in the first and second loops - # update the parent frame - - unified_frame_buffer = {} - for ssa_val, lattice_element in captured_first_loop_entries.items(): - verified_latticed_element = second_loop_frame.entries[ssa_val].join( - lattice_element - ) - unified_frame_buffer[ssa_val] = verified_latticed_element - - frame.entries.update(unified_frame_buffer) - frame.global_record_state.offset_existing_records( - first_loop_frame.measure_count_offset - ) - - if captured_first_loop_vars is None or second_loop_vars is None: - return () - - joined_loop_vars = [] - for first_loop_var, second_loop_var in zip( - captured_first_loop_vars, second_loop_vars - ): - joined_loop_vars.append(first_loop_var.join(second_loop_var)) - - # TrimYield is currently disabled meaning that the same RecordIdx - # can get copied into the parent frame twice! As a result - # we need to be careful to only add unique RecordIdx entries - witnessed_record_idxs = set() - for var in joined_loop_vars: - if isinstance(var, RecordTuple): - for member in var.members: - if ( - isinstance(member, RecordIdx) - and member.idx not in witnessed_record_idxs - ): - witnessed_record_idxs.add(member.idx) - frame.global_record_state.buffer.append(member) - - return tuple(joined_loop_vars) - - @interp.impl(scf.stmts.Yield) - def for_yield( - self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.Yield - ): - return interp.YieldValue(frame.get_values(stmt.values)) - - -@py.dialect.register(key="record") -class ConstantForwarding(interp.MethodTable): - @interp.impl(py.Constant) - def constant( - self, - interp_: RecordAnalysis, - frame: RecordFrame, - stmt: py.Constant, - ): - # can't use interp_.maybe_const/expect_const because it assumes the data is already - # there to begin with... - if not isinstance(stmt.value, PyAttr): - return (InvalidRecord(),) - - expected_int_or_slice = stmt.value.data - - if not isinstance(expected_int_or_slice, (int, slice)): - return (InvalidRecord(),) - - return (ConstantCarrier(value=expected_int_or_slice),) - - -# outside_frame -> create new frame with context manager COPIED from outside frame -# the frame and the stack are separate diff --git a/src/bloqade/analysis/record/lattice.py b/src/bloqade/analysis/record/lattice.py deleted file mode 100644 index 2fd065a8..00000000 --- a/src/bloqade/analysis/record/lattice.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import final -from dataclasses import dataclass - -from kirin.lattice import ( - SingletonMeta, - BoundedLattice, - SimpleJoinMixin, - SimpleMeetMixin, -) - -# Taken directly from Kai-Hsin Wu's implementation -# with minor changes to names and addition of CanMeasureId type - - -@dataclass -class Record( - SimpleJoinMixin["Record"], - SimpleMeetMixin["Record"], - BoundedLattice["Record"], -): - - @classmethod - def bottom(cls) -> "Record": - return InvalidRecord() - - @classmethod - def top(cls) -> "Record": - return AnyRecord() - - -# Can pop up if user constructs some list containing a mixture -# of bools from measure results and other places, -# in which case the whole list is invalid -@final -@dataclass -class InvalidRecord(Record, metaclass=SingletonMeta): - - def is_subseteq(self, other: Record) -> bool: - return True - - -@final -@dataclass -class AnyRecord(Record, metaclass=SingletonMeta): - - def is_subseteq(self, other: Record) -> bool: - return isinstance(other, AnyRecord) - - -@final -@dataclass -class NotRecord(Record, metaclass=SingletonMeta): - - def is_subseteq(self, other: Record) -> bool: - return isinstance(other, NotRecord) - - -# For now I only care about propagating constant integers or slices, -# things that can be used as indices to list of measurements -@final -@dataclass -class ConstantCarrier(Record): - value: int | slice - - def is_subseteq(self, other: Record) -> bool: - if isinstance(other, ConstantCarrier): - return self.value == other.value - return False - - -@final -@dataclass -class RecordIdx(Record): - idx: int - id: int - - def is_subseteq(self, other: Record) -> bool: - if isinstance(other, RecordIdx): - return self.idx == other.idx - return False - - -@final -@dataclass -class RecordTuple(Record): - members: tuple[RecordIdx, ...] - - def is_subseteq(self, other: Record) -> bool: - if isinstance(other, RecordTuple): - return all(a.is_subseteq(b) for a, b in zip(self.members, other.members)) - return False - - -@final -@dataclass -class ImmutableRecords(Record): - members: tuple[RecordIdx, ...] - - def is_subseteq(self, other: Record) -> bool: - if isinstance(other, ImmutableRecords): - return all(a.is_subseteq(b) for a, b in zip(self.members, other.members)) - return False diff --git a/test/analysis/record/test_record_analysis.py b/test/analysis/record/test_record_analysis.py deleted file mode 100644 index e8c436df..00000000 --- a/test/analysis/record/test_record_analysis.py +++ /dev/null @@ -1,206 +0,0 @@ -# from kirin.passes.fold import Fold - -from bloqade import squin - -# from bloqade.stim.passes import SquinToStimPass -from bloqade.analysis.record import RecordAnalysis - -# from bloqade.analysis.measure_id import MeasurementIDAnalysis -from bloqade.stim.passes.flatten_except_loops import FlattenExceptForLoop - -""" -@squin.kernel -def test(): - qs = squin.qalloc(5) - data_qs = [qs[0], qs[2], qs[4]] - and_qs = [qs[1], qs[3]] - - init_and_meas_res = squin.broadcast.measure(and_qs) - squin.set_detector([init_and_meas_res[0]], coordinates=[0, 0]) - squin.set_detector([init_and_meas_res[1]], coordinates=[0, 1]) - - and_meas_res = None - for _ in range(10): - and_meas_res = squin.broadcast.measure(and_qs) - - squin.set_detector([and_meas_res[0], init_and_meas_res[0]], coordinates=[0, 0]) - squin.set_detector([and_meas_res[1], init_and_meas_res[1]], coordinates=[1, 1]) - - init_and_meas_res = and_meas_res - - data_meas_res = squin.broadcast.measure(data_qs) - squin.set_detector( - [data_meas_res[0], data_meas_res[1], and_meas_res[0]], coordinates=[2, 0] - ) - squin.set_detector( - [data_meas_res[2], data_meas_res[1], and_meas_res[1]], coordinates=[2, 1] - ) - squin.set_observable([data_meas_res[0]]) - - # return and_meas_res - - -test.print() -SoftFlatten(dialects=test.dialects).fixpoint(test) -test.print() -frame, _ = RecordAnalysis(dialects=test.dialects).run(test) -test.print(analysis=frame.entries) -""" - -""" -def hint_const_failure(): - - @squin.kernel - def test(): - qs = squin.qalloc(3) - ms0 = squin.broadcast.measure(qs) - i = 0 - for _ in range(5): - ms1 = squin.broadcast.measure(qs) - squin.set_detector([ms0[i], ms1[i]], coordinates=[i, i]) - - # SoftFlatten(dialects=test.dialects).fixpoint(test) - Fold(dialects=test.dialects, no_raise=False).fixpoint(test) - test.print(hint="const") - # frame, _ = RecordAnalysis(dialects=test.dialects).run(test) - # test.print(analysis=frame.entries, hint="const") - - -# Problematic having the variable substitution happen at the end -""" - - -def test_custom_const_carrier(): - - @squin.kernel(fold=False) - def test(x: int): - y = None - z = None - for _ in range(5): - f = [1, 2, 3, 4, 5, 5, 6, 7, 8] - z = slice(0, 2) - y = f[z] - y[0] += x - return y, z - - FlattenExceptForLoop(dialects=test.dialects).fixpoint(test) - test.print() - frame, _ = RecordAnalysis(dialects=test.dialects).run(test) - test.print(analysis=frame.entries, hint="const") - - -""" -def assignment_last_rep_code(): - @squin.kernel(fold=True) - def test(): - - qs = squin.qalloc(5) - data_qs = [qs[0], qs[2], qs[4]] - and_qs = [qs[1], qs[3]] - - init_and_ms = squin.broadcast.measure(and_qs) - - squin.set_detector([init_and_ms[0]], coordinates=[0, 0]) - squin.set_detector([init_and_ms[1]], coordinates=[0, 1]) - - # loop_and_ms = None - for _ in range(5): - loop_and_ms = squin.broadcast.measure(and_qs) - squin.annotate.set_detector([loop_and_ms[0], init_and_ms[0]], coordinates=[0,0]) - squin.annotate.set_detector([loop_and_ms[1], init_and_ms[1]], coordinates=[1,1]) - - #for i in range(len(curr_ms)): - # squin.annotate.set_detector([curr_ms[i], prev_ms[i]], coordinates=[1,1]) - - ##init_and_ms = loop_and_ms - - #data_ms = squin.broadcast.measure(data_qs) - #squin.set_detector( - # [data_ms[0], data_ms[1], loop_and_ms[0]], coordinates=[2, 0] - #) - #squin.set_detector( - # [data_ms[2], data_ms[1], loop_and_ms[1]], coordinates=[2, 1] - #) - - - SoftFlatten(dialects=test.dialects).fixpoint(test) - test.print() - frame, _ = RecordAnalysis(dialects=test.dialects).run(test) - test.print(analysis=frame.entries, hint="const") - -""" - -""" -from kirin.prelude import structural_no_opt - -@structural_no_opt -def demo(): - - a = 0 - b= 1 - for _ in range(10): - c = b - b = a - a = c - -demo.print() -""" - - -def assignment_first_rep_code(): - @squin.kernel - def test(): - - qs = squin.qalloc(5) - data_qs = [qs[0], qs[2], qs[4]] - and_qs = [qs[1], qs[3]] - - curr_ms = squin.broadcast.measure(and_qs) # 2 meas - squin.set_detector([curr_ms[0]], coordinates=[0, 0]) - squin.set_detector([curr_ms[1]], coordinates=[0, 1]) - - for _ in range(5): - # prev lives entirely in the loop - prev_ms = curr_ms - curr_ms = squin.broadcast.measure(and_qs) # another 2 meas - squin.annotate.set_detector([prev_ms[0], curr_ms[0]], coordinates=[0, 0]) - squin.annotate.set_detector([prev_ms[1], curr_ms[1]], coordinates=[0, 1]) - - data_ms = squin.broadcast.measure(data_qs) # 3 meas - - squin.set_detector([data_ms[0], data_ms[1], curr_ms[0]], coordinates=[2, 0]) - squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1]) - squin.set_observable([data_ms[2]]) - - FlattenExceptForLoop(dialects=test.dialects).fixpoint(test) - # test.print() - frame, _ = RecordAnalysis(dialects=test.dialects).run(test) - test.print(analysis=frame.entries) - - # frame, _ = MeasurementIDAnalysis(dialects=test.dialects).run(test) - # test.print(analysis=frame.entries) - - -assignment_first_rep_code() - -""" -@squin.kernel -def analysis_demo(): - qs = squin.qalloc(3) - ms0 = squin.broadcast.measure(qs) - ms1 = squin.broadcast.measure(qs) - squin.set_detector(ms0, coordinates=[0, 0]) # -4 -5 -6 - squin.set_detector(ms1, coordinates=[0, 1]) # -1 -2 -3 - # squin.broadcast.measure(qs) - squin.set_detector(ms1, coordinates=[0, 2]) # -4 -5 -6 - - # get aliasing to work - ms1 = ms0 - squin.set_detector(ms1, coordinates=[1, 0]) # should also be -4 -5 -6 - - -SoftFlatten(dialects=analysis_demo.dialects).fixpoint(analysis_demo) -analysis_demo.print() -frame, _ = RecordAnalysis(dialects=analysis_demo.dialects).run(analysis_demo) -analysis_demo.print(analysis=frame.entries) -""" diff --git a/test/test_annotate.py b/test/test_annotate.py deleted file mode 100644 index f902866c..00000000 --- a/test/test_annotate.py +++ /dev/null @@ -1,35 +0,0 @@ -import io - -from kirin import ir - -from bloqade import stim, squin -from bloqade.stim.emit import EmitStimMain -from bloqade.stim.passes import SquinToStimPass - - -def codegen(mt: ir.Method): - # method should not have any arguments! - buf = io.StringIO() - emit = EmitStimMain(dialects=stim.main, io=buf) - emit.initialize() - emit.run(mt) - return buf.getvalue().strip() - - -def test_annotate(): - - @squin.kernel - def test(): - qs = squin.qalloc(4) - ms = squin.broadcast.measure(qs) - squin.set_detector([ms[0], ms[1], ms[2]], coordinates=(0, 0)) - squin.set_observable([ms[3]]) - - SquinToStimPass(dialects=test.dialects)(test) - codegen_output = codegen(test) - expected_output = ( - "MZ(0.00000000) 0 1 2 3\n" - "DETECTOR(0, 0) rec[-4] rec[-3] rec[-2]\n" - "OBSERVABLE_INCLUDE(0) rec[-1]" - ) - assert codegen_output == expected_output