# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Tensor category
===============
Lambeq's internal representation of the tensor category. This work is
based on DisCoPy (https://discopy.org/) which is released under the
BSD 3-Clause 'New' or 'Revised' License.
"""
from __future__ import annotations
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field, replace
from functools import cached_property
import math
from typing import Mapping
import numpy as np
import sympy
import tensornetwork as tn
from typing_extensions import Any, Self
from lambeq.backend import grammar
from lambeq.backend.numerical_backend import get_backend
# Category object for the tensor backend. It is applied below as a class
# decorator (`@tensor`, `@tensor('Ty')`) to this module's types.
tensor = grammar.Category('tensor')
@tensor('Ty')
@dataclass(init=False)
class Dim(grammar.Ty):
    """Dimension in the tensor category.

    Attributes
    ----------
    dim : tuple of int
        Tuple of dimensions represented by the object.
    product : int
        Product of contained dimensions.

    """

    objects: list[Self]  # type: ignore[assignment,misc]

    def __init__(self,
                 *dim: int,
                 objects: list[Self] | None = None) -> None:
        """Initialise a Dim type.

        Parameters
        ----------
        dim : list[int]
            List of dimensions to initialise. Entries that are not
            greater than 1 are discarded.
        objects : list[Self] or None, default None
            List of `Dim`s, to prepare a non-atomic `Dim` object.
            Must not be combined with positional `dim` values.

        """
        if objects:
            assert not len(dim)
            super().__init__(objects=objects)  # type: ignore[arg-type]
        else:
            dims: list[int] = [d for d in dim if d > 1]
            # Exactly one initialisation must run. The branches are
            # chained with `elif` so the empty case does not fall
            # through and initialise the object a second time.
            if not dims:
                super().__init__()
            elif len(dims) == 1:
                super().__init__(str(dims[0]))
            else:
                super().__init__(objects=[Dim(d) for d in dims])

    @property
    def dim(self) -> tuple[int, ...]:
        """Flat tuple of the dimensions contained in this object."""
        if self.is_atomic:
            # An atomic Dim stores its size as the type name.
            return (int(self.name), )  # type: ignore[arg-type]
        return tuple(dim for subdim in self.objects for dim in subdim.dim)

    def rotate(self, z: int) -> Self:
        """Rotate the type `z` times; atomic dims are returned as-is."""
        if self.is_atomic:
            return self
        return super().rotate(z)

    def _repr_rec(self) -> str:
        """Comma-separated rendering of the contained dimensions."""
        if self.is_empty:
            return '1'
        elif self.is_atomic:
            return self.name  # type: ignore[return-value]
        else:
            return ', '.join(d._repr_rec() for d in self.objects)

    @property
    def product(self) -> int:
        """Product of all contained dimensions."""
        return math.prod(self.dim)

    def __repr__(self) -> str:
        return f'Dim({self._repr_rec()})'

    def __hash__(self) -> int:
        return hash(repr(self))
@dataclass(init=False)
@tensor
class Box(grammar.Box):
    """Box (tensor) in the tensor category.

    Attributes
    ----------
    data : np.array or float or None
        Data used to represent the `array` attribute. Typically either
        the array itself, or a symbolic array.
    array : np.array or float
        Tensor which the box represents.
    free_symbols : set of sympy.Symbol
        In case of a symbolic tensor, set of symbols in the box's data.

    """

    name: str
    dom: Dim
    cod: Dim
    data: float | np.ndarray | None
    z: int

    def __init__(self,
                 name: str,
                 dom: Dim,
                 cod: Dim,
                 data: float | np.ndarray | None = None,
                 z: int = 0):
        """Initialise a `tensor.Box` type.

        Parameters
        ----------
        name : str
            Name for the box.
        dom : Dim
            Dimension of the box's domain.
        cod : Dim
            Dimension of the box's codomain.
        data : float or np.ndarray, optional
            The concrete tensor the box represents.
        z : int, optional
            Winding number of the box, indicating conjugation. Starts at
            0 if not provided.

        """
        self.name = name
        self.dom = dom
        self.cod = cod
        self.data = data
        self.z = z

    @property
    def array(self):
        """Tensor of the box, shaped as `dom.dim + cod.dim`.

        The diagrammatic conjugate is used when `z` is odd. Returns
        `None` when the box carries no data.
        """
        if self.data is None:
            return None
        if self.z % 2:
            ret_arr = self._conjugate_array()
        else:
            ret_arr = get_backend().array(self.data)
        return ret_arr.reshape(self.dom.dim + self.cod.dim)

    def _adjoint_array(self):
        """Returns the adjoint of the box's data"""
        arr = self.array
        n_dom, n_cod = len(self.dom), len(self.cod)
        n_legs = len(self.dom @ self.cod)
        source = range(n_legs)
        # Domain axes are moved behind the codomain axes, i.e. the
        # inputs and outputs of the tensor are exchanged.
        target = [i + n_cod if i < n_dom else i - n_dom
                  for i in range(n_legs)]
        return np.conjugate(np.moveaxis(arr, source, target))

    def _conjugate_array(self):
        """Returns the diagrammatic conjugate of the box's data"""
        dom, cod = self.dom, self.cod
        n_legs = len(dom @ cod)
        # The first len(dom) axes map to len(dom)-1 .. 0; later axes
        # get negative targets -1, -2, ... so each group is reversed.
        array = np.moveaxis(self.data,
                            range(n_legs),
                            [len(dom) - i - 1 for i in range(n_legs)])
        return np.conjugate(array)

    def dagger(self):
        """Get the dagger (adjoint) of the box.

        Returns
        -------
        Box
            Dagger of the box.

        """
        return Daggered(self)

    def rotate(self, z: int):
        """Get the result of conjugating the box `z` times.

        Parameters
        ----------
        z : int
            Winding count. The number of conjugations to apply to the box.

        Returns
        -------
        Box
            The box conjugated z times.

        """
        return replace(self,
                       dom=self.dom.rotate(z),
                       cod=self.cod.rotate(z),
                       z=(self.z + z) % 2)

    @cached_property
    def free_symbols(self) -> set[sympy.Symbol]:
        """Symbols found in the box's data, searched recursively."""
        def recursive_free_symbols(data) -> set[sympy.Symbol]:
            if isinstance(data, Mapping):
                data = data.values()
            if isinstance(data, Iterable):
                # Objects whose `shape` is () are treated as leaves
                # rather than iterated over.
                if not hasattr(data, 'shape') or data.shape != ():
                    return set().union(*map(recursive_free_symbols, data))
            return getattr(data, 'free_symbols', set())
        return recursive_free_symbols(self.data)

    def lambdify(self, *symbols: 'sympy.Symbol', **kwargs) -> Callable:
        """Get a lambdified version of a symbolic box.

        Returns a function which when provided appropriate parameters,
        initialises a concrete box.

        Parameters
        ----------
        symbols : list of sympy.Symbols
            List of symbols in the box in the order in which their
            assigned values will appear in the concretisation call.
        kwargs:
            Additional parameters to pass to `sympy.lambdify`.

        Returns
        -------
        Callable[..., Box]:
            A lambda function which when invoked with appropriate
            parameters, returns a concrete version of the box. The
            winding number `z` of the box is carried over.

        """
        if not any(x in self.free_symbols for x in symbols):
            return lambda *xs: self

        # `z` must be forwarded: otherwise a conjugated symbolic box
        # would come back with z=0 and lose its conjugation.
        return lambda *xs: type(self)(
            self.name, self.dom, self.cod,
            sympy.lambdify(symbols, self.data, **kwargs)(*xs),
            z=self.z)

    def __repr__(self) -> str:
        return (f'[{self.name}{".l"*(-self.z)}{".r"*self.z}; '
                f'{repr(self.dom)} -> {repr(self.cod)}]')

    def __hash__(self) -> int:
        return hash(repr(self))
@dataclass
@tensor
class Layer(grammar.Layer):
    """Layer in the tensor category.

    Attributes
    ----------
    left : Dim
        Dimension of the wires to the left of the box.
    box : Box
        The box contained in the layer.
    right : Dim
        Dimension of the wires to the right of the box.

    """

    left: Dim
    box: Box
    right: Dim
@dataclass
@tensor
class Diagram(grammar.Diagram):
    """Diagram in the tensor category."""

    dom: Dim
    cod: Dim
    layers: list[Layer]  # type: ignore[assignment]

    def lambdify(self, *symbols, **kwargs):
        """Get a lambdified version of a symbolic diagram.

        Every box is lambdified once up front. The returned callable
        takes the parameter values and builds a new diagram of the same
        type whose boxes are the lambdified boxes applied to them.
        """
        prepared = [(l_, bx.lambdify(*symbols, **kwargs), r_)
                    for l_, bx, r_ in self.layers]

        def lambda_diagram(*xs):
            concrete_layers = [
                self.category.Layer(left, make_box(*xs), right)
                for left, make_box, right in prepared
            ]
            return type(self)(self.dom, self.cod, concrete_layers)

        return lambda_diagram

    @cached_property
    def free_symbols(self) -> set[sympy.Symbol]:
        """Union of the free symbols of all boxes in the diagram."""
        found: set[sympy.Symbol] = set()
        for box in self.boxes:
            found.update(box.free_symbols)
        return found

    def eval(self, contractor=tn.contractors.auto, dtype: type | None = None):
        """Evaluate the tensor diagram.

        Parameters
        ----------
        contractor : tn contractor
            `tensornetwork` contractor for chosen contraction algorithm.
        dtype : type, optional
            Data type forwarded to `to_tn`. Defaults to `np.float32`.

        Returns
        -------
        numpy.ndarray
            n-dimension array representing the contracted tensor.

        """
        nodes, edge_order = self.to_tn(dtype=dtype)
        return contractor(nodes, edge_order).tensor

    def to_tn(self, dtype: type | None = None):
        """Convert the diagram to a `tensornetwork` TN.

        Parameters
        ----------
        dtype : type, optional
            Data type used for the copy nodes created for the input
            wires and for spiders. Defaults to `np.float32`.

        Returns
        -------
        tuple[list[tn.Node], list[tn.Edge]]
            `tensornetwork` representation of the diagram. An edge
            object is returned for each dangling edge in the network.

        """
        if dtype is None:
            dtype = np.float32
        backend = get_backend().name

        # One 2-legged copy node per input wire: leg 0 stays dangling
        # as the input edge, leg 1 is the current open end of the wire.
        nodes = [tn.CopyNode(2, dim, dtype=dtype, backend=backend)
                 for dim in self.dom.dim]
        inputs = [node[0] for node in nodes]
        scan = [node[1] for node in nodes]

        # Rebuild the diagram, replacing every box that offers a
        # `decompose` method by its decomposition.
        diag = self.category.Diagram.id(self.dom)
        for layer in self.layers:
            left, box, right = layer.unpack()
            subdiag = box.decompose() if hasattr(box, 'decompose') else box
            diag >>= (self.category.Diagram.id(left)
                      @ subdiag
                      @ self.category.Diagram.id(right))

        for lyr in diag.layers:
            left, box, _ = lyr.unpack()
            offset = len(left)

            if isinstance(box, Swap):
                # Swaps only permute the open edges.
                scan[offset], scan[offset + 1] = (scan[offset + 1],
                                                  scan[offset])
            elif isinstance(box, Cup):
                # Cups join two neighbouring open edges directly.
                tn.connect(scan[offset], scan[offset + 1])
                del scan[offset: offset + 2]
            else:
                if isinstance(box, Spider):
                    node = tn.CopyNode(box.n_legs_in + box.n_legs_out,
                                       box.type.product,
                                       dtype=dtype,
                                       backend=backend)
                else:
                    node = tn.Node(box.array,
                                   str(box.name),
                                   backend=backend)
                nodes.append(node)

                n_in = len(box.dom)
                for i in range(n_in):
                    tn.connect(scan[offset + i], node[i])
                # The node's remaining legs replace the consumed edges.
                scan = (scan[:offset]
                        + node[n_in:]
                        + scan[offset + n_in:])

        # nodes, then the dangling edges: inputs followed by outputs
        return nodes, inputs + scan

    __hash__ = grammar.Diagram.__hash__
@Diagram.register_special_box('cap')
class Cap(grammar.Cap, Box):
    """A Cap in the tensor category."""

    left: Dim
    right: Dim
    name: str
    dom: Dim
    cod: Dim
    z: int = 0
    is_reversed: bool = False

    def __init__(self, left: Dim, right: Dim, is_reversed: bool = False):
        """Initialise a tensor Cap.

        Parameters
        ----------
        left : Dim
            Dimension (type) of the left leg of the cap. Must be the
            conjugate of `right`.
        right : Dim
            Dimension (type) of the right leg of the cap. Must be the
            conjugate of `left`.
        is_reversed : bool, default False
            Ignored parameter, since left and right conjugates are
            equivalent in the tensor category. Necessary to inherit
            from `grammar.Cap` appropriately.

        """
        super().__init__(left, right)

        # Flattened d x d identity, i.e. ones at every index i * (d + 1).
        # Setting only the first and last entries is correct solely
        # for d <= 2 and drops diagonal terms for larger dimensions.
        # NOTE(review): assumes `left` is a single (atomic) wire --
        # confirm that `grammar.Cap` rejects composite types.
        self.data = np.eye(left.product).reshape(-1)

    __hash__ = Box.__hash__
    __repr__ = Box.__repr__
@Diagram.register_special_box('cup')
class Cup(grammar.Cup, Box):
    """A Cup in the tensor category."""

    left: Dim
    right: Dim
    name: str
    dom: Dim
    cod: Dim
    z: int = 0
    is_reversed: bool = False

    def __init__(self, left: Dim, right: Dim, is_reversed: bool = False):
        """Initialise a tensor Cup.

        Parameters
        ----------
        left : Dim
            Dimension (type) of the left leg of the cup. Must be the
            conjugate of `right`.
        right : Dim
            Dimension (type) of the right leg of the cup. Must be the
            conjugate of `left`.
        is_reversed : bool, default False
            Ignored parameter, since left and right conjugates are
            equivalent in the tensor category. Necessary to inherit
            from `grammar.Cup` appropriately.

        """
        super().__init__(left, right)

        # Flattened d x d identity, i.e. ones at every index i * (d + 1).
        # Setting only the first and last entries is correct solely
        # for d <= 2 and drops diagonal terms for larger dimensions.
        # NOTE(review): assumes `left` is a single (atomic) wire --
        # confirm that `grammar.Cup` rejects composite types.
        self.data = np.eye(left.product).reshape(-1)

    __hash__ = Box.__hash__
    __repr__ = Box.__repr__
@Diagram.register_special_box('swap')
class Swap(grammar.Swap, Box):
    """A Swap in the tensor category.

    The box maps `left @ right` to `right @ left` and carries no data.
    """

    left: Dim
    right: Dim
    name: str
    dom: Dim
    cod: Dim
    z: int = 0

    def __init__(self, left: Dim, right: Dim):
        """Initialise a tensor Swap.

        Parameters
        ----------
        left : Dim
            Dimension (type) of the left input of the swap.
        right : Dim
            Dimension (type) of the right input of the swap.

        """
        grammar.Swap.__init__(self, left, right)
        # The Box initialiser runs last, so it has the final say on
        # name, dom, cod, data and z.
        domain = left @ right
        codomain = right @ left
        Box.__init__(self, 'SWAP', domain, codomain)

    def dagger(self):
        """Return the swap with `left` and `right` exchanged."""
        cls = type(self)
        return cls(self.right, self.left)

    __hash__ = Box.__hash__
    __repr__ = Box.__repr__
@Diagram.register_special_box('spider')
class Spider(grammar.Spider, Box):
    """A Spider in the tensor category.

    Concretely represented by a copy node.

    """

    type: Dim
    n_legs_in: int
    n_legs_out: int
    name: str
    dom: Dim
    cod: Dim
    z: int = 0

    def __init__(self, type: Dim, n_legs_in: int, n_legs_out: int):
        """Initialise a tensor Spider.

        Parameters
        ----------
        type : Dim
            Dimension (type) of each leg of the spider.
        n_legs_in : int
            Number of input legs of the spider.
        n_legs_out : int
            Number of output legs of the spider.

        """
        domain = type ** n_legs_in
        codomain = type ** n_legs_out
        Box.__init__(self, 'SPIDER', domain, codomain)
        grammar.Spider.__init__(self, type, n_legs_in, n_legs_out)

    def dagger(self) -> Self:
        """Return the spider with input and output leg counts exchanged."""
        cls = type(self)
        return cls(self.type, self.n_legs_out, self.n_legs_in)

    __hash__ = Box.__hash__
    __repr__ = Box.__repr__
# Module-level alias for `Diagram.id`.
Id = Diagram.id
@dataclass
class Daggered(grammar.Daggered, Box):
    """A daggered box.

    Attributes
    ----------
    box : Box
        The box to be daggered.

    """

    box: Box
    name: str = field(init=False)
    dom: Dim = field(init=False)
    cod: Dim = field(init=False)
    data: float | np.ndarray | None = field(default=None, init=False)
    z: int = field(init=False)

    def __post_init__(self) -> None:
        """Derive every remaining field from the wrapped box."""
        inner = self.box
        self.name = inner.name + '†'
        # Domain and codomain are exchanged relative to the wrapped box.
        self.dom, self.cod = inner.cod, inner.dom
        self.data = inner.data
        self.z = inner.z

    def __setattr__(self, __name: str, __value: Any) -> None:
        # Writes to `data` are mirrored onto the wrapped box first, so
        # both objects always refer to the same data.
        if __name == 'data':
            self.box.data = __value
        return super().__setattr__(__name, __value)

    @property
    def array(self):
        """Adjoint of the wrapped box's array."""
        return self.box._adjoint_array()

    __hash__ = Box.__hash__
    __repr__ = Box.__repr__