# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Grammar category
================
Lambeq's internal representation of the grammar category. This work is
based on DisCoPy (https://discopy.org/) which is released under the
BSD 3-Clause "New" or "Revised" License.
"""
from __future__ import annotations
from collections.abc import Callable, Iterable, Iterator
from copy import deepcopy
from dataclasses import dataclass, field, InitVar, replace
import json
from typing import Any, ClassVar, Dict, Protocol, Type, TypeVar
from typing import cast, overload, TYPE_CHECKING
from typing_extensions import Self
if TYPE_CHECKING:
import discopy
@dataclass
class Entity:
    """Common base for every object that belongs to a :py:class:`Category`.

    Concrete entity classes are attached to a category with
    :py:meth:`Category.set` (usually via the category used as a class
    decorator), which assigns the class-level ``category`` attribute.
    """

    # Class-level back-reference to the owning category; declared as a
    # ClassVar so the dataclass machinery does not treat it as a field.
    category: ClassVar[Category]
# Types
# A decoded JSON object: string keys mapping to arbitrary values, as
# produced by ``json.loads`` and by the ``to_json`` methods below.
_JSONDictT = Dict[str, Any]
# Type variable over Entity *classes* (not instances); lets
# ``Category.set`` / ``Category.__call__`` return the registered class
# with its precise type.
_EntityType = TypeVar('_EntityType', bound=Type[Entity])
@dataclass
class Category:
    """The base class for all categories.

    A category owns the concrete ``Ty``, ``Box``, ``Layer`` and
    ``Diagram`` classes that live in it. Those attributes are not set by
    the constructor; they are attached later through :py:meth:`set`,
    typically by using the category instance as a class decorator.
    """
    name: str

    Ty: type[Ty] = field(init=False)
    Box: type[Box] = field(init=False)
    Layer: type[Layer] = field(init=False)
    Diagram: type[Diagram] = field(init=False)

    def set(self, name: str, entity: _EntityType) -> _EntityType:
        """Attach ``entity`` to this category as attribute ``name``.

        The entity class also gets its ``category`` pointed at this
        category. The class itself is returned, so the method can be
        used inside decorators.
        """
        setattr(self, name, entity)
        entity.category = self
        return entity

    @overload
    def __call__(self, name_or_entity: str) -> Callable[[_EntityType],
                                                         _EntityType]:
        ...

    @overload
    def __call__(self, name_or_entity: _EntityType) -> _EntityType: ...

    def __call__(
        self,
        name_or_entity: _EntityType | str
    ) -> _EntityType | Callable[[_EntityType], _EntityType]:
        """Register an entity class; usable as a class decorator.

        ``@category`` registers the decorated class under its own
        ``__name__``. ``@category('Alias')`` returns a decorator that
        registers the class under ``'Alias'`` instead.
        """
        if not isinstance(name_or_entity, str):
            return self.set(name_or_entity.__name__, name_or_entity)

        alias = name_or_entity

        def register(entity: _EntityType) -> _EntityType:
            return self.set(alias, entity)

        return register

    def from_json(self, data: _JSONDictT | str) -> Entity:
        """Decode a JSON object or string into an entity from
        this category.

        The ``'entity'`` key of the data selects which class performs
        the decoding.

        Returns
        -------
        :py:class:`~lambeq.backend.grammar.Entity`
            The entity generated from the JSON data. This could be
            a :py:class:`~lambeq.backend.Ty`,
            a :py:class:`~lambeq.backend.Box` subclass,
            or a :py:class:`~lambeq.backend.Diagram` instance.

        """
        payload = data if not isinstance(data, str) else json.loads(data)
        decoders = {
            'Cap': Cap,
            'Cup': Cup,
            'Daggered': Daggered,
            'Spider': Spider,
            'Swap': Swap,
            'Word': Word,
            'Ty': self.Ty,
            'Box': self.Box,
            'Layer': self.Layer,
            'Diagram': self.Diagram,
        }
        decoder = decoders[payload['entity']]
        return (  # type: ignore[no-any-return]
            decoder.from_json(payload)  # type: ignore[attr-defined]
        )
# The grammar category. The ``Ty``, ``Box``, ``Layer`` and ``Diagram``
# classes below register themselves on it via the ``@grammar`` decorator.
grammar = Category('grammar')
@grammar
@dataclass
class Ty(Entity):
    """A type in the grammar category.

    Every type is either atomic, complex, or empty. Complex types are
    tensor products of atomic types, and empty types are the identity
    type.

    Parameters
    ----------
    name : str, optional
        The name of the type, by default None.
    objects : list[Ty], optional
        The objects defining a complex type, by default [].
    z : int, optional
        The winding number of the type, by default 0.

    """
    name: str | None = None
    objects: list[Self] = field(default_factory=list)
    z: int = 0

    category: ClassVar[Category]

    def __post_init__(self) -> None:
        # A complex type has at least two objects (a single object would
        # just be that atomic type) and never carries a name of its own.
        assert len(self.objects) != 1
        assert not (len(self.objects) > 1 and self.name is not None)
        # Only atomic types carry a winding number; ``rotate`` pushes the
        # winding of a complex type down onto its objects.
        if not self.is_atomic:
            assert self.z == 0

    @property
    def is_empty(self) -> bool:
        """Whether this is the empty (identity) type."""
        return not self.objects and self.name is None

    @property
    def is_atomic(self) -> bool:
        """Whether this is a single named type."""
        return not self.objects and self.name is not None

    @property
    def is_complex(self) -> bool:
        """Whether this is a tensor product of atomic types."""
        return bool(self.objects)

    def to_diagram(self) -> Diagram:
        """Return the identity diagram on this type."""
        return self.category.Diagram.id(self)

    def __repr__(self) -> str:
        if self.is_empty:
            return 'Ty()'
        elif self.is_atomic:
            return f'Ty({self.name}){".l"*(-self.z)}{".r"*self.z}'
        else:
            return ' @ '.join(map(repr, self.objects))

    def __str__(self) -> str:
        if self.is_empty:
            return 'Ty()'
        elif self.is_atomic:
            return f'{self.name}{".l"*(-self.z)}{".r"*self.z}'
        else:
            return ' @ '.join(map(str, self.objects))

    def __hash__(self) -> int:
        return hash(repr(self))

    def __len__(self) -> int:
        """Number of atomic types in this type (0 for the empty type)."""
        return 1 if self.is_atomic else len(self.objects)

    def __iter__(self) -> Iterator[Self]:
        """Iterate over the atomic types making up this type."""
        if self.is_atomic:
            yield self
        else:
            yield from self.objects

    def __getitem__(self, index: int | slice) -> Self:
        """Return one atomic type (int index) or a sub-type (slice)."""
        objects = [*self]
        if TYPE_CHECKING:
            objects = cast(list[Self], objects)
        if isinstance(index, int):
            return objects[index]
        else:
            return self._fromiter(objects[index])

    @classmethod
    def _fromiter(cls, objects: Iterable[Self]) -> Self:
        """Create a Ty from an iterable of atomic objects."""
        objects = list(objects)
        if not objects:
            return cls()
        elif len(objects) == 1:
            return objects[0]
        else:
            return cls(objects=objects)  # type: ignore[arg-type]

    def count(self, other: Self) -> int:
        """Count how many times the atomic type ``other`` occurs."""
        assert other.is_atomic
        return sum(1 for ob in self if ob == other)

    @overload
    def tensor(self, other: Iterable[Self]) -> Self: ...

    @overload
    def tensor(self, other: Self, *rest: Self) -> Self: ...

    def tensor(self, other: Self | Iterable[Self], *rest: Self) -> Self:
        """Tensor this type with further types.

        ``other`` is either a single type (optionally followed by more
        types in ``rest``) or an iterable of types. Returns
        ``NotImplemented`` if any operand is not a type of the same
        class and category.
        """
        try:
            tys = [*other, *rest]
        except TypeError:
            return NotImplemented

        if any(not isinstance(ty, type(self))
               or self.category != ty.category for ty in tys):
            return NotImplemented

        return self._fromiter(ob for ty in (self, *tys) for ob in ty)

    def __matmul__(self, rhs: Self) -> Self:
        return self.tensor(rhs)

    def rotate(self, z: int) -> Self:
        """Rotate the type, changing the winding number."""
        if self.is_empty or z == 0:
            return self
        elif self.is_atomic:
            return replace(self, z=self.z + z)
        else:
            # An odd number of rotations also reverses the order of the
            # objects.
            objects = reversed(self.objects) if z % 2 == 1 else self.objects
            return type(self)(objects=[ob.rotate(z) for ob in objects])

    def unwind(self) -> Self:
        """Undo this type's own winding (rotate by ``-self.z``)."""
        return self.rotate(-self.z)

    @property
    def l(self) -> Self:  # noqa: E741, E743
        """Left adjoint: rotation by -1."""
        return self.rotate(-1)

    @property
    def r(self) -> Self:
        """Right adjoint: rotation by +1."""
        return self.rotate(1)

    def __lshift__(self, rhs: Self) -> Self:
        if not isinstance(rhs, type(self)) or self.category != rhs.category:
            return NotImplemented
        return self @ rhs.l

    def __rshift__(self, rhs: Self) -> Self:
        if not isinstance(rhs, type(self)) or self.category != rhs.category:
            return NotImplemented
        return self.r @ rhs

    def repeat(self, times: int) -> Self:
        """Tensor ``times`` copies of this type (``times`` >= 0)."""
        assert times >= 0
        return type(self)().tensor([self] * times)

    def __pow__(self, times: int) -> Self:
        return self.repeat(times)

    def apply_functor(self, functor: Functor) -> Ty:
        """Apply ``functor`` to this (non-empty) type.

        Complex types are mapped object by object. A wound atomic type is
        mapped in its unwound form and then re-rotated. Otherwise the
        functor's ``ob`` mapping is used.
        """
        assert not self.is_empty
        if self.is_complex:
            return functor.target_category.Ty().tensor(
                functor(ob) for ob in self.objects
            )
        elif self.z != 0:
            return functor(self.unwind()).rotate(self.z)
        else:
            return functor.ob(self)

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Decode a JSON object or string into a
        :py:class:`~lambeq.backend.Ty`.

        A dict argument is left untouched: decoding works on a copy.

        Returns
        -------
        :py:class:`~lambeq.backend.Ty`
            The type generated from the JSON data.

        """
        # Copy so that the ``pop`` calls and the ``objects`` rewrite
        # below do not modify the caller's dict (previously, decoding the
        # same dict twice failed with KeyError('entity')).
        data_dict = (json.loads(data) if isinstance(data, str)
                     else deepcopy(data))
        data_dict.pop('category', None)
        data_dict.pop('entity')
        data_dict['objects'] = [cls.from_json(obj_data)
                                for obj_data in data_dict['objects']]
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Encode this type to a JSON object.

        Parameters
        ----------
        is_top_level : bool, optional
            This flag indicates that this object is the top-most object
            and should have the global metadata (e.g. category). This
            should be set to `False` when calling `to_json` on attribute
            instances to avoid duplication of said global metadata.

        """
        data_dict: _JSONDictT = {
            'entity': self.__class__.__name__,
            'name': self.name,
            'objects': [obj.to_json(is_top_level=False)
                        for obj in self.objects],
            'z': self.z
        }
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict
class Diagrammable(Protocol):
    """Structural interface shared by every diagram-like object.

    The protocol exists for static type checkers that support
    structural sub-typing (duck-typing). Concrete classes satisfy it by
    shape alone and do not have to inherit from it.
    """

    @property
    def dom(self) -> Ty:
        """Input type (domain) of the diagram."""
        ...

    @property
    def cod(self) -> Ty:
        """Output type (co-domain) of the diagram."""
        ...

    def to_diagram(self) -> Diagram:
        """Convert this object into a real Diagram instance."""
        ...

    @property
    def is_id(self) -> bool:
        """True when the object is an identity diagram."""
        ...

    def apply_functor(self, functor: Functor) -> Diagrammable:
        """Map this object through ``functor``."""
        ...

    def rotate(self, z: int) -> Diagrammable:
        """Take the adjoint `z` times.

        Positive `z` means the right adjoint is applied `z` times;
        negative `z` means the left adjoint is applied `-z` times.
        """
        ...

    def __matmul__(self, rhs: Diagrammable | Ty) -> Diagrammable:
        """Tensor product `@` with another diagram (or a type)."""
        ...

    def __rshift__(self, rhs: Diagrammable) -> Diagrammable:
        """Sequential composition `>>` with another diagram."""
        ...
@grammar
@dataclass
class Box(Entity):
    """A box in the grammar category.

    Parameters
    ----------
    name : str
        The name of the box.
    dom : Ty
        The domain of the box.
    cod : Ty
        The codomain of the box.
    z : int, optional
        The winding number of the box, by default 0.

    """
    name: str
    dom: Ty
    cod: Ty
    z: int = 0

    def __getattr__(self, name: str) -> Any:
        # Called only for attributes not found normally: fall back to the
        # equivalent one-box diagram, so a box can stand in for a diagram.
        return getattr(self.to_diagram(), name)

    def __repr__(self) -> str:
        return (f'[{self.name}{".l"*(-self.z)}{".r"*self.z}; '
                f'{repr(self.dom)} -> {repr(self.cod)}]')

    def __str__(self) -> str:
        return f'{self.name}{".l"*(-self.z)}{".r"*self.z}'

    def __hash__(self) -> int:
        return hash(repr(self))

    def to_diagram(self) -> Diagram:
        """Wrap this box in a one-layer diagram with no side wires."""
        ID = self.category.Ty()
        # Read the fields through ``super().__getattribute__`` so that a
        # missing field raises AttributeError here instead of falling back
        # to ``__getattr__``, which would call ``to_diagram`` again.
        dom = super().__getattribute__('dom')
        cod = super().__getattribute__('cod')
        return self.category.Diagram(dom=dom,
                                     cod=cod,
                                     layers=[self.category.Layer(box=self,
                                                                 left=ID,
                                                                 right=ID)])

    def __matmul__(self, rhs: Diagrammable | Ty) -> Diagram:
        return self.to_diagram().tensor(rhs.to_diagram())

    def __rmatmul__(self, rhs: Diagrammable | Ty) -> Diagram:
        return rhs.to_diagram().tensor(self.to_diagram())

    def __rshift__(self, rhs: Diagrammable) -> Diagram:
        return self.to_diagram().then(rhs.to_diagram())

    def rotate(self, z: int) -> Self:
        """Rotate the box, changing the winding number."""
        return replace(self,
                       dom=self.dom.rotate(z),
                       cod=self.cod.rotate(z),
                       z=self.z + z)

    @property
    def l(self) -> Self:  # noqa: E741, E743
        """Left adjoint: rotation by -1."""
        return self.rotate(-1)

    @property
    def r(self) -> Self:
        """Right adjoint: rotation by +1."""
        return self.rotate(1)

    def unwind(self) -> Self:
        """Undo this box's own winding (rotate by ``-self.z``)."""
        return self.rotate(-self.z)

    def dagger(self) -> Daggered | Box:
        """Return the dagger of this box, wrapped in ``Daggered``."""
        return Daggered(self)

    def apply_functor(self, functor: Functor) -> Diagrammable:
        """Apply ``functor`` to this box.

        A wound box is mapped in its unwound form and then re-rotated.
        Otherwise the functor's ``ar`` mapping is used.
        """
        if self.z != 0:
            return functor(self.unwind()).rotate(self.z)
        else:
            return functor.ar(self)

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Decode a JSON object or string into a
        :py:class:`~lambeq.backend.Box`.

        A dict argument (including its nested type dicts) is left
        untouched: decoding works on a deep copy.

        Returns
        -------
        :py:class:`~lambeq.backend.Box`
            The box generated from the JSON data.

        """
        # Deep copy: the pops below, and the nested type decoders, must
        # not modify the caller's data.
        data_dict = (json.loads(data) if isinstance(data, str)
                     else deepcopy(data))
        data_dict.pop('category', None)
        data_dict.pop('entity')
        data_dict['dom'] = cls.category.Ty.from_json(data_dict['dom'])
        data_dict['cod'] = cls.category.Ty.from_json(data_dict['cod'])
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Encode this box to a JSON object.

        Parameters
        ----------
        is_top_level : bool, optional
            This flag indicates that this object is the top-most object
            and should have the global metadata (e.g. category). This
            should be set to `False` when calling `to_json` on attribute
            instances to avoid duplication of said global metadata.

        """
        data_dict: _JSONDictT = {
            'entity': self.__class__.__name__,
            'name': self.name,
            'dom': self.dom.to_json(is_top_level=False),
            'cod': self.cod.to_json(is_top_level=False),
            'z': self.z
        }
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict
@grammar
@dataclass
class Layer(Entity):
    """A layer in a diagram.

    Parameters
    ----------
    box : Box
        The box in the layer.
    left : Ty
        The wire type to the left of the box.
    right : Ty
        The wire type to the right of the box.

    """
    left: Ty
    box: Box
    right: Ty

    def __repr__(self) -> str:
        return f'|{repr(self.left)} @ {repr(self.box)} @ {repr(self.right)}|'

    def __iter__(self) -> Iterator[Ty | Box]:
        """Iterate over ``(left, box, right)``."""
        iterable_res: Iterable[Ty | Box] = self.unpack()
        yield from iterable_res

    @property
    def dom(self) -> Ty:
        """Input type: the box's domain with the side wires."""
        return self.left @ self.box.dom @ self.right

    @property
    def cod(self) -> Ty:
        """Output type: the box's codomain with the side wires."""
        return self.left @ self.box.cod @ self.right

    def unpack(self) -> tuple[Ty, Box, Ty]:
        """Return the layer as a ``(left, box, right)`` tuple."""
        return self.left, self.box, self.right

    def extend(self,
               left: Ty | None = None,
               right: Ty | None = None) -> Self:
        """Add extra wires on the outside of the layer.

        ``left`` is prepended to the left wires and ``right`` is
        appended to the right wires. ``None`` means no extra wires.
        """
        ID = self.category.Ty()
        if left is None:
            left = ID
        if right is None:
            right = ID
        return replace(self, left=left @ self.left, right=self.right @ right)

    def rotate(self, z: int) -> Self:
        """Rotate the layer."""
        # An odd number of rotations mirrors the layer, so the left and
        # right wires trade places.
        if z % 2 == 1:
            left, right = self.right, self.left
        else:
            left, right = self.left, self.right
        return replace(self,
                       left=left.rotate(z),
                       box=self.box.rotate(z),
                       right=right.rotate(z))

    def dagger(self) -> Self:
        """Return the layer with its box replaced by the box's dagger."""
        return replace(self, box=self.box.dagger())

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Decode a JSON object or string into a
        :py:class:`~lambeq.backend.grammar.Layer`.

        A dict argument (including its nested box and type dicts) is left
        untouched: decoding works on a deep copy.

        Returns
        -------
        :py:class:`~lambeq.backend.grammar.Layer`
            The layer generated from the JSON data.

        """
        # Deep copy: the pops below, and the nested box/type decoders,
        # must not modify the caller's data.
        data_dict = (json.loads(data) if isinstance(data, str)
                     else deepcopy(data))
        data_dict.pop('category', None)
        data_dict.pop('entity')
        data_dict['left'] = cls.category.Ty.from_json(data_dict['left'])
        _entity_mapping = {
            'Cap': Cap,
            'Cup': Cup,
            'Daggered': Daggered,
            'Spider': Spider,
            'Swap': Swap,
            'Word': Word,
            'Box': cls.category.Box,
        }
        box_cls = _entity_mapping[data_dict['box']['entity']]
        data_dict['box'] = box_cls.from_json(  # type: ignore[attr-defined]
            data_dict['box']
        )
        data_dict['right'] = cls.category.Ty.from_json(data_dict['right'])
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Encode this layer to a JSON object.

        Parameters
        ----------
        is_top_level : bool, optional
            This flag indicates that this object is the top-most object
            and should have the global metadata (e.g. category). This
            should be set to `False` when calling `to_json` on attribute
            instances to avoid duplication of said global metadata.

        """
        data_dict: _JSONDictT = {'entity': self.__class__.__name__}
        for attr in ('left', 'right', 'box'):
            data_dict[attr] = getattr(self, attr).to_json(
                is_top_level=False
            )
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict
# Any callable that builds a diagram-like object; the values stored in
# ``Diagram.special_boxes`` (e.g. the ``Cap`` class) have this type.
_DiagrammableFactory = Callable[..., Diagrammable]
# Type variable over such factories, so the decorator form of
# ``Diagram.register_special_box`` returns the factory with its exact type.
_DiagrammableFactoryT = TypeVar('_DiagrammableFactoryT',
                                bound=_DiagrammableFactory)
@grammar
@dataclass
class Diagram(Entity):
    """A diagram in the grammar category.

    Parameters
    ----------
    dom : Ty
        The type of the input wires.
    cod : Ty
        The type of the output wires.
    layers : list[Layer]
        The layers of the diagram.

    """
    dom: Ty
    cod: Ty
    layers: list[Layer]

    # Factories for structural boxes (looked up below under 'cap', 'cup'
    # and 'swap'), filled in through ``register_special_box``.
    special_boxes: ClassVar[dict[str, _DiagrammableFactory]] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Cooperate with other ``__init_subclass__`` hooks, then give each
        # subclass its own registry instead of sharing the parent's dict.
        super().__init_subclass__(**kwargs)
        cls.special_boxes = {}

    @classmethod
    @overload
    def register_special_box(
        cls,
        name: str,
        diagram_factory: None = None
    ) -> Callable[[_DiagrammableFactoryT], _DiagrammableFactoryT]: ...

    @classmethod
    @overload
    def register_special_box(
        cls,
        name: str,
        diagram_factory: _DiagrammableFactory
    ) -> None: ...

    @classmethod
    def register_special_box(
        cls,
        name: str,
        diagram_factory: _DiagrammableFactory | None = None
    ) -> None | Callable[[_DiagrammableFactoryT], _DiagrammableFactoryT]:
        """Register a factory for a special box under ``name``.

        If ``diagram_factory`` is given, it is registered immediately and
        ``None`` is returned. Otherwise a decorator is returned that
        registers the decorated factory and hands it back unchanged.
        """
        def set_(
            diagram_factory: _DiagrammableFactoryT
        ) -> _DiagrammableFactoryT:
            cls.special_boxes[name] = diagram_factory
            return diagram_factory

        if diagram_factory is None:
            return set_
        else:
            set_(diagram_factory)
            return None

    def __repr__(self) -> str:
        if self.is_id:
            return f'Id({repr(self.dom)})'
        else:
            return ' >> '.join(map(repr, self.layers))

    def __hash__(self) -> int:
        return hash(repr(self))

    @classmethod
    def fa(cls, left: Ty, right: Ty) -> Self:
        """Build ``Id(left) @ cups(right.l, right)``.

        Presumably the pregroup form of CCG forward application.
        """
        return cls.id(left) @ cls.cups(right.l, right)

    @classmethod
    def ba(cls, left: Ty, right: Ty) -> Self:
        """Build ``cups(left, left.r) @ Id(right)``.

        Presumably the pregroup form of CCG backward application.
        """
        return cls.id().tensor(cls.cups(left, left.r), cls.id(right))

    @classmethod
    def fc(cls, left: Ty, middle: Ty, right: Ty) -> Self:
        """Build ``Id(left) @ cups(middle.l, middle) @ Id(right.l)``."""
        return cls.id(left) @ cls.cups(middle.l, middle) @ cls.id(right.l)

    @classmethod
    def bc(cls, left: Ty, middle: Ty, right: Ty) -> Self:
        """Build ``Id(left.r) @ cups(middle, middle.r) @ Id(right)``."""
        return cls.id(left.r) @ cls.cups(middle, middle.r) @ cls.id(right)

    @classmethod
    def fx(cls, left: Ty, middle: Ty, right: Ty) -> Self:
        """Build a two-layer diagram of swaps followed by cups on `middle`.

        Presumably the pregroup form of CCG forward crossed composition.
        """
        return (cls.id(left) @ cls.swap(middle.l, right.r) @ cls.id(middle)
                >> cls.swap(left, right.r) @ cls.cups(middle.l, middle))

    @classmethod
    def bx(cls, left: Ty, middle: Ty, right: Ty) -> Self:
        """Build a two-layer diagram of swaps followed by cups on `middle`.

        Presumably the pregroup form of CCG backward crossed composition.
        """
        return (cls.id(middle) @ cls.swap(left.l, middle.r) @ cls.id(right)
                >> cls.cups(middle, middle.r) @ cls.swap(left.l, right))

    @classmethod
    def caps(cls,
             left: Ty,
             right: Ty,
             is_reversed: bool = False) -> Diagrammable:
        """Build caps with the registered 'cap' factory."""
        return cls.special_boxes['cap'](left, right, is_reversed)

    @classmethod
    def cups(cls,
             left: Ty,
             right: Ty, is_reversed: bool = False) -> Diagrammable:
        """Build cups with the registered 'cup' factory."""
        return cls.special_boxes['cup'](left, right, is_reversed)

    @classmethod
    def swap(cls,
             left: Ty,
             right: Ty) -> Diagrammable:
        """Build a swap with the registered 'swap' factory."""
        return cls.special_boxes['swap'](left, right)

    def to_diagram(self) -> Self:
        """A diagram is already a diagram: return it unchanged."""
        return self

    @classmethod
    def id(cls, dom: Ty | None = None) -> Self:
        """Identity diagram on ``dom`` (the empty type by default)."""
        if dom is None:
            dom = cls.category.Ty()
        return cls(dom=dom, cod=dom, layers=[])

    @property
    def is_id(self) -> bool:
        """Whether the diagram has no layers."""
        return not self.layers

    @property
    def boxes(self) -> list[Box]:
        """The box of each layer, in order."""
        return [layer.box for layer in self.layers]

    @classmethod
    def create_pregroup_diagram(
        cls,
        words: list[Word],
        morphisms: list[tuple[type, int, int]]
    ) -> Self:
        """Create a :py:class:`~.Diagram` from cups and swaps.

            >>> n, s = Ty('n'), Ty('s')
            >>> words = [Word('she', n), Word('goes', n.r @ s @ n.l),
            ...          Word('home', n)]
            >>> morphs = [(Cup, 0, 1), (Cup, 3, 4)]
            >>> diagram = Diagram.create_pregroup_diagram(words, morphs)

        Parameters
        ----------
        words : list of :py:class:`~lambeq.backend.Word`
            A list of :py:class:`~lambeq.backend.Word` s
            corresponding to the words of the sentence.
        morphisms: list of tuple[type, int, int]
            A list of tuples of the form:
                (morphism, start_wire_idx, end_wire_idx).
            Morphisms can be :py:class:`~lambeq.backend.Cup` s or
            :py:class:`~lambeq.backend.Swap` s, while the two numbers
            define the indices of the wires on which the morphism is
            applied.

        Returns
        -------
        :py:class:`~lambeq.backend.Diagram`
            The generated pregroup diagram.

        Raises
        ------
        ValueError
            If the provided morphism list does not type-check properly.

        """
        types: Ty = cls.category.Ty()
        boxes: list[Word] = []
        offsets: list[int] = []
        # Words sit side by side: each one starts where the wires of the
        # previous words end.
        for w in words:
            boxes.append(w)
            offsets.append(len(types))
            types @= w.cod

        for idx, (typ, start, end) in enumerate(morphisms):
            if typ not in (Cup, Swap):
                raise ValueError(f'Unknown morphism type: {typ}')
            box = typ(types[start:start+1], types[end:end+1])
            boxes.append(box)
            # ``start`` indexes the word wires; each earlier cup starting
            # to its left has already removed two wires.
            actual_idx = start
            for pr_idx in range(idx):
                if (morphisms[pr_idx][0] == Cup
                        and morphisms[pr_idx][1] < start):
                    actual_idx -= 2
            offsets.append(actual_idx)

        diagram = cls.id()
        for box, offset in zip(boxes, offsets):
            left = diagram.cod[:offset]
            right = diagram.cod[offset + len(box.dom):]
            diagram = diagram >> cls.id(left) @ box @ cls.id(right)
        return diagram

    @property
    def is_pregroup(self) -> bool:
        """Check if a diagram is a pregroup diagram.

        Adapted from :py:class:`discopy.grammar.pregroup.draw`.

        Returns
        -------
        bool
            Whether the diagram is a pregroup diagram.

        """
        if self.dom:
            # pregroup diagrams must have empty domain
            return False

        in_words = True
        for layer in self.layers:
            if in_words and isinstance(layer.box, Word):
                # each word is appended to the right of the previous ones
                if not layer.right.is_empty:
                    return False
            else:
                # after the words, only cups and swaps are allowed
                if not isinstance(layer.box, (Cup, Swap)):
                    return False
                in_words = False

        return True

    @classmethod
    def lift(cls, diagrams: Iterable[Diagrammable | Ty]) -> list[Self]:
        """Lift diagrams to the current category.

        Given a list of boxes or diagrams, call `to_diagram` on each,
        then check all of the diagrams are in the same category as the
        calling class.

        Parameters
        ----------
        diagrams : iterable
            The diagrams to lift and check.

        Returns
        -------
        list of Diagram
            The diagrams after calling `to_diagram` on each.

        Raises
        ------
        ValueError
            If any of the diagrams are not in the same category of the
            calling class.

        """
        try:
            diags = [diagram.to_diagram() for diagram in diagrams]
        except AttributeError as e:
            raise ValueError from e

        if any(not isinstance(diagram, cls)
               or cls.category != diagram.category for diagram in diags):
            raise ValueError

        return diags  # type: ignore[return-value]

    def tensor(self, *diagrams: Diagrammable | Ty) -> Self:
        """Place ``diagrams`` side by side to the right of this one.

        Returns ``NotImplemented`` if any operand cannot be lifted to
        this category. With no arguments, returns an equal diagram.
        """
        try:
            diags = self.lift([self, *diagrams])
        except ValueError:
            return NotImplemented

        # Pass the domains as one list: ``Ty.tensor`` needs at least one
        # positional argument, so unpacking an empty list (``d.tensor()``)
        # used to raise TypeError.
        right = dom = self.dom.tensor([diagram.dom for diagram in diags[1:]])
        left = self.category.Ty()
        layers = []
        for diagram in diags:
            right = right[len(diagram.dom):]
            layers += [layer.extend(left, right) for layer in diagram.layers]
            left @= diagram.cod

        return type(self)(dom=dom, cod=left, layers=layers)

    def __matmul__(self, rhs: Diagrammable | Ty) -> Self:
        return self.tensor(rhs)

    def __rmatmul__(self, rhs: Diagrammable | Ty) -> Diagram:
        return rhs.to_diagram().tensor(self)

    @property
    def offsets(self) -> list[int]:
        """The offset of a box is the length of the type on its left."""
        return [len(layer.left) for layer in self.layers]

    def __iter__(self) -> Iterator[Layer]:
        yield from self.layers

    def __len__(self) -> int:
        return len(self.layers)

    def __getitem__(self, key: int | slice) -> Self:
        """Index or slice the diagram layer-wise.

        A slice with step ``-1`` returns the dagger of the selected
        layers. Any other step than ``1`` or ``-1`` raises IndexError.
        """
        if isinstance(key, slice):
            if key.step == -1:
                selected = self.layers[key]
                layers = [layer.dagger() for layer in selected]
                if selected:
                    # The daggered diagram runs from the output of the
                    # first selected layer to the input of the last one;
                    # using the whole diagram's cod/dom was only right
                    # for a full reversal.
                    return type(self)(selected[0].cod, selected[-1].dom,
                                      layers)
                return type(self)(self.cod, self.dom, layers)
            if (key.step or 1) != 1:
                raise IndexError
            layers = self.layers[key]
            if not layers:
                if (key.start or 0) >= len(self):
                    return self.id(self.cod)
                if (key.start or 0) <= -len(self):
                    return self.id(self.dom)
                return self.id(self.layers[key.start or 0].dom)
            return type(self)(
                layers[0].dom, layers[-1].cod, layers)
        if isinstance(key, int):
            if key >= len(self) or key < -len(self):
                raise IndexError
            if key < 0:
                return self[len(self) + key]
            return self[key:key + 1]
        raise TypeError

    def then(self, *diagrams: Diagrammable) -> Self:
        """Compose ``diagrams`` in sequence after this one.

        Returns ``NotImplemented`` if any operand cannot be lifted to
        this category.

        Raises
        ------
        ValueError
            If consecutive diagrams do not have matching types.

        """
        try:
            diags = self.lift(diagrams)
        except ValueError:
            return NotImplemented

        layers = [*self.layers]
        cod = self.cod
        for n, diagram in enumerate(diags):
            if diagram.dom != cod:
                raise ValueError(f'Diagram {n} (cod={cod}) does not compose '
                                 f'with diagram {n+1} (dom={diagram.dom})')
            cod = diagram.cod
            layers.extend(diagram.layers)

        return type(self)(dom=self.dom, cod=cod, layers=layers)

    def then_at(self, diagram: Diagrammable, index: int) -> Self:
        """Compose ``diagram`` onto the output wires starting at ``index``."""
        return (self
                >> (self.id(self.cod[:index])
                    @ diagram
                    @ self.id(self.cod[index+len(diagram.dom):])))

    def __rshift__(self, rhs: Diagrammable) -> Self:
        return self.then(rhs)

    def rotate(self, z: int) -> Self:
        """Rotate the diagram."""
        return type(self)(dom=self.dom.rotate(z),
                          cod=self.cod.rotate(z),
                          layers=[layer.rotate(z) for layer in self.layers])

    @property
    def l(self) -> Self:  # noqa: E741, E743
        """Left adjoint: rotation by -1."""
        return self.rotate(-1)

    @property
    def r(self) -> Self:
        """Right adjoint: rotation by +1."""
        return self.rotate(1)

    def dagger(self) -> Self:
        """Swap dom and cod, reverse the layers and dagger every box."""
        if self.is_id:
            return self
        else:
            return type(self)(dom=self.cod,
                              cod=self.dom,
                              layers=[replace(layer, box=layer.box.dagger())
                                      for layer in reversed(self.layers)])

    def transpose(self, left: bool = False) -> Self:
        """Construct the diagrammatic transpose.

        The transpose of any diagram in a category with cups and caps
        can be constructed as follows:

        .. code-block:: console

                                   (default)
            Left transpose      Right transpose

                 │╭╮                 ╭╮│
                 │█│                 │█│
                 ╰╯│                 │╰╯

        The input and output types of the transposed diagram are the
        adjoints of the respective types of the original diagram.
        This means that for diagrams with composite types, the order of
        the objects are reversed.

        Parameters
        ----------
        left : bool, default: False
            Whether to transpose to the diagram to the left.

        Returns
        -------
        Diagram
            The transposed diagram, constructed as shown above.

        """
        Cap = self.category.Diagram.special_boxes['cap']
        Cup = self.category.Diagram.special_boxes['cup']
        Id = self.id

        if left:
            top_layer = Id(self.cod.l) @ Cap(self.dom, self.dom.l)
            mid_layer = Id(self.cod.l) @ self @ Id(self.dom.l)
            bot_layer = Cup(self.cod.l, self.cod) @ Id(self.dom.l)
        else:
            top_layer = Cap(self.dom.r, self.dom)  # type: ignore[assignment]
            top_layer @= Id(self.cod.r)
            mid_layer = Id(self.dom.r) @ self @ Id(self.cod.r)
            bot_layer = Id(self.dom.r) @ Cup(self.cod, self.cod.r)

        return top_layer >> mid_layer >> bot_layer

    def curry(self, n: int = 1, left: bool = True) -> Self:
        """Curry the diagram by bending input wires round into outputs.

        Parameters
        ----------
        n : int, default: 1
            The number of input wires to bend.
        left : bool, default: True
            If `True`, the last `n` input wires are bent up with caps
            and appear, as their left adjoints, to the right of the
            codomain. Otherwise the first `n` input wires are bent and
            appear, as their right adjoints, to the left of the codomain.

        Returns
        -------
        Diagram
            A diagram whose domain is the remaining input wires.

        """
        Cap = self.category.Diagram.special_boxes['cap']
        Id = self.id

        if left:
            if n == 0:
                # ``self.dom[:-0]`` would be empty and bend every wire;
                # currying zero wires must bend none.
                base, exponent = self.dom, self.category.Ty()
            else:
                base, exponent = self.dom[:-n], self.dom[-n:]
            return (Id(base) @ Cap(exponent, exponent.l)
                    >> self @ Id(exponent.l))
        else:
            base, exponent = self.dom[n:], self.dom[:n]
            return (Cap(exponent.r, exponent) @ Id(base)  # type: ignore[return-value] # noqa: E501
                    >> Id(exponent.r) @ self)

    @classmethod
    def permutation(cls, dom: Ty, permutation: Iterable[int]) -> Self:
        """Create a diagram of swaps that permutes the wires of ``dom``.

        Output position ``i`` carries input wire ``permutation[i]``.

        Raises
        ------
        ValueError
            If ``permutation`` is not a permutation of
            ``range(len(dom))``.

        """
        permutation = list(permutation)
        if not (len(permutation) == len(dom)
                and set(permutation) == set(range(len(dom)))):
            raise ValueError('Invalid permutation for type of length '
                             f'{len(dom)}: {permutation}')

        # wire_index[k] tracks the current position of input wire k (it
        # is only read for wires that have not been placed yet).
        wire_index = [*range(len(dom))]

        diagram = cls.id(dom)
        for out_index in range(len(dom) - 1):
            in_index = wire_index[permutation[out_index]]
            assert in_index >= out_index
            # Move the wire left one position at a time with swaps.
            for i in reversed(range(out_index, in_index)):
                diagram >>= (
                    cls.id(diagram.cod[:i])
                    @ cls.special_boxes['swap'](*diagram.cod[i:i+2])
                    @ cls.id(diagram.cod[i+2:])
                )
            for i in range(permutation[out_index]):
                wire_index[i] += 1

        return diagram

    def permuted(self, permutation: Iterable[int]) -> Self:
        """Compose this diagram with a permutation of its output wires."""
        return self >> self.permutation(self.cod, permutation)

    def draw(self, draw_as_pregroup=True, **kwargs: Any) -> None:
        """Draw the diagram.

        Parameters
        ----------
        draw_as_pregroup : bool, optional
            Whether to try drawing the diagram as a pregroup diagram,
            default is `True`.
        draw_as_nodes : bool, optional
            Whether to draw boxes as nodes, default is `False`.
        color : string, optional
            Color of the box or node, default is white (`'#ffffff'`) for
            boxes and red (`'#ff0000'`) for nodes.
        textpad : pair of floats, optional
            Padding between text and wires, default is `(0.1, 0.1)`.
        draw_type_labels : bool, optional
            Whether to draw type labels, default is `False`.
        draw_box_labels : bool, optional
            Whether to draw box labels, default is `True`.
        aspect : string, optional
            Aspect ratio, one of `['auto', 'equal']`.
        margins : tuple, optional
            Margins, default is `(0.05, 0.05)`.
        nodesize : float, optional
            BoxNode size for spiders and controlled gates.
        fontsize : int, optional
            Font size for the boxes, default is `12`.
        fontsize_types : int, optional
            Font size for the types, default is `12`.
        figsize : tuple, optional
            Figure size.
        path : str, optional
            Where to save the image, if `None` we call `plt.show()`.
        to_tikz : bool, optional
            Whether to output tikz code instead of matplotlib.
        asymmetry : float, optional
            Make a box and its dagger mirror images, default is
            `.25 * any(box.is_dagger for box in diagram.boxes)`.

        """
        if draw_as_pregroup and self.is_pregroup:
            from lambeq.backend.drawing import draw_pregroup
            draw_pregroup(self, **kwargs)
        else:
            from lambeq.backend.drawing import draw
            draw(self, **kwargs)

    def render_as_str(self, **kwargs: Any) -> str:
        """Render the diagram as text.

        Presently only implemented for pregroup diagrams.

        Parameters
        ----------
        word_spacing : int, default: 2
            The number of spaces between the words of the diagrams.
        use_at_separator : bool, default: False
            Whether to represent types using @ as the monoidal product.
            Otherwise, use the unicode dot character.
        compress_layers : bool, default: True
            Whether to draw boxes in the same layer when they can occur
            simultaneously, otherwise, draw one box per layer.
        use_ascii: bool, default: False
            Whether to draw using ASCII characters only, for
            compatibility reasons.

        Returns
        -------
        str
            Drawing of diagram in string format.

        """
        from lambeq.backend.drawing import render_as_str
        return render_as_str(self, **kwargs)

    def apply_functor(self, functor: Functor) -> Diagram:
        """Apply ``functor`` layer by layer to this non-identity diagram."""
        assert not self.is_id
        diagram = functor(self.id(self.dom))
        for layer in self.layers:
            left, box, right = layer.unpack()
            diagram >>= (functor(self.id(left))
                         @ functor(box).to_diagram()
                         @ functor(self.id(right)))
        return diagram

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Decode a JSON object or string into a
        :py:class:`~lambeq.backend.Diagram`.

        A dict argument (including its nested dicts) is left untouched:
        decoding works on a deep copy.

        Returns
        -------
        :py:class:`~lambeq.backend.Diagram`
            The diagram generated from the JSON data.

        """
        # Deep copy: the pops below, and the nested decoders, must not
        # modify the caller's data.
        data_dict = (json.loads(data) if isinstance(data, str)
                     else deepcopy(data))
        data_dict.pop('category', None)
        data_dict.pop('entity')
        data_dict['dom'] = cls.category.Ty.from_json(data_dict['dom'])
        data_dict['cod'] = cls.category.Ty.from_json(data_dict['cod'])
        data_dict['layers'] = [cls.category.Layer.from_json(layer_data)
                               for layer_data in data_dict['layers']]
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Encode this diagram to a JSON object.

        Parameters
        ----------
        is_top_level : bool, optional
            This flag indicates that this object is the top-most object
            and should have the global metadata (e.g. category). This
            should be set to `False` when calling `to_json` on attribute
            instances to avoid duplication of said global metadata.

        """
        data_dict: _JSONDictT = {'entity': self.__class__.__name__}
        for attr in ('dom', 'cod'):
            data_dict[attr] = getattr(self, attr).to_json(is_top_level=False)
        data_dict['layers'] = [layer.to_json(is_top_level=False)
                               for layer in self.layers]
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict

    def to_discopy(self) -> 'discopy.monoidal.Diagram':
        """Export lambeq diagram to discopy diagram.

        Returns
        -------
        :class:`discopy.monoidal.Diagram`

        """
        from lambeq.backend.converters.discopy import to_discopy
        return to_discopy(self)

    @classmethod
    def from_discopy(cls,
                     diagram: 'discopy.monoidal.Diagram') -> Diagram:
        """Import discopy diagram to lambeq diagram.

        Parameters
        ----------
        diagram : :class:`discopy.monoidal.Diagram`

        """
        from lambeq.backend.converters.discopy import from_discopy
        return from_discopy(diagram)
@Diagram.register_special_box('cap')
@dataclass
class Cap(Box):
"""The unit of the adjunction for an atomic type.
Parameters
----------
left : Ty
The type of the left output.
right : Ty
The type of the right output.
is_reversed : bool, default: False
Whether the cap is reversed or not. Normally, caps only allow
outputs where `right` is the left adjoint of `left`. However,
to facilitate operations like `dagger`, we pass in a flag that
indicates that the inputs are the opposite way round, which
initialises a reversed cap. Then, when a cap is adjointed, it
turns into a reversed cap, which can be adjointed again to turn
it back into a normal cap.
"""
left: Ty
right: Ty
is_reversed: InitVar[bool] = False
name: str = field(init=False)
dom: Ty = field(init=False)
cod: Ty = field(init=False)
z: int = field(init=False)
def __post_init__(self, is_reversed: bool) -> None:
if not self.left.is_atomic or not self.right.is_atomic:
raise ValueError('left and right need to be atomic types.')
self._check_adjoint(self.left, self.right, is_reversed)
self.name = 'CAP'
self.dom = self.category.Ty()
self.cod = self.left @ self.right
self.z = int(is_reversed)
@staticmethod
def _check_adjoint(left: Ty, right: Ty, is_reversed: bool) -> None:
if is_reversed:
if left != right.l:
raise ValueError('left and right need to be adjoints')
else:
if left != right.r:
raise ValueError('left and right need to be adjoints')
def __new__(cls, # type: ignore[misc]
left: Ty,
right: Ty,
is_reversed: bool = False) -> Diagrammable:
if left.is_atomic and right.is_atomic:
return super().__new__(cls)
else:
cls._check_adjoint(left, right, is_reversed)
diagram = cls.category.Diagram.id()
for i, (l_ob, r_ob) in enumerate(zip(left, reversed(right))):
diagram = diagram.then_at(cls(l_ob, r_ob), i)
return diagram
def __reduce__(self):
return (self.__class__, (self.left, self.right, bool(self.z % 2)))
def __deepcopy__(self, memo) -> Self:
left_copy = deepcopy(self.left, memo)
right_copy = deepcopy(self.right, memo)
return type(self)(left_copy, right_copy, bool(self.z % 2))
[docs]
@classmethod
def to_right(cls, left: Ty, is_reversed: bool = False) -> Self | Diagram:
return cls(left, left.r if is_reversed else left.l)
[docs]
@classmethod
def to_left(cls, right: Ty, is_reversed: bool = False) -> Self | Diagram:
return cls(right.l if is_reversed else right.r, right)
[docs]
def rotate(self, z: int) -> Self:
"""Rotate the cap."""
if z % 2 == 1:
left, right = self.right, self.left
else:
left, right = self.left, self.right
is_reversed = (self.z + z) % 2 == 1
return type(self)(left.rotate(z),
right.rotate(z),
is_reversed=is_reversed)
[docs]
def dagger(self) -> Cup:
Cup = self.category.Diagram.special_boxes['cup']
return Cup(self.left, # type: ignore[return-value]
self.right,
is_reversed=not self.z)
[docs]
def apply_functor(self, functor: Functor) -> Diagrammable:
return functor.target_category.Diagram.special_boxes['cap'](
functor(self.left),
functor(self.right),
is_reversed=bool(self.z)
)
[docs]
@classmethod
def from_json(cls, data: _JSONDictT | str) -> Self:
data_dict = json.loads(data) if isinstance(data, str) else data
_, _ = data_dict.pop('category', None), data_dict.pop('entity')
data_dict['left'] = cls.category.Ty.from_json(data_dict['left'])
data_dict['right'] = cls.category.Ty.from_json(data_dict['right'])
return cls(**data_dict)
[docs]
def to_json(self, is_top_level: bool = True) -> _JSONDictT:
data_dict: _JSONDictT = {'entity': self.__class__.__name__,
'is_reversed': bool(self.z)}
for attr in ('left', 'right'):
data_dict[attr] = getattr(self, attr).to_json(is_top_level=False)
if is_top_level:
data_dict['category'] = self.category.name
return data_dict
__repr__ = Box.__repr__
__hash__ = Box.__hash__
@Diagram.register_special_box('cup')
@dataclass
class Cup(Box):
    """The counit of the adjunction for an atomic type.

    Parameters
    ----------
    left : Ty
        The type of the left input.
    right : Ty
        The type of the right input.
    is_reversed : bool, default: False
        Whether the cup is reversed or not. Normally, cups only allow
        inputs where `right` is the right adjoint of `left`. However,
        to facilitate operations like `dagger`, we pass in a flag that
        indicates that the inputs are the opposite way round, which
        initialises a reversed cup. Then, when a cup is adjointed, it
        turns into a reversed cup, which can be adjointed again to turn
        it back into a normal cup.

    Notes
    -----
    If `left` or `right` is not atomic, the constructor returns a
    diagram built from atomic cups (the i-th object from the end of
    `left` paired with the i-th object of `right`) instead of a
    :py:class:`Cup` instance.

    """
    left: Ty
    right: Ty
    is_reversed: InitVar[bool] = False
    name: str = field(init=False)
    dom: Ty = field(init=False)
    cod: Ty = field(init=False)
    z: int = field(init=False)

    def __post_init__(self, is_reversed: bool) -> None:
        if not self.left.is_atomic or not self.right.is_atomic:
            raise ValueError('left and right need to be atomic types.')
        self._check_adjoint(self.left, self.right, is_reversed)
        self.name = 'CUP'
        self.dom = self.left @ self.right
        self.cod = self.category.Ty()
        # The winding number stores the reversal flag (see __reduce__).
        self.z = int(is_reversed)

    @staticmethod
    def _check_adjoint(left: Ty, right: Ty, is_reversed: bool) -> None:
        """Raise ValueError unless `left` and `right` can form a cup.

        A normal cup requires ``left == right.l``; a reversed cup
        requires ``left == right.r``.
        """
        if is_reversed:
            if left != right.r:
                raise ValueError('left and right need to be adjoints')
        else:
            if left != right.l:
                raise ValueError('left and right need to be adjoints')

    def __new__(cls,  # type: ignore[misc]
                left: Ty,
                right: Ty,
                is_reversed: bool = False) -> Diagrammable:
        if left.is_atomic and right.is_atomic:
            return super().__new__(cls)
        cls._check_adjoint(left, right, is_reversed)
        # Contract the innermost pair first, working outwards. The
        # reversal flag must be forwarded: the atomic pairs of a reversed
        # cup only satisfy the reversed adjointness check.
        diagram = cls.category.Diagram.id(left @ right)
        for i, (l_ob, r_ob) in enumerate(zip(reversed(left), right)):
            diagram = diagram.then_at(cls(l_ob, r_ob, is_reversed),
                                      len(left) - 1 - i)
        return diagram

    def __reduce__(self):
        return (self.__class__, (self.left, self.right, bool(self.z % 2)))

    def __deepcopy__(self, memo) -> Self:
        left_copy = deepcopy(self.left, memo)
        right_copy = deepcopy(self.right, memo)
        return type(self)(left_copy, right_copy, bool(self.z % 2))

    @classmethod
    def to_right(cls, left: Ty, is_reversed: bool = False) -> Self | Diagram:
        """Build a cup whose left input is `left`.

        The right input is ``left.r`` for a normal cup and ``left.l``
        for a reversed one.
        """
        return cls(left,
                   left.l if is_reversed else left.r,
                   is_reversed=is_reversed)

    @classmethod
    def to_left(cls, right: Ty, is_reversed: bool = False) -> Self | Diagram:
        """Build a cup whose right input is `right`.

        The left input is ``right.l`` for a normal cup and ``right.r``
        for a reversed one.
        """
        return cls(right.r if is_reversed else right.l,
                   right,
                   is_reversed=is_reversed)

    def rotate(self, z: int) -> Self:
        """Rotate the cup."""
        if z % 2 == 1:
            left, right = self.right, self.left
        else:
            left, right = self.left, self.right
        is_reversed = (self.z + z) % 2 == 1
        return type(self)(left.rotate(z),
                          right.rotate(z),
                          is_reversed=is_reversed)

    def dagger(self) -> Cap:
        """Return the cap with the same wires and flipped reversal flag."""
        Cap = self.category.Diagram.special_boxes['cap']
        return Cap(  # type: ignore[return-value]
            self.left,
            self.right,
            is_reversed=not self.z
        )

    def apply_functor(self, functor: Functor) -> Diagrammable:
        """Map both inputs with `functor`, keeping the reversal flag."""
        return functor.target_category.Diagram.special_boxes['cup'](
            functor(self.left),
            functor(self.right),
            is_reversed=bool(self.z)
        )

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Rebuild a cup from the output of :py:meth:`to_json`."""
        data_dict = json.loads(data) if isinstance(data, str) else data
        _, _ = data_dict.pop('category', None), data_dict.pop('entity')
        data_dict['left'] = cls.category.Ty.from_json(data_dict['left'])
        data_dict['right'] = cls.category.Ty.from_json(data_dict['right'])
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Serialise the cup; the category name is added at top level."""
        data_dict: _JSONDictT = {'entity': self.__class__.__name__,
                                 'is_reversed': bool(self.z)}
        for attr in ('left', 'right'):
            data_dict[attr] = getattr(self, attr).to_json(is_top_level=False)
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict

    # Use Box's repr/hash instead of dataclass-generated ones.
    __repr__ = Box.__repr__
    __hash__ = Box.__hash__
@dataclass
class Daggered(Box):
    """The dagger of an existing box.

    The wrapped box's domain and codomain are exchanged, its name gets a
    trailing '†', and its winding number is kept.

    Parameters
    ----------
    box : Box
        The box to be daggered.

    """
    box: Box
    name: str = field(init=False)
    dom: Ty = field(init=False)
    cod: Ty = field(init=False)
    z: int = field(init=False)

    def __post_init__(self) -> None:
        wrapped = self.box
        self.name = wrapped.name + '†'
        self.dom, self.cod = wrapped.cod, wrapped.dom
        self.z = wrapped.z

    def rotate(self, z: int) -> Self:
        """Rotate the underlying box and dagger the result again."""
        rotated = self.box.rotate(z)
        return type(self)(rotated)

    def dagger(self) -> Box:
        """Undo the dagger by returning the wrapped box."""
        return self.box

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Rebuild the wrapped box from JSON and return its dagger."""
        payload = data if not isinstance(data, str) else json.loads(data)
        payload.pop('category', None)
        payload.pop('entity')
        inner = cls.category.Box.from_json(payload['box'])
        return inner.dagger()  # type: ignore[return-value]

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Serialise as the wrapped box; category is added at top level."""
        payload: _JSONDictT = {}
        payload['entity'] = self.__class__.__name__
        payload['box'] = self.box.to_json(is_top_level=False)
        if is_top_level:
            payload['category'] = self.category.name
        return payload

    __repr__ = Box.__repr__
    __hash__ = Box.__hash__
@Diagram.register_special_box('spider')
@dataclass
class Spider(Box):
    """A spider in the grammar category.

    Parameters
    ----------
    type : Ty
        The atomic type of the spider.
    n_legs_in : int
        The number of input legs.
    n_legs_out : int
        The number of output legs.

    Notes
    -----
    If `type` is not atomic, the constructor returns a diagram that
    combines one atomic spider per object of `type` with permutations of
    the legs, instead of a :py:class:`Spider` instance.

    """
    type: Ty
    n_legs_in: int
    n_legs_out: int
    name: str = field(init=False)
    dom: Ty = field(init=False)
    cod: Ty = field(init=False)
    z: int = field(default=0, init=False)

    def __post_init__(self) -> None:
        if not self.type.is_atomic:
            raise TypeError('Spider type needs to be atomic.')
        self.name = 'SPIDER'
        self.dom = self.type ** self.n_legs_in
        self.cod = self.type ** self.n_legs_out

    def __new__(cls,  # type: ignore[misc]
                type: Ty,
                n_legs_in: int,
                n_legs_out: int) -> Diagrammable:
        if type.is_atomic:
            return super().__new__(cls)
        size = len(type)
        # in_perm lists, object by object, the positions of that object's
        # copies inside `type ** n_legs_in`.
        in_perm = [j
                   for i in range(size)
                   for j in range(i, size * n_legs_in, size)]
        # out_perm lists, copy by copy, the positions of each object's
        # i-th output leg in the tensor of atomic spiders.
        out_perm = [j
                    for i in range(n_legs_out)
                    for j in range(i, size * n_legs_out, n_legs_out)]
        regroup = cls.category.Diagram.permutation(type ** n_legs_in, in_perm)
        spiders = cls.category.Diagram.id().tensor(
            *(cls(ob, n_legs_in, n_legs_out) for ob in type)
        )
        return regroup >> spiders.permuted(out_perm)

    def __reduce__(self):
        return (self.__class__, (self.type, self.n_legs_in, self.n_legs_out))

    def __deepcopy__(self, memo) -> Self:
        # Thread `memo` through every field, consistently with the other
        # special boxes.
        typ = deepcopy(self.type, memo)
        n_legs_in = deepcopy(self.n_legs_in, memo)
        n_legs_out = deepcopy(self.n_legs_out, memo)
        return type(self)(typ, n_legs_in, n_legs_out)

    def rotate(self, z: int) -> Self:
        """Rotate the spider."""
        return type(self)(self.type.rotate(z), len(self.dom), len(self.cod))

    def dagger(self) -> Self:
        """Return the spider with input and output leg counts exchanged."""
        return type(self)(self.type, self.n_legs_out, self.n_legs_in)

    def apply_functor(self, functor: Functor) -> Diagrammable:
        """Map the spider's type with `functor`, keeping the leg counts."""
        return functor.target_category.Diagram.special_boxes['spider'](
            functor(self.type),
            self.n_legs_in,
            self.n_legs_out
        )

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Rebuild a spider from the output of :py:meth:`to_json`."""
        data_dict = json.loads(data) if isinstance(data, str) else data
        _, _ = data_dict.pop('category', None), data_dict.pop('entity')
        data_dict['type'] = cls.category.Ty.from_json(data_dict['type'])
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Serialise the spider; the category is added at top level."""
        data_dict: _JSONDictT = {
            'entity': self.__class__.__name__,
            'type': self.type.to_json(is_top_level=False),
            'n_legs_in': self.n_legs_in,
            'n_legs_out': self.n_legs_out
        }
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict

    __repr__ = Box.__repr__
    __hash__ = Box.__hash__
@Diagram.register_special_box('swap')
@dataclass
class Swap(Box):
    """A swap in the grammar category.

    Exchanges two atomic wires: the domain is ``left @ right`` and the
    codomain is ``right @ left``.

    Parameters
    ----------
    left : Ty
        The atomic type of the left input wire.
    right : Ty
        The atomic type of the right input wire.

    """
    left: Ty
    right: Ty
    name: str = field(init=False)
    dom: Ty = field(init=False)
    cod: Ty = field(init=False)
    z: int = field(default=0, init=False)

    def __post_init__(self) -> None:
        if not (self.left.is_atomic and self.right.is_atomic):
            raise ValueError('Types need to be atomic')
        self.name = 'SWAP'
        self.dom = self.left @ self.right
        self.cod = self.right @ self.left

    def __new__(cls,  # type: ignore[misc]
                left: Ty,
                right: Ty) -> Swap | Diagram:
        if not (left.is_atomic and right.is_atomic):
            # Move each object of `right`, in turn, leftwards past every
            # object of `left` using atomic swaps.
            result = cls.category.Diagram.id(left @ right)
            for offset, r_ob in enumerate(right):
                for idx in range(len(left) - 1, -1, -1):
                    result = result.then_at(cls(left[idx], r_ob),
                                            offset + idx)
            return result
        return super().__new__(cls)

    def __reduce__(self):
        args = (self.left, self.right)
        return (self.__class__, args)

    def __deepcopy__(self, memo) -> Self:
        return type(self)(deepcopy(self.left, memo),
                          deepcopy(self.right, memo))

    def rotate(self, z: int) -> Self:
        """Rotate the swap; an odd `z` also exchanges the two wires."""
        first, second = ((self.right, self.left) if z % 2 == 1
                         else (self.left, self.right))
        return type(self)(first.rotate(z), second.rotate(z))

    def dagger(self) -> Self:
        """Return the swap going the other way."""
        return type(self)(self.right, self.left)

    def apply_functor(self, functor: Functor) -> Diagrammable:
        """Build a swap on the images of both wires under `functor`."""
        make_swap = functor.target_category.Diagram.special_boxes['swap']
        return make_swap(functor(self.left), functor(self.right))

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Rebuild a swap from the output of :py:meth:`to_json`."""
        data_dict = data if not isinstance(data, str) else json.loads(data)
        data_dict.pop('category', None)
        data_dict.pop('entity')
        ty_from_json = cls.category.Ty.from_json
        data_dict['left'] = ty_from_json(data_dict['left'])
        data_dict['right'] = ty_from_json(data_dict['right'])
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Serialise the swap; the category is added at top level."""
        data_dict: _JSONDictT = {
            'entity': self.__class__.__name__,
            'left': self.left.to_json(is_top_level=False),
            'right': self.right.to_json(is_top_level=False),
        }
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict

    __repr__ = Box.__repr__
    __hash__ = Box.__hash__
@dataclass
class Word(Box):
    """A word in the grammar category.

    A word is a :py:class:`~.Box` with an empty domain.

    Parameters
    ----------
    name : str
        The name of the word.
    cod : Ty
        The codomain of the word.
    z : int, optional
        The winding number of the word, by default 0

    """
    name: str
    cod: Ty
    dom: Ty = field(init=False)

    def __post_init__(self) -> None:
        self.dom = self.category.Ty()

    def __repr__(self) -> str:
        # Each field gets its own replacement field. __hash__ is derived
        # from this string, so name, codomain and winding number must all
        # appear in it.
        return f'Word({self.name}, {self.cod!r}, {self.z!r})'

    def __hash__(self) -> int:
        return hash(repr(self))

    def rotate(self, z: int) -> Self:
        """Rotate the Word box, changing the winding number."""
        return type(self)(self.name, self.cod.rotate(z), self.z + z)

    def dagger(self) -> Daggered:
        """Return this word wrapped in :py:class:`Daggered`."""
        return Daggered(self)

    @classmethod
    def from_json(cls, data: _JSONDictT | str) -> Self:
        """Rebuild a word from the output of :py:meth:`to_json`."""
        data_dict = json.loads(data) if isinstance(data, str) else data
        _, _ = data_dict.pop('category', None), data_dict.pop('entity')
        data_dict['cod'] = cls.category.Ty.from_json(data_dict['cod'])
        return cls(**data_dict)

    def to_json(self, is_top_level: bool = True) -> _JSONDictT:
        """Serialise the word; the category is added at top level."""
        data_dict: _JSONDictT = {
            'entity': self.__class__.__name__,
            'name': self.name,
            'cod': self.cod.to_json(is_top_level=False),
            'z': self.z,
        }
        if is_top_level:
            data_dict['category'] = self.category.name
        return data_dict
# Shorthand alias for ``Diagram.id``; the ``Functor`` docstring example
# uses it as ``Id(n)``.
Id = Diagram.id
@dataclass(init=False)
class Functor:
    """A functor that maps between categories.

    Parameters
    ----------
    target_category : Category
        The category to which the functor maps.
    ob : callable, optional
        A function that maps types to types, by default None. It is
        called as ``ob(functor, ty)``; when it is omitted,
        :py:meth:`ob` raises :py:exc:`AttributeError`.
    ar : callable, optional
        A function that maps boxes to Diagrammables, by default None.
        It is called as ``ar(functor, box)``; when it is omitted,
        :py:meth:`ar` raises :py:exc:`AttributeError`.

    Examples
    --------
    >>> n = Ty('n')
    >>> diag = Cap(n, n.l) @ Id(n) >> Id(n) @ Cup(n.l, n)
    >>> diag.draw(
    ...     figsize=(2, 2), path='./snake.png')

    .. image:: ./_static/images/snake.png
        :align: center

    >>> F = Functor(grammar, lambda _, ty : ty @ ty)
    >>> F(diag).draw(
    ...     figsize=(2, 2), path='./snake-2.png')

    .. image:: ./_static/images/snake-2.png
        :align: center

    """
    target_category: Category

    def __init__(
        self,
        target_category: Category,
        ob: Callable[[Functor, Ty], Ty] | None = None,
        ar: Callable[[Functor, Box], Diagrammable] | None = None
    ) -> None:
        self.target_category = target_category
        self.custom_ob = ob
        self.custom_ar = ar
        # Memoisation tables for already-mapped types and arrows.
        self.ob_cache: dict[Ty, Ty] = {}
        self.ar_cache: dict[Diagrammable, Diagrammable] = {}

    @overload
    def __call__(self, entity: Ty) -> Ty: ...

    @overload
    def __call__(self, entity: Box) -> Diagrammable: ...

    @overload
    def __call__(self, entity: Diagram) -> Diagram: ...

    @overload
    def __call__(self, entity: Diagrammable) -> Diagrammable: ...

    def __call__(self, entity: Ty | Diagrammable) -> Ty | Diagrammable:
        """Apply the functor to a type or a diagrammable.

        Parameters
        ----------
        entity : Ty or Diagrammable
            The type or diagrammable to which the functor is applied.

        """
        if isinstance(entity, Ty):
            return self.ob_with_cache(entity)
        else:
            return self.ar_with_cache(entity)

    def ob_with_cache(self, ob: Ty) -> Ty:
        """Apply the functor to a type, caching the result.

        Cache hits return a deep copy of the stored type. Empty types
        map to the empty type of the target category.
        """
        try:
            return deepcopy(self.ob_cache[ob])
        except KeyError:
            pass
        if ob.is_empty:
            ret = self.target_category.Ty()
        else:
            ret = ob.apply_functor(self)
        self.ob_cache[ob] = ret
        return ret

    def ar_with_cache(self, ar: Diagrammable) -> Diagrammable:
        """Apply the functor to a diagrammable, caching the result.

        Cache hits return a deep copy of the stored image.

        Raises
        ------
        TypeError
            If the image's domain or codomain differs from the image of
            ``ar.dom`` or ``ar.cod`` under this functor.

        """
        try:
            return deepcopy(self.ar_cache[ar])
        except KeyError:
            pass
        if not ar.is_id:
            ret = ar.apply_functor(self)
        else:
            ret = self.target_category.Diagram.id(self.ob_with_cache(ar.dom))
        cod_check = self.ob_with_cache(ar.cod)
        dom_check = self.ob_with_cache(ar.dom)
        if ret.cod != cod_check or ret.dom != dom_check:
            raise TypeError(f'The arrow is ill-defined. Applying the functor '
                            f'to a box returns dom = {ret.dom}, cod = '
                            f'{ret.cod} expected dom = {dom_check}, cod = '
                            f'{cod_check}')
        # Cache only after validation; otherwise a later cache hit would
        # hand back an ill-typed image without raising.
        self.ar_cache[ar] = ret
        return ret

    def ob(self, ob: Ty) -> Ty:
        """Apply the functor to a type."""
        if self.custom_ob is None:
            raise AttributeError('Specify a custom ob function if you want to '
                                 'use the functor on types.')
        return self.custom_ob(self, ob)

    def ar(self, ar: Box) -> Diagrammable:
        """Apply the functor to a box."""
        if self.custom_ar is None:
            raise AttributeError('Specify a custom ar function if you want to '
                                 'use the functor on boxes.')
        return self.custom_ar(self, ar)