# Source code for lambeq.backend.snake_removal

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Snake removal
=============
This module contains functions for removing snakes (yankable cup/cap
pairs) from diagrams and for normalizing diagrams using interchanger
moves. This work is based on DisCoPy (https://discopy.org/), which is
released under the BSD 3-Clause "New" or "Revised" License.

"""

from __future__ import annotations

from collections.abc import Iterator

from lambeq.backend.grammar import Box, Cap, Cup, Diagram


class InterchangerError(Exception):
    """Raised when trying to interchange two boxes that are connected.

    Parameters
    ----------
    box0, box1 : :class:`Box`
        The two boxes that could not be interchanged; both are
        interpolated into the exception message.

    """

    def __init__(self, box0: Box, box1: Box) -> None:
        message = f'Boxes {box0} and {box1} do not commute.'
        super().__init__(message)
def snake_removal(diagram: Diagram,
                  left: bool = False) -> Iterator[Diagram]:
    """
    Returns a generator which yields normalization steps.

    Snakes are removed one at a time. A snake is a cap one of whose
    output wires leads straight into a cup, in one of the two forms
    ``Id @ Cap >> Cup @ Id`` (left snake) or ``Cap @ Id >> Id @ Cup``
    (right snake). Boxes lying between the cap and the cup are first
    moved out of the way one interchange at a time, and each
    intermediate diagram is yielded. The diagram with the cup/cap pair
    deleted is yielded next. Once no snake remains, the steps produced
    by :func:`normalize` are yielded.

    Parameters
    ----------
    diagram : :class:`Diagram`
        The diagram to remove snakes from.
    left : bool, optional
        Whether to apply left interchangers. This flag is only
        forwarded to the final :func:`normalize` pass.

    Yields
    ------
    diagram : :class:`Diagram`
        Rewrite steps.

    Examples
    --------
    >>> from lambeq.backend.grammar import Ty, Box, Cup, Cap, Id
    >>> n, s = Ty('n'), Ty('s')
    >>> cup, cap = Cup(n, n.r), Cap(n.r, n)
    >>> f = Box('f', n, n)
    >>> g = Box('g', s @ n, n)
    >>> h = Box('h', n, n @ s)
    >>> diagram = g @ cap >> f.dagger() @ Id(n.r) @ f >> cup @ h
    >>> for d in snake_removal(diagram):
    ...     print(d)  # doctest: +ELLIPSIS
    |Ty... >> |Ty() @ [CUP; Ty(n) @ Ty(n).r -> Ty()] @ Ty(n)| >>...
    |Ty... >> |Ty(n) @ [CAP; Ty() -> Ty(n).r @ Ty(n)] @ Ty()| >> \
|Ty() @ [CUP; Ty(n) @ Ty(n).r -> Ty()] @ Ty(n)| >>...
    |Ty() @ [g; Ty(s) @ Ty(n) -> Ty(n)] @ Ty()| >> \
|Ty() @ [f†; Ty(n) -> Ty(n)] @ Ty()| >> \
|Ty() @ [f; Ty(n) -> Ty(n)] @ Ty()| >> \
|Ty() @ [h; Ty(n) -> Ty(n) @ Ty(s)] @ Ty()|

    """
    def follow_wire(diagram: Diagram,
                    i: int,
                    j: int) -> tuple[int, int, tuple[list[int], list[int]]]:
        """
        Follow the output wire at offset ``j`` of box ``i`` downwards.

        Returns ``(index, offset, (lefts, rights))``:

        - ``index`` is the index of the first box that takes this wire
          as input, or ``len(diagram)`` if the wire reaches the bottom
          boundary.
        - ``offset`` is the offset of the wire at its bottom end.
        - ``lefts`` / ``rights`` list the indices of the boxes passed
          on the left / right of the wire along the way.

        """
        lefts: list[int] = []
        rights: list[int] = []
        for idx in range(i + 1, len(diagram)):
            box, off = diagram.boxes[idx], diagram.offsets[idx]
            if off <= j < off + len(box.dom):
                return idx, j, (lefts, rights)
            if off <= j:
                # The box does not consume the wire and starts at or
                # before it, so the wire's offset moves by the box's
                # change in width.
                j += len(box.cod) - len(box.dom)
                lefts.append(idx)
            else:
                rights.append(idx)
        return len(diagram), j, (lefts, rights)

    def find_snake(diagram: Diagram
                   ) -> None | tuple[int, int,
                                     tuple[list[int], list[int]], bool]:
        """
        Return ``(cup, cap, obstructions, left_snake)`` for the first
        yankable cup/cap pair found, or ``None`` if there is none.

        """
        for cap in range(len(diagram)):
            if not isinstance(diagram.boxes[cap], Cap):
                continue
            cap_off = diagram.offsets[cap]
            # The left output of the cap may form a left snake, the
            # right output a right snake.
            for left_snake, start in ((True, cap_off), (False, cap_off + 1)):
                cup, wire, obstructions = follow_wire(diagram, cap, start)
                if (cup == len(diagram)
                        or not isinstance(diagram.boxes[cup], Cup)):
                    continue
                # A left snake must enter the cup through its right
                # input, a right snake through its left input.
                expected = diagram.offsets[cup] + (1 if left_snake else 0)
                if wire == expected:
                    return cup, cap, obstructions, left_snake
        return None

    def unsnake(diagram: Diagram,
                cup: int,
                cap: int,
                obstructions: tuple[list[int], list[int]],
                left_snake: bool = False) -> Iterator[Diagram]:
        """
        Yield the steps that remove the snake formed by ``cap`` and
        ``cup``, given the boxes obstructing it on either side.

        Each obstructing box is interchanged out from between the cap
        and the cup (yielding the diagram after each move), then the
        now adjacent cap and cup layers are dropped.

        A left snake is one of the form Id @ Cap >> Cup @ Id.
        A right snake is one of the form Cap @ Id >> Id @ Cup.

        """
        lefts, rights = obstructions
        if left_snake:
            # Move each left obstruction to the cap's index; the cap
            # and every box in between shift one index later.
            for box in lefts:
                diagram = interchange(diagram, box, cap)
                yield diagram
                rights = [r + 1 if r < box else r for r in rights]
                cap += 1
            # Move each right obstruction to the cup's index, starting
            # from the one closest to the cup; the cup shifts earlier.
            for box in reversed(rights):
                diagram = interchange(diagram, box, cup)
                yield diagram
                cup -= 1
        else:
            # Mirror image of the branch above.
            for box in reversed(lefts):
                diagram = interchange(diagram, box, cup)
                yield diagram
                rights = [r - 1 if r > box else r for r in rights]
                cup -= 1
            for box in rights:
                diagram = interchange(diagram, box, cap)
                yield diagram
                cap += 1
        kept = diagram.layers[:cap] + diagram.layers[cup + 1:]
        yield diagram.category.Diagram(diagram.dom, diagram.cod, kept)

    while True:
        snake = find_snake(diagram)
        if snake is None:
            break
        for step in unsnake(diagram, *snake):
            yield step
        # The last step is the diagram with this snake removed.
        diagram = step
    yield from normalize(diagram, left=left)
def interchange(diagram: Diagram,
                i: int,
                j: int,
                left: bool = False) -> Diagram:
    """
    Returns a new diagram with boxes i and j interchanged.

    Gets called recursively whenever :code:`j > i + 1 or j < i - 1`:
    the box at index i is then moved to index j one adjacent
    interchange at a time.

    Parameters
    ----------
    diagram : :class:`Diagram`
        The diagram to interchange boxes in.
    i : int
        Index of the box to interchange.
    j : int
        Index of the new position for the box.
    left : bool, optional
        Whether to apply left interchangers.

    Returns
    -------
    :class:`Diagram`
        The rewritten diagram, or ``diagram`` itself when ``i == j``.

    Raises
    ------
    IndexError
        If ``i`` or ``j`` is not a valid box index of ``diagram``.
    InterchangerError
        If two boxes that need to be swapped are connected.

    Notes
    -----
    By default, we apply only right exchange moves::

        top >> Id(left @ box1.dom @ mid) @ box0 @ Id(right)
            >> Id(left) @ box1 @ Id(mid @ box0.cod @ right) >> bottom

    gets rewritten to::

        top >> Id(left) @ box1 @ Id(mid @ box0.dom @ right)
            >> Id(left @ box1.cod @ mid) @ box0 @ Id(right) >> bottom

    With ``left=True``, the mirrored left exchange move is preferred
    whenever box0 lies entirely to the left of box1. If only one of the
    two moves is possible, that move is applied regardless of ``left``.

    """
    if not 0 <= i < len(diagram) or not 0 <= j < len(diagram):
        raise IndexError(f'Box indices {i} and {j} are out of range for a '
                         f'diagram with {len(diagram)} boxes.')
    if i == j:
        return diagram
    if abs(i - j) > 1:
        # Walk box i towards j through adjacent interchanges.
        step = 1 if j > i else -1
        result = diagram
        for k in range(i, j, step):
            result = interchange(result, k, k + step, left=left)
        return result

    # Adjacent boxes: from here on box0 is in layer i and box1 in
    # layer i + 1.
    if j < i:
        i, j = j, i
    off0, off1 = diagram.offsets[i], diagram.offsets[j]
    left0, box0, right0 = diagram.layers[i].unpack()
    left1, box1, right1 = diagram.layers[j].unpack()

    box0_is_left = off1 >= off0 + len(box0.cod)
    box0_is_right = off0 >= off1 + len(box1.dom)
    make_layer = diagram.category.Layer

    # By default the right move is tried first and the left move is the
    # fallback; with ``left=True`` the order is reversed.
    if box0_is_left and (left or not box0_is_right):
        middle = left1[len(left0 @ box0.cod):]
        layer0 = make_layer(left0, box0, middle @ box1.cod @ right1)
        layer1 = make_layer(left0 @ box0.dom @ middle, box1, right1)
    elif box0_is_right:
        middle = left0[len(left1 @ box1.dom):]
        layer0 = make_layer(left1 @ box1.cod @ middle, box0, right0)
        layer1 = make_layer(left1, box1, middle @ box0.dom @ right0)
    else:
        raise InterchangerError(box0, box1)

    layers = diagram.layers[:i] + [layer1, layer0] + diagram.layers[i + 2:]
    return diagram.category.Diagram(diagram.dom, diagram.cod, layers=layers)
def normalize(diagram: Diagram, left: bool = False) -> Iterator[Diagram]:
    """
    Implements normalization of diagrams, see arXiv:1804.07832.

    The diagram is swept repeatedly over each pair of adjacent boxes.
    Every pair that can be interchanged in the chosen direction is
    interchanged, and the resulting diagram is yielded. The generator
    stops after a full sweep makes no move. It may never stop: the
    example below shows two boxes with empty domain and codomain
    swapping back and forth indefinitely.

    Parameters
    ----------
    diagram : :class:`Diagram`
        The diagram to normalize.
    left : bool, optional
        Passed to :func:`interchange`.

    Yields
    ------
    diagram : :class:`Diagram`
        Rewrite steps.

    Examples
    --------
    >>> from lambeq.backend.grammar import Ty, Box
    >>> s0, s1 = Box('s0', Ty(), Ty()), Box('s1', Ty(), Ty())
    >>> gen = normalize(s0 @ s1)
    >>> for _ in range(3):
    ...     print(next(gen))
    |Ty() @ [s1; Ty() -> Ty()] @ Ty()| >> \
|Ty() @ [s0; Ty() -> Ty()] @ Ty()|
    |Ty() @ [s0; Ty() -> Ty()] @ Ty()| >> \
|Ty() @ [s1; Ty() -> Ty()] @ Ty()|
    |Ty() @ [s1; Ty() -> Ty()] @ Ty()| >> \
|Ty() @ [s0; Ty() -> Ty()] @ Ty()|

    """
    moved = True
    while moved:
        moved = False
        for i in range(len(diagram) - 1):
            first, second = diagram.boxes[i], diagram.boxes[i + 1]
            first_off = diagram.offsets[i]
            second_off = diagram.offsets[i + 1]
            # Only interchange pairs that the chosen kind of move
            # (left or right) can handle.
            if left:
                movable = second_off >= first_off + len(first.cod)
            else:
                movable = first_off >= second_off + len(second.dom)
            if movable:
                diagram = interchange(diagram, i, i + 1, left=left)
                yield diagram
                moved = True