Coverage for /home/runner/work/tket/tket/pytket/pytket/qasm/qasm.py: 92%
1057 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-25 16:00 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import itertools
16import os
17import re
18import uuid
19from collections import OrderedDict
20from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
21from dataclasses import dataclass
22from decimal import Decimal
23from importlib import import_module
24from itertools import chain, groupby
25from typing import Any, NewType, TextIO, TypeVar, Union, cast
27from lark import Discard, Lark, Token, Transformer, Tree
28from sympy import Expr, Symbol, pi
30from pytket._tket.circuit import (
31 BarrierOp,
32 ClExpr,
33 ClExprOp,
34 Command,
35 Conditional,
36 CopyBitsOp,
37 MultiBitOp,
38 RangePredicateOp,
39 SetBitsOp,
40 WASMOp,
41 WiredClExpr,
42)
43from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE
44from pytket.circuit import (
45 Bit,
46 BitRegister,
47 Circuit,
48 Op,
49 OpType,
50 Qubit,
51 QubitRegister,
52 UnitID,
53)
54from pytket.circuit.clexpr import (
55 check_register_alignments,
56 has_reg_output,
57 wired_clexpr_from_logic_exp,
58)
59from pytket.circuit.decompose_classical import _int_to_bools
60from pytket.circuit.logic_exp import (
61 ArgType,
62 BitLogicExp,
63 BitWiseOp,
64 LogicExp,
65 PredicateExp,
66 RegEq,
67 RegLogicExp,
68 RegNeg,
69 RegWiseOp,
70 create_logic_exp,
71 create_predicate_exp,
72)
73from pytket.passes import (
74 AutoRebase,
75 DecomposeBoxes,
76 RemoveRedundancies,
77 scratch_reg_resize_pass,
78)
79from pytket.qasm.grammar import grammar
80from pytket.wasm import WasmFileHandler, WasmModuleHandler
83class QASMParseError(Exception):
84 """Error while parsing QASM input."""
86 def __init__(self, msg: str, line: int | None = None, fname: str | None = None):
87 self.msg = msg
88 self.line = line
89 self.fname = fname
91 ctx = "" if fname is None else f"\nFile:{fname}: "
92 ctx += "" if line is None else f"\nLine:{line}. "
94 super().__init__(f"{msg}{ctx}")
class QASMUnsupportedError(Exception):
    """
    Error due to QASM input being incompatible with the supported fragment.

    Raised, for example, when a classical register is wider than the
    configured ``maxwidth``.
    """
# Scalar value as it may appear in a parsed QASM expression.
Value = Union[int, float, str]  # noqa: UP007
T = TypeVar("T")

# Operator symbols accepted for bit-wise and register-wise classical logic.
# "+" and "-" are also accepted on bits, where both are parsed to XOR.
_BITOPS = {op.value for op in BitWiseOp} | {"+", "-"}
_REGOPS = {op.value for op in RegWiseOp}

# A gate/statement argument: either a ``[register_name, [index]]`` list or a
# bare register name.
Arg = Union[list, str]  # noqa: UP007


# Standard QASM gate names without parameters -> tket OpType.
NOPARAM_COMMANDS = {
    "CX": OpType.CX,  # built-in gate equivalent to "cx"
    "cx": OpType.CX,
    "x": OpType.X,
    "y": OpType.Y,
    "z": OpType.Z,
    "h": OpType.H,
    "s": OpType.S,
    "sdg": OpType.Sdg,
    "t": OpType.T,
    "tdg": OpType.Tdg,
    "sx": OpType.SX,
    "sxdg": OpType.SXdg,
    "cz": OpType.CZ,
    "cy": OpType.CY,
    "ch": OpType.CH,
    "csx": OpType.CSX,
    "ccx": OpType.CCX,
    "c3x": OpType.CnX,
    "c4x": OpType.CnX,
    "ZZ": OpType.ZZMax,
    "measure": OpType.Measure,
    "reset": OpType.Reset,
    "id": OpType.noop,
    "barrier": OpType.Barrier,
    "swap": OpType.SWAP,
    "cswap": OpType.CSWAP,
}
# Standard QASM gate names taking angle parameters -> tket OpType.
PARAM_COMMANDS = {
    "p": OpType.U1,  # alias. https://github.com/Qiskit/qiskit-terra/pull/4765
    "u": OpType.U3,  # alias. https://github.com/Qiskit/qiskit-terra/pull/4765
    "U": OpType.U3,  # built-in gate equivalent to "u3"
    "u3": OpType.U3,
    "u2": OpType.U2,
    "u1": OpType.U1,
    "rx": OpType.Rx,
    "rxx": OpType.XXPhase,
    "ry": OpType.Ry,
    "rz": OpType.Rz,
    "RZZ": OpType.ZZPhase,
    "rzz": OpType.ZZPhase,
    "Rz": OpType.Rz,
    "U1q": OpType.PhasedX,
    "crz": OpType.CRz,
    "crx": OpType.CRx,
    "cry": OpType.CRy,
    "cu1": OpType.CU1,
    "cu3": OpType.CU3,
    "Rxxyyzz": OpType.TK2,
}

# tket-specific gates (not in the standard QASM spec) without parameters.
NOPARAM_EXTRA_COMMANDS = {
    "v": OpType.V,
    "vdg": OpType.Vdg,
    "cv": OpType.CV,
    "cvdg": OpType.CVdg,
    "csxdg": OpType.CSXdg,
    "bridge": OpType.BRIDGE,
    "iswapmax": OpType.ISWAPMax,
    "zzmax": OpType.ZZMax,
    "ecr": OpType.ECR,
    "cs": OpType.CS,
    "csdg": OpType.CSdg,
}

# tket-specific gates (not in the standard QASM spec) with parameters.
PARAM_EXTRA_COMMANDS = {
    "tk2": OpType.TK2,
    "iswap": OpType.ISWAP,
    "phasediswap": OpType.PhasedISWAP,
    "yyphase": OpType.YYPhase,
    "xxphase3": OpType.XXPhase3,
    "eswap": OpType.ESWAP,
    "fsim": OpType.FSim,
}

# Number of parameters each parameterised extra gate takes.
N_PARAMS_EXTRA_COMMANDS = {
    OpType.TK2: 3,
    OpType.ISWAP: 1,
    OpType.PhasedISWAP: 2,
    OpType.YYPhase: 1,
    OpType.XXPhase3: 1,
    OpType.ESWAP: 1,
    OpType.FSim: 2,
}

# Reverse lookups (OpType -> QASM name). Where several names share an OpType
# the later table entry wins; the preferred spellings are then forced.
_tk_to_qasm_noparams = {optype: name for name, optype in NOPARAM_COMMANDS.items()}
_tk_to_qasm_noparams[OpType.CX] = "cx"  # prefer "cx" to "CX"
_tk_to_qasm_params = {optype: name for name, optype in PARAM_COMMANDS.items()}
_tk_to_qasm_params[OpType.U3] = "u3"  # prefer "u3" to "U"
_tk_to_qasm_params[OpType.Rz] = "rz"  # prefer "rz" to "Rz"
_tk_to_qasm_extra_noparams = {
    optype: name for name, optype in NOPARAM_EXTRA_COMMANDS.items()
}
_tk_to_qasm_extra_params = {
    optype: name for name, optype in PARAM_EXTRA_COMMANDS.items()
}

_classical_gatestr_map = {"AND": "&", "OR": "|", "XOR": "^"}


# Every gate name the parser knows natively.
_all_known_gates = {
    *NOPARAM_COMMANDS,
    *PARAM_COMMANDS,
    *PARAM_EXTRA_COMMANDS,
    *NOPARAM_EXTRA_COMMANDS,
}
# QASM gate name -> OpType name, as used in serialized command dicts.
_all_string_maps = {
    name: optype.name
    for table in (
        PARAM_COMMANDS,
        NOPARAM_COMMANDS,
        PARAM_EXTRA_COMMANDS,
        NOPARAM_EXTRA_COMMANDS,
    )
    for name, optype in table.items()
}

# ``name[index]`` unit references and bare register names.
unit_regex = re.compile(r"([a-z][a-zA-Z0-9_]*)\[([\d]+)\]")
regname_regex = re.compile(r"^[a-z][a-zA-Z0-9_]*$")
def _extract_reg(var: Token) -> tuple[str, int]:
    """Split an indexed unit token such as ``q[3]`` into ``("q", 3)``.

    Only the start of the token text has to match ``unit_regex``.

    :param var: token whose ``value`` is the indexed unit text
    :raises QASMParseError: if the text does not look like ``name[index]``
    :return: register name and integer index
    """
    found = unit_regex.match(var.value)
    if found is not None:
        reg_name, reg_index = found.group(1), found.group(2)
        return reg_name, int(reg_index)
    raise QASMParseError(
        f"Invalid register definition '{var.value}'. Register definitions "
        "must follow the pattern '<name> [<size in integer>]'. "
        "For example, 'q [5]'. QASM register names must begin with a "
        "lowercase letter and may only contain lowercase and uppercase "
        "letters, numbers, and underscores."
    )
def _load_include_module(
    header_name: str, do_filter: bool, decls_only: bool
) -> dict[str, dict]:
    """Load the gate table shipped for a QASM include header.

    :param header_name: header name without extension (e.g. ``"qelib1"``)
    :param do_filter: drop gates the parser already knows natively
    :param decls_only: load ``_INCLUDE_DECLS`` from the ``_decls`` module
        instead of ``_INCLUDE_DEFS`` from the ``_defs`` module
    :raises QASMParseError: if no module exists for the header
    :return: mapping from gate name to its stored description
    """
    if decls_only:
        module_suffix, table_name = "decls", "_INCLUDE_DECLS"
    else:
        module_suffix, table_name = "defs", "_INCLUDE_DEFS"
    try:
        include_module = import_module(
            f"pytket.qasm.includes._{header_name}_{module_suffix}"
        )
    except ModuleNotFoundError as e:
        raise QASMParseError(
            f"Header {header_name} is not known and cannot be loaded."
        ) from e
    include_def: dict[str, dict] = getattr(include_module, table_name)
    keep_all = not do_filter
    return {
        gate_name: gate_info
        for gate_name, gate_info in include_def.items()
        if keep_all or gate_name not in _all_known_gates
    }
267def _bin_par_exp(op: "str") -> Callable[["_CircuitTransformer", list[str]], str]:
268 def f(self: "_CircuitTransformer", vals: list[str]) -> str:
269 return f"({vals[0]} {op} {vals[1]})"
271 return f
274def _un_par_exp(op: "str") -> Callable[["_CircuitTransformer", list[str]], str]:
275 def f(self: "_CircuitTransformer", vals: list[str]) -> str:
276 return f"({op}{vals[0]})"
278 return f
281def _un_call_exp(op: "str") -> Callable[["_CircuitTransformer", list[str]], str]:
282 def f(self: "_CircuitTransformer", vals: list[str]) -> str:
283 return f"{op}({vals[0]})"
285 return f
288def _hashable_uid(arg: list) -> tuple[str, int]:
289 return arg[0], arg[1][0]
292def _can_treat_as_bit(arg: ArgType) -> bool:
293 if isinstance(arg, Bit | BitLogicExp):
294 return True
295 if isinstance(arg, int):
296 return arg in (0, 1)
297 if isinstance(arg, BitRegister):
298 return arg.size == 1
299 return False
# Name of a whole (quantum or classical) register.
Reg = NewType("Reg", str)
# A serialized command as consumed by ``Circuit.from_dict``.
CommandDict = dict[str, Any]
306@dataclass
307class _ParsMap:
308 pars: Iterable[str]
310 def __iter__(self) -> Iterable[str]:
311 return self.pars
class _CircuitTransformer(Transformer):
    """Lark transformer turning a parsed OpenQASM 2 tree into a circuit dict.

    Each grammar rule is handled by the method of the same name. Statement
    handlers yield serialized command dictionaries; ``prog`` collects them
    into the dictionary passed to ``Circuit.from_dict``. Register
    declarations, gate definitions, the include name and an optional WASM
    module handler are tracked on the instance while a program is processed.
    """

    def __init__(
        self,
        return_gate_dict: bool = False,
        maxwidth: int = 32,
    ) -> None:
        """
        :param return_gate_dict: if True, ``prog`` returns the collected gate
            definitions instead of a circuit dict
        :param maxwidth: maximum allowed width of classical registers
        """
        super().__init__()
        self.q_registers: dict[str, int] = {}
        self.c_registers: dict[str, int] = {}
        self.gate_dict: dict[str, dict] = {}
        self.wasm: WasmModuleHandler | None = None
        self.include = ""
        self.return_gate_dict = return_gate_dict
        self.maxwidth = maxwidth

    def _fresh_temp_bit(self) -> list:
        """Grow the scratch register by one bit and return that bit's uid."""
        if _TEMP_BIT_NAME in self.c_registers:
            idx = self.c_registers[_TEMP_BIT_NAME]
        else:
            idx = 0
        self.c_registers[_TEMP_BIT_NAME] = idx + 1

        return [_TEMP_BIT_NAME, [idx]]

    def _reset_context(self, reset_wasm: bool = True) -> None:
        """Forget registers, gate definitions and the include name.

        :param reset_wasm: also drop the WASM module handler
        """
        self.q_registers = {}
        self.c_registers = {}
        self.gate_dict = {}
        self.include = ""
        if reset_wasm:
            self.wasm = None

    def _get_reg(self, name: str) -> Reg:
        return Reg(name)

    def _get_uid(self, iarg: Token) -> list:
        """Convert an indexed-argument token into ``[name, [index]]``."""
        name, idx = _extract_reg(iarg)
        return [name, [idx]]

    def _get_arg(self, arg: Token) -> Arg:
        """Convert an argument token into a unit list or a register name."""
        if arg.type == "IARG":
            return self._get_uid(arg)
        return self._get_reg(arg.value)

    def unroll_all_args(self, args: Iterable[Arg]) -> Iterator[list[Any]]:
        """Expand each argument into the list of units it denotes.

        A register name becomes one ``[name, [i]]`` entry per index (quantum
        registers are looked up first); an indexed unit stays a single entry.
        """
        for arg in args:
            if isinstance(arg, str):
                size = (
                    self.q_registers[arg]
                    if arg in self.q_registers
                    else self.c_registers[arg]
                )
                yield [[arg, [idx]] for idx in range(size)]
            else:
                yield [arg]

    def margs(self, tree: Iterable[Token]) -> Iterator[Arg]:
        return map(self._get_arg, tree)

    def iargs(self, tree: Iterable[Token]) -> Iterator[list]:
        return map(self._get_uid, tree)

    def args(self, tree: Iterable[Token]) -> Iterator[list]:
        return ([tok.value, [0]] for tok in tree)

    def creg(self, tree: list[Token]) -> None:
        """Declare a classical register.

        :raises QASMUnsupportedError: if the register is wider than maxwidth
        """
        name, size = _extract_reg(tree[0])
        if size > self.maxwidth:
            raise QASMUnsupportedError(
                f"Circuit contains classical register {name} of size {size} > "
                f"{self.maxwidth}: try setting the `maxwidth` parameter to a larger "
                "value."
            )
        self.c_registers[Reg(name)] = size

    def qreg(self, tree: list[Token]) -> None:
        """Declare a quantum register."""
        name, size = _extract_reg(tree[0])
        self.q_registers[Reg(name)] = size

    def meas(self, tree: list[Token]) -> Iterable[CommandDict]:
        """Yield one Measure per (qubit, bit) pair, broadcasting registers."""
        for args in zip(*self.unroll_all_args(self.margs(tree)), strict=False):
            yield {"args": list(args), "op": {"type": "Measure"}}

    def barr(self, tree: list[Arg]) -> Iterable[CommandDict]:
        """Yield a single Barrier over all listed units.

        :raises QASMParseError: if an argument's register is not declared
        """
        args = [q for qs in self.unroll_all_args(tree[0]) for q in qs]
        signature: list[str] = []
        for arg in args:
            if arg[0] in self.c_registers:
                signature.append("C")
            elif arg[0] in self.q_registers:
                signature.append("Q")
            else:
                raise QASMParseError(
                    "UnitID " + str(arg) + " in Barrier arguments is not declared."
                )
        yield {
            "args": args,
            "op": {"signature": signature, "type": "Barrier"},
        }

    def reset(self, tree: list[Token]) -> Iterable[CommandDict]:
        """Yield one Reset per qubit of the (first) argument."""
        for qb in next(self.unroll_all_args(self.margs(tree))):
            yield {"args": [qb], "op": {"type": "Reset"}}

    def pars(self, vals: Iterable[str]) -> _ParsMap:
        return _ParsMap(map(str, vals))

    def mixedcall(self, tree: list) -> Iterator[CommandDict]:
        """Handle a gate call, with or without a parameter list.

        ``tree`` is ``[name_token, args]`` or ``[name_token, _ParsMap, args]``.
        Gates defined in the program (``gate_dict``) become CustomGate boxes,
        or Barriers for the ``sleep``/``orderN``/``groupN`` pseudo-gates; all
        other names are looked up in the built-in gate tables. Register
        arguments are broadcast, yielding one command per index.

        :raises QASMParseError: if the gate name is not known
        """
        child_iter = iter(tree)

        optoken = next(child_iter)
        opstr = optoken.value
        next_tree = next(child_iter)
        try:
            args = next(child_iter)
            pars = cast("_ParsMap", next_tree).pars
        except StopIteration:
            args = next_tree
            pars = []

        # Pseudo-gates turned into Barriers whose data string records the
        # gate name and parameters: sleep, order2..order20, group2..group20.
        treat_as_barrier = {
            "sleep",
            *(f"order{i}" for i in range(2, 21)),
            *(f"group{i}" for i in range(2, 21)),
        }
        # other opaque gates, which are not handled as barrier
        # ["RZZ", "Rxxyyzz", "Rxxyyzz_zphase", "cu", "cp", "rccx", "rc3x", "c3sqrtx"]

        args = list(args)

        if opstr in treat_as_barrier:
            params = [f"{par}" for par in pars]
        else:
            # Convert QASM angles (radians) to multiples of pi.
            params = [f"({par})/pi" for par in pars]
        if opstr in self.gate_dict:
            op: dict[str, Any] = {}
            if opstr in treat_as_barrier:
                op["type"] = "Barrier"
                param_sorted = ",".join(params)

                op["data"] = f"{opstr}({param_sorted})"

                # NOTE(review): this records each argument's first element
                # (the register name for an indexed unit, the first character
                # for a bare register name) rather than "Q"/"C" as in `barr`
                # -- confirm this is intended.
                op["signature"] = [arg[0] for arg in args]
            else:
                gdef = self.gate_dict[opstr]
                op["type"] = "CustomGate"
                box = {
                    "type": "CustomGate",
                    "id": str(uuid.uuid4()),
                    "gate": gdef,
                }
                box["params"] = params
                op["box"] = box
        else:
            try:
                optype = _all_string_maps[opstr]
            except KeyError as e:
                raise QASMParseError(
                    f"Cannot parse gate of type: {opstr}", optoken.line
                ) from e
            op = {"type": optype}
            if params:
                op["params"] = params
            # Operations needing special handling:
            if optype.startswith("Cn"):
                # n-controlled rotations have variable signature
                op["n_qb"] = len(args)
            elif optype == "Barrier":
                op["signature"] = ["Q"] * len(args)

        for arg in zip(*self.unroll_all_args(args), strict=False):
            yield {"args": list(arg), "op": op}

    def gatecall(self, tree: list) -> Iterable[CommandDict]:
        return self.mixedcall(tree)

    def exp_args(self, tree: Iterable[Token]) -> Iterable[Reg]:
        """Yield the register arguments of an extern call.

        :raises QASMParseError: for any argument that is not a whole register
        """
        for arg in tree:
            if arg.type == "ARG":
                yield self._get_reg(arg.value)
            else:
                raise QASMParseError(
                    "Non register arguments not supported for extern call.", arg.line
                )

    def _logic_exp(self, tree: list, opstr: str) -> LogicExp:
        """Build a classical logic expression for the operator ``opstr``.

        Operators valid only on bits (or only on registers) pick that kind;
        otherwise integer-only operands mean register-wise, all-bit-like
        operands mean bit-wise, and anything else register-wise. On bits,
        "+" and "-" are interpreted as XOR.
        """
        args, _ = self._get_logic_args(tree)
        openum: type[BitWiseOp] | type[RegWiseOp]
        if opstr in _BITOPS and opstr not in _REGOPS:
            openum = BitWiseOp
        elif (opstr in _REGOPS and opstr not in _BITOPS) or all(
            isinstance(arg, int) for arg in args
        ):
            openum = RegWiseOp
        elif all(_can_treat_as_bit(arg) for arg in args):
            openum = BitWiseOp
        else:
            openum = RegWiseOp
        if openum is BitWiseOp and opstr in ("+", "-"):
            op: BitWiseOp | RegWiseOp = BitWiseOp.XOR
        else:
            op = openum(opstr)
        return create_logic_exp(op, args)

    def _get_logic_args(
        self, tree: Sequence[Token | LogicExp]
    ) -> tuple[list[LogicExp | Bit | BitRegister | int], int | None]:
        """Convert expression operands to logic-expression arguments.

        :return: the converted operands and the line of the last token seen
        :raises QASMParseError: for an operand of an unexpected kind
        """
        args: list[LogicExp | Bit | BitRegister | int] = []
        line = None
        for tok in tree:
            if isinstance(tok, LogicExp):
                args.append(tok)
            elif isinstance(tok, Token):
                line = tok.line
                if tok.type == "INT":
                    args.append(int(tok.value))
                elif tok.type == "IARG":
                    args.append(Bit(*_extract_reg(tok)))
                elif tok.type == "ARG":
                    args.append(BitRegister(tok.value, self.c_registers[tok.value]))
                else:
                    raise QASMParseError(f"Could not parse argument {tok}")
            else:
                raise QASMParseError(f"Could not parse argument {tok}")
        return args, line

    par_add = _bin_par_exp("+")
    par_sub = _bin_par_exp("-")
    par_mul = _bin_par_exp("*")
    par_div = _bin_par_exp("/")
    par_pow = _bin_par_exp("**")

    par_neg = _un_par_exp("-")

    sqrt = _un_call_exp("sqrt")
    sin = _un_call_exp("sin")
    cos = _un_call_exp("cos")
    tan = _un_call_exp("tan")
    ln = _un_call_exp("ln")

    b_and = lambda self, tree: self._logic_exp(tree, "&")
    b_not = lambda self, tree: self._logic_exp(tree, "~")
    b_or = lambda self, tree: self._logic_exp(tree, "|")
    xor = lambda self, tree: self._logic_exp(tree, "^")
    lshift = lambda self, tree: self._logic_exp(tree, "<<")
    rshift = lambda self, tree: self._logic_exp(tree, ">>")
    add = lambda self, tree: self._logic_exp(tree, "+")
    sub = lambda self, tree: self._logic_exp(tree, "-")
    mul = lambda self, tree: self._logic_exp(tree, "*")
    div = lambda self, tree: self._logic_exp(tree, "/")
    ipow = lambda self, tree: self._logic_exp(tree, "**")

    def neg(self, tree: list[Token | LogicExp]) -> RegNeg:
        """Build a register negation of the single operand."""
        arg = self._get_logic_args(tree)[0][0]
        assert isinstance(arg, RegLogicExp | BitRegister | int)
        return RegNeg(arg)

    def cond(self, tree: list[Token]) -> PredicateExp:
        """Build the predicate of an ``if`` statement.

        ``tree[1]`` is a bit or register, ``tree[2]`` the comparison operator
        and ``tree[3]`` the integer it is compared against.
        """
        op: BitWiseOp | RegWiseOp
        arg: Bit | BitRegister
        if tree[1].type == "IARG":
            arg = Bit(*_extract_reg(tree[1]))
            op = BitWiseOp(str(tree[2]))
        else:
            arg = BitRegister(tree[1].value, self.c_registers[tree[1].value])
            op = RegWiseOp(str(tree[2]))

        return create_predicate_exp(op, [arg, int(tree[3].value)])

    def ifc(self, tree: Sequence) -> Iterable[CommandDict]:
        """Wrap the commands of an ``if`` body in Conditional ops.

        A single-bit test conditions directly on that bit (``!=`` is rewritten
        as ``==`` on the flipped value). A register ``==`` test conditions on
        all register bits. Any other register comparison first yields a
        RangePredicate writing to a fresh scratch bit, and the body is
        conditioned on that bit.
        """
        condition = cast("PredicateExp", tree[0])

        var, val = condition.args
        condition_bits = []

        if isinstance(var, Bit):
            assert condition.op in (BitWiseOp.EQ, BitWiseOp.NEQ)
            assert isinstance(val, int)
            assert val in (0, 1)
            if condition.op == BitWiseOp.NEQ:
                condition.op = BitWiseOp.EQ
                val = 1 ^ val
            condition_bits = [var.to_list()]

        else:
            assert isinstance(var, BitRegister)
            reg_bits = next(self.unroll_all_args([var.name]))
            if isinstance(condition, RegEq):
                # special case for base qasm
                condition_bits = reg_bits
            else:
                pred_val = cast("int", val)
                minval = 0
                maxval = (1 << self.maxwidth) - 1
                if condition.op == RegWiseOp.LT:
                    maxval = pred_val - 1
                elif condition.op == RegWiseOp.GT:
                    minval = pred_val + 1
                if condition.op in (RegWiseOp.LEQ, RegWiseOp.EQ, RegWiseOp.NEQ):
                    maxval = pred_val
                if condition.op in (RegWiseOp.GEQ, RegWiseOp.EQ, RegWiseOp.NEQ):
                    minval = pred_val

                condition_bit = self._fresh_temp_bit()
                yield {
                    "args": [*reg_bits, condition_bit],
                    "op": {
                        "classical": {
                            "lower": minval,
                            "n_i": len(reg_bits),
                            "upper": maxval,
                        },
                        "type": "RangePredicate",
                    },
                }
                condition_bits = [condition_bit]
                # For "!=" the body runs when the value is outside the range
                # [pred_val, pred_val], i.e. when the predicate bit is 0.
                val = int(condition.op != RegWiseOp.NEQ)

        for com in filter(lambda x: x is not None and x is not Discard, tree[1]):
            com["args"] = condition_bits + com["args"]
            com["op"] = {
                "conditional": {
                    "op": com["op"],
                    "value": val,
                    "width": len(condition_bits),
                },
                "type": "Conditional",
            }

            yield com

    def cop(self, tree: Sequence[Iterable[CommandDict]]) -> Iterable[CommandDict]:
        return tree[0]

    def _calc_exp_io(
        self, exp: LogicExp, out_args: list
    ) -> tuple[list[list], dict[str, Any]]:
        """Order an expression's units as inputs, in/outs, then outputs.

        :return: the ordered unit lists and their ``n_i``/``n_io``/``n_o`` counts
        """
        all_inps: list[tuple[str, int]] = []
        for inp in exp.all_inputs_ordered():
            if isinstance(inp, Bit):
                all_inps.append((inp.reg_name, inp.index[0]))
            else:
                assert isinstance(inp, BitRegister)
                for bit in inp:
                    all_inps.append((bit.reg_name, bit.index[0]))  # noqa: PERF401
        outs = (_hashable_uid(arg) for arg in out_args)
        o = []
        io = []
        for out in outs:
            if out in all_inps:
                all_inps.remove(out)
                io.append(out)
            else:
                o.append(out)

        exp_args = [[x[0], [x[1]]] for x in chain.from_iterable((all_inps, io, o))]
        numbers_dict = {
            "n_i": len(all_inps),
            "n_io": len(io),
            "n_o": len(o),
        }
        return exp_args, numbers_dict

    def _clexpr_dict(self, exp: LogicExp, out_args: list[list]) -> CommandDict:
        """Serialize ``exp`` writing to ``out_args`` as a ClExpr command."""
        # Convert the LogicExp to a serialization of a command containing the
        # corresponding ClExprOp.
        wexpr, args = wired_clexpr_from_logic_exp(
            exp, [Bit.from_list(arg) for arg in out_args]
        )
        return {
            "op": {
                "type": "ClExpr",
                "expr": wexpr.to_dict(),
            },
            "args": [arg.to_list() for arg in args],
        }

    def assign(self, tree: list) -> Iterable[CommandDict]:  # noqa: PLR0912
        """Handle an assignment to a bit or register.

        The right-hand side may be an integer literal, a bit/register, a
        logic expression, or an extern (WASM) call whose outputs are the
        assigned registers.

        :raises QASMParseError: for an unsupported right-hand side
        """
        child_iter = iter(tree)
        out_args = list(next(child_iter))
        args_uids = list(self.unroll_all_args(out_args))

        exp_tree = next(child_iter)

        exp: str | list | LogicExp | int = ""
        line = None
        if isinstance(exp_tree, Token):
            if exp_tree.type == "INT":
                exp = int(exp_tree.value)
            elif exp_tree.type in ("ARG", "IARG"):
                exp = self._get_arg(exp_tree)
            else:
                # Previously fell through with exp == "" and failed later with
                # an unhelpful KeyError.
                raise QASMParseError(
                    f"Unexpected expression in assignment {exp_tree}", exp_tree.line
                )
            line = exp_tree.line
        elif isinstance(exp_tree, Generator):
            # assume to be extern (wasm) call
            chained_uids = list(chain.from_iterable(args_uids))
            com = next(exp_tree)
            com["args"].pop()  # remove the wasmstate from the args
            com["args"] += chained_uids
            com["args"].append(["_w", [0]])
            com["op"]["wasm"]["n"] += len(chained_uids)
            com["op"]["wasm"]["width_o_parameter"] = [
                self.c_registers[reg] for reg in out_args
            ]

            yield com
            return
        else:
            exp = exp_tree

        assert len(out_args) == 1
        out_arg = out_args[0]
        args = args_uids[0]
        if isinstance(out_arg, list):
            # Assignment to a single bit.
            if isinstance(exp, LogicExp):
                yield self._clexpr_dict(exp, args)
            elif isinstance(exp, int | bool):
                assert exp in (0, 1, True, False)
                yield {
                    "args": args,
                    "op": {"classical": {"values": [bool(exp)]}, "type": "SetBits"},
                }
            elif isinstance(exp, list):
                yield {
                    "args": [exp, *args],
                    "op": {"classical": {"n_i": 1}, "type": "CopyBits"},
                }
            else:
                raise QASMParseError(f"Unexpected expression in assignment {exp}", line)
        else:
            # Assignment to a whole register.
            reg = out_arg
            if isinstance(exp, RegLogicExp):
                yield self._clexpr_dict(exp, args)
            elif isinstance(exp, BitLogicExp):
                yield self._clexpr_dict(exp, args[:1])
            elif isinstance(exp, int):
                yield {
                    "args": args,
                    "op": {
                        "classical": {
                            "values": _int_to_bools(exp, self.c_registers[reg])
                        },
                        "type": "SetBits",
                    },
                }

            elif isinstance(exp, str):
                # Register-to-register copy, truncated to the shorter width.
                width = min(self.c_registers[exp], len(args))
                yield {
                    "args": [[exp, [i]] for i in range(width)] + args[:width],
                    "op": {"classical": {"n_i": width}, "type": "CopyBits"},
                }

            else:
                raise QASMParseError(f"Unexpected expression in assignment {exp}", line)

    def extern(self, tree: list[Any]) -> Any:
        # TODO parse extern defs
        return Discard

    def ccall(self, tree: list) -> Iterable[CommandDict]:
        return self.cce_call(tree)

    def cce_call(self, tree: list) -> Iterable[CommandDict]:
        """Yield a WASM op for an extern call on whole-register arguments.

        The output widths are filled in later by ``assign``.

        :raises QASMParseError: if no WASM module has been attached
        """
        name = tree[0].value
        params = list(tree[1])
        if self.wasm is None:
            raise QASMParseError(
                "Cannot include extern calls without a wasm module specified.",
                tree[0].line,
            )
        n_i_vec = [self.c_registers[reg] for reg in params]

        wasm_args = list(chain.from_iterable(self.unroll_all_args(params)))

        wasm_args.append(["_w", [0]])

        yield {
            "args": wasm_args,
            "op": {
                "type": "WASM",
                "wasm": {
                    "func_name": name,
                    "ww_n": 1,
                    "n": sum(n_i_vec),
                    "width_i_parameter": n_i_vec,
                    "width_o_parameter": [],  # this will be set in the assign function
                    "wasm_file_uid": str(self.wasm),
                },
            },
        }

    def transform(self, tree: Tree) -> dict[str, Any]:
        self._reset_context()
        return cast("dict[str, Any]", super().transform(tree))

    def gdef(self, tree: list) -> None:
        """Record a ``gate``/``opaque`` definition in ``gate_dict``.

        The body is transformed by a fresh transformer into a circuit on the
        formal qubits. If the name is one of the extra tket gates and the body
        matches the definition the pytket converter would emit for it, nothing
        is recorded, so calls resolve to the native tket op instead.
        """
        child_iter = iter(tree)
        gate = next(child_iter).value
        next_tree = next(child_iter)
        symbols, args = [], []
        if isinstance(next_tree, _ParsMap):
            symbols = list(next_tree.pars)
            args = list(next(child_iter))
        else:
            args = list(next_tree)

        symbol_map = {sym: sym * pi for sym in map(Symbol, symbols)}
        rename_map = {Qubit.from_list(qb): Qubit("q", i) for i, qb in enumerate(args)}

        new = _CircuitTransformer(maxwidth=self.maxwidth)
        circ_dict = new.prog(child_iter)

        circ_dict["qubits"] = args
        gate_circ = Circuit.from_dict(circ_dict)

        # check to see whether gate definition was generated by pytket converter
        # if true, add op as pytket Op
        existing_op: bool = False
        # NOPARAM_EXTRA_COMMANDS and PARAM_EXTRA_COMMANDS are
        # gates that aren't in the standard qasm spec but in the standard TKET
        # optypes
        if gate in NOPARAM_EXTRA_COMMANDS:
            qubit_args = [
                Qubit(gate + "q" + str(index), 0) for index in range(len(args))
            ]
            comparison_circ = _get_gate_circuit(
                NOPARAM_EXTRA_COMMANDS[gate], qubit_args
            )
            if circuit_to_qasm_str(
                comparison_circ, maxwidth=self.maxwidth
            ) == circuit_to_qasm_str(gate_circ, maxwidth=self.maxwidth):
                existing_op = True
        elif gate in PARAM_EXTRA_COMMANDS:
            optype = PARAM_EXTRA_COMMANDS[gate]
            # Check the parameter count first, as _get_gate_circuit would
            # reject a mismatching number of parameters.
            if len(symbols) == N_PARAMS_EXTRA_COMMANDS[optype]:
                qubit_args = [
                    Qubit(gate + "q" + str(index), 0) for index in range(len(args))
                ]
                comparison_circ = _get_gate_circuit(
                    optype,
                    qubit_args,
                    [
                        Symbol("param" + str(index) + "/pi")
                        for index in range(len(symbols))
                    ],
                )
                gate_cmds = gate_circ.get_commands()
                comparison_cmds = comparison_circ.get_commands()
                # Compare command by command. The counts must match as well:
                # zip() alone would accept a body that is only a prefix (or an
                # empty body) of the expected circuit.
                existing_op = len(gate_cmds) == len(comparison_cmds) and all(
                    str(g) == str(c)
                    for g, c in zip(gate_cmds, comparison_cmds, strict=True)
                )
        if not existing_op:
            gate_circ.symbol_substitution(symbol_map)
            gate_circ.rename_units(cast("dict[UnitID, UnitID]", rename_map))
            self.gate_dict[gate] = {
                "definition": gate_circ.to_dict(),
                "args": symbols,
                "name": gate,
            }

    opaq = gdef

    def oqasm(self, tree: list) -> Any:
        return Discard

    def incl(self, tree: list[Token]) -> None:
        """Load the (filtered) gate definitions of an included header."""
        self.include = str(tree[0].value).split(".")[0]
        self.gate_dict.update(_load_include_module(self.include, True, False))

    def prog(self, tree: Iterable) -> dict[str, Any]:
        """Assemble the circuit dict for a whole program.

        Statement results are consumed here, so commands are generated only
        after all declarations have been processed. Returns ``gate_dict``
        instead when ``return_gate_dict`` is set; otherwise the context is
        reset after building the dict.
        """
        outdict: dict[str, Any] = {
            "commands": list(
                chain.from_iterable(
                    filter(lambda x: x is not None and x is not Discard, tree)
                )
            )
        }
        if self.return_gate_dict:
            return self.gate_dict
        outdict["qubits"] = [
            [reg, [i]] for reg, size in self.q_registers.items() for i in range(size)
        ]
        outdict["bits"] = [
            [reg, [i]] for reg, size in self.c_registers.items() for i in range(size)
        ]
        outdict["implicit_permutation"] = [[q, q] for q in outdict["qubits"]]
        outdict["phase"] = "0.0"
        self._reset_context()
        return outdict
def _parser(maxwidth: int) -> Lark:
    """Build an LALR parser that transforms QASM straight into a circuit dict.

    :param maxwidth: maximum classical register width for the transformer
    """
    circuit_transformer = _CircuitTransformer(maxwidth=maxwidth)
    lark_options = {
        "start": "prog",
        "debug": False,
        "parser": "lalr",
        "cache": True,
        "transformer": circuit_transformer,
    }
    return Lark(grammar, **lark_options)
# Shared parser instance, built on first use by _set_parser and rebuilt when
# a different maxwidth is requested.
g_parser = None
g_maxwidth = 32
def _set_parser(maxwidth: int) -> None:
    """Make sure ``g_parser`` exists and was built for ``maxwidth``."""
    global g_parser, g_maxwidth  # noqa: PLW0603
    if g_parser is not None and g_maxwidth == maxwidth:  # type: ignore
        return
    g_parser = _parser(maxwidth=maxwidth)
    g_maxwidth = maxwidth
def circuit_from_qasm(
    input_file: Union[str, "os.PathLike[Any]"],
    encoding: str = "utf-8",
    maxwidth: int = 32,
) -> Circuit:
    """A method to generate a tket Circuit from a qasm file.

    :param input_file: path to qasm file; filename must have ``.qasm`` extension
    :param encoding: file encoding (default utf-8)
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :raises TypeError: if the file name does not end in ``.qasm``
    :raises QASMParseError: if parsing fails; the error is re-raised with the
        file name attached
    :return: pytket circuit
    """
    ext = os.path.splitext(input_file)[-1]
    if ext != ".qasm":
        raise TypeError("Can only convert .qasm files")
    with open(input_file, encoding=encoding) as f:
        try:
            circ = circuit_from_qasm_io(f, maxwidth=maxwidth)
        except QASMParseError as e:
            # Attach the file name, keeping the original error as the explicit
            # cause so its traceback is preserved.
            raise QASMParseError(e.msg, e.line, str(input_file)) from e
    return circ
def circuit_from_qasm_str(qasm_str: str, maxwidth: int = 32) -> Circuit:
    """A method to generate a tket Circuit from a qasm string.

    :param qasm_str: qasm string
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    """
    _set_parser(maxwidth=maxwidth)
    parser = g_parser
    assert parser is not None
    transformer = cast("_CircuitTransformer", parser.options.transformer)
    # Clear registers/gates from earlier parses but keep any WASM module
    # handler a circuit_from_qasm_*wasm* wrapper has just attached.
    transformer._reset_context(reset_wasm=False)  # noqa: SLF001
    circuit_dict = parser.parse(qasm_str)
    circ = Circuit.from_dict(circuit_dict)  # type: ignore[arg-type]
    scratch_reg_resize_pass(maxwidth).apply(circ)
    return circ
def circuit_from_qasm_io(stream_in: TextIO, maxwidth: int = 32) -> Circuit:
    """A method to generate a tket Circuit from a qasm text stream.

    The stream is read completely and parsed as a single string.
    """
    qasm_text = stream_in.read()
    return circuit_from_qasm_str(qasm_text, maxwidth=maxwidth)
def circuit_from_qasm_wasm(
    input_file: Union[str, "os.PathLike[Any]"],
    wasm_file: Union[str, "os.PathLike[Any]"],
    encoding: str = "utf-8",
    maxwidth: int = 32,
) -> Circuit:
    """A method to generate a tket Circuit from a qasm string and external WASM module.

    :param input_file: path to qasm file; filename must have ``.qasm`` extension
    :param wasm_file: path to WASM file containing functions used in qasm
    :param encoding: encoding of qasm file (default utf-8)
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    """
    handler = WasmFileHandler(str(wasm_file))
    _set_parser(maxwidth=maxwidth)
    parser = g_parser
    assert parser is not None
    # Attach the module so extern calls in the QASM can refer to it.
    cast("_CircuitTransformer", parser.options.transformer).wasm = handler
    return circuit_from_qasm(input_file, encoding=encoding, maxwidth=maxwidth)
def circuit_from_qasm_str_wasm(
    qasm_str: str,
    wasm: bytes,
    maxwidth: int = 32,
) -> Circuit:
    """A method to generate a tket Circuit from a qasm string and external WASM module.

    :param qasm_str: qasm string
    :param wasm: bytes of the corresponding wasm module
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    """
    handler = WasmModuleHandler(wasm)
    _set_parser(maxwidth=maxwidth)
    parser = g_parser
    assert parser is not None
    # Attach the module so extern calls in the QASM can refer to it.
    cast("_CircuitTransformer", parser.options.transformer).wasm = handler
    return circuit_from_qasm_str(qasm_str, maxwidth=maxwidth)
def circuit_from_qasm_str_wasmmh(
    qasm_str: str,
    wasmmh: WasmModuleHandler,
    maxwidth: int = 32,
) -> Circuit:
    """A method to generate a tket Circuit from a qasm string and external WASM module.

    :param qasm_str: qasm string
    :param wasmmh: handler corresponding to the wasm module
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    """
    _set_parser(maxwidth=maxwidth)
    parser = g_parser
    assert parser is not None
    # Attach the given handler so extern calls in the QASM can refer to it.
    cast("_CircuitTransformer", parser.options.transformer).wasm = wasmmh
    return circuit_from_qasm_str(qasm_str, maxwidth=maxwidth)
def circuit_to_qasm(
    circ: Circuit, output_file: str, header: str = "qelib1", maxwidth: int = 32
) -> None:
    """Convert a Circuit to QASM and write it to a file.

    Classical bits in the pytket circuit must be singly-indexed.

    Note that this will not account for implicit qubit permutations in the Circuit.

    The file is written as UTF-8, consistent with the ``utf-8`` default documented
    for the ``encoding`` parameter of the QASM file readers.

    :param circ: pytket circuit
    :param output_file: path to output qasm file
    :param header: qasm header (default "qelib1")
    :param maxwidth: maximum allowed width of classical registers (default 32)
    """
    # Use an explicit encoding so the output does not depend on the platform's
    # locale default (which is not UTF-8 on every system).
    with open(output_file, "w", encoding="utf-8") as out:
        circuit_to_qasm_io(circ, out, header=header, maxwidth=maxwidth)
def _filtered_qasm_str(qasm: str) -> str:
    """Remove declarations of scratch registers that are never referenced.

    A ``creg`` line declaring a register whose name starts with ``_TEMP_BIT_NAME``
    is deleted unless a subsequent non-declaration line uses one of that
    register's bits (``name[i]``) as an argument.
    """
    lines = qasm.split("\n")
    decl_pattern = re.compile(rf"creg ({_TEMP_BIT_NAME}\_*\d*)\[\d+\]")
    use_pattern = re.compile(rf"({_TEMP_BIT_NAME}\_*\d*)\[\d+\]")
    # register name -> line index of its most recent declaration, for registers
    # not (yet) seen in use
    pending: dict[str, int] = {}
    for idx, text in enumerate(lines):
        declared = decl_pattern.match(text)
        if declared is not None:
            pending[declared.group(1)] = idx
            continue
        for name in use_pattern.findall(text):
            pending.pop(name, None)
    # Delete from the bottom up so earlier indices stay valid.
    for idx in sorted(pending.values(), reverse=True):
        del lines[idx]
    return "\n".join(lines)
def _is_empty_customgate(op: Op) -> bool:
    """Return True iff ``op`` is a CustomGate whose circuit contains no gates."""
    if op.type != OpType.CustomGate:
        return False
    return op.get_circuit().n_gates == 0  # type: ignore
def _check_can_convert_circuit(circ: Circuit, header: str, maxwidth: int) -> None:
    """Raise if ``circ`` cannot be written as QASM with the given settings.

    Checks, in this order:

    * RangePredicate, MultiBit, ExplicitPredicate, ExplicitModifier, SetBits and
      CopyBits ops require a header accepted by ``_hqs_header``;
    * no bit index may be ``maxwidth`` or larger;
    * every bit must belong to one of the circuit's classical registers;
    * no empty CustomGate (plain or inside a Conditional) may remain;
    * classical expressions on registers must be register-aligned.

    :raises QASMUnsupportedError: if any check fails.
    """
    needs_hqs_types = (
        OpType.RangePredicate,
        OpType.MultiBit,
        OpType.ExplicitPredicate,
        OpType.ExplicitModifier,
        OpType.SetBits,
        OpType.CopyBits,
    )
    has_complex_classical = any(circ.n_gates_of_type(t) for t in needs_hqs_types)
    if has_complex_classical and not _hqs_header(header):
        raise QASMUnsupportedError(
            "Complex classical gates not supported with qelib1: try converting with "
            "`header=hqslib1`"
        )
    if any(bit.index[0] >= maxwidth for bit in circ.bits):
        raise QASMUnsupportedError(
            f"Circuit contains a classical register larger than {maxwidth}: try "
            "setting the `maxwidth` parameter to a higher value."
        )
    known_registers = {creg.name for creg in circ.c_registers}
    stray = next((b for b in circ.bits if b.reg_name not in known_registers), None)
    if stray is not None:
        raise QASMUnsupportedError(
            f"Circuit contains an invalid classical register {stray.reg_name}."
        )
    # Empty CustomGates should have been removed by DecomposeBoxes().
    for command in circ:
        inner = command.op
        assert not _is_empty_customgate(inner)
        if isinstance(inner, Conditional):
            assert not _is_empty_customgate(inner.op)
    if not check_register_alignments(circ):
        raise QASMUnsupportedError(
            "Circuit contains classical expressions on registers whose arguments or "
            "outputs are not register-aligned."
        )
def circuit_to_qasm_str(
    circ: Circuit,
    header: str = "qelib1",
    include_gate_defs: set[str] | None = None,
    maxwidth: int = 32,
) -> str:
    """Convert a Circuit to QASM and return the string.

    Classical bits in the pytket circuit must be singly-indexed.

    Note that this will not account for implicit qubit permutations in the Circuit.

    Boxes are decomposed on a copy of ``circ`` (the argument itself is not
    modified) before the circuit is checked and converted.

    :param circ: pytket circuit
    :param header: qasm header (default "qelib1")
    :param include_gate_defs: optional set of gates to include. When given, the
        writer tracks no registers and produces no ``OPENQASM``/``include``
        prefix (this mode is used for gate-definition bodies).
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: qasm string
    :raises QASMUnsupportedError: if the circuit cannot be expressed in QASM
    """
    # Registers are taken from the original circuit's units.
    qasm_writer = _QasmWriter(
        circ.qubits, circ.bits, header, include_gate_defs, maxwidth
    )
    circ1 = circ.copy()
    DecomposeBoxes().apply(circ1)
    _check_can_convert_circuit(circ1, header, maxwidth)
    for command in circ1:
        assert isinstance(command, Command)
        qasm_writer.add_op(command.op, command.args)
    return qasm_writer.finalize()
# Register class accepted and produced by `_retrieve_registers`: either a classical
# (BitRegister) or a quantum (QubitRegister) register type.
TypeReg = TypeVar("TypeReg", BitRegister, QubitRegister)
1200def _retrieve_registers(
1201 units: list[UnitID], reg_type: type[TypeReg]
1202) -> dict[str, TypeReg]:
1203 if any(len(unit.index) != 1 for unit in units):
1204 raise NotImplementedError("OPENQASM registers must use a single index")
1205 maxunits = map(lambda x: max(x[1]), groupby(units, key=lambda un: un.reg_name)) # noqa: C417
1206 return {
1207 maxunit.reg_name: reg_type(maxunit.reg_name, maxunit.index[0] + 1)
1208 for maxunit in maxunits
1209 }
1212def _parse_range(minval: int, maxval: int, maxwidth: int) -> tuple[str, int]:
1213 if maxwidth > 64: # noqa: PLR2004 1213 ↛ 1214line 1213 didn't jump to line 1214 because the condition on line 1213 was never true
1214 raise NotImplementedError("Register width exceeds maximum of 64.")
1216 REGMAX = (1 << maxwidth) - 1
1218 if minval > REGMAX: 1218 ↛ 1219line 1218 didn't jump to line 1219 because the condition on line 1218 was never true
1219 raise NotImplementedError("Range's lower bound exceeds register capacity.")
1220 if minval > maxval: 1220 ↛ 1221line 1220 didn't jump to line 1221 because the condition on line 1220 was never true
1221 raise NotImplementedError("Range's lower bound exceeds upper bound.")
1222 maxval = min(maxval, REGMAX)
1224 if minval == maxval:
1225 return ("==", minval)
1226 if minval == 0:
1227 return ("<=", maxval)
1228 if maxval == REGMAX: 1228 ↛ 1230line 1228 didn't jump to line 1230 because the condition on line 1228 was always true
1229 return (">=", minval)
1230 raise NotImplementedError("Range can only be bounded on one side.")
1233def _negate_comparator(comparator: str) -> str:
1234 if comparator == "==":
1235 return "!="
1236 if comparator == "!=": 1236 ↛ 1237line 1236 didn't jump to line 1237 because the condition on line 1236 was never true
1237 return "=="
1238 if comparator == "<=":
1239 return ">"
1240 if comparator == ">": 1240 ↛ 1241line 1240 didn't jump to line 1241 because the condition on line 1240 was never true
1241 return "<="
1242 if comparator == ">=": 1242 ↛ 1244line 1242 didn't jump to line 1244 because the condition on line 1242 was always true
1243 return "<"
1244 assert comparator == "<"
1245 return ">="
def _get_optype_and_params(op: Op) -> tuple[OpType, list[float | Expr] | None]:
    """Return the op type and parameters to use when writing ``op`` as QASM.

    TK1 is rewritten as U3. CustomGate keeps its own parameters. Any other op
    keeps its parameters only if its type appears in ``_tk_to_qasm_params`` or
    ``_tk_to_qasm_extra_params``; otherwise the parameters are ``None``.
    """
    optype = op.type
    if optype == OpType.TK1:
        # TK1(a, b, c) is emitted as U3(b, a - 1/2, c + 1/2).
        tk1 = op.params
        return OpType.U3, [tk1[1], tk1[0] - 0.5, tk1[2] + 0.5]
    if optype == OpType.CustomGate:
        return optype, op.params
    known = optype in _tk_to_qasm_params or optype in _tk_to_qasm_extra_params
    return optype, (op.params if known else None)
def _get_gate_circuit(
    optype: OpType, qubits: list[Qubit], symbols: list[Symbol] | None = None
) -> Circuit:
    """Build a circuit holding one ``optype`` gate, rebased to {CX, U3}.

    :param optype: type of the gate to place
    :param qubits: qubits to add to the circuit and apply the gate to
    :param symbols: symbolic parameters for the gate, if any
    :return: the rebased circuit, after RemoveRedundancies
    """
    gate_circ = Circuit()
    for qb in qubits:
        gate_circ.add_qubit(qb)
    targets = cast("list[UnitID]", qubits)
    if not symbols:
        gate_circ.add_gate(optype, targets)
    else:
        gate_circ.add_gate(optype, [sym.as_expr() for sym in symbols], targets)
    AutoRebase({OpType.CX, OpType.U3}).apply(gate_circ)
    RemoveRedundancies().apply(gate_circ)
    return gate_circ
1283def _hqs_header(header: str) -> bool:
1284 return header in ["hqslib1", "hqslib1_dev"]
1287@dataclass
1288class _ConditionString:
1289 variable: str # variable, e.g. "c[1]"
1290 comparator: str # comparator, e.g. "=="
1291 value: int # value, e.g. "1"
1294class _LabelledStringList:
1295 """
1296 Wrapper class for an ordered sequence of strings, where each string has a unique
1297 label, returned when the string is added, and a string may be removed from the
1298 sequence given its label. There is a method to retrieve the concatenation of all
1299 strings in order. The conditions (e.g. "if(c[0]==1)") for some strings are stored
1300 separately in `conditions`. These conditions will be converted to text when
1301 retrieving the full string.
1302 """
1304 def __init__(self) -> None:
1305 self.strings: OrderedDict[int, str] = OrderedDict()
1306 self.conditions: dict[int, _ConditionString] = {}
1307 self.label = 0
1309 def add_string(self, string: str) -> int:
1310 label = self.label
1311 self.strings[label] = string
1312 self.label += 1
1313 return label
1315 def get_string(self, label: int) -> str | None:
1316 return self.strings.get(label, None)
1318 def del_string(self, label: int) -> None:
1319 self.strings.pop(label, None)
1321 def get_full_string(self) -> str:
1322 strings = []
1323 for l, s in self.strings.items():
1324 condition = self.conditions.get(l)
1325 if condition is not None:
1326 strings.append(
1327 f"if({condition.variable}{condition.comparator}{condition.value}) "
1328 + s
1329 )
1330 else:
1331 strings.append(s)
1332 return "".join(strings)
1335def _make_params_str(params: list[float | Expr] | None) -> str:
1336 s = ""
1337 if params is not None: 1337 ↛ 1356line 1337 didn't jump to line 1356 because the condition on line 1337 was always true
1338 n_params = len(params)
1339 s += "("
1340 for i in range(n_params):
1341 reduced = True
1342 try:
1343 p: float | Expr = float(params[i])
1344 except TypeError:
1345 reduced = False
1346 p = params[i]
1347 if i < n_params - 1:
1348 if reduced:
1349 s += f"{p}*pi,"
1350 else:
1351 s += f"({p})*pi,"
1352 elif reduced:
1353 s += f"{p}*pi)"
1354 else:
1355 s += f"({p})*pi)"
1356 s += " "
1357 return s
1360def _make_args_str(args: Sequence[UnitID]) -> str:
1361 s = ""
1362 for i in range(len(args)):
1363 s += f"{args[i]}"
1364 if i < len(args) - 1:
1365 s += ","
1366 else:
1367 s += ";\n"
1368 return s
1371@dataclass
1372class _ScratchPredicate:
1373 variable: str # variable, e.g. "c[1]"
1374 comparator: str # comparator, e.g. "=="
1375 value: int # value, e.g. "1"
1376 dest: str # destination bit, e.g. "tk_SCRATCH_BIT[0]"
1379def _vars_overlap(v: str, w: str) -> bool:
1380 """check if two variables have overlapping bits"""
1381 v_split = v.split("[")
1382 w_split = w.split("[")
1383 if v_split[0] != w_split[0]:
1384 # different registers
1385 return False
1386 # e.g. (a[1], a), (a, a[1]), (a[1], a[1]), (a, a)
1387 return len(v_split) != len(w_split) or v == w
1390def _var_appears(v: str, s: str) -> bool:
1391 """check if variable v appears in string s"""
1392 v_split = v.split("[")
1393 if len(v_split) == 1: 1393 ↛ 1396line 1393 didn't jump to line 1396 because the condition on line 1393 was never true
1394 # check if v appears in s and is not surrounded by word characters
1395 # e.g. a = a & b or a = a[1] & b[1]
1396 return bool(re.search(r"(?<!\w)" + re.escape(v) + r"(?![\w])", s))
1397 if re.search(r"(?<!\w)" + re.escape(v), s): 1397 ↛ 1400line 1397 didn't jump to line 1400 because the condition on line 1397 was never true
1398 # check if v appears in s and is not proceeded by word characters
1399 # e.g. a[1] = a[1]
1400 return True
1401 # check the register of v appears in s
1402 # e.g. a[1] = a & b
1403 return bool(re.search(r"(?<!\w)" + re.escape(v_split[0]) + r"(?![\[\w])", s))
1406class _QasmWriter:
1407 """
1408 Helper class for converting a sequence of TKET Commands to QASM, and retrieving the
1409 final QASM string afterwards.
1410 """
1412 def __init__(
1413 self,
1414 qubits: list[Qubit],
1415 bits: list[Bit],
1416 header: str = "qelib1",
1417 include_gate_defs: set[str] | None = None,
1418 maxwidth: int = 32,
1419 ):
1420 self.header = header
1421 self.maxwidth = maxwidth
1422 self.added_gate_definitions: set[str] = set()
1423 self.include_module_gates = {"measure", "reset", "barrier"}
1424 self.include_module_gates.update(
1425 _load_include_module(header, False, True).keys()
1426 )
1427 self.prefix = ""
1428 self.gatedefs = ""
1429 self.strings = _LabelledStringList()
1431 # Record of `RangePredicate` operations that set a "scratch" bit to 0 or 1
1432 # depending on the value of the predicate. This map is consulted when we
1433 # encounter a `Conditional` operation to see if the condition bit is one of
1434 # these scratch bits, which we can then replace with the original.
1435 self.range_preds: dict[int, _ScratchPredicate] = {}
1437 if include_gate_defs is None:
1438 self.include_gate_defs = self.include_module_gates
1439 self.include_gate_defs.update(NOPARAM_EXTRA_COMMANDS.keys())
1440 self.include_gate_defs.update(PARAM_EXTRA_COMMANDS.keys())
1441 self.prefix = f'OPENQASM 2.0;\ninclude "{header}.inc";\n\n'
1442 self.qregs = _retrieve_registers(
1443 cast("list[UnitID]", qubits), QubitRegister
1444 )
1445 self.cregs = _retrieve_registers(cast("list[UnitID]", bits), BitRegister)
1446 for reg in self.qregs.values():
1447 if regname_regex.match(reg.name) is None:
1448 raise QASMUnsupportedError(
1449 f"Invalid register name '{reg.name}'. QASM register names must "
1450 "begin with a lowercase letter and may only contain lowercase "
1451 "and uppercase letters, numbers, and underscores. "
1452 "Try renaming the register with `rename_units` first."
1453 )
1454 for bit_reg in self.cregs.values():
1455 if regname_regex.match(bit_reg.name) is None: 1455 ↛ 1456line 1455 didn't jump to line 1456 because the condition on line 1455 was never true
1456 raise QASMUnsupportedError(
1457 f"Invalid register name '{bit_reg.name}'. QASM register names "
1458 "must begin with a lowercase letter and may only contain "
1459 "lowercase and uppercase letters, numbers, and underscores. "
1460 "Try renaming the register with `rename_units` first."
1461 )
1462 else:
1463 # gate definition, no header necessary for file
1464 self.include_gate_defs = include_gate_defs
1465 self.cregs = {}
1466 self.qregs = {}
1468 self.cregs_as_bitseqs = {tuple(creg) for creg in self.cregs.values()}
1470 # for holding condition values when writing Conditional blocks
1471 # the size changes when adding and removing scratch bits
1472 self.scratch_reg = BitRegister(
1473 next(
1474 f"{_TEMP_BIT_REG_BASE}_{i}"
1475 for i in itertools.count()
1476 if f"{_TEMP_BIT_REG_BASE}_{i}" not in self.qregs
1477 ),
1478 0,
1479 )
1480 # if a string writes to some classical variables, the string label and
1481 # the affected variables will be recorded.
1482 self.variable_writes: dict[int, list[str]] = {}
1484 def fresh_scratch_bit(self) -> Bit:
1485 self.scratch_reg = BitRegister(self.scratch_reg.name, self.scratch_reg.size + 1)
1486 return Bit(self.scratch_reg.name, self.scratch_reg.size - 1)
1488 def remove_last_scratch_bit(self) -> None:
1489 assert self.scratch_reg.size > 0
1490 self.scratch_reg = BitRegister(self.scratch_reg.name, self.scratch_reg.size - 1)
1492 def write_params(self, params: list[float | Expr] | None) -> None:
1493 params_str = _make_params_str(params)
1494 self.strings.add_string(params_str)
1496 def write_args(self, args: Sequence[UnitID]) -> None:
1497 args_str = _make_args_str(args)
1498 self.strings.add_string(args_str)
    def make_gate_definition(
        self,
        n_qubits: int,
        opstr: str,
        optype: OpType,
        n_params: int | None = None,
    ) -> str:
        """Return a QASM ``gate`` definition of ``optype`` under the name ``opstr``.

        The body is produced by placing one ``optype`` gate on fresh qubits,
        rebasing it via `_get_gate_circuit` and converting the result recursively
        with `circuit_to_qasm_str`, using this writer's header,
        ``include_gate_defs`` and ``maxwidth``.

        :param n_qubits: number of qubit arguments of the gate
        :param opstr: gate name used in the definition
        :param optype: tket op type being defined
        :param n_params: number of parameters, or None for a parameterless gate.
            NOTE(review): ``n_params == 0`` or ``n_qubits == 0`` would raise
            IndexError at the ``[-1]`` lookups below.
        :return: the definition text, ending with a closing brace and newline
        """
        s = "gate " + opstr + " "
        symbols: list[Symbol] | None = None
        if n_params is not None:
            # need to add parameters to gate definition
            s += "("
            # The body's parameters are named "param<i>/pi": `_make_params_str`
            # writes symbolic parameters as "(<expr>)*pi", so the body refers to
            # "(param<i>/pi)*pi", i.e. the header's "param<i>".
            symbols = [
                Symbol("param" + str(index) + "/pi") for index in range(n_params)
            ]
            symbols_header = [Symbol("param" + str(index)) for index in range(n_params)]
            for symbol in symbols_header[:-1]:
                s += symbol.name + ", "
            s += symbols_header[-1].name + ") "

        # add qubits to gate definition
        qubit_args = [
            Qubit(opstr + "q" + str(index)) for index in list(range(n_qubits))
        ]
        for qb in qubit_args[:-1]:
            s += str(qb) + ","
        s += str(qubit_args[-1]) + " {\n"
        # get rebased circuit for constructing qasm
        gate_circ = _get_gate_circuit(optype, qubit_args, symbols)
        # write circuit to qasm
        s += circuit_to_qasm_str(
            gate_circ, self.header, self.include_gate_defs, self.maxwidth
        )
        s += "}\n"
        return s
1536 def mark_as_written(self, label: int, written_variable: str) -> None:
1537 if label in self.variable_writes: 1537 ↛ 1538line 1537 didn't jump to line 1538 because the condition on line 1537 was never true
1538 self.variable_writes[label].append(written_variable)
1539 else:
1540 self.variable_writes[label] = [written_variable]
1542 def check_range_predicate(self, op: RangePredicateOp, args: list[Bit]) -> None:
1543 if (not _hqs_header(self.header)) and op.lower != op.upper: 1543 ↛ 1544line 1543 didn't jump to line 1544 because the condition on line 1543 was never true
1544 raise QASMUnsupportedError(
1545 "OpenQASM conditions must be on a register's fixed value."
1546 )
1547 variable = args[0].reg_name
1548 assert isinstance(variable, str)
1549 if op.n_inputs != self.cregs[variable].size:
1550 raise QASMUnsupportedError(
1551 "RangePredicate conditions must be an entire classical register"
1552 )
1553 if args[:-1] != self.cregs[variable].to_list(): 1553 ↛ 1554line 1553 didn't jump to line 1554 because the condition on line 1553 was never true
1554 raise QASMUnsupportedError(
1555 "RangePredicate conditions must be a single classical register"
1556 )
1558 def add_range_predicate(self, op: RangePredicateOp, args: list[Bit]) -> None:
1559 self.check_range_predicate(op, args)
1560 comparator, value = _parse_range(op.lower, op.upper, self.maxwidth)
1561 variable = args[0].reg_name
1562 dest_bit = str(args[-1])
1563 label = self.strings.add_string(
1564 "".join(
1565 [
1566 f"if({variable}{comparator}{value}) " + f"{dest_bit} = 1;\n",
1567 f"if({variable}{_negate_comparator(comparator)}{value}) "
1568 f"{dest_bit} = 0;\n",
1569 ]
1570 )
1571 )
1572 # Record this operation.
1573 # Later if we find a conditional based on dest_bit, we can replace dest_bit with
1574 # (variable, comparator, value), provided that variable hasn't been written to
1575 # in the mean time. (So we must watch for that, and remove the record from the
1576 # list if it is.)
1577 # Note that we only perform such rewrites for internal scratch bits.
1578 if dest_bit.startswith(_TEMP_BIT_NAME):
1579 self.range_preds[label] = _ScratchPredicate(
1580 variable, comparator, value, dest_bit
1581 )
    def replace_condition(self, pred_label: int) -> bool:
        """Try to substitute the predicate at ``pred_label`` into later conditions.

        ``self.range_preds[pred_label]`` records that the string at ``pred_label``
        sets the bit ``dest`` from the test ``variable comparator value``. The
        strings after it are grouped into lines, where a line ends with a string
        containing a newline character. For each line, in order:

        1. If the line carries exactly one condition and that condition tests
           ``dest``, it is replaced by the predicate's own test (negated when the
           original condition was ``dest == 0``).
        2. If any string of the line is recorded in ``variable_writes`` as writing
           a variable overlapping ``dest`` or ``variable``, scanning stops.

        :param pred_label: label of a recorded predicate (must be in range_preds)
        :return: True if at least one condition was replaced.
        """
        assert pred_label in self.range_preds
        success = False
        pred = self.range_preds[pred_label]
        # labels of the (non-deleted) strings making up the current line
        line_labels = []
        for label in range(pred_label + 1, self.strings.label):
            string = self.strings.get_string(label)
            if string is None:
                continue
            line_labels.append(label)
            if "\n" not in string:
                # line not finished yet
                continue
            written_variables: list[str] = []
            # (label, condition)
            conditions: list[tuple[int, _ConditionString]] = []
            for l in line_labels:
                written_variables.extend(self.variable_writes.get(l, []))
                cond = self.strings.conditions.get(l)
                if cond:
                    conditions.append((l, cond))
            if len(conditions) == 1 and pred.dest == conditions[0][1].variable:
                # if the condition is dest, replace the condition with pred
                success = True
                if conditions[0][1].value == 1:
                    self.strings.conditions[conditions[0][0]] = _ConditionString(
                        pred.variable, pred.comparator, pred.value
                    )
                else:
                    assert conditions[0][1].value == 0
                    self.strings.conditions[conditions[0][0]] = _ConditionString(
                        pred.variable,
                        _negate_comparator(pred.comparator),
                        pred.value,
                    )
            # Stop once dest or the predicate's variable is overwritten: later
            # lines can no longer use the predicate in place of dest. (The check
            # follows the replacement, so the writing line itself is still
            # rewritten.)
            if any(_vars_overlap(pred.dest, v) for v in written_variables) or any(
                _vars_overlap(pred.variable, v) for v in written_variables
            ):
                return success
            line_labels.clear()
            conditions.clear()
            written_variables.clear()
        return success
1633 def remove_unused_predicate(self, pred_label: int) -> bool:
1634 """Given the label of a predicate p=(var, comp, value, dest, label),
1635 we remove p if dest never appears after p."""
1636 assert pred_label in self.range_preds
1637 pred = self.range_preds[pred_label]
1638 for label in range(pred_label + 1, self.strings.label):
1639 string = self.strings.get_string(label)
1640 if string is None:
1641 continue
1642 if _var_appears(pred.dest, string) or (
1643 label in self.strings.conditions
1644 and _vars_overlap(pred.dest, self.strings.conditions[label].variable)
1645 ):
1646 return False
1647 self.range_preds.pop(pred_label)
1648 self.strings.del_string(pred_label)
1649 return True
    def add_conditional(self, op: Conditional, args: Sequence[UnitID]) -> None:
        """Emit a Conditional op.

        The first ``op.width`` arguments are the condition bits; the remaining
        ones are passed to the inner op. With an hqslib1-style header the
        condition may be a single bit or a whole register; otherwise it must be
        exactly one whole register. A conditional Phase is dropped. A conditional
        RangePredicate is expanded using scratch bits. Any other inner op is first
        conditioned on a scratch bit, which `replace_condition` /
        `remove_unused_predicate` may later replace with the original condition.

        :raises QASMUnsupportedError: if the condition bits are not acceptable for
            the header.
        """
        control_bits = args[: op.width]
        if op.width == 1 and _hqs_header(self.header):
            # hqslib1 allows conditions on a single bit, e.g. "c[0]"
            variable = str(control_bits[0])
        else:
            variable = control_bits[0].reg_name
            if (
                _hqs_header(self.header)
                and control_bits != self.cregs[variable].to_list()
            ):
                raise QASMUnsupportedError(
                    "hqslib1 QASM conditions must be an entire classical "
                    "register or a single bit"
                )
        if not _hqs_header(self.header):
            if op.width != self.cregs[variable].size:
                raise QASMUnsupportedError(
                    "OpenQASM conditions must be an entire classical register"
                )
            if control_bits != self.cregs[variable].to_list():
                raise QASMUnsupportedError(
                    "OpenQASM conditions must be a single classical register"
                )
        if op.op.type == OpType.Phase:
            # Conditional phase is ignored.
            return
        if op.op.type == OpType.RangePredicate:
            # Special handling for nested ifs
            # if condition
            #     if pred dest = 1
            #     if not pred dest = 0
            # can be written as
            # if condition s0 = 1
            # if pred s1 = 1
            # s2 = s0 & s1
            # s3 = s0 & ~s1
            # if s2 dest = 1
            # if s3 dest = 0
            # where s0, s1, s2, and s3 are scratch bits
            # NOTE(review): s0 and s1 are only ever set to 1 here; this relies on
            # freshly allocated scratch bits reading as 0 -- confirm.
            s0 = self.fresh_scratch_bit()
            l = self.strings.add_string(f"{s0} = 1;\n")
            # we store the condition in self.strings.conditions
            # as it can be later replaced by `replace_condition`
            # if possible
            self.strings.conditions[l] = _ConditionString(variable, "==", op.value)
            # output the RangePredicate to s1
            s1 = self.fresh_scratch_bit()
            assert isinstance(op.op, RangePredicateOp)
            self.check_range_predicate(op.op, cast("list[Bit]", args[op.width :]))
            pred_comparator, pred_value = _parse_range(
                op.op.lower, op.op.upper, self.maxwidth
            )
            pred_variable = args[op.width :][0].reg_name
            self.strings.add_string(
                f"if({pred_variable}{pred_comparator}{pred_value}) {s1} = 1;\n"
            )
            s2 = self.fresh_scratch_bit()
            self.strings.add_string(f"{s2} = {s0} & {s1};\n")
            s3 = self.fresh_scratch_bit()
            self.strings.add_string(f"{s3} = {s0} & (~ {s1});\n")
            self.strings.add_string(f"if({s2}==1) {args[-1]} = 1;\n")
            self.strings.add_string(f"if({s3}==1) {args[-1]} = 0;\n")
            return
        # we assign the condition to a scratch bit, which we will later remove
        # if the condition variable is unchanged.
        scratch_bit = self.fresh_scratch_bit()
        pred_label = self.strings.add_string(
            f"if({variable}=={op.value}) " + f"{scratch_bit} = 1;\n"
        )
        self.range_preds[pred_label] = _ScratchPredicate(
            variable, "==", op.value, str(scratch_bit)
        )
        # we will later add condition to all lines starting from next_label
        next_label = self.strings.label
        self.add_op(op.op, args[op.width :])
        # add conditions to the lines after the predicate: only the first string
        # of each line (and not a bare newline) gets the "if(...)" prefix
        is_new_line = True
        for label in range(next_label, self.strings.label):
            string = self.strings.get_string(label)
            assert string is not None
            if is_new_line and string != "\n":
                self.strings.conditions[label] = _ConditionString(
                    str(scratch_bit), "==", 1
                )
            is_new_line = "\n" in string
        if self.replace_condition(pred_label) and self.remove_unused_predicate(
            pred_label
        ):
            # remove the unused scratch bit
            # NOTE(review): this drops the highest-index scratch bit, which is
            # `scratch_bit` only if the inner add_op allocated no further scratch
            # bits -- confirm for nested conditionals.
            self.remove_last_scratch_bit()
1742 def add_set_bits(self, op: SetBitsOp, args: list[Bit]) -> None:
1743 creg_name = args[0].reg_name
1744 bits, vals = zip(*sorted(zip(args, op.values, strict=False)), strict=False)
1745 # check if whole register can be set at once
1746 if bits == tuple(self.cregs[creg_name].to_list()):
1747 value = int("".join(map(str, map(int, vals[::-1]))), 2)
1748 label = self.strings.add_string(f"{creg_name} = {value};\n")
1749 self.mark_as_written(label, f"{creg_name}")
1750 else:
1751 for bit, value in zip(bits, vals, strict=False):
1752 label = self.strings.add_string(f"{bit} = {int(value)};\n")
1753 self.mark_as_written(label, f"{bit}")
1755 def add_copy_bits(self, op: CopyBitsOp, args: list[Bit]) -> None:
1756 l_args = args[op.n_inputs :]
1757 r_args = args[: op.n_inputs]
1758 l_name = l_args[0].reg_name
1759 r_name = r_args[0].reg_name
1760 # check if whole register can be set at once
1761 if (
1762 l_args == self.cregs[l_name].to_list()
1763 and r_args == self.cregs[r_name].to_list()
1764 ):
1765 label = self.strings.add_string(f"{l_name} = {r_name};\n")
1766 self.mark_as_written(label, f"{l_name}")
1767 else:
1768 for bit_l, bit_r in zip(l_args, r_args, strict=False):
1769 label = self.strings.add_string(f"{bit_l} = {bit_r};\n")
1770 self.mark_as_written(label, f"{bit_l}")
1772 def add_multi_bit(self, op: MultiBitOp, args: list[Bit]) -> None:
1773 basic_op = op.basic_op
1774 basic_n = basic_op.n_inputs + basic_op.n_outputs + basic_op.n_input_outputs
1775 n_args = len(args)
1776 assert n_args % basic_n == 0
1777 arity = n_args // basic_n
1779 # If the operation is register-aligned we can write it more succinctly.
1780 poss_regs = [
1781 tuple(args[basic_n * i + j] for i in range(arity)) for j in range(basic_n)
1782 ]
1783 if all(poss_reg in self.cregs_as_bitseqs for poss_reg in poss_regs):
1784 # The operation is register-aligned.
1785 self.add_op(basic_op, [poss_regs[j][0].reg_name for j in range(basic_n)]) # type: ignore
1786 else:
1787 # The operation is not register-aligned.
1788 for i in range(arity):
1789 basic_args = args[basic_n * i : basic_n * (i + 1)]
1790 self.add_op(basic_op, basic_args)
1792 def add_explicit_op(self, op: Op, args: list[Bit]) -> None:
1793 # &, ^ and | gates
1794 opstr = str(op)
1795 if opstr not in _classical_gatestr_map: 1795 ↛ 1796line 1795 didn't jump to line 1796 because the condition on line 1795 was never true
1796 raise QASMUnsupportedError(f"Classical gate {opstr} not supported.")
1797 label = self.strings.add_string(
1798 f"{args[-1]} = {args[0]} {_classical_gatestr_map[opstr]} {args[1]};\n"
1799 )
1800 self.mark_as_written(label, f"{args[-1]}")
1802 def add_wired_clexpr(self, op: ClExprOp, args: list[Bit]) -> None:
1803 wexpr: WiredClExpr = op.expr
1804 # 1. Determine the mappings from bit variables to bits and from register
1805 # variables to registers.
1806 expr: ClExpr = wexpr.expr
1807 bit_posn: dict[int, int] = wexpr.bit_posn
1808 reg_posn: dict[int, list[int]] = wexpr.reg_posn
1809 output_posn: list[int] = wexpr.output_posn
1810 input_bits: dict[int, Bit] = {i: args[j] for i, j in bit_posn.items()}
1811 input_regs: dict[int, BitRegister] = {}
1812 all_cregs = set(self.cregs.values())
1813 for i, posns in reg_posn.items():
1814 reg_args = [args[j] for j in posns]
1815 for creg in all_cregs: 1815 ↛ 1820line 1815 didn't jump to line 1820 because the loop on line 1815 didn't complete
1816 if creg.to_list() == reg_args:
1817 input_regs[i] = creg
1818 break
1819 else:
1820 assert (
1821 not f"ClExprOp ({wexpr}) contains a register variable (r{i}) that "
1822 "is not wired to any BitRegister in the circuit."
1823 )
1824 # 2. Write the left-hand side of the assignment.
1825 output_repr: str | None = None
1826 output_args: list[Bit] = [args[j] for j in output_posn]
1827 n_output_args = len(output_args)
1828 expect_reg_output = has_reg_output(expr.op)
1829 if n_output_args == 0: 1829 ↛ 1830line 1829 didn't jump to line 1830 because the condition on line 1829 was never true
1830 raise QASMUnsupportedError("Expression has no output.")
1831 if n_output_args == 1:
1832 output_arg = output_args[0]
1833 output_repr = output_arg.reg_name if expect_reg_output else str(output_arg)
1834 else:
1835 if not expect_reg_output: 1835 ↛ 1836line 1835 didn't jump to line 1836 because the condition on line 1835 was never true
1836 raise QASMUnsupportedError("Unexpected output for operation.")
1837 for creg in all_cregs: 1837 ↛ 1841line 1837 didn't jump to line 1841 because the loop on line 1837 didn't complete
1838 if creg.to_list() == output_args:
1839 output_repr = creg.name
1840 break
1841 assert output_repr is not None
1842 self.strings.add_string(f"{output_repr} = ")
1843 # 3. Write the right-hand side of the assignment.
1844 self.strings.add_string(
1845 expr.as_qasm(input_bits=input_bits, input_regs=input_regs)
1846 )
1847 self.strings.add_string(";\n")
1849 def add_wasm(self, op: WASMOp, args: list[Bit]) -> None:
1850 inputs: list[str] = []
1851 outputs: list[str] = []
1852 for reglist, sizes in [(inputs, op.input_widths), (outputs, op.output_widths)]:
1853 for in_width in sizes:
1854 bits = args[:in_width]
1855 args = args[in_width:]
1856 regname = bits[0].reg_name
1857 if bits != list(self.cregs[regname]): 1857 ↛ 1858line 1857 didn't jump to line 1858 because the condition on line 1857 was never true
1858 QASMUnsupportedError("WASM ops must act on entire registers.")
1859 reglist.append(regname)
1860 if outputs:
1861 label = self.strings.add_string(f"{', '.join(outputs)} = ")
1862 self.strings.add_string(f"{op.func_name}({', '.join(inputs)});\n")
1863 for variable in outputs:
1864 self.mark_as_written(label, variable)
1866 def add_measure(self, args: Sequence[UnitID]) -> None:
1867 label = self.strings.add_string(f"measure {args[0]} -> {args[1]};\n")
1868 self.mark_as_written(label, f"{args[1]}")
1870 def add_zzphase(self, param: float | Expr, args: Sequence[UnitID]) -> None:
1871 # as op.params returns reduced parameters, we can assume
1872 # that 0 <= param < 4
1873 if param > 1:
1874 # first get in to 0 <= param < 2 range
1875 param = Decimal(str(param)) % Decimal(2)
1876 # then flip 1 <= param < 2 range into
1877 # -1 <= param < 0
1878 if param > 1:
1879 param = -2 + param
1880 self.strings.add_string("RZZ")
1881 self.write_params([param])
1882 self.write_args(args)
1884 def add_cnx(self, args: Sequence[UnitID]) -> None:
1885 n_ctrls = len(args) - 1
1886 assert n_ctrls >= 0
1887 match n_ctrls:
1888 case 0: 1888 ↛ 1889line 1888 didn't jump to line 1889 because the pattern on line 1888 never matched
1889 self.strings.add_string("x")
1890 case 1: 1890 ↛ 1891line 1890 didn't jump to line 1891 because the pattern on line 1890 never matched
1891 self.strings.add_string("cx")
1892 case 2: 1892 ↛ 1893line 1892 didn't jump to line 1893 because the pattern on line 1892 never matched
1893 self.strings.add_string("ccx")
1894 case 3:
1895 self.strings.add_string("c3x")
1896 case 4: 1896 ↛ 1898line 1896 didn't jump to line 1898 because the pattern on line 1896 always matched
1897 self.strings.add_string("c4x")
1898 case _:
1899 raise QASMUnsupportedError("CnX with n > 4 not supported in QASM")
1900 self.strings.add_string(" ")
1901 self.write_args(args)
def add_data(self, op: BarrierOp, args: Sequence[UnitID]) -> None:
    """Emit a barrier, using the op's ``data`` string as the gate name.

    When ``op.data`` is empty, the standard barrier name is used instead.
    """
    custom_name = op.data
    if custom_name == "":
        gate_name = _tk_to_qasm_noparams[OpType.Barrier]
    else:
        gate_name = custom_name
    emit = self.strings.add_string
    emit(gate_name)
    emit(" ")
    self.write_args(args)
def add_gate_noparams(self, op: Op, args: Sequence[UnitID]) -> None:
    """Emit a parameterless gate by its name in ``_tk_to_qasm_noparams``."""
    gate_name = _tk_to_qasm_noparams[op.type]
    emit = self.strings.add_string
    emit(gate_name)
    emit(" ")
    self.write_args(args)
def add_gate_params(self, op: Op, args: Sequence[UnitID]) -> None:
    """Emit a parametrised gate by its name in ``_tk_to_qasm_params``.

    The gate name is followed by its parameters and then its arguments.
    """
    resolved_type, gate_params = _get_optype_and_params(op)
    gate_name = _tk_to_qasm_params[resolved_type]
    self.strings.add_string(gate_name)
    self.write_params(gate_params)
    self.write_args(args)
def add_extra_noparams(self, op: Op, args: Sequence[UnitID]) -> tuple[str, str]:
    """Build the text for a parameterless gate outside the included library.

    A gate definition is generated only the first time a given gate name is
    seen; afterwards an empty definition string is returned.

    :return: ``(gate definition or "", gate call string)``.
    """
    optype = op.type
    gate_name = _tk_to_qasm_extra_noparams[optype]
    if gate_name in self.added_gate_definitions:
        definition = ""
    else:
        self.added_gate_definitions.add(gate_name)
        definition = self.make_gate_definition(op.n_qubits, gate_name, optype)
    call = f"{gate_name} {_make_args_str(args)}"
    return definition, call
def add_extra_params(self, op: Op, args: Sequence[UnitID]) -> tuple[str, str]:
    """Build the text for a parametrised gate outside the included library.

    A gate definition is generated only the first time a given gate name is
    seen; afterwards an empty definition string is returned.

    :return: ``(gate definition or "", gate call string)``.
    """
    resolved_type, gate_params = _get_optype_and_params(op)
    assert gate_params is not None
    gate_name = _tk_to_qasm_extra_params[resolved_type]
    definition = ""
    if gate_name not in self.added_gate_definitions:
        self.added_gate_definitions.add(gate_name)
        definition = self.make_gate_definition(
            op.n_qubits, gate_name, resolved_type, len(gate_params)
        )
    call = "".join(
        (gate_name, _make_params_str(gate_params), _make_args_str(args))
    )
    return definition, call
1943 def add_op(self, op: Op, args: Sequence[UnitID]) -> None: # noqa: PLR0912
1944 optype, _params = _get_optype_and_params(op)
1945 if optype == OpType.RangePredicate:
1946 assert isinstance(op, RangePredicateOp)
1947 self.add_range_predicate(op, cast("list[Bit]", args))
1948 elif optype == OpType.Conditional:
1949 assert isinstance(op, Conditional)
1950 self.add_conditional(op, args)
1951 elif optype == OpType.Phase:
1952 # global phase is ignored in QASM
1953 pass
1954 elif optype == OpType.SetBits:
1955 assert isinstance(op, SetBitsOp)
1956 self.add_set_bits(op, cast("list[Bit]", args))
1957 elif optype == OpType.CopyBits:
1958 assert isinstance(op, CopyBitsOp)
1959 self.add_copy_bits(op, cast("list[Bit]", args))
1960 elif optype == OpType.MultiBit:
1961 assert isinstance(op, MultiBitOp)
1962 self.add_multi_bit(op, cast("list[Bit]", args))
1963 elif optype in (OpType.ExplicitPredicate, OpType.ExplicitModifier):
1964 self.add_explicit_op(op, cast("list[Bit]", args))
1965 elif optype == OpType.ClExpr:
1966 assert isinstance(op, ClExprOp)
1967 self.add_wired_clexpr(op, cast("list[Bit]", args))
1968 elif optype == OpType.WASM:
1969 assert isinstance(op, WASMOp)
1970 self.add_wasm(op, cast("list[Bit]", args))
1971 elif optype == OpType.Measure:
1972 self.add_measure(args)
1973 elif _hqs_header(self.header) and optype == OpType.ZZPhase:
1974 # special handling for zzphase
1975 assert len(op.params) == 1
1976 self.add_zzphase(op.params[0], args)
1977 elif optype == OpType.CnX:
1978 self.add_cnx(args)
1979 elif optype == OpType.Barrier and self.header == "hqslib1_dev":
1980 assert isinstance(op, BarrierOp)
1981 self.add_data(op, args)
1982 elif (
1983 optype in _tk_to_qasm_noparams
1984 and _tk_to_qasm_noparams[optype] in self.include_module_gates
1985 ):
1986 self.add_gate_noparams(op, args)
1987 elif (
1988 optype in _tk_to_qasm_params
1989 and _tk_to_qasm_params[optype] in self.include_module_gates
1990 ):
1991 self.add_gate_params(op, args)
1992 elif optype in _tk_to_qasm_extra_noparams:
1993 gatedefstr, mainstr = self.add_extra_noparams(op, args)
1994 self.gatedefs += gatedefstr
1995 self.strings.add_string(mainstr)
1996 elif optype in _tk_to_qasm_extra_params: 1996 ↛ 2001line 1996 didn't jump to line 2001 because the condition on line 1996 was always true
1997 gatedefstr, mainstr = self.add_extra_params(op, args)
1998 self.gatedefs += gatedefstr
1999 self.strings.add_string(mainstr)
2000 else:
2001 raise QASMUnsupportedError(f"Cannot print command of type: {op.get_name()}")
def finalize(self) -> str:
    """Resolve range predicates and assemble the full QASM program text.

    For each range predicate, the method first tries to substitute it into
    conditions, then tries to remove it if unused. The result is the prefix,
    then the accumulated gate definitions, then the filtered text of the
    register declarations followed by the program body. Declarations cover
    quantum registers, classical registers, and the scratch register when it
    is non-empty.
    """
    # Iterate over a snapshot of the labels because the loop body may remove
    # predicates (presumably from ``self.range_preds``).
    for pred_label in list(self.range_preds):
        self.replace_condition(pred_label)
        self.remove_unused_predicate(pred_label)

    declarations = _LabelledStringList()
    for qreg in self.qregs.values():
        declarations.add_string(f"qreg {qreg.name}[{qreg.size}];\n")
    for creg in self.cregs.values():
        declarations.add_string(f"creg {creg.name}[{creg.size}];\n")
    scratch = self.scratch_reg
    if scratch.size > 0:
        declarations.add_string(f"creg {scratch.name}[{scratch.size}];\n")

    program_body = declarations.get_full_string() + self.strings.get_full_string()
    return "".join((self.prefix, self.gatedefs, _filtered_qasm_str(program_body)))
def circuit_to_qasm_io(
    circ: Circuit,
    stream_out: TextIO,
    header: str = "qelib1",
    include_gate_defs: set[str] | None = None,
    maxwidth: int = 32,
) -> None:
    """Convert a Circuit to QASM and write to a text stream.

    Classical bits in the pytket circuit must be singly-indexed.

    Note that this will not account for implicit qubit permutations in the Circuit.

    :param circ: pytket circuit
    :param stream_out: text stream to be written to
    :param header: qasm header (default "qelib1")
    :param include_gate_defs: optional set of gates to include
    :param maxwidth: maximum allowed width of classical registers (default 32)
    """
    qasm_text = circuit_to_qasm_str(
        circ,
        header=header,
        include_gate_defs=include_gate_defs,
        maxwidth=maxwidth,
    )
    stream_out.write(qasm_text)