Source code for lambeq.rewrite.rewrite_diagram

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Diagram Rewrite
===============
Class hierarchy for allowing rewriting at the diagram level (as opposed
to rewrite rules that apply on the box level).

Subclass :py:class:`DiagramRewriter` to define a custom diagram rewriter.
"""
from __future__ import annotations

# Public names exported by this module.
__all__ = ['DiagramRewriter',
           'RemoveCupsRewriter',
           'RemoveSwapsRewriter',
           'UnifyCodomainRewriter',]

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import overload

from lambeq.backend.grammar import (Box, Cup, Diagram, Id, Swap,
                                    Ty, Word)
from lambeq.core.types import AtomicType

# Shorthands for the atomic noun and sentence types; ``S`` is the default
# ``output_type`` of ``UnifyCodomainRewriter``.
N = AtomicType.NOUN
S = AtomicType.SENTENCE
# Name given to the placeholder boxes that ``RemoveCupsRewriter`` creates
# when it merges nested cups, and later recognises when contracting them.
CUP_TOKEN = '**CUP**'


class DiagramRewriter(ABC):
    """Base class for diagram level rewriters.

    Subclasses implement :py:meth:`matches` and :py:meth:`rewrite`;
    calling the rewriter applies :py:meth:`rewrite` only to the
    diagrams for which :py:meth:`matches` returns ``True``.

    """

    @abstractmethod
    def matches(self, diagram: Diagram) -> bool:
        """Check if the given diagram should be rewritten."""

    @abstractmethod
    def rewrite(self, diagram: Diagram) -> Diagram:
        """Rewrite the given diagram."""

    @overload
    def __call__(self, target: list[Diagram]) -> list[Diagram]: ...

    @overload
    def __call__(self, target: Diagram) -> Diagram: ...

    def __call__(self,
                 target: list[Diagram] | Diagram) -> list[Diagram] | Diagram:
        """Rewrite the given diagram(s) if the rule applies.

        Parameters
        ----------
        target : :py:class:`lambeq.backend.grammar.Diagram` or list of Diagram
            The candidate diagram(s) to be rewritten.

        Returns
        -------
        :py:class:`lambeq.backend.grammar.Diagram` or list of Diagram
            The rewritten diagram. If the rule does not apply, the
            original diagram is returned. When a list is given, a new
            list of the same length is returned, with the rule applied
            to each element independently.

        """
        if isinstance(target, list):
            # Recurse so that every element goes through the same
            # match-then-rewrite path as a single diagram.
            return [self(d) for d in target]
        if self.matches(target):
            return self.rewrite(target)
        return target
@dataclass
class UnifyCodomainRewriter(DiagramRewriter):
    """Unifies the codomain of diagrams to match a given type.

    Any diagram ``d`` whose codomain differs from ``output_type`` gets a
    single ``d.cod -> output_type`` box composed after it; diagrams that
    already end in ``output_type`` are left alone.

    Attributes
    ----------
    output_type : :py:class:`lambeq.backend.grammar.Ty`, default ``S``
        The output type of the appended box.

    """
    output_type: Ty = S

    def matches(self, diagram: Diagram) -> bool:
        """Return ``True`` when the codomain is not ``output_type``."""
        differs = diagram.cod != self.output_type
        return bool(differs)

    def rewrite(self, diagram: Diagram) -> Diagram:
        """Compose a ``MERGE_<cod>`` box mapping the codomain to
        ``output_type`` after the diagram."""
        source_type = diagram.cod
        merge_box = Box(f'MERGE_{source_type}', source_type,
                        self.output_type)
        return diagram >> merge_box
class RemoveCupsRewriter(DiagramRewriter):
    """Removes cups from a given diagram.

    Diagrams with fewer cups become circuits with less post-selection,
    which results in faster QML experiments.

    """

    def matches(self, diagram: Diagram) -> bool:
        """Always ``True``: the rewriter is applied to every diagram."""
        return True

    def _compress_cups(self, diagram: Diagram) -> Diagram:
        """Merge a cup with the cup directly nested inside it.

        A cup whose offset is one less than that of the previous layer,
        when that previous layer is itself a cup, is fused with it into
        a single ``CUP_TOKEN`` box with an empty codomain.

        """
        merged: list[tuple[Box, int]] = []

        for box, offset in zip(diagram.boxes, diagram.offsets):
            # The checks are kept in this order on purpose: `merged`
            # must be non-empty before its last entry is inspected.
            if (isinstance(box, Cup)
                    and merged
                    and isinstance(merged[-1][0].boxes[0], Cup)
                    and offset == merged[-1][1] - 1):
                inner_box = merged[-1][0]
                # outer-left wire, the inner cup's wires, outer-right wire
                dom = box.dom[:1] @ inner_box.dom @ box.dom[1:]
                merged[-1] = (Box(CUP_TOKEN, dom, Ty()), offset)
            else:
                merged.append((box, offset))

        result = Id(diagram.dom)
        for box, offset in merged:
            result = result.then_at(box, offset)
        return result

    def _remove_cups(self, diagram: Diagram) -> Diagram:
        """Contract cups by bending the states they are attached to.

        The diagram is rebuilt box by box as a list of side-by-side
        sub-diagrams. When a cup (or a ``CUP_TOKEN`` box) has one half
        of its domain covered exactly by a sub-diagram with empty
        domain, that sub-diagram is daggered, transposed and composed
        after its neighbour instead of keeping the cup.

        """
        parts: list[Diagram | Box] = [Id(diagram.dom)]

        for box, offset in zip(diagram.boxes, diagram.offsets):
            # Find the first sub-diagram the box touches and the offset
            # of the box relative to that sub-diagram's codomain.
            idx, rel = 0, offset
            while idx < len(parts) and rel >= len(parts[idx].cod):
                rel -= len(parts[idx].cod)
                idx += 1

            if rel == 0 and not box.dom:
                # A box without inputs simply becomes a new sub-diagram.
                parts.insert(idx, box)
                continue

            left, right = parts[idx], Id(Ty())
            span = 1
            width = len(box.dom)
            # Absorb sub-diagrams on the right until the box's whole
            # domain is covered:
            # |   left   |   right   |
            #   rel  |  box  |
            while len(left.cod @ right.cod) < rel + width:
                assert idx + span < len(parts)
                right = right @ parts[idx + span]
                span += 1

            cod = left.cod @ right.cod
            wires_l = Id(cod[:rel])
            wires_r = Id(cod[rel + width:])

            if box.name == CUP_TOKEN or isinstance(box, Cup):
                # contract greedily, else combine
                half = width // 2
                ty_l, ty_r = box.dom[:half], box.dom[half:]

                if len(left.cod) == half and not left.dom:
                    # The second transpose is used for an illegal cup.
                    bent = (left.dagger().r if ty_l.r == ty_r
                            else left.dagger().l)
                    combined = right >> (bent @ wires_r)
                elif len(right.cod) == half and not right.dom:
                    bent = (right.dagger().l if ty_l.r == ty_r
                            else right.dagger().r)
                    combined = left >> (wires_l @ bent)
                else:
                    nbox = Diagram.cups(ty_l, ty_r,
                                        is_reversed=ty_r != ty_l.r)
                    combined = (left @ right) >> (wires_l @ nbox @ wires_r)
            else:
                combined = (left @ right) >> (wires_l @ box @ wires_r)

            parts[idx:idx + span] = [combined]

        return Id().tensor(*parts)

    def rewrite(self, diagram: Diagram) -> Diagram:
        """Remove cups, merge the nested cups that remain, then remove
        cups once more."""
        first_pass = self._remove_cups(diagram)
        compressed = self._compress_cups(first_pass)
        return self._remove_cups(compressed)
class RemoveSwapsRewriter(DiagramRewriter):
    """Produce a proper pregroup diagram by removing any swaps.

    Direct conversion of a CCG derivation into a string diagram form
    may introduce swaps, caused by cross-composition rules and unary
    rules that may change types and the directionality of composition
    at any point of the derivation. This class removes swaps, producing
    a valid pregroup diagram (in J. Lambek's sense) as follows:

    1. Eliminate swap morphisms by swapping the actual atomic types of
       the words.
    2. Scan the new diagram for any detached parts, and remove them by
       merging words together when possible.

    Parameters
    ----------
    diagram : :py:class:`lambeq.backend.grammar.Diagram`
        The input diagram.

    Returns
    -------
    :py:class:`lambeq.backend.grammar.Diagram`
        A copy of the input diagram without swaps.

    Raises
    ------
    ValueError
        If the input diagram is not in "pregroup" form, i.e. when words
        do not strictly precede the morphisms.

    Notes
    -----
    The class trades off diagrammatic simplicity and conformance to a
    formal pregroup grammar for a larger vocabulary, since each word is
    associated with more types than before and new words (combined
    tokens) are added to the vocabulary. Depending on the size of your
    dataset, this might lead to data sparsity problems during training.

    Examples
    --------
    In the following example, "am" and "not" are combined at the CCG
    level using cross composition, which introduces the interwoven
    pattern of wires.

    .. code-block:: text

        I       am            not        sleeping
        ─  ───────────  ───────────────  ────────
        n  n.r·s·s.l·n  s.r·n.r.r·n.r·s   n.r·s
        │   │  │  │  ╰─╮─╯    │    │  │    │  │
        │   │  │  │  ╭─╰─╮    │    │  │    │  │
        │   │  │  ╰╮─╯   ╰─╮──╯    │  │    │  │
        │   │  │  ╭╰─╮   ╭─╰──╮    │  │    │  │
        │   │  ╰──╯  ╰─╮─╯    ╰─╮──╯  │    │  │
        │   │        ╭─╰─╮    ╭─╰──╮  │    │  │
        │   ╰────────╯   ╰─╮──╯    ╰╮─╯    │  │
        │                ╭─╰──╮    ╭╰─╮    │  │
        ╰────────────────╯    ╰─╮──╯  ╰────╯  │
                              ╭─╰──╮          │
                              │    ╰──────────╯

    Rewriting with the :py:class:`RemoveSwapsRewriter` class will
    return:

    .. code-block:: text

        I     am not    sleeping
        ─  ───────────  ────────
        n  n.r·s·s.l·n   n.r·s
        ╰───╯  │  │  ╰────╯  │
               │  ╰──────────╯

    removing the swaps and combining "am" and "not" into one token.

    """

    @dataclass
    class _Word:
        """Helper class for
        :py:meth:`RemoveSwapsRewriter._remove_detached_cups` method."""
        word: Word
        offset: int

    @dataclass
    class _Morphism:
        """Helper class for
        :py:meth:`RemoveSwapsRewriter._remove_detached_cups` method."""
        morphism: Box
        start: int
        end: int
        offset: int
        deleted: bool = False

    def matches(self, diagram: Diagram) -> bool:
        """Validate the diagram; every valid diagram is rewritten.

        Returns
        -------
        bool
            Always ``True`` when no exception is raised.

        Raises
        ------
        ValueError
            If the diagram is not in pregroup form and computing its
            normal form fails.

        """
        if not diagram.is_pregroup:
            # The normal form is computed only as a validity check; its
            # result is discarded.
            # NOTE(review): `rewrite` still receives the original,
            # un-normalised diagram -- confirm this is intended.
            try:
                diagram.normal_form()
            except ValueError as e:
                raise ValueError('Not a valid pregroup diagram.') from e
        return True

    def _remove_detached_cups(self, diagram: Diagram) -> Diagram:
        """Remove any detached cups from a diagram.

        Helper function for :py:meth:`RemoveSwapsRewriter.rewrite`
        method. A cup is "detached" when both of its wires come from
        the same word; such cups are dropped together with the two
        atomic types they connect.

        Parameters
        ----------
        diagram : :py:class:`lambeq.backend.grammar.Diagram`
            A diagram in pregroup form.

        Returns
        -------
        :py:class:`lambeq.backend.grammar.Diagram`
            A new diagram without the detached cups.

        Raises
        ------
        ValueError
            If the diagram is not in pregroup form.

        """
        if not diagram.is_pregroup:
            raise ValueError('Not a valid pregroup diagram.')

        atomic_types = [ob
                        for b in diagram.boxes if isinstance(b, Word)
                        for ob in b.cod]
        # scan[k] is the index (into `atomic_types`) of the k-th wire
        # that is still open at the current layer.
        scan = list(range(len(atomic_types)))

        # Create lists with offset info for words and morphisms
        words: list[RemoveSwapsRewriter._Word] = []
        morphisms: list[RemoveSwapsRewriter._Morphism] = []
        for box, offset in zip(diagram.boxes, diagram.offsets):
            if isinstance(box, Word):
                words.append(self._Word(box, offset))
            else:
                start = scan[offset]
                end = scan[offset + len(box.dom) - 1]
                if isinstance(box, Cup):
                    del scan[offset:offset + len(box.dom)]
                morphisms.append(self._Morphism(box, start, end, offset))

        # Scan each word for detached cups
        new_words: list[Word] = []
        for w_idx, wrd in enumerate(words):
            rng = range(wrd.offset, wrd.offset + len(wrd.word.cod))
            scan = list(rng)
            for mor in morphisms:
                if (isinstance(mor.morphism, Cup)
                        and mor.start in rng and mor.end in rng):
                    # Drop the two connected types by value: positions
                    # in `scan` shift once an earlier cup of the same
                    # word has been removed, so a slice computed from
                    # the original word offset can miss them.
                    scan = [i for i in scan
                            if i != mor.start and i != mor.end]
                    mor.deleted = True

            if len(scan) == len(rng):
                # word type hasn't changed
                new_words.append(wrd.word)
            elif scan:
                # word type has been reduced in length
                typ = Ty().tensor(*[atomic_types[i] for i in scan])
                new_words.append(Word(wrd.word.name, typ))
            elif w_idx + 1 < len(words):
                # word type has been eliminated, merge word label
                # with next one
                next_wrd = words[w_idx + 1]
                next_wrd.word = Word(
                    f'{wrd.word.name} {next_wrd.word.name}',
                    next_wrd.word.cod
                )
            elif new_words:
                # The eliminated word is the last one, so there is no
                # following word; attach its label to the previous
                # kept word instead of indexing past the end.
                prev = new_words[-1]
                new_words[-1] = Word(f'{prev.name} {wrd.word.name}',
                                     prev.cod)
            else:
                # No other word is left to carry the label: keep the
                # word itself with an empty type.
                new_words.append(Word(wrd.word.name, Ty()))

        # Compute new word offsets
        total_ofs = 0
        wrd_offsets: list[int] = []
        for w in new_words:
            wrd_offsets.append(total_ofs)
            total_ofs += len(w.cod)

        # Compute new morphism offsets. The offsets are kept parallel
        # to `morphisms` while adjusting, so that index `j` always
        # refers to morphism `j` even after some cups were deleted.
        adjusted = [m.offset for m in morphisms]
        for m_idx, m in enumerate(morphisms):
            if not m.deleted:
                continue
            # cup is deleted, adjust all above offsets if required
            for j in range(m_idx):
                earlier = morphisms[j]
                if not earlier.deleted and earlier.start > m.start:
                    adjusted[j] -= 2

        new_morphisms: list[Box] = []
        mor_offsets: list[int] = []
        for m, ofs in zip(morphisms, adjusted):
            if not m.deleted:
                new_morphisms.append(m.morphism)
                mor_offsets.append(ofs)

        new_diag = Id(diagram.dom)
        for box, offset in zip(new_words + new_morphisms,
                               wrd_offsets + mor_offsets):
            new_diag = new_diag.then_at(box, offset)

        return new_diag

    def rewrite(self, diagram: Diagram) -> Diagram:
        """Remove the swaps of the diagram, then its detached cups.

        Each swap is replaced by exchanging the two atomic types it
        acts on inside the words; all other morphisms are kept at
        their original offsets.

        Parameters
        ----------
        diagram : :py:class:`lambeq.backend.grammar.Diagram`
            The input diagram.

        Returns
        -------
        :py:class:`lambeq.backend.grammar.Diagram`
            A new diagram without swaps.

        Raises
        ------
        ValueError
            If the swap-free diagram is not in pregroup form.

        """
        atomic_types = [ob
                        for b in diagram.boxes if isinstance(b, Word)
                        for ob in b.cod]
        # scan[k] is the index (into `atomic_types`) of the k-th wire
        # that is still open at the current layer.
        scan = list(range(len(atomic_types)))

        # Create lists with offset info for words and morphisms
        words: list[tuple[Box, int]] = []
        morphisms: list[tuple[Box, int]] = []
        for box, offset in zip(diagram.boxes, diagram.offsets):
            if isinstance(box, Word):
                words.append((box, offset))
            else:
                morphisms.append((box, offset))

        # Detect Swaps and swap the actual types
        for box, ofs in morphisms:
            if isinstance(box, Swap):
                tidx_l, tidx_r = scan[ofs], scan[ofs + 1]
                atomic_types[tidx_l], atomic_types[tidx_r] = (
                    atomic_types[tidx_r], atomic_types[tidx_l]
                )
            elif isinstance(box, Cup):
                del scan[ofs:ofs + 2]

        new_diagr = Id(diagram.dom)
        for wrd, ofs in words:
            typ = Ty().tensor(*atomic_types[ofs:ofs + len(wrd.cod)])
            new_diagr = new_diagr.then_at(Word(wrd.name, typ), ofs)

        for mor, ofs in morphisms:
            if not isinstance(mor, Swap):
                new_diagr = new_diagr.then_at(mor, ofs)

        return self._remove_detached_cups(new_diagr)