# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Drawable Components
===================
Utilities to convert a grammar diagram into a drawable form.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
import sys
from typing_extensions import Self
from lambeq.backend import grammar
from lambeq.backend.quantum import quantum
# Layout constants, in diagram coordinate units.
X_SPACING = 2.5  # Minimum space between adjacent wires
LEDGE = 0.5  # Space between the outermost wire and the box edge (each side)
BOX_HEIGHT = 0.5  # Box height; the half-height below is derived from it
HALF_BOX_HEIGHT = BOX_HEIGHT / 2  # Offset of dom/cod endpoints from box y
class WireEndpointType(Enum):
    """Kinds of :py:class:`WireEndpoint`.

    Every wire endpoint in a diagram is one of the following:

    .. glossary::

        DOM
            Domain of a box.

        COD
            Codomain of a box.

        INPUT
            Input wire to the diagram.

        OUTPUT
            Output wire from the diagram.

    """

    DOM = 0
    COD = 1
    INPUT = 2
    OUTPUT = 3

    def __repr__(self) -> str:
        # Show just the member name (e.g. ``DOM``) instead of the default
        # ``<WireEndpointType.DOM: 0>``.
        return f'{self.name}'
@dataclass
class WireEndpoint:
    """
    One end of a wire in a DrawableDiagram.

    Attributes
    ----------
    kind: WireEndpointType
        What the endpoint is attached to: a box's domain or codomain,
        or the diagram's input or output.
    obj: grammar.Ty
        Categorial type carried by the wire.
    x: float
        X coordinate of the wire end.
    y: float
        Y coordinate of the wire end.
    coordinates: (float, float)
        Read-only ``(x, y)`` pair.

    """

    kind: WireEndpointType
    obj: grammar.Ty
    x: float
    y: float

    @property
    def coordinates(self) -> tuple[float, float]:
        """The current ``(x, y)`` position of this endpoint."""
        position = self.x, self.y
        return position
@dataclass
class BoxNode:
    """
    Box in a DrawableDiagram.

    Attributes
    ----------
    obj: grammar.Box
        Grammar box represented by the node.
    x: float
        X coordinate of the box.
    y: float
        Y coordinate of the box.
    coordinates: (float, float)
        (x, y) coordinates.
    dom_wires: list of int
        Wire endpoints in the domain of the box, represented by
        indices into an array maintained by `DrawableDiagram`.
    cod_wires: list of int
        Wire endpoints in the codomain of the box, represented by
        indices into an array maintained by `DrawableDiagram`.

    """

    obj: grammar.Box
    x: float
    y: float
    dom_wires: list[int] = field(default_factory=list)
    cod_wires: list[int] = field(default_factory=list)

    @property
    def coordinates(self) -> tuple[float, float]:
        """The ``(x, y)`` position of the box."""
        return (self.x, self.y)

    def add_dom_wire(self, idx: int) -> None:
        """
        Add a wire to the box's domain.

        Parameters
        ----------
        idx : int
            Index of wire in associated `DrawableDiagram`'s
            `wire_endpoints` attribute.

        """
        self.dom_wires.append(idx)

    def add_cod_wire(self, idx: int) -> None:
        """
        Add a wire to the box's codomain.

        Parameters
        ----------
        idx : int
            Index of wire in associated `DrawableDiagram`'s
            `wire_endpoints` attribute.

        """
        self.cod_wires.append(idx)

    def get_x_lims(self,
                   drawable_diagram: DrawableDiagram) -> tuple[float, float]:
        """
        Get left and right limits of the box.

        The limits span all of the box's dom and cod wire endpoints,
        padded by `LEDGE` on each side. A box with no wires (a scalar)
        is centred on its own x coordinate.

        Parameters
        ----------
        drawable_diagram : DrawableDiagram
            `DrawableDiagram` with which this box is associated.

        Returns
        -------
        tuple of (float, float)
            The ``(left, right)`` x coordinates of the box.

        """
        wire_xs = [drawable_diagram.wire_endpoints[idx].x
                   for idx in self.cod_wires + self.dom_wires]
        if not wire_xs:  # scalar box
            wire_xs = [self.x]
        return min(wire_xs) - LEDGE, max(wire_xs) + LEDGE
@dataclass
class DrawableDiagram:
    """
    Representation of a lambeq diagram carrying all
    information necessary to render it.

    Attributes
    ----------
    boxes: list of BoxNode
        Boxes in the diagram.
    wire_endpoints: list of WireEndpoint
        Endpoints for all wires in the diagram.
    wires: list of tuple of the form (int, int)
        The wires in a diagram, each represented by the indices of
        its 2 endpoints in `wire_endpoints`.

    """

    boxes: list[BoxNode] = field(default_factory=list)
    wire_endpoints: list[WireEndpoint] = field(default_factory=list)
    wires: list[tuple[int, int]] = field(default_factory=list)

    def _add_wire(self, source: int, target: int) -> None:
        """Add an edge between 2 connected wire endpoints.

        Parameters
        ----------
        source, target : int
            Indices of the two endpoints in `wire_endpoints`.

        """
        self.wires.append((source, target))

    def _add_wire_end(self, wire_end: WireEndpoint) -> int:
        """Add a `WireEndpoint` to the diagram.

        Returns
        -------
        int
            Index of the new endpoint in `wire_endpoints`.

        """
        self.wire_endpoints.append(wire_end)
        return len(self.wire_endpoints) - 1

    def _add_boxnode(self, box: BoxNode) -> int:
        """Add a `BoxNode` to the diagram.

        Returns
        -------
        int
            Index of the new node in `boxes`.

        """
        self.boxes.append(box)
        return len(self.boxes) - 1

    def _add_box(self,
                 scan: list[int],
                 box: grammar.Box,
                 off: int,
                 x_pos: float,
                 y_pos: float) -> list[int]:
        """Add a box to the graph, creating necessary wire endpoints.

        Parameters
        ----------
        scan : list of int
            Indices (into `wire_endpoints`) of the currently open wires,
            ordered left to right.
        box : grammar.Box
            The box to add.
        off : int
            Position in `scan` of the first wire consumed by `box`.
        x_pos, y_pos : float
            Coordinates of the box.

        Returns
        -------
        list of int
            The updated scan, in which the wires consumed by the box's
            domain are replaced by the endpoints of its codomain.

        """
        node = BoxNode(box, x_pos, y_pos)
        self._add_boxnode(node)

        # One DOM endpoint per domain object, at the x of the open wire
        # it connects to, just above the box.
        for i, obj in enumerate(box.dom):
            nbr_idx = scan[off + i]
            wire_end = WireEndpoint(WireEndpointType.DOM,
                                    obj=obj,
                                    x=self.wire_endpoints[nbr_idx].x,
                                    y=y_pos + HALF_BOX_HEIGHT)
            wire_idx = self._add_wire_end(wire_end)
            node.add_dom_wire(wire_idx)
            self._add_wire(nbr_idx, wire_idx)

        # Quantum gates with matching dom/cod arity keep their wires
        # straight; other boxes spread their codomain evenly around x_pos.
        keep_wire_x = (box.category == quantum
                       and len(box.dom) == len(box.cod))
        centre_offset = (len(box.cod) - 1) / 2
        scan_insert = []
        for i, obj in enumerate(box.cod):
            if keep_wire_x:
                x = self.wire_endpoints[scan[off + i]].x
            else:
                x = x_pos + X_SPACING * (i - centre_offset)
            wire_end = WireEndpoint(WireEndpointType.COD,
                                    obj=obj,
                                    x=x,
                                    y=y_pos - HALF_BOX_HEIGHT)
            wire_idx = self._add_wire_end(wire_end)
            scan_insert.append(wire_idx)
            node.add_cod_wire(wire_idx)

        # Replace node's dom with its cod in scan
        return scan[:off] + scan_insert + scan[off + len(box.dom):]

    def _find_box_edges(self,
                        box: grammar.Box,
                        x: float,
                        off: int,
                        scan: list[int]) -> tuple[float, float]:
        """Compute the horizontal extent a box would occupy at `x`.

        The extent covers `x` itself, the open wires consumed by the
        box's domain (``scan[off]`` to ``scan[off + len(box.dom) - 1]``)
        and the codomain wires as spread out around `x` by `_add_box`,
        padded by `LEDGE` on each side.

        Returns
        -------
        tuple of (float, float)
            The ``(left, right)`` edges.

        """
        left_edge = x
        right_edge = x

        # dom edges come from upstream wire endpoints
        if box.dom:
            left_edge = min(self.wire_endpoints[scan[off]].x, left_edge)
            right_edge = max(
                self.wire_endpoints[scan[off + len(box.dom) - 1]].x,
                right_edge)

        # cod edges are evenly spaced around x
        if box.cod:
            half_span = X_SPACING * len(box.cod[1:]) / 2
            left_edge = min(x - half_span, left_edge)
            right_edge = max(x + half_span, right_edge)

        return left_edge - LEDGE, right_edge + LEDGE

    def _make_space(self,
                    scan: list[int],
                    box: grammar.Box,
                    off: int) -> tuple[float, float]:
        """Determines x and y coords for a new box.

        Modifies x coordinates of existing nodes to make space: nodes
        left of the box are shifted further left, and nodes right of it
        further right, until neighbouring open wires are at least
        `half_width` away from the box centre. The box is then placed
        just below the lowest existing box it overlaps horizontally (or
        at y = 0 if it overlaps none).

        Parameters
        ----------
        scan : list of int
            Indices of the currently open wires, left to right.
        box : grammar.Box
            The box about to be added.
        off : int
            Position in `scan` of the first wire consumed by `box`.

        Returns
        -------
        tuple of (float, float)
            The ``(x, y)`` coordinates for the box.

        """
        if not scan:
            # No open wires to align with or push aside. The overlap
            # check below still runs, so a box created after all earlier
            # wires have closed is not drawn on top of earlier boxes.
            x = 0.0
        else:
            half_width = X_SPACING * (len(box.cod[:-1]) / 2 + 1)
            if not box.dom:
                if not off:
                    x = self.wire_endpoints[scan[0]].x - half_width
                elif off == len(scan):
                    x = self.wire_endpoints[scan[-1]].x + half_width
                else:
                    # Midway between the two open wires it sits between.
                    right = self.wire_endpoints[scan[off]].x
                    x = (self.wire_endpoints[scan[off - 1]].x + right) / 2
            else:
                # Centred under the wires it consumes.
                right = self.wire_endpoints[scan[off + len(box.dom) - 1]].x
                x = (self.wire_endpoints[scan[off]].x + right) / 2

            # Push everything at or left of the left neighbour further left.
            if off and self.wire_endpoints[scan[off - 1]].x > x - half_width:
                limit = self.wire_endpoints[scan[off - 1]].x
                pad = limit - x + half_width
                for node in self.boxes + self.wire_endpoints:
                    if node.x <= limit:
                        node.x -= pad

            # Push everything at or right of the right neighbour further
            # right.
            if (off + len(box.dom) < len(scan)
                    and (self.wire_endpoints[scan[off + len(box.dom)]].x
                         < x + half_width)):
                limit = self.wire_endpoints[scan[off + len(box.dom)]].x
                pad = x + half_width - limit
                for node in self.boxes + self.wire_endpoints:
                    if node.x >= limit:
                        node.x += pad

        left_edge, right_edge = self._find_box_edges(box, x, off, scan)
        y = 0.0
        for upstream_box in self.boxes:
            bl, br = upstream_box.get_x_lims(self)
            if not (bl > right_edge or br < left_edge):
                # Boxes overlap
                y = min(y, upstream_box.y - 1.0)
        return x, y

    def _move_to_origin(self) -> None:
        """Set the min x and middle-y coordinates of the diagram to 0.

        Setting the diagram to be centred on the y axis allows us to
        avoid precomputing the diagram's height. An empty diagram is
        left unchanged.
        """
        nodes = self.boxes + self.wire_endpoints
        if not nodes:
            return

        min_x = min(node.x for node in nodes)
        ys = [node.y for node in nodes]
        mid_y = (min(ys) + max(ys)) / 2

        for node in nodes:
            node.x -= min_x
            node.y -= mid_y

    @classmethod
    def from_diagram(cls,
                     diagram: grammar.Diagram,
                     foliated: bool = False) -> Self:
        """
        Builds a graph representation of the diagram, calculating
        coordinates for each box and wire.

        Parameters
        ----------
        diagram : grammar Diagram
            A lambeq diagram.
        foliated : bool, default: False
            If true, each box of the diagram is drawn in a separate
            layer. By default boxes are compressed upwards into
            available space.

        Returns
        -------
        drawable : DrawableDiagram
            Representation of diagram including all coordinates
            necessary to draw it.

        """
        drawable = cls()

        # Diagram inputs sit on a row at y = 1, X_SPACING apart.
        scan: list[int] = []
        for i, obj in enumerate(diagram.dom):
            wire_end = WireEndpoint(WireEndpointType.INPUT,
                                    obj=obj,
                                    x=X_SPACING * i,
                                    y=1)
            scan.append(drawable._add_wire_end(wire_end))

        min_y = 1.0
        for depth, (box, off) in enumerate(zip(diagram.boxes,
                                               diagram.offsets)):
            x, y = drawable._make_space(scan, box, off)
            y = -depth if foliated else y
            scan = drawable._add_box(scan, box, off, x, y)
            min_y = min(min_y, y)

        # Diagram outputs sit one unit below the lowest box.
        for i, obj in enumerate(diagram.cod):
            wire_end = WireEndpoint(WireEndpointType.OUTPUT,
                                    obj=obj,
                                    x=drawable.wire_endpoints[scan[i]].x,
                                    y=min_y - 1)
            wire_end_idx = drawable._add_wire_end(wire_end)
            drawable._add_wire(scan[i], wire_end_idx)

        drawable._move_to_origin()
        return drawable

    def scale_and_pad(self,
                      scale: tuple[float, float],
                      pad: tuple[float, float]) -> None:
        """Scales and pads the diagram as specified.

        Coordinates are scaled relative to the diagram's minimum x and y,
        then offset by `pad`. Afterwards every box's dom/cod endpoints
        are re-placed at ``HALF_BOX_HEIGHT * scale[1]`` above/below the
        box. An empty diagram is left unchanged.

        Parameters
        ----------
        scale : tuple of 2 floats
            Scaling factors for x and y axes respectively.
        pad : tuple of 2 floats
            Padding values for x and y axes respectively.

        """
        nodes = self.boxes + self.wire_endpoints
        if not nodes:
            return

        min_x = min(node.x for node in nodes)
        min_y = min(node.y for node in nodes)

        for wire_end in self.wire_endpoints:
            wire_end.x = min_x + (wire_end.x - min_x) * scale[0] + pad[0]
            wire_end.y = min_y + (wire_end.y - min_y) * scale[1] + pad[1]

        for box in self.boxes:
            box.x = min_x + (box.x - min_x) * scale[0] + pad[0]
            box.y = min_y + (box.y - min_y) * scale[1] + pad[1]
            for wire_end_idx in box.dom_wires:
                self.wire_endpoints[wire_end_idx].y = (
                    box.y + HALF_BOX_HEIGHT * scale[1])
            for wire_end_idx in box.cod_wires:
                self.wire_endpoints[wire_end_idx].y = (
                    box.y - HALF_BOX_HEIGHT * scale[1])
class PregroupError(Exception):
    """Raised when a diagram is not shaped like a pregroup diagram."""

    def __init__(self, diagram):
        message = (f'Diagram {diagram} is not a pregroup diagram. '
                   'A pregroup diagram must be structured like '
                   '(State @ State ... State) >> (Cups and Swaps)')
        super().__init__(message)
@dataclass
class DrawablePregroup(DrawableDiagram):
    """
    Representation of a lambeq pregroup diagram carrying all
    information necessary to render it.

    Attributes
    ----------
    x_tracks: list of int
        Stores the "track" on which the corresponding `WireEndpoint` in
        `wire_endpoints` lies. This helps determine the depth of
        pregroup grammar boxes in the diagram.

    """

    x_tracks: list[int] = field(default_factory=list)

    def _add_wire_end(self,
                      wire_end: WireEndpoint,
                      x_track: int = -1) -> int:
        """Add a `WireEndpoint` to the diagram, with track information.

        Parameters
        ----------
        wire_end : WireEndpoint
            The endpoint to add.
        x_track : int, default: -1
            Track the endpoint lies on, stored in `x_tracks` at the same
            index as the endpoint in `wire_endpoints`. Endpoints added
            without a track (the diagram outputs) get -1.

        Returns
        -------
        int
            Index of the new endpoint in `wire_endpoints`.

        """
        self.x_tracks.append(x_track)
        return super()._add_wire_end(wire_end)

    @classmethod
    def from_diagram(cls,
                     diagram: grammar.Diagram,
                     foliated: bool = False) -> Self:
        """
        Builds a graph representation of the diagram, calculating
        coordinates for each box and wire.

        Parameters
        ----------
        diagram : grammar Diagram
            A lambeq diagram.
        foliated : bool, default: False
            This parameter is not used for pregroup diagrams, which are
            always drawn un-foliated.

        Returns
        -------
        drawable : DrawableDiagram
            Representation of diagram including all coordinates
            necessary to draw it.

        Raises
        ------
        PregroupError
            If the diagram is not a row of word states followed only by
            cups and swaps.

        """
        if foliated:
            print('Pregroup diagrams cannot be drawn foliated.'
                  ' Set `draw_as_pregroup` to `False` to see'
                  ' foliation for this diagram.', file=sys.stderr)

        # The leading layers are the words; the first cup or swap marks
        # the start of the grammar section.
        words = []
        grammar_start_idx = len(diagram)
        for i, layer in enumerate(diagram.layers):
            if isinstance(layer.box, (grammar.Cup, grammar.Swap)):
                grammar_start_idx = i
                break
            # A word must be a state with nothing to its right.
            if layer.right or layer.box.dom:
                raise PregroupError(diagram)
            words.append(layer.box)

        HSPACE = 0.5
        VSPACE = 0.75
        BOX_WIDTH = 2

        drawable = cls()

        # Lay the words out in a row; every codomain wire gets its own
        # track, numbered left to right.
        scan = []
        track_ctr = 0
        for i, word in enumerate(words):
            box_x = (HSPACE + BOX_WIDTH) * i
            if isinstance(word, grammar.Cap):
                box_x += 0.5 * BOX_WIDTH
            node = BoxNode(word, box_x, 0)
            for j, ty in enumerate(word.cod):
                wire_x = ((HSPACE + BOX_WIDTH) * i
                          + (BOX_WIDTH / (len(word.cod) + 1)) * (j + 1))
                wire_end_idx = drawable._add_wire_end(
                    WireEndpoint(WireEndpointType.COD,
                                 ty,
                                 wire_x,
                                 0.25), track_ctr)
                node.add_cod_wire(wire_end_idx)
                scan.append(wire_end_idx)
                track_ctr += 1
            drawable._add_boxnode(node)

        # depth_map[t] is the y at which the next grammar box spanning
        # track t may start.
        depth_map = [0.0 for _ in range(track_ctr)]
        for layer in diagram.layers[grammar_start_idx:]:
            box = layer.box
            # Validate before indexing box.dom / scan[off + 1], so that a
            # malformed diagram raises PregroupError, not IndexError.
            if not isinstance(box, (grammar.Cup, grammar.Swap)):
                raise PregroupError(diagram)

            off = len(layer.left)
            lx = drawable.wire_endpoints[scan[off]].x
            rx = drawable.wire_endpoints[scan[off + 1]].x
            l_track = drawable.x_tracks[scan[off]]
            r_track = drawable.x_tracks[scan[off + 1]]
            # Place the box below everything already drawn on the tracks
            # it spans.
            y = min(depth_map[l_track: r_track + 1])

            l_wire_end_idx = drawable._add_wire_end(
                WireEndpoint(WireEndpointType.DOM,
                             box.dom[0],
                             lx,
                             y - VSPACE / 2), l_track)
            r_wire_end_idx = drawable._add_wire_end(
                WireEndpoint(WireEndpointType.DOM,
                             box.dom[1],
                             rx,
                             y - VSPACE / 2), r_track)
            drawable._add_wire(scan[off], l_wire_end_idx)
            drawable._add_wire(scan[off + 1], r_wire_end_idx)

            grammar_box = BoxNode(box, (lx + rx) / 2, y - VSPACE)
            grammar_box.add_dom_wire(l_wire_end_idx)
            grammar_box.add_dom_wire(r_wire_end_idx)

            if isinstance(box, grammar.Swap):
                # A swap keeps both wires open, with fresh endpoints on
                # the same tracks.
                l_idx = drawable._add_wire_end(
                    WireEndpoint(WireEndpointType.COD,
                                 box.cod[0],
                                 lx,
                                 y - VSPACE), l_track)
                r_idx = drawable._add_wire_end(
                    WireEndpoint(WireEndpointType.COD,
                                 box.cod[1],
                                 rx,
                                 y - VSPACE), r_track)
                grammar_box.add_cod_wire(l_idx)
                grammar_box.add_cod_wire(r_idx)
                scan[off] = l_idx
                scan[off + 1] = r_idx
            else:
                # A cup consumes both wires.
                scan = scan[:off] + scan[off + 2:]

            drawable._add_boxnode(grammar_box)
            for i in range(l_track, r_track + 1):
                depth_map[i] = y - VSPACE

        # `default` covers a diagram with no words (no tracks).
        min_y = min(depth_map, default=0.0)
        for i, obj in enumerate(diagram.cod):
            wire_end = WireEndpoint(WireEndpointType.OUTPUT,
                                    obj,
                                    drawable.wire_endpoints[scan[i]].x,
                                    min_y - VSPACE)
            wire_end_idx = drawable._add_wire_end(wire_end)
            drawable._add_wire(scan[i], wire_end_idx)

        drawable._move_to_origin()
        return drawable