Source code for lambeq.backend.drawing.drawable

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Drawable Components
===================
Utilities to convert a grammar diagram into a drawable form.

"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
import math
import sys
from typing import Optional

from typing_extensions import Self

from lambeq.backend import grammar
from lambeq.backend.quantum import quantum


X_SPACING = 2.5  # Minimum space between adjacent wires
BOX_SPACING = 0.5   # Minimum space between adjacent boxes
LEDGE = 0.5  # Space from last wire to right box edge
BOX_HEIGHT = 0.5
HALF_BOX_HEIGHT = 0.25
FRAME_COMPONENTS_SPACING = 1.5 * LEDGE


class WireEndpointType(Enum):
    """An enumeration for :py:class:`WireEndpoint`.

    WireEndpoints in diagrams can be of 4 types:

    .. glossary::

        DOM
            Domain of a box.

        COD
            Codomain of a box.

        INPUT
            Input wire to the diagram.

        OUTPUT
            Output wire from the diagram.

    """

    DOM = 0
    COD = 1
    INPUT = 2
    OUTPUT = 3

    def __repr__(self) -> str:
        return self.name


@dataclass
class WireEndpoint:
    """
    One end of a wire in a DrawableDiagram.

    Attributes
    ----------
    kind: WireEndpointType
        Type of wire endpoint.
    obj: grammar.Ty
        Categorial type carried by the wire.
    x: float
        X coordinate of the wire end.
    y: float
        Y coordinate of the wire end.
    coordinates: (float, float)
        (x, y) coordinates.

    """

    kind: WireEndpointType
    obj: grammar.Ty

    x: float
    y: float
    noun_id: int = 0    # New attribute for wire noun
    parent: Optional['BoxNode'] = None

    @property
    def coordinates(self) -> tuple[float, float]:
        return (self.x, self.y)

    def __eq__(self, other: object) -> bool:
        """Check if these are the same instances, excluding
        the parent to avoid recursion."""
        if not isinstance(other, WireEndpoint):
            return NotImplemented
        else:
            return all([
                self.kind == other.kind,
                self.obj == other.obj,
                math.isclose(self.x, other.x),
                math.isclose(self.y, other.y),
            ])

    def __hash__(self) -> int:
        return hash(repr(self))

    def _apply_drawing_offset(self,
                              offset: tuple[float, float]) -> None:
        """Apply the offset to all the components inside the drawable.

        Parameters
        ----------
        offset : tuple[float, float]
            The x and y offsets to be applied.
        """
        self.x += offset[0]
        self.y += offset[1]


@dataclass
class BoxNode:
    """
    Box in a DrawableDiagram.

    Attributes
    ----------
    obj: grammar.Box
        Grammar box represented by the node.
    x: float
        X coordinate of the box.
    y: float
        Y coordinate of the box.
    coordinates: (float, float)
        (x, y) coordinates.
    dom_wires: list of int
        Wire endpoints in the domain of the box, represented by
        indices into an array maintained by `DrawableDiagram`.
    com_wires: list of int
        Wire endpoints in the codomain of the box, represented by
        indices into an array maintained by `DrawableDiagram`.

    """

    obj: grammar.Diagrammable

    x: float
    y: float
    h: Optional[float] = None
    w: Optional[float] = None

    dom_wires: list[int] = field(default_factory=list)
    cod_wires: list[int] = field(default_factory=list)

    child_boxes: list[Self] = field(default_factory=list)
    child_wire_endpoints: list[WireEndpoint] = field(default_factory=list)
    child_wires: list[tuple[int, int]] = field(default_factory=list)
    parent: Optional[Self] = None

    @property
    def coordinates(self):
        return (self.x, self.y)

    def __eq__(self, other: object) -> bool:
        """Check if these are the same instances, excluding
        the parent to avoid recursion."""
        if not isinstance(other, BoxNode):
            return NotImplemented
        else:
            return all([
                self.obj == other.obj,
                math.isclose(self.x, other.x),
                math.isclose(self.y, other.y),
                math.isclose(
                    self.h if self.h else float('inf'),
                    other.h if other.h else float('inf'),
                ),
                math.isclose(
                    self.w if self.w else float('inf'),
                    other.w if other.w else float('inf'),
                ),
                self.child_boxes == other.child_boxes,
                self.child_wire_endpoints == other.child_wire_endpoints,
                self.child_wires == other.child_wires,
            ])

    @property
    def has_wires(self):
        return self.dom_wires or self.cod_wires

    def __hash__(self) -> int:
        return hash(repr(self))

    def add_dom_wire(self, idx: int) -> None:
        """
        Add a wire to to box's domain.

        Parameters
        ----------
        idx : int
            Index of wire in associated `DrawableDiagram`'s
            `wire_endpoints` attribute.

        """
        self.dom_wires.append(idx)

    def add_cod_wire(self, idx: int) -> None:
        """
        Add a wire to to box's codomain.

        Parameters
        ----------
        idx : int
            Index of wire in associated `DrawableDiagram`'s
            `wire_endpoints` attribute.

        """
        self.cod_wires.append(idx)

    def get_x_lims(self,
                   drawable_diagram: DrawableDiagram) -> tuple[float, float]:
        """
        Get left and right limits of the box.

        Parameters
        ----------
        drawable_diagram : DrawableDiagram
            `DrawableDiagram` with which this box is associated.

        """

        if self.w is None:
            all_wires_pos = [drawable_diagram.wire_endpoints[wire].x
                             for wire in self.cod_wires + self.dom_wires]

            if not all_wires_pos:  # scalar box
                all_wires_pos = [self.x]

            left = min(all_wires_pos) - LEDGE
            right = max(all_wires_pos) + LEDGE
        else:
            left = self.x - self.w / 2
            right = self.x + self.w / 2

        return left, right

    def get_y_lims(self,
                   drawable_diagram: DrawableDiagram) -> tuple[float, float]:
        """
        Get top and bottom limits of the box.

        Parameters
        ----------
        drawable_diagram : DrawableDiagram
            `DrawableDiagram` with which this box is associated.

        """

        if self.h is None:
            all_wires_pos = [drawable_diagram.wire_endpoints[wire].y
                             for wire in self.cod_wires + self.dom_wires]

            if not all_wires_pos:  # scalar box
                all_wires_pos = [self.y]

            top = max(all_wires_pos) + LEDGE
            bottom = min(all_wires_pos) - LEDGE
        else:
            top = self.y + self.h / 2
            bottom = self.y - self.h / 2

        return top, bottom

    def _apply_drawing_offset(self,
                              offset: tuple[float, float]) -> None:
        """Apply the offset to all the components inside the drawable.

        Parameters
        ----------
        offset : tuple[float, float]
            The x and y offsets to be applied.
        """
        self.x += offset[0]
        self.y += offset[1]
        for obj in self.child_boxes + self.child_wire_endpoints:
            obj._apply_drawing_offset(offset)


[docs] @dataclass class DrawableDiagram: """ Representation of a lambeq diagram carrying all information necessary to render it. Attributes ---------- boxes: list of BoxNode Boxes in the diagram. wire_endpoints: list of WireEndpoint Endpoints for all wires in the diagram. wires: list of tuple of the form (int, int) The wires in a diagram, each represented by the indices of its 2 endpoints in `wire_endpoints`. """ boxes: list[BoxNode] = field(default_factory=list) wire_endpoints: list[WireEndpoint] = field(default_factory=list) wires: list[tuple[int, int]] = field(default_factory=list) def _add_wire(self, source: int, target: int) -> None: """Add an edge between 2 connected wire endpoints.""" self.wires.append((source, target)) def _add_wire_end(self, wire_end: WireEndpoint) -> int: """Add a `WireEndpoint` to the diagram.""" self.wire_endpoints.append(wire_end) return len(self.wire_endpoints) - 1 def _add_boxnode(self, box: BoxNode) -> int: """Add a `BoxNode` to the diagram.""" self.boxes.append(box) return len(self.boxes) - 1 def _add_box(self, scan: list[int], box: grammar.Box, off: int, x_pos: float, y_pos: float) -> tuple[list[int], int]: """Add a box to the graph, creating necessary wire endpoints. Returns ------- list[int] The new scan of wire endpoints after adding the box box_ind : int The index of the newly added `BoxNode` """ node = BoxNode(box, x_pos, y_pos) box_ind = self._add_boxnode(node) # Create a node representing each element in the box's domain for i, obj in enumerate(box.dom): nbr_idx = scan[off + i] wire_end = WireEndpoint(WireEndpointType.DOM, obj=obj, x=self.wire_endpoints[nbr_idx].x, y=y_pos + HALF_BOX_HEIGHT) wire_idx = self._add_wire_end(wire_end) node.add_dom_wire(wire_idx) self._add_wire(nbr_idx, wire_idx) scan_insert = [] # Create a node representing each element in the box's codomain for i, obj in enumerate(box.cod): # If the box is a quantum gate, retain x coordinate of wires if box.category == quantum and len(box.dom) == len(box.cod): nbr_idx = scan[off + i] x = self.wire_endpoints[nbr_idx].x else: x = x_pos + X_SPACING * (i - len(box.cod[1:]) / 2) y = y_pos - HALF_BOX_HEIGHT wire_end = WireEndpoint(WireEndpointType.COD, obj=obj, x=x, y=y) wire_idx = self._add_wire_end(wire_end) scan_insert.append(wire_idx) node.add_cod_wire(wire_idx) # Replace node's dom with its cod in scan return scan[:off] + scan_insert + scan[off + len(box.dom):], box_ind def _find_box_edges(self, box: grammar.Box, x: float, off: int, scan: list[int]): left_edge = x right_edge = x # dom edges come from upstream wire endpoints if box.dom: left_edge = min(self.wire_endpoints[scan[off]].x, left_edge) right_edge = max( self.wire_endpoints[scan[off + len(box.dom) - 1]].x, right_edge) # cod edges are evenly spaced if box.cod: left_edge = min(x - X_SPACING * len(box.cod[1:]) / 2, left_edge) right_edge = max(x + X_SPACING * (len(box.cod[1:]) - len(box.cod[1:]) / 2), right_edge) return left_edge - LEDGE, right_edge + LEDGE def _make_space(self, scan: list[int], box: grammar.Box, off: int, foliated: bool) -> tuple[float, float]: """Determines x and y coords for a new box. Modifies x coordinates of existing nodes to make space. Returns ------- x, y : tuple[float, float] The x and y coordinates of the box. """ if not scan: return 0, 0 half_width = X_SPACING * (len(box.cod[:-1]) / 2 + 1) if not box.dom: if not off: x = self.wire_endpoints[scan[0]].x - half_width elif off == len(scan): x = self.wire_endpoints[scan[-1]].x + half_width else: right = self.wire_endpoints[scan[off + len(box.dom)]].x x = (self.wire_endpoints[scan[off - 1]].x + right) / 2 else: right = self.wire_endpoints[scan[off + len(box.dom) - 1]].x x = (self.wire_endpoints[scan[off]].x + right) / 2 if off and self.wire_endpoints[scan[off - 1]].x > x - half_width: limit = self.wire_endpoints[scan[off - 1]].x pad = limit - x + half_width for node in self.boxes + self.wire_endpoints: if node.x <= limit: node.x -= pad if (off + len(box.dom) < len(scan) and (self.wire_endpoints[scan[off + len(box.dom)]].x < x + half_width)): limit = self.wire_endpoints[scan[off + len(box.dom)]].x pad = x + half_width - limit for node in self.boxes + self.wire_endpoints: if node.x >= limit: node.x += pad left_edge, right_edge = self._find_box_edges(box, x, off, scan) y = 0.0 for upstream_box in self.boxes: bl, br = upstream_box.get_x_lims(self) if not (bl > right_edge or br < left_edge): # Boxes overlap y = min(y, upstream_box.y - 1.0) return x, y def _move_to_origin(self) -> None: """Set the min x and middle-y coordinates of the diagram to 0. Setting the diagram to be centred on the y axis allows us to avoid precomputing the diagram's height. """ min_x = min( [node.x for node in self.boxes + self.wire_endpoints]) min_y = min( [node.y for node in self.boxes + self.wire_endpoints]) max_y = max( [node.y for node in self.boxes + self.wire_endpoints]) mid_y = (min_y + max_y) / 2 for node in self.boxes + self.wire_endpoints: node.x -= min_x node.y -= mid_y
[docs] @classmethod def from_diagram(cls, diagram: grammar.Diagram, foliated: bool = False) -> Self: """ Builds a graph representation of the diagram, calculating coordinates for each box and wire. Parameters ---------- diagram : grammar Diagram A lambeq diagram. foliated : bool, default: False If true, each box of the diagram is drawn in a separate layer. By default boxes are compressed upwards into available space. Returns ------- drawable : DrawableDiagram Representation of diagram including all coordinates necessary to draw it. """ drawable = cls() scan = [] for i, obj in enumerate(diagram.dom): wire_end = WireEndpoint(WireEndpointType.INPUT, obj=obj, x=X_SPACING * i, y=1) wire_end_idx = drawable._add_wire_end(wire_end) scan.append(wire_end_idx) min_y = 1.0 for depth, (box, off) in enumerate(zip(diagram.boxes, diagram.offsets)): x, y = drawable._make_space(scan, box, off, foliated) y = -depth if foliated else y scan, _ = drawable._add_box(scan, box, off, x, y) min_y = min(min_y, y) for i, obj in enumerate(diagram.cod): wire_end = WireEndpoint(WireEndpointType.OUTPUT, obj=obj, x=drawable.wire_endpoints[scan[i]].x, y=min_y - 1) wire_end_idx = drawable._add_wire_end(wire_end) drawable._add_wire(scan[i], wire_end_idx) drawable._move_to_origin() return drawable
[docs] def scale_and_pad(self, scale: tuple[float, float], pad: tuple[float, float]): """Scales and pads the diagram as specified. Parameters ---------- scale : tuple of 2 floats Scaling factors for x and y axes respectively. pad : tuple of 2 floats Padding values for x and y axes respectively. """ min_x = min([node.x for node in self.boxes + self.wire_endpoints]) min_y = min([node.y for node in self.boxes + self.wire_endpoints]) for wire_end in self.wire_endpoints: wire_end.x = min_x + (wire_end.x - min_x) * scale[0] + pad[0] wire_end.y = min_y + (wire_end.y - min_y) * scale[1] + pad[1] for box in self.boxes: box.x = min_x + (box.x - min_x) * scale[0] + pad[0] box.y = min_y + (box.y - min_y) * scale[1] + pad[1] for wire_end_idx in box.dom_wires: self.wire_endpoints[wire_end_idx].y = ( box.y + HALF_BOX_HEIGHT * scale[1]) for wire_end_idx in box.cod_wires: self.wire_endpoints[wire_end_idx].y = ( box.y - HALF_BOX_HEIGHT * scale[1])
class PregroupError(Exception): def __init__(self, diagram): super().__init__(f'Diagram {diagram} is not a pregroup diagram. ' 'A pregroup diagram must be structured like ' '(State @ State ... State) >> (Cups and Swaps)') @dataclass class DrawablePregroup(DrawableDiagram): """ Representation of a lambeq pregroup diagram carrying all information necessary to render it. Attributes ---------- x_tracks: list of int Stores the "track" on which the corresponding `WireEndpoint` in `wire_endpoints` lies. This helps determine the depth of pregroup grammar boxes in the diagram. """ x_tracks: list[int] = field(default_factory=list) def _add_wire_end(self, wire_end: WireEndpoint, x_track=-1) -> int: """Add a `WireEndpoint` to the diagram, with track information.""" self.x_tracks.append(x_track) return super()._add_wire_end(wire_end) @classmethod def from_diagram(cls, diagram: grammar.Diagram, foliated: bool = False) -> Self: """ Builds a graph representation of the diagram, calculating coordinates for each box and wire. Parameters ---------- diagram : grammar.Diagram A lambeq diagram. foliated : bool, default: False This parameter is not used for pregroup diagrams, which are always drawn un-foliated. Returns ------- drawable : DrawableDiagram Representation of diagram including all coordinates necessary to draw it. """ if foliated: print('Pregroup diagrams cannot be drawn foliated.' ' Set `draw_as_pregroup` to `False` to see' ' foliation for this diagram.', file=sys.stderr) words = [] grammar_start_idx = len(diagram) for i, layer in enumerate(diagram.layers): if (isinstance(layer.box, grammar.Cup) or isinstance(layer.box, grammar.Swap)): grammar_start_idx = i break if layer.right or layer.box.dom: raise PregroupError(diagram) words.append(layer.box) HSPACE = 0.5 VSPACE = 0.75 BOX_WIDTH = 2 drawable = cls() scan = [] track_ctr = 0 for i, word in enumerate(words): node = BoxNode(word, (HSPACE + BOX_WIDTH) * i + (0.5 * BOX_WIDTH * isinstance(word, grammar.Cap)), 0) for j, ty in enumerate(word.cod): wire_x = ((HSPACE + BOX_WIDTH) * i + (BOX_WIDTH / (len(word.cod) + 1)) * (j + 1)) wire_end_idx = drawable._add_wire_end( WireEndpoint(WireEndpointType.COD, ty, wire_x, 0.25), track_ctr) node.add_cod_wire(wire_end_idx) scan.append(wire_end_idx) track_ctr += 1 drawable.boxes.append(node) depth_map = [0.0 for _ in range(track_ctr)] for layer in diagram.layers[grammar_start_idx:]: off = len(layer.left) box = layer.box lx = drawable.wire_endpoints[scan[off]].x rx = drawable.wire_endpoints[scan[off + 1]].x l_track = drawable.x_tracks[scan[off]] r_track = drawable.x_tracks[scan[off + 1]] y = min(depth_map[l_track: r_track + 1]) l_wire_end_idx = drawable._add_wire_end( WireEndpoint(WireEndpointType.DOM, box.dom[0], lx, y - VSPACE / 2), l_track) r_wire_end_idx = drawable._add_wire_end( WireEndpoint(WireEndpointType.DOM, box.dom[1], rx, y - VSPACE / 2), r_track) drawable._add_wire(scan[off], l_wire_end_idx) drawable._add_wire(scan[off + 1], r_wire_end_idx) grammar_box = BoxNode(box, (lx + rx) / 2, y - VSPACE) grammar_box.add_dom_wire(l_wire_end_idx) grammar_box.add_dom_wire(r_wire_end_idx) if isinstance(box, grammar.Swap): l_idx = drawable._add_wire_end( WireEndpoint(WireEndpointType.COD, box.cod[0], lx, y - VSPACE), l_track) r_idx = drawable._add_wire_end( WireEndpoint(WireEndpointType.COD, box.cod[1], rx, y - VSPACE), r_track) grammar_box.add_cod_wire(l_idx) grammar_box.add_cod_wire(r_idx) scan[off] = l_idx scan[off + 1] = r_idx elif isinstance(box, grammar.Cup): # 2 elements of the codomain are consumed. scan = scan[:off] + scan[off + 2:] else: raise PregroupError(diagram) drawable.boxes.append(grammar_box) for i in range(l_track, r_track + 1): depth_map[i] = y - VSPACE min_y = min(depth_map) for i, obj in enumerate(diagram.cod): wire_end = WireEndpoint(WireEndpointType.OUTPUT, obj, drawable.wire_endpoints[scan[i]].x, min_y - VSPACE) wire_end_idx = drawable._add_wire_end(wire_end) drawable._add_wire(scan[i], wire_end_idx) drawable._move_to_origin() return drawable @dataclass class DrawableDiagramWithFrames(DrawableDiagram): """ Representation of a lambeq diagram that contains at least one frame, carrying all information necessary to render it. """ noun_id_counter: int = 1 def _make_space(self, scan: list[int], box: grammar.Box, off: int, foliated: bool) -> tuple[float, float]: """Determines x and y coords for a new box. Modifies x coordinates of existing nodes to make space.""" if not scan: return 0, 0 half_width = X_SPACING * (len(box.cod[:-1]) / 2 + 1) if not box.dom: if not off: x = self.wire_endpoints[scan[0]].x - half_width elif off == len(scan): x = self.wire_endpoints[scan[-1]].x + half_width else: right = self.wire_endpoints[scan[off + len(box.dom)]].x x = (self.wire_endpoints[scan[off - 1]].x + right) / 2 else: right = self.wire_endpoints[scan[off + len(box.dom) - 1]].x x = (self.wire_endpoints[scan[off]].x + right) / 2 if off and self.wire_endpoints[scan[off - 1]].x > x - half_width: limit = self.wire_endpoints[scan[off - 1]].x pad = limit - x + half_width for node in self.boxes + self.wire_endpoints: if node.parent is None and node.x <= limit: node._apply_drawing_offset((-pad, 0)) if (off + len(box.dom) < len(scan) and (self.wire_endpoints[scan[off + len(box.dom)]].x < x + half_width)): limit = self.wire_endpoints[scan[off + len(box.dom)]].x pad = x + half_width - limit for node in self.boxes + self.wire_endpoints: if node.parent is None and node.x >= limit: node._apply_drawing_offset((pad, 0)) left_edge, right_edge = self._find_box_edges(box, x, off, scan) y = 0.0 for upstream_box in self.boxes: if upstream_box.parent is None: bl, br = upstream_box.get_x_lims(self) if not (bl > right_edge or br < left_edge) or foliated: # Boxes overlap upstream_box_h = upstream_box.h or BOX_HEIGHT y = min( y, (upstream_box.y - 0.5 * upstream_box_h - 4.5 * BOX_HEIGHT) ) return x, y def _add_box_with_nouns( self, scan: list[int], box: grammar.Box, off: int, x_pos: float, y_pos: float, input_nouns: list[int] ) -> tuple[list[int], int, list[int]]: """Add a box to the graph, creating necessary wire endpoints. Returns ------- list : int The new scan of wire endpoints after adding the box box_ind : int The index of the newly added `BoxNode` input_nouns : list[int] The new order of input_nouns after adding the box """ node = BoxNode(box, x_pos, y_pos) box_ind = self._add_boxnode(node) num_input = len(box.dom) input_nouns = input_nouns or [] for i in range(num_input): if i < len(input_nouns): pass else: # If we run out of input nouns, generate new ones new_color = self.get_noun_id() input_nouns.append(new_color) # Create a node representing each element in the box's domain for i, obj in enumerate(box.dom): idx = off + i nbr_idx = scan[off + i] noun_id = ( input_nouns[idx] if (input_nouns and idx < len(input_nouns)) else self.get_noun_id() ) # generate new noun_id if needed wire_end = WireEndpoint(WireEndpointType.DOM, obj=obj, x=self.wire_endpoints[nbr_idx].x, y=y_pos + HALF_BOX_HEIGHT, noun_id=noun_id) wire_idx = self._add_wire_end(wire_end) node.add_dom_wire(wire_idx) self._add_wire(nbr_idx, wire_idx) scan_insert = [] if isinstance(box, grammar.Swap): # If Swap, exchange the noun_ids if input_nouns and len(box.dom) > 1: dom_idx_1 = off dom_idx_2 = off + 1 input_nouns[dom_idx_1], input_nouns[dom_idx_2] = ( input_nouns[dom_idx_2], input_nouns[dom_idx_1] ) elif isinstance(node.obj, grammar.Spider): # If Spider, expand or shrink the noun_ids based on type if len(box.dom) == 1 and len(box.cod) > 1: dom_noun = (input_nouns[off] if input_nouns and off < len(input_nouns) else self.get_noun_id()) expanded_colors = [dom_noun] * len(box.cod) input_nouns = (input_nouns[:off] + expanded_colors + input_nouns[off + len(box.dom):]) elif len(box.dom) > 1 and len(box.cod) == 1: cod_noun = (input_nouns[off] if input_nouns and off < len(input_nouns) else self.get_noun_id()) input_nouns = (input_nouns[:off] + [cod_noun] + input_nouns[off + len(box.dom):]) num_output = off + len(box.cod) for i in range(num_output): if i < len(input_nouns): pass else: # If we run out of input nouns, generate new ones new_color = self.get_noun_id() input_nouns.append(new_color) # Create a node representing each element in the box's codomain for i, obj in enumerate(box.cod): # If the box is a quantum gate, retain x coordinate of wires if box.category == quantum and len(box.dom) == len(box.cod): nbr_idx = scan[off + i] x = self.wire_endpoints[nbr_idx].x else: x = x_pos + X_SPACING * (i - len(box.cod[1:]) / 2) y = y_pos - HALF_BOX_HEIGHT idx = off + i noun_id = (input_nouns[idx] if input_nouns and idx < len(input_nouns) else self.get_noun_id()) wire_end = WireEndpoint(WireEndpointType.COD, obj=obj, x=x, y=y, noun_id=noun_id) wire_idx = self._add_wire_end(wire_end) scan_insert.append(wire_idx) node.add_cod_wire(wire_idx) # Replace node's dom with its cod in scan return (scan[:off] + scan_insert + scan[off + len(box.dom):], box_ind, input_nouns) def _make_space_for_frame(self, scan: list[int], off: int, outer_box: BoxNode, foliated: bool) -> None: """Shift x and y coords for a new box. Modifies x coordinates of existing nodes to make space.""" assert outer_box.w is not None components_to_left = self._get_components_connected_to_top(scan[:off]) components_to_outer_box = self._get_components_connected_to_top( outer_box.cod_wires, ) if components_to_left: # Get rightmost edge rightmost_edge = float('-inf') for obj in components_to_left: if obj.parent is None: if isinstance(obj, WireEndpoint): obj_right = obj.x else: obj_right = obj.get_x_lims(self)[1] rightmost_edge = max(rightmost_edge, obj_right) left_frame_end = outer_box.x - (outer_box.w / 2) if not set(components_to_left).intersection( set(components_to_outer_box)): for obj in components_to_outer_box: if obj.parent is None: if isinstance(obj, WireEndpoint): obj_left = obj.x else: obj_left = obj.get_x_lims(self)[0] left_frame_end = min(left_frame_end, obj_left) pad = rightmost_edge + BOX_SPACING - left_frame_end for obj in self.boxes + self.wire_endpoints: if obj.parent is None and obj not in components_to_left: obj._apply_drawing_offset((pad, 0)) # Move components to the right of this frame box to the right components_to_right = [] for obj in self.boxes + self.wire_endpoints: if (obj.parent is None and obj not in components_to_left and obj not in components_to_outer_box and (isinstance(obj, WireEndpoint) or obj.has_wires)): components_to_right.append(obj) right_frame_end = outer_box.x + (outer_box.w / 2) for obj in components_to_outer_box: if obj.parent is None: if isinstance(obj, WireEndpoint): obj_right = obj.x else: obj_right = obj.get_x_lims(self)[1] right_frame_end = max(right_frame_end, obj_right) if components_to_right: leftmost_edge = float('inf') for obj in components_to_right: if isinstance(obj, WireEndpoint): obj_left = obj.x else: obj_left = obj.get_x_lims(self)[0] leftmost_edge = min(leftmost_edge, obj_left) pad = right_frame_end - leftmost_edge + BOX_SPACING for obj in components_to_right: obj._apply_drawing_offset((pad, 0)) left_edge, right_edge = ( (outer_box.x - outer_box.w / 2), (outer_box.x + outer_box.w / 2) ) y = 0.0 for upstream_box in self.boxes: bl, br = upstream_box.get_x_lims(self) if not (bl > right_edge or br < left_edge) or foliated: # Boxes overlap upstream_box_h = upstream_box.h or BOX_HEIGHT y = min( y, upstream_box.y - 0.5 * upstream_box_h - 4.5 * BOX_HEIGHT ) def calculate_bounds(self) -> tuple[float, float, float, float]: """Calculate the bounding box of the drawable. Returns ------- tuple of (min_x, min_y, max_x, max_y) The bounds of the drawable. """ # Iterate over boxes all_xs = [wire.x for wire in self.wire_endpoints] all_ys = [obj.y for obj in self.wire_endpoints] for box in self.boxes: all_xs.extend(box.get_x_lims(self)) all_ys.extend(box.get_y_lims(self)) return min(all_xs), min(all_ys), max(all_xs), max(all_ys) def get_noun_id(self) -> int: """Get the latest available numerical ID for the noun wire. Returns ------- noun_id : int The latest noun wire ID. """ # Increment and return the next available ID noun_id = self.noun_id_counter self.noun_id_counter += 1 return noun_id @classmethod def from_diagram(cls, diagram: grammar.Diagram, foliated: bool = False) -> Self: """ Build a graph representation of the diagram, calculating coordinates for each box and wire. Parameters ---------- diagram : grammar.Diagram A lambeq diagram. foliated : bool, default: False If true, each box of the diagram is drawn in a separate layer. By default boxes are compressed upwards into available space. Returns ------- drawable : DrawableDiagram Representation of diagram including all coordinates necessary to draw it. """ drawable = cls() scan = [] # Generate unique noun_ids for input wires num_input = len(diagram.dom) input_nouns = [] for _ in range(num_input): new_color = drawable.get_noun_id() input_nouns.append(new_color) for i, obj in enumerate(diagram.dom): wire_end = WireEndpoint(WireEndpointType.INPUT, obj=obj, x=X_SPACING * i, y=1, noun_id=input_nouns[i]) wire_end_idx = drawable._add_wire_end(wire_end) scan.append(wire_end_idx) min_y = 1.0 max_box_half_height = 0. for _, (box, off) in enumerate(zip(diagram.boxes, diagram.offsets)): # TODO: Debug issues with y coord x, y = drawable._make_space(scan, box, off, foliated=foliated) scan, box_ind, input_nouns = drawable._add_box_with_nouns( scan, box, off, x, y, input_nouns) box_height = BOX_HEIGHT # Add drawables for the inside of the frame if isinstance(box, grammar.Frame): x, y, box_height = drawable._add_components_inside_frame( scan, box, box_ind, off, foliated=foliated, ) max_box_half_height = max(max_box_half_height, (box_height / 2)) min_y = min(min_y, y) num_output = len(diagram.cod) # Match output nouns with input nouns as much as possible for i in range(num_output): if i < len(input_nouns): pass else: # If we run out of input nouns, generate new ones new_color = drawable.get_noun_id() input_nouns.append(new_color) for i, obj in enumerate(diagram.cod): wire_end = WireEndpoint( WireEndpointType.OUTPUT, obj=obj, x=drawable.wire_endpoints[scan[i]].x, y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT, noun_id=input_nouns[i] ) wire_end_idx = drawable._add_wire_end(wire_end) drawable._add_wire(scan[i], wire_end_idx) drawable._move_to_origin() # Push top-level floating boxes, i.e. boxes without # any wires to the right. drawable._relocate_floating_boxes(scan) # Center top-level frames on its dom/cod wires drawable._center_frames_on_wires() return drawable def _center_frames_on_wires(self) -> None: """Center frames to bounds defined by its wires.""" for box in self.boxes: if box.parent is None and isinstance(box.obj, grammar.Frame): # New width dom_width = 0. if box.dom_wires: dom_width = ( self.wire_endpoints[box.dom_wires[-1]].x - self.wire_endpoints[box.dom_wires[0]].x + 2 * LEDGE ) cod_width = 0. if box.cod_wires: cod_width = ( self.wire_endpoints[box.cod_wires[-1]].x - self.wire_endpoints[box.cod_wires[0]].x + 2 * LEDGE ) candidate_width = dom_width ref_wires = box.dom_wires if cod_width > candidate_width: candidate_width = cod_width ref_wires = box.cod_wires if box.w is not None and candidate_width > box.w: box.w = candidate_width # Also shift the box if ref_wires: ref_wires_center = self.wire_endpoints[ref_wires[0]].x ref_wires_center += ( self.wire_endpoints[ref_wires[-1]].x - self.wire_endpoints[ref_wires[0]].x ) / 2 box._apply_drawing_offset( (ref_wires_center - box.x, 0) ) def _relocate_floating_boxes(self, diagram_output: list[int]) -> None: """Push floating boxes to rightmost side of diagram.""" connected_components = self._get_components_connected_to_top( diagram_output, ) rightmost_edge = None if connected_components: rightmost_edge = float('-inf') for obj in connected_components: if obj.parent is None: if isinstance(obj, WireEndpoint): obj_right = obj.x else: obj_right = obj.get_x_lims(self)[1] rightmost_edge = max(rightmost_edge, obj_right) floating_boxes = [] for box in self.boxes: if box.parent is None and not (box.dom_wires or box.cod_wires): floating_boxes.append(box) floating_boxes = sorted( floating_boxes, key=lambda b: b.get_x_lims(self)[0] ) box_start = (rightmost_edge + BOX_SPACING if rightmost_edge is not None else floating_boxes[0].get_x_lims(self)[0]) for i, box in enumerate(floating_boxes): pad = (box_start - box.get_x_lims(self)[0] + (BOX_SPACING if i else 0)) box._apply_drawing_offset((pad, 0)) box_start = box.get_x_lims(self)[1] def _calculate_pos_and_size(self) -> tuple[float, float, float, float]: """Calculate the position and dimensions of the drawable. Returns ------- tuple of 4 floats The x-coordinate, y-coordinate, height, and width of the drawable. """ bl_x, bl_y, tr_x, tr_y = self.calculate_bounds() w = tr_x - bl_x h = tr_y - bl_y x = bl_x + w / 2 y = bl_y + h / 2 return (x, y, h, w) def _add_components_inside_frame( self, scan: list[int], frame: grammar.Frame, box_ind: int, off: int, foliated: bool = False ) -> tuple[float, float, float]: """ Add the drawable components (boxes, wire endpoints, etc.) that come from the frame components to the drawable components in `self`. Parameters ---------- scan : list of int The indices of the wire endpoints we can immediately connect to. frame : grammar.Frame A lambeq frame. box_ind : int The index of the `BoxNode` for the outermost box of the frame in `self.boxes`. off : int The (wire) offset of the frame. foliated : bool, default: False If true, each box of the diagram is drawn in a separate layer. By default boxes are compressed upwards into available space. Returns ------- tuple of 3 floats The x-coordinate, y-coordinate, and height of the modified outermost box of the frame after considering all the drawables inside it. """ # We've just added this box - this is the box # where the dom and cod wires of the frame originate from frame_outer_box = self.boxes[box_ind] component_x_offset = 0. component_y_offset = 2. * LEDGE # Create an empty drawable that would contain all the components # inside the frame frame_drawable = self.__class__() for component in frame.components: # Create a drawable for each component component_drawable = self.__class__.from_diagram( component.to_diagram(), foliated=foliated ) for obj in (component_drawable.boxes + component_drawable.wire_endpoints): obj.parent = None if isinstance(obj, BoxNode): obj.child_boxes = [] obj.child_wire_endpoints = [] obj.child_wires = [] # Assume first that the following is the final # position and size of the box (component_x, component_y, component_h, component_w) = component_drawable._calculate_pos_and_size() # Give some horizontal breathing room component_w += 2 * LEDGE # Add space when component doesn't have dom, cod wires if not component.dom: component_h += LEDGE component_y += LEDGE / 2 if not component.cod: component_h += LEDGE component_y -= LEDGE / 2 # Create wrapper box for the component component_wrapper_box = BoxNode( obj=component.to_diagram(), x=component_x, y=component_y, h=component_h, w=component_w, ) # Put wrapper box to head of list so that it gets # rendered first because boxes are opaque component_drawable.boxes = ([component_wrapper_box] + component_drawable.boxes) component_bounds = component_drawable.calculate_bounds() if component_bounds[0] < 0: # Apply offset so that leftmost edge of component # drawable sits at x=0 in its local coordinates, # otherwise, it will overlap with the component # to its left component_drawable._apply_drawing_offset(( -component_bounds[0], 0 )) # Apply horizontal offset component_drawable._apply_drawing_offset( (component_x_offset, component_y_offset), ) # Compute new offset component_bounds = component_drawable.calculate_bounds() component_x_offset = (component_bounds[2] + FRAME_COMPONENTS_SPACING) # Add this drawable to the main drawable frame_drawable._merge_with(component_drawable) # Create a box node for the entire frame drawable frame_drawable._move_to_origin() (frame_x, frame_y, frame_h, frame_w) = frame_drawable._calculate_pos_and_size() frame_w += 2 * LEDGE # Extra vertical clearance for the name of the frame frame_h += 4 * LEDGE (frame_outer_box_left, frame_outer_box_right) = frame_outer_box.get_x_lims(self) frame_wire_based_width = frame_outer_box_right - frame_outer_box_left frame_outer_box_y_offset = -(frame_h - BOX_HEIGHT) / 2 # We follow the bigger width between # 1) the width computed after considering all # the wires connected to the frame, vs # 2) the width computed for the tightest box # that can contain all the components frame_w = max(frame_w, frame_wire_based_width) frame_drawable_x_offset = (frame_outer_box_left + frame_wire_based_width / 2 - frame_x) # Adjust size of the outer box based on the above data frame_outer_box.w, frame_outer_box.h = frame_w, frame_h frame_outer_box.y += frame_outer_box_y_offset frame_components_offset = ( frame_drawable_x_offset, frame_outer_box.y - frame_y, ) frame_drawable._apply_drawing_offset(frame_components_offset) # Assign parent, child relationship frame_outer_box.child_boxes = frame_drawable.boxes frame_outer_box.child_wire_endpoints = frame_drawable.wire_endpoints frame_outer_box.child_wires = frame_drawable.wires for obj in (frame_outer_box.child_boxes + frame_outer_box.child_wire_endpoints): obj.parent = frame_outer_box # Update y values of cod wires connected to frame_outer_box for i in range(len(frame_outer_box.cod_wires)): self.wire_endpoints[-(i + 1)].y += frame_outer_box_y_offset * 2 # Adjust spacing around frame before merging self._make_space_for_frame(scan, off, frame_outer_box, foliated=foliated) # Merge frame to calling diagram self._merge_with(frame_drawable) return frame_outer_box.x, frame_outer_box.y, frame_outer_box.h def _get_components_connected_to_top( self, scan: list[int], ) -> list[BoxNode | WireEndpoint]: """Return all the boxes and wire endpoints connected to the provided box from the top. This is used to determine the boxes and wire endpoints that shouldn't be moved when making space for a frame. Parameters ---------- scan : list of int The wire endpoints to start with. Returns ------- list of `BoxNode` or `WireEndpoint` instances The objects reachable from the starting wire endpoints. """ list_of_components: list[BoxNode | WireEndpoint] = [] # These are wire indices curr_scan: list[int | BoxNode] = list(scan) new_scan: list[int | BoxNode] = [] while curr_scan: for obj in curr_scan: # If wire index: if isinstance(obj, int): we = self.wire_endpoints[obj] if we.parent is None: # Update output with actual wire endpoint object list_of_components.append(we) if we.kind == WireEndpointType.DOM: # Add other end of wire (just the index) # to `new_scan` for start, end in self.wires: if (start == obj and self.wire_endpoints[end].y > we.y and end not in new_scan): new_scan.append(end) break elif (end == obj and self.wire_endpoints[start].y > we.y and start not in new_scan): new_scan.append(start) break elif we.kind == WireEndpointType.COD: # Check boxes that have the wire as cod for bx in self.boxes: if (bx.parent is None and obj in bx.cod_wires and bx not in new_scan): new_scan.append(bx) elif isinstance(obj, BoxNode) and obj.parent is None: if obj.has_wires: list_of_components.append(obj) # Add cod wire endpoints for we_ind in obj.cod_wires: we = self.wire_endpoints[we_ind] list_of_components.append(we) # Add dom to next scan for wire_ind in obj.dom_wires: if wire_ind not in new_scan: new_scan.append(wire_ind) curr_scan = new_scan new_scan = [] return list_of_components def _apply_drawing_offset(self, offset: tuple[float, float]) -> None: """Apply the offset to all the components inside the drawable. Parameters ---------- offset : tuple[float, float] The x and y offsets to be applied. """ for obj in self.boxes + self.wire_endpoints: obj._apply_drawing_offset(offset) def _merge_with(self, drawable: 'DrawableDiagramWithFrames') -> None: """Merge the passed drawable into the calling drawable. The box, wire endpoint, and wire lists are combined with reindexing. Parameters ---------- drawable : DrawableDiagramWithFrames The drawable to be merged. """ last_wire_endpoint = len(self.wire_endpoints) for wire_endpoint in drawable.wire_endpoints: wire_endpoint.noun_id = 0 self.wire_endpoints.append(wire_endpoint) for box in drawable.boxes: box.dom_wires = [dom_wire + last_wire_endpoint for dom_wire in box.dom_wires] box.cod_wires = [cod_wire + last_wire_endpoint for cod_wire in box.cod_wires] self.boxes.append(box) for wire in drawable.wires: self.wires.append( (wire[0] + last_wire_endpoint, wire[1] + last_wire_endpoint) ) def scale_and_pad(self, scale: tuple[float, float], pad: tuple[float, float]): """Scales and pads the diagram as specified. Parameters ---------- scale : tuple of 2 floats Scaling factors for x and y axes respectively. pad : tuple of 2 floats Padding values for x and y axes respectively. """ min_x = min([node.x for node in self.boxes + self.wire_endpoints]) min_y = min([node.y for node in self.boxes + self.wire_endpoints]) for wire_end in self.wire_endpoints: wire_end.x = min_x + (wire_end.x - min_x) * scale[0] + pad[0] wire_end.y = min_y + (wire_end.y - min_y) * scale[1] + pad[1] for box in self.boxes: box.x = min_x + (box.x - min_x) * scale[0] + pad[0] box.y = min_y + (box.y - min_y) * scale[1] + pad[1] half_box_height = (box.h / 2 if box.h is not None else HALF_BOX_HEIGHT) for wire_end_idx in box.dom_wires: self.wire_endpoints[wire_end_idx].y = ( box.y + half_box_height * scale[1]) for wire_end_idx in box.cod_wires: self.wire_endpoints[wire_end_idx].y = ( box.y - half_box_height * scale[1])