From 8668d4b2414c8c7d61aecc3af730d0aaa8658843 Mon Sep 17 00:00:00 2001 From: arunjmoorthy Date: Fri, 5 Sep 2025 17:54:23 -0700 Subject: [PATCH 1/3] merge function and tests --- src/pyqasm/modules/base.py | 153 +++++++++++++++++++++++++++++++++++++ tests/qasm3/test_merge.py | 118 ++++++++++++++++++++++++++++ 2 files changed, 271 insertions(+) create mode 100644 tests/qasm3/test_merge.py diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index 3b21330a..b8d98475 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -20,6 +20,7 @@ import functools from abc import ABC, abstractmethod +import re from collections import Counter from copy import deepcopy from typing import Optional @@ -36,6 +37,7 @@ from pyqasm.visitor import QasmVisitor, ScopeManager + def track_user_operation(func): """Decorator to track user operations on a QasmModule.""" @@ -761,3 +763,154 @@ def accept(self, visitor): Args: visitor (QasmVisitor): The visitor to accept """ + + + def merge(self, other: "QasmModule", device_qubits: Optional[int] = None) -> "QasmModule": + """Merge this module with another module into a single consolidated module. + + Notes: + - Both modules are unrolled with consolidated qubit registers prior to merging. + - The resulting module has a single declaration: ``qubit[] __PYQASM_QUBITS__``. + - All quantum operations from the second module are appended after the first, with + qubit indices offset by the size of the first module. + + Args: + other (QasmModule): The module to merge with the current module. + device_qubits (int | None): Optional device qubit budget to use during unrolling. + + Returns: + QasmModule: A new Qasm3Module representing the merged program. + """ + + if not isinstance(other, QasmModule): + raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}") + + # Normalize both modules to QASM3 form (without mutating originals) + from pyqasm.modules.qasm2 import Qasm2Module # pylint: disable=import-outside-toplevel + from pyqasm.modules.qasm3 import Qasm3Module # pylint: disable=import-outside-toplevel + left_mod = self.to_qasm3(as_str=False) if isinstance(self, Qasm2Module) else self.copy() + right_mod = other.to_qasm3(as_str=False) if isinstance(other, Qasm2Module) else other.copy() + + # Unroll with qubit consolidation so both sides use __PYQASM_QUBITS__ + unroll_kwargs: dict[str, object] = {"consolidate_qubits": True} + if device_qubits is not None: + unroll_kwargs["device_qubits"] = device_qubits + + left_mod.unroll(**unroll_kwargs) + right_mod.unroll(**unroll_kwargs) + + # Determine sizes after consolidation + left_qubits = left_mod.num_qubits + right_qubits = right_mod.num_qubits + total_qubits = left_qubits + right_qubits + + # Build a new Program. We'll add includes (unique) first, then declaration and ops + merged_program = Program(statements=[], version="3.0") + + # gets unique include filenames from both modules + # added this because we get duplicate File 'stdgates.inc' errors + include_names: list[str] = [] + for module in (left_mod, right_mod): + for stmt in module.unrolled_ast.statements: + if isinstance(stmt, qasm3_ast.Include): + if stmt.filename not in include_names: + include_names.append(stmt.filename) + for inc_name in include_names: + merged_program.statements.append(qasm3_ast.Include(filename=inc_name)) + + # single consolidated qubit declaration + merged_qubit_decl = qasm3_ast.QubitDeclaration( + size=qasm3_ast.IntegerLiteral(value=total_qubits), + qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"), + ) + merged_program.statements.append(merged_qubit_decl) + + # Append left (self) statements, skipping its consolidated qubit declaration + for stmt in left_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): + continue + merged_program.statements.append(deepcopy(stmt)) + + # Offsets indices inside a statement by a fixed amount to make sure we merge correctly + def _offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int): + if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement): + # Offset measured qubit source + bit = stmt.measure.qubit + if isinstance(bit, qasm3_ast.IndexedIdentifier): + for group in bit.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + # target is classical; leave untouched + return + + if isinstance(stmt, qasm3_ast.QuantumGate): + # Offset all qubit operands + for q in stmt.qubits: + for group in q.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumReset): + q = stmt.qubits + if isinstance(q, qasm3_ast.IndexedIdentifier): + for group in q.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumBarrier): + # Barrier can be represented with IndexedIdentifier or a string slice on Identifier + qubits = stmt.qubits + if len(qubits) == 0: + return + first = qubits[0] + if isinstance(first, qasm3_ast.IndexedIdentifier): + for group in first.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + elif isinstance(first, qasm3_ast.Identifier): + # Handle forms: __PYQASM_QUBITS__[:E], [S:], [S:E] + name = first.name + if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"): + slice_str = name[len("__PYQASM_QUBITS__"):] + # Parse slice forms [S:E], [:E], or [S:] and capture optional start/end integers + m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str) + if m: + start_s, end_s = m.group(1), m.group(2) + if start_s is None and end_s is not None: + # [:E] + end_v = int(end_s) + offset + first.name = f"__PYQASM_QUBITS__[:{end_v}]" + elif start_s is not None and end_s is None: + # [S:] + start_v = int(start_s) + offset + first.name = f"__PYQASM_QUBITS__[{start_v}:]" + elif start_s is not None and end_s is not None: + # [S:E] + start_v = int(start_s) + offset + end_v = int(end_s) + offset + first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]" + return + + # Append statements with index offset, skipping its qubit declaration and include statements + for stmt in right_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): + continue + stmt_copy = deepcopy(stmt) + _offset_statement_qubits(stmt_copy, left_qubits) + merged_program.statements.append(stmt_copy) + + # Build merged module + merged_module = Qasm3Module(name=f"{left_mod.name}_merged_{right_mod.name}", program=merged_program) + + # inputs already unrolled, we can set the unrolled AST directly + merged_module.unrolled_ast = Program(statements=list(merged_program.statements), version="3.0") + + # Combine metadata/history in a straightforward manner + merged_module._external_gates = list({*left_mod._external_gates, *right_mod._external_gates}) + merged_module._user_operations = list(left_mod.history) + list(right_mod.history) + merged_module._user_operations.append(f"merge(other={right_mod.name})") + merged_module.validate() + + return merged_module diff --git a/tests/qasm3/test_merge.py b/tests/qasm3/test_merge.py new file mode 100644 index 00000000..6ed763fb --- /dev/null +++ b/tests/qasm3/test_merge.py @@ -0,0 +1,118 @@ +# Copyright 2025 qBraid +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for QasmModule.merge(). +""" + +from pyqasm.entrypoint import loads +from pyqasm.modules import QasmModule + + +def _qasm3(qasm: str) -> QasmModule: + return loads(qasm) + + +def test_merge_basic_gates_and_offsets(): + qasm_a = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[2] q;\n" + "x q[0];\n" + "cx q[0], q[1];\n" + ) + qasm_b = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[3] r;\n" + "h r[0];\n" + "cx r[1], r[2];\n" + ) + + mod_a = _qasm3(qasm_a) + mod_b = _qasm3(qasm_b) + + merged = mod_a.merge(mod_b) + + # Unrolled representation should have a single consolidated qubit declaration of size 5 + text = str(merged) + assert "qubit[5] __PYQASM_QUBITS__;" in text + + lines = [l.strip() for l in text.splitlines() if l.strip()] + # Keep only gate lines for comparison; skip version/includes/declarations + gate_lines = [ + l + for l in lines + if l[0].isalpha() + and not l.startswith("include") + and not l.startswith("OPENQASM") + and not l.startswith("qubit") + ] + assert gate_lines[0].startswith("x __PYQASM_QUBITS__[0]") + assert gate_lines[1].startswith("cx __PYQASM_QUBITS__[0], __PYQASM_QUBITS__[1]") + assert any(l.startswith("h __PYQASM_QUBITS__[2]") for l in gate_lines) + assert any(l.startswith("cx __PYQASM_QUBITS__[3], __PYQASM_QUBITS__[4]") for l in gate_lines) + + +def test_merge_with_measurements_and_barriers(): + # Module A: 1 qubit + classical 1; has barrier and measure + qasm_a = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[1] qa; bit[1] ca;\n" + "h qa[0];\n" + "barrier qa;\n" + "ca[0] = measure qa[0];\n" + ) + # Module B: 2 qubits + classical 2 + qasm_b = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[2] qb; bit[2] cb;\n" + "x qb[1];\n" + "cb[1] = measure qb[1];\n" + ) + + mod_a = _qasm3(qasm_a) + mod_b = _qasm3(qasm_b) + + merged = mod_a.merge(mod_b) + merged_text = str(merged) + + assert "qubit[3] __PYQASM_QUBITS__;" in merged_text + assert "measure __PYQASM_QUBITS__[2];" in merged_text + assert "barrier __PYQASM_QUBITS__" in merged_text + + +def test_merge_qasm2_with_qasm3(): + qasm2 = ( + "OPENQASM 2.0;\n" + "include \"qelib1.inc\";\n" + "qreg q[1];\n" + "h q[0];\n" + ) + qasm3 = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[2] r;\n" + "x r[0];\n" + ) + + mod2 = loads(qasm2) + mod3 = loads(qasm3) + + merged = mod2.merge(mod3) + text = str(merged) + assert "qubit[3] __PYQASM_QUBITS__;" in text + assert "x __PYQASM_QUBITS__[1];" in text \ No newline at end of file From ec2dfdde71937c84bad6dabaaaac3d6516a472f5 Mon Sep 17 00:00:00 2001 From: arunjmoorthy Date: Mon, 8 Sep 2025 10:08:29 -0700 Subject: [PATCH 2/3] format changes --- examples/unroll_example.py | 2 +- src/pyqasm/modules/base.py | 16 ++++++++++++---- tests/qasm3/test_merge.py | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/unroll_example.py b/examples/unroll_example.py index 8d7f375b..b7ae33ad 100755 --- a/examples/unroll_example.py +++ b/examples/unroll_example.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, cyclic-import """ Script demonstrating how to unroll a QASM 3 program using pyqasm. diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index b8d98475..8a2f5249 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -765,7 +765,7 @@ def accept(self, visitor): """ - def merge(self, other: "QasmModule", device_qubits: Optional[int] = None) -> "QasmModule": + def merge(self, other: "QasmModule", device_qubits: Optional[int] = None) -> "QasmModule": """Merge this module with another module into a single consolidated module. Notes: @@ -902,13 +902,21 @@ def _offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int): merged_program.statements.append(stmt_copy) # Build merged module - merged_module = Qasm3Module(name=f"{left_mod.name}_merged_{right_mod.name}", program=merged_program) + merged_module = Qasm3Module( + name=f"{left_mod.name}_merged_{right_mod.name}", + program=merged_program, + ) # inputs already unrolled, we can set the unrolled AST directly - merged_module.unrolled_ast = Program(statements=list(merged_program.statements), version="3.0") + merged_module.unrolled_ast = Program( + statements=list(merged_program.statements), + version="3.0", + ) # Combine metadata/history in a straightforward manner - merged_module._external_gates = list({*left_mod._external_gates, *right_mod._external_gates}) + merged_module._external_gates = list( + {*left_mod._external_gates, *right_mod._external_gates} + ) merged_module._user_operations = list(left_mod.history) + list(right_mod.history) merged_module._user_operations.append(f"merge(other={right_mod.name})") merged_module.validate() diff --git a/tests/qasm3/test_merge.py b/tests/qasm3/test_merge.py index 6ed763fb..e99f53b4 100644 --- a/tests/qasm3/test_merge.py +++ b/tests/qasm3/test_merge.py @@ -115,4 +115,4 @@ def test_merge_qasm2_with_qasm3(): merged = mod2.merge(mod3) text = str(merged) assert "qubit[3] __PYQASM_QUBITS__;" in text - assert "x __PYQASM_QUBITS__[1];" in text \ No newline at end of file + assert "x __PYQASM_QUBITS__[1];" in text From 98cc66ccf4cd3a36e09372a672eee3c13d47c34d Mon Sep 17 00:00:00 2001 From: arunjmoorthy Date: Thu, 11 Sep 2025 13:19:46 -0700 Subject: [PATCH 3/3] comment changes/moving to new PR --- src/pyqasm/modules/base.py | 219 +++++++++++------------------------- src/pyqasm/modules/qasm2.py | 73 ++++++++++++ src/pyqasm/modules/qasm3.py | 70 +++++++++++- 3 files changed, 208 insertions(+), 154 deletions(-) diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index 8a2f5249..2164705a 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -765,160 +765,73 @@ def accept(self, visitor): """ - def merge(self, other: "QasmModule", device_qubits: Optional[int] = None) -> "QasmModule": - """Merge this module with another module into a single consolidated module. - - Notes: - - Both modules are unrolled with consolidated qubit registers prior to merging. - - The resulting module has a single declaration: ``qubit[] __PYQASM_QUBITS__``. - - All quantum operations from the second module are appended after the first, with - qubit indices offset by the size of the first module. - - Args: - other (QasmModule): The module to merge with the current module. - device_qubits (int | None): Optional device qubit budget to use during unrolling. - - Returns: - QasmModule: A new Qasm3Module representing the merged program. + @abstractmethod + def merge( + self, + other: "QasmModule", + device_qubits: Optional[int] = None, + ) -> "QasmModule": + """Merge this module with another module. + + Implemented by concrete subclasses to avoid version mixing and + import-time cycles. Implementations should ensure both operands + are normalized to the same version prior to merging. """ - if not isinstance(other, QasmModule): - raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}") - - # Normalize both modules to QASM3 form (without mutating originals) - from pyqasm.modules.qasm2 import Qasm2Module # pylint: disable=import-outside-toplevel - from pyqasm.modules.qasm3 import Qasm3Module # pylint: disable=import-outside-toplevel - left_mod = self.to_qasm3(as_str=False) if isinstance(self, Qasm2Module) else self.copy() - right_mod = other.to_qasm3(as_str=False) if isinstance(other, Qasm2Module) else other.copy() - - # Unroll with qubit consolidation so both sides use __PYQASM_QUBITS__ - unroll_kwargs: dict[str, object] = {"consolidate_qubits": True} - if device_qubits is not None: - unroll_kwargs["device_qubits"] = device_qubits - - left_mod.unroll(**unroll_kwargs) - right_mod.unroll(**unroll_kwargs) - - # Determine sizes after consolidation - left_qubits = left_mod.num_qubits - right_qubits = right_mod.num_qubits - total_qubits = left_qubits + right_qubits - - # Build a new Program. We'll add includes (unique) first, then declaration and ops - merged_program = Program(statements=[], version="3.0") - - # gets unique include filenames from both modules - # added this because we get duplicate File 'stdgates.inc' errors - include_names: list[str] = [] - for module in (left_mod, right_mod): - for stmt in module.unrolled_ast.statements: - if isinstance(stmt, qasm3_ast.Include): - if stmt.filename not in include_names: - include_names.append(stmt.filename) - for inc_name in include_names: - merged_program.statements.append(qasm3_ast.Include(filename=inc_name)) - - # single consolidated qubit declaration - merged_qubit_decl = qasm3_ast.QubitDeclaration( - size=qasm3_ast.IntegerLiteral(value=total_qubits), - qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"), - ) - merged_program.statements.append(merged_qubit_decl) - - # Append left (self) statements, skipping its consolidated qubit declaration - for stmt in left_mod.unrolled_ast.statements: - if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): - continue - merged_program.statements.append(deepcopy(stmt)) - - # Offsets indices inside a statement by a fixed amount to make sure we merge correctly - def _offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int): - if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement): - # Offset measured qubit source - bit = stmt.measure.qubit - if isinstance(bit, qasm3_ast.IndexedIdentifier): - for group in bit.indices: - for ind in group: - ind.value += offset # type: ignore[attr-defined] - # target is classical; leave untouched - return - - if isinstance(stmt, qasm3_ast.QuantumGate): - # Offset all qubit operands - for q in stmt.qubits: - for group in q.indices: - for ind in group: - ind.value += offset # type: ignore[attr-defined] - return - - if isinstance(stmt, qasm3_ast.QuantumReset): - q = stmt.qubits - if isinstance(q, qasm3_ast.IndexedIdentifier): - for group in q.indices: - for ind in group: - ind.value += offset # type: ignore[attr-defined] - return - - if isinstance(stmt, qasm3_ast.QuantumBarrier): - # Barrier can be represented with IndexedIdentifier or a string slice on Identifier - qubits = stmt.qubits - if len(qubits) == 0: - return - first = qubits[0] - if isinstance(first, qasm3_ast.IndexedIdentifier): - for group in first.indices: - for ind in group: - ind.value += offset # type: ignore[attr-defined] - elif isinstance(first, qasm3_ast.Identifier): - # Handle forms: __PYQASM_QUBITS__[:E], [S:], [S:E] - name = first.name - if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"): - slice_str = name[len("__PYQASM_QUBITS__"):] - # Parse slice forms [S:E], [:E], or [S:] and capture optional start/end integers - m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str) - if m: - start_s, end_s = m.group(1), m.group(2) - if start_s is None and end_s is not None: - # [:E] - end_v = int(end_s) + offset - first.name = f"__PYQASM_QUBITS__[:{end_v}]" - elif start_s is not None and end_s is None: - # [S:] - start_v = int(start_s) + offset - first.name = f"__PYQASM_QUBITS__[{start_v}:]" - elif start_s is not None and end_s is not None: - # [S:E] - start_v = int(start_s) + offset - end_v = int(end_s) + offset - first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]" - return - - # Append statements with index offset, skipping its qubit declaration and include statements - for stmt in right_mod.unrolled_ast.statements: - if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): - continue - stmt_copy = deepcopy(stmt) - _offset_statement_qubits(stmt_copy, left_qubits) - merged_program.statements.append(stmt_copy) - - # Build merged module - merged_module = Qasm3Module( - name=f"{left_mod.name}_merged_{right_mod.name}", - program=merged_program, - ) - # inputs already unrolled, we can set the unrolled AST directly - merged_module.unrolled_ast = Program( - statements=list(merged_program.statements), - version="3.0", - ) +def offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int): + """Offset qubit indices for a given statement in-place by ``offset``. - # Combine metadata/history in a straightforward manner - merged_module._external_gates = list( - {*left_mod._external_gates, *right_mod._external_gates} - ) - merged_module._user_operations = list(left_mod.history) + list(right_mod.history) - merged_module._user_operations.append(f"merge(other={right_mod.name})") - merged_module.validate() - - return merged_module + Handles gates, measurements, resets, and barriers (including slice forms). + """ + if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement): + bit = stmt.measure.qubit + if isinstance(bit, qasm3_ast.IndexedIdentifier): + for group in bit.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumGate): + for q in stmt.qubits: + for group in q.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumReset): + q = stmt.qubits + if isinstance(q, qasm3_ast.IndexedIdentifier): + for group in q.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumBarrier): + qubits = stmt.qubits + if len(qubits) == 0: + return + first = qubits[0] + if isinstance(first, qasm3_ast.IndexedIdentifier): + for group in first.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + elif isinstance(first, qasm3_ast.Identifier): + # Handle forms: __PYQASM_QUBITS__[:E], [S:], [S:E] + name = first.name + if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"): + slice_str = name[len("__PYQASM_QUBITS__"):] + # Parse slice forms [S:E], [:E], or [S:] + m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str) + if m: + start_s, end_s = m.group(1), m.group(2) + if start_s is None and end_s is not None: + end_v = int(end_s) + offset + first.name = f"__PYQASM_QUBITS__[:{end_v}]" + elif start_s is not None and end_s is None: + start_v = int(start_s) + offset + first.name = f"__PYQASM_QUBITS__[{start_v}:]" + elif start_s is not None and end_s is not None: + start_v = int(start_s) + offset + end_v = int(end_s) + offset + first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]" diff --git a/src/pyqasm/modules/qasm2.py b/src/pyqasm/modules/qasm2.py index f4b0de9d..808ed8a3 100644 --- a/src/pyqasm/modules/qasm2.py +++ b/src/pyqasm/modules/qasm2.py @@ -26,6 +26,7 @@ from pyqasm.exceptions import ValidationError from pyqasm.modules.base import QasmModule from pyqasm.modules.qasm3 import Qasm3Module +from pyqasm.modules.base import offset_statement_qubits class Qasm2Module(QasmModule): @@ -108,3 +109,75 @@ def accept(self, visitor): final_stmt_list = visitor.finalize(unrolled_stmt_list) self.unrolled_ast.statements = final_stmt_list + + def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule: + """Merge two modules and return a QASM2 result without mixing versions. + + - If ``other`` is QASM3, it is merged into this module's semantics, and + any standard gate includes are mapped to ``qelib1.inc``. + - The merged program keeps version "2.0" and prints as QASM2. + """ + if not isinstance(other, QasmModule): + raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}") + + left_mod = self.copy() + right_mod = other.copy() + + # Unroll with qubit consolidation so both sides use __PYQASM_QUBITS__ + unroll_kwargs: dict[str, object] = {"consolidate_qubits": True} + if device_qubits is not None: + unroll_kwargs["device_qubits"] = device_qubits + + left_mod.unroll(**unroll_kwargs) + right_mod.unroll(**unroll_kwargs) + + left_qubits = left_mod.num_qubits + total_qubits = left_qubits + right_mod.num_qubits + + merged_program = Program(statements=[], version="2.0") + + # Unique includes first; map stdgates.inc -> qelib1.inc for QASM2 + include_names: list[str] = [] + for module in (left_mod, right_mod): + for stmt in module.unrolled_ast.statements: + if isinstance(stmt, Include): + fname = stmt.filename + if fname == "stdgates.inc": + fname = "qelib1.inc" + if fname not in include_names: + include_names.append(fname) + for name in include_names: + merged_program.statements.append(Include(filename=name)) + + # Consolidated qubit declaration (converted to qreg on print) + merged_program.statements.append( + qasm3_ast.QubitDeclaration( + size=qasm3_ast.IntegerLiteral(value=total_qubits), + qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"), + ) + ) + + # Append left ops (skip decls and includes) + for stmt in left_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)): + continue + merged_program.statements.append(deepcopy(stmt)) + + # Append right ops with index offset + for stmt in right_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)): + continue + stmt_copy = deepcopy(stmt) + offset_statement_qubits(stmt_copy, left_qubits) + merged_program.statements.append(stmt_copy) + + merged_module = Qasm2Module( + name=f"{left_mod.name}_merged_{right_mod.name}", + program=merged_program, + ) + merged_module.unrolled_ast = Program(statements=list(merged_program.statements), version="2.0") + merged_module._external_gates = list({*left_mod._external_gates, *right_mod._external_gates}) + merged_module._user_operations = list(left_mod.history) + list(right_mod.history) + merged_module._user_operations.append(f"merge(other={right_mod.name})") + merged_module.validate() + return merged_module diff --git a/src/pyqasm/modules/qasm3.py b/src/pyqasm/modules/qasm3.py index 8ed08d51..f8eaa971 100644 --- a/src/pyqasm/modules/qasm3.py +++ b/src/pyqasm/modules/qasm3.py @@ -16,10 +16,11 @@ Defines a module for handling OpenQASM 3.0 programs. """ +import openqasm3.ast as qasm3_ast from openqasm3.ast import Program from openqasm3.printer import dumps -from pyqasm.modules.base import QasmModule +from pyqasm.modules.base import QasmModule, offset_statement_qubits class Qasm3Module(QasmModule): @@ -52,3 +53,70 @@ def accept(self, visitor): final_stmt_list = visitor.finalize(unrolled_stmt_list) self._unrolled_ast.statements = final_stmt_list + + def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule: + """Merge two modules as OpenQASM 3.0 without mixing versions. + + If ``other`` is QASM2, it will be converted to QASM3 before merging. + The merged program keeps version "3.0". + """ + if not isinstance(other, QasmModule): + raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}") + + # Convert right to QASM3 if it supports conversion; otherwise copy + convert = getattr(other, "to_qasm3", None) + right_mod = convert(as_str=False) if callable(convert) else other.copy() # type: ignore[assignment] + + left_mod = self.copy() + + # Unroll with consolidation so both use __PYQASM_QUBITS__ + unroll_kwargs: dict[str, object] = {"consolidate_qubits": True} + if device_qubits is not None: + unroll_kwargs["device_qubits"] = device_qubits + + left_mod.unroll(**unroll_kwargs) + right_mod.unroll(**unroll_kwargs) + + left_qubits = left_mod.num_qubits + total_qubits = left_qubits + right_mod.num_qubits + + merged_program = Program(statements=[], version="3.0") + + # Unique includes first + include_names: list[str] = [] + for module in (left_mod, right_mod): + for stmt in module.unrolled_ast.statements: + if isinstance(stmt, qasm3_ast.Include) and stmt.filename not in include_names: + include_names.append(stmt.filename) + for name in include_names: + merged_program.statements.append(qasm3_ast.Include(filename=name)) + + # Consolidated qubit declaration + merged_program.statements.append( + qasm3_ast.QubitDeclaration( + size=qasm3_ast.IntegerLiteral(value=total_qubits), + qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"), + ) + ) + + # Append left ops + for stmt in left_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): + continue + merged_program.statements.append(stmt) + + # Append right ops with index offset + for stmt in right_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): + continue + # right_mod is a copy, so it's safe to modify statements in place + offset_statement_qubits(stmt, left_qubits) + merged_program.statements.append(stmt) + + merged_module = Qasm3Module(name=f"{left_mod.name}_merged_{right_mod.name}", program=merged_program) + merged_module.unrolled_ast = Program(statements=list(merged_program.statements), version="3.0") + merged_module._external_gates = list({*left_mod._external_gates, *right_mod._external_gates}) + merged_module._user_operations = list(left_mod.history) + list(right_mod.history) + merged_module._user_operations.append(f"merge(other={right_mod.name})") + merged_module.validate() + return merged_module