# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rewrite
=======
A rewrite rule is a schema for transforming/simplifying a diagram.
The :py:class:`Rewriter` applies a set of rewrite rules functorially to
a given diagram.
Subclass :py:class:`RewriteRule` to define a custom rewrite rule. An
example rewrite rule :py:class:`SimpleRewriteRule` has been provided for
basic rewrites, as well as a number of example rules. These can be used
by specifying their name when instantiating a :py:class:`Rewriter`. A list
of provided rules can be retrieved using
:py:meth:`Rewriter.available_rules`. They are:

.. glossary::

    auxiliary
        The auxiliary rule removes auxiliary verbs (such as "do") by
        replacing them with caps.

    connector
        The connector rule removes sentence connectors (such as "that")
        by replacing them with caps.

    coordination
        The coordination rule simplifies "and" based on [Kar2016]_
        by replacing it with a layer of interleaving spiders.

    curry
        The curry rewrite rule uses map-state duality to remove adjoint
        types from the boxes. When used in conjunction with
        :py:meth:`lambeq.backend.grammar.Diagram.pregroup_normal_form`,
        this removes cups from the diagram.

    determiner
        The determiner rule removes determiners (such as "the") by
        replacing them with caps.

    object_rel_pronoun
        The object relative pronoun rule simplifies object relative
        pronouns based on [SCC2014a]_ using cups, spiders and a loop.

    postadverb, preadverb
        The adverb rules simplify adverbs by passing through the noun
        wire transparently using a cup.

    prepositional_phrase
        The prepositional phrase rule simplifies the preposition in a
        prepositional phrase by passing through the noun wire
        transparently using a cup.

    subject_rel_pronoun
        The subject relative pronoun rule simplifies subject relative
        pronouns based on [SCC2014a]_ using cups and spiders.

See `examples/rewrite.ipynb` for illustrative usage.
"""
from __future__ import annotations
# Names exported by a star-import of this module.
# NOTE(review): ``CurryRewriteRule`` is defined in this file and has no
# leading underscore, but it is not listed here -- confirm whether the
# omission is intentional.
__all__ = ['CoordinationRewriteRule', 'Rewriter',
'RewriteRule', 'SimpleRewriteRule', 'UnknownWordsRewriteRule']
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Container, Iterable
from lambeq.backend.grammar import (Box, Cap, Cup, Diagram, Diagrammable,
Functor, grammar, Id, Spider, Swap,
Ty, Word)
from lambeq.core.types import AtomicType
# Short aliases for the noun and sentence atomic types; the rule
# definitions in this file build their codomains and templates from them.
N = AtomicType.NOUN
S = AtomicType.SENTENCE
class RewriteRule(ABC):
    """Abstract interface shared by all rewrite rules.

    A rule decides whether it applies to a box (:py:meth:`matches`)
    and what the box is replaced with (:py:meth:`rewrite`). Calling the
    rule object combines both steps.
    """

    @abstractmethod
    def matches(self, box: Box) -> bool:
        """Return whether the given box should be rewritten."""

    @abstractmethod
    def rewrite(self, box: Box) -> Diagrammable:
        """Return the replacement for the given box."""

    def __call__(self, box: Box) -> Diagrammable | None:
        """Apply the rewrite rule to a box.

        Parameters
        ----------
        box : :py:class:`lambeq.backend.grammar.Box`
            The candidate box to be tested against this rewrite rule.

        Returns
        -------
        :py:class:`lambeq.backend.grammar.Diagrammable`, optional
            The result of :py:meth:`rewrite`, or :py:obj:`None` if
            :py:meth:`matches` rejects the box.

        Notes
        -----
        This default implementation is built on :py:meth:`matches` and
        :py:meth:`rewrite`. Derived classes may override this method
        and ignore those two, since the default :py:class:`Rewriter`
        implementation only ever calls the rule object itself.

        """
        if not self.matches(box):
            return None
        return self.rewrite(box)
class SimpleRewriteRule(RewriteRule):
    """A template-based rewrite rule.

    A box is rewritten when its codomain equals the required codomain
    and, if a collection of words was supplied, its name is found in
    that collection. A matching box is replaced by a fixed template in
    which every placeholder box is substituted by the matched word.
    """

    # Name carried by placeholder boxes inside a template.
    PLACEHOLDER_WORD = '<PLACEHOLDER>'

    def __init__(self,
                 cod: Ty,
                 template: Diagrammable,
                 words: Container[str] | None = None,
                 case_sensitive: bool = False) -> None:
        """Instantiate a simple rewrite rule.

        Parameters
        ----------
        cod : :py:class:`lambeq.backend.grammar.Ty`
            The type that the codomain of each box is matched against.
        template : :py:class:`lambeq.backend.grammar.Diagrammable`
            The diagram that a matching box is replaced with. A special
            placeholder box is replaced by the word in the matched box,
            and can be created using
            :py:meth:`SimpleRewriteRule.placeholder`.
        words : container of str, optional
            If provided, only boxes whose word is in this container are
            rewritten, even if the codomain matches. If omitted, all
            words are permitted.
        case_sensitive : bool, default: False
            Whether the box name is looked up as-is. When
            :py:obj:`False`, the box name is lower-cased before the
            lookup; the entries of `words` themselves are not
            modified, so they should be given in lower case.

        """
        self.cod = cod
        self.template = template
        self.words = words
        self.case_sensitive = case_sensitive

    def matches(self, box: Box) -> bool:
        """Check the box's codomain and word against this rule."""
        name = box.name if self.case_sensitive else box.name.lower()
        if not (box.cod == self.cod):
            return False
        return self.words is None or name in self.words

    def rewrite(self, box: Box) -> Diagrammable:
        """Return the template with placeholders filled in.

        Every box in the template named :py:attr:`PLACEHOLDER_WORD` is
        replaced by a word carrying the name of `box` and the codomain
        of the placeholder; all other boxes and all types are left
        untouched.
        """
        def fill_in(_, arrow: Box) -> Box:
            is_placeholder = arrow.name == self.PLACEHOLDER_WORD
            return Word(box.name, arrow.cod) if is_placeholder else arrow

        substitute = Functor(target_category=grammar,
                             ob=lambda _, ty: ty,
                             ar=fill_in)
        return substitute(self.template)

    @classmethod
    def placeholder(cls, cod: Ty) -> Word:
        """Helper function to generate the placeholder for a template.

        Parameters
        ----------
        cod : :py:class:`lambeq.backend.grammar.Ty`
            The codomain of the placeholder, and hence the word in the
            resulting rewritten diagram.

        Returns
        -------
        :py:class:`lambeq.backend.grammar.Word`
            A placeholder word with the given codomain.

        """
        return Word(cls.PLACEHOLDER_WORD, cod)
# Sentence connectors (such as "that") are replaced with a cap.
connector_rule = SimpleRewriteRule(
    cod=S << S,
    words=['and', 'but', 'however', 'if', 'that', 'whether'],
    template=Cap(S, S.l),
)

# Determiners (such as "the") are replaced with a cap.
determiner_rule = SimpleRewriteRule(
    cod=N << N,
    words=['a', 'an', 'the'],
    template=Cap(N, N.l),
)
# Adverb following the verb phrase: the matched word is kept with the
# smaller type ``S >> S`` and the noun wires are supplied by a cap.
postadverb_rule = SimpleRewriteRule(
    cod=(N >> S) >> (N >> S),
    template=(
        SimpleRewriteRule.placeholder(S >> S)
        >> (Id(S.r) @ Cap(N.r.r, N.r) @ Id(S))
    ),
)

# Adverb preceding the verb phrase: the matched word is kept with the
# smaller type ``S << S`` and placed inside a cap on the noun wires.
preadverb_rule = SimpleRewriteRule(
    cod=(N >> S) << (N >> S),
    template=(
        Cap(N.r, N)
        >> (Id(N.r) @ SimpleRewriteRule.placeholder(S << S) @ Id(N))
    ),
)

# Auxiliary verbs share the pre-adverb codomain; they are matched by
# word and replaced entirely by caps between the two halves of that
# codomain.
auxiliary_rule = SimpleRewriteRule(
    cod=preadverb_rule.cod,
    template=Diagram.caps(preadverb_rule.cod[:2], preadverb_rule.cod[2:]),
    words=[
        'am', 'are', 'be', 'been', 'being', 'is', 'was', 'were',
        'did', 'do', 'does',
        "'d", 'had', 'has', 'have',
        'may', 'might',
        'will',
    ],
)

# Preposition in a prepositional phrase: the matched word is kept with
# the smaller type ``S >> S << N`` and the noun wires are supplied by
# a cap.
prepositional_phrase_rule = SimpleRewriteRule(
    cod=(N >> S) >> (N >> S << N),
    template=(
        SimpleRewriteRule.placeholder(S >> S << N)
        >> (Id(S.r) @ Cap(N.r.r, N.r) @ Id(S @ N.l))
    ),
)
# Loop on the noun wire, built from a cap, a swap and a cup; used by
# the object relative pronoun template below.
_noun_loop = (
    ((Cap(N.l, N.l.l) >> Swap(N.l, N.l.l)) @ Id(N))
    >> (Id(N.l.l) @ Cup(N.l, N))
)

# Object relative pronouns: the word box is replaced by a cap, spiders
# and the noun loop.
object_rel_pronoun_rule = SimpleRewriteRule(
    cod=N.r @ N @ N.l.l @ S.l,
    words=['that', 'which', 'who', 'whom', 'whose'],
    template=(
        Cap(N.r, N)
        >> (Id(N.r) @ Spider(N, 1, 2) @ Spider(S.l, 0, 1))
        >> (Id(N.r @ N) @ _noun_loop @ Id(S.l))
    ),
)

# Subject relative pronouns: the word box is replaced by a cap and
# spiders.
subject_rel_pronoun_rule = SimpleRewriteRule(
    cod=N.r @ N @ S.l @ N,
    words=['that', 'which', 'who', 'whom', 'whose'],
    template=(
        Cap(N.r, N)
        >> (Id(N.r) @ Spider(N, 1, 2))
        >> (Id(N.r @ N) @ Spider(S.l, 0, 1) @ Id(N))
    ),
)
class CoordinationRewriteRule(RewriteRule):
    """A rewrite rule for coordination.

    This rule matches the word 'and' with codomain
    :py:obj:`a.r @ a @ a.l` for pregroup type :py:obj:`a`, and replaces
    the word, based on [Kar2016]_, with a layer of interleaving spiders.
    """

    def __init__(self, words: Container[str] | None = None) -> None:
        """Instantiate a CoordinationRewriteRule.

        Parameters
        ----------
        words : container of str, optional
            A list of words to be rewritten by this rule. If a box does
            not have one of these words, it will not be rewritten, even
            if the codomain matches.
            If omitted, the rewrite applies only to the word "and".

        """
        if words is None:
            words = ['and']
        self.words = words

    @staticmethod
    def _split_in_thirds(cod: Ty) -> tuple[Ty, Ty, Ty]:
        """Slice a codomain into three consecutive, equally long parts."""
        third = len(cod) // 3
        return cod[:third], cod[third:2*third], cod[2*third:]

    def matches(self, box: Box) -> bool:
        """Check for a listed word whose codomain has the expected shape.

        The codomain must split into three equally long parts
        ``left @ mid @ right`` with ``left.l == mid == right.r``.
        """
        if box.name not in self.words:
            return False
        if len(box.cod) % 3 != 0:
            return False
        left, mid, right = self._split_in_thirds(box.cod)
        return bool(right.r == mid == left.l)

    def rewrite(self, box: Box) -> Diagrammable:
        """Replace the box with caps followed by a merging spider."""
        left, mid, right = self._split_in_thirds(box.cod)
        # Internal invariant: callers are expected to have checked
        # `matches` first.
        assert right.r == mid == left.l
        caps = Diagram.caps(left, mid) @ Diagram.caps(mid, right)
        merge = Id(left) @ Spider(mid, 2, 1) @ Id(right)
        return caps >> merge
class CurryRewriteRule(RewriteRule):
    """A rewrite rule using map-state duality."""

    def __init__(self) -> None:
        """Instantiate a CurryRewriteRule.

        This rule uses the map-state duality by iteratively
        uncurrying on both sides of each box. When used in conjunction
        with
        :py:meth:`lambeq.backend.grammar.Diagram.pregroup_normal_form`,
        this removes cups from the diagram in exchange for depth.
        Diagrams with fewer cups become circuits with fewer
        post-selection, which results in faster QML experiments.

        """

    def matches(self, box: Box) -> bool:
        """Match boxes whose first or last output wire has non-zero ``z``."""
        cod = box.cod
        if not cod:
            return False
        return bool(cod[0].z or cod[-1].z)

    def rewrite(self, box: Box) -> Diagrammable:
        """Move the outer adjoint wires of the codomain to the domain.

        The leading wires with positive ``z`` and the trailing wires
        with negative ``z`` are stripped from the codomain of a new box
        with the same name, added (as ``.l`` / ``.r``) around its
        domain, and then restored on the output by currying.
        """
        cod = box.cod
        size = len(cod)

        # `start` ends up just past the run of leading wires with z > 0.
        start = 0
        while start < size and cod[start].z > 0:
            start += 1

        # `stop` ends up at the first wire of the trailing run with z < 0.
        stop = size
        while stop > 0 and cod[stop - 1].z < 0:
            stop -= 1

        left, right = cod[:start], cod[stop:]
        widened_dom = left.l @ box.dom @ right.r
        curried = Box(box.name, widened_dom, cod[start:stop])
        if left:
            curried = curried.curry(n=len(left), left=False)
        if right:
            curried = curried.curry(n=len(right), left=True)
        return curried
class Rewriter:
    """Class that rewrites diagrams.

    Comes with a set of default rules.
    """

    # Rules used when no explicit rule list is given.
    _default_rules = {
        'auxiliary': auxiliary_rule,
        'connector': connector_rule,
        'determiner': determiner_rule,
        'postadverb': postadverb_rule,
        'preadverb': preadverb_rule,
        'prepositional_phrase': prepositional_phrase_rule,
    }

    # Every rule that can be requested by name: the defaults plus the
    # opt-in rules.
    _available_rules = {
        **_default_rules,
        'coordination': CoordinationRewriteRule(),
        'curry': CurryRewriteRule(),
        'object_rel_pronoun': object_rel_pronoun_rule,
        'subject_rel_pronoun': subject_rel_pronoun_rule
    }

    def __init__(self,
                 rules: Iterable[RewriteRule | str] | None = None) -> None:
        """Initialise a rewriter.

        Parameters
        ----------
        rules : iterable of str or RewriteRule, optional
            A list of rewrite rules to use. :py:class:`RewriteRule`
            instances are used directly, `str` objects are looked up
            by name among the provided rules. See
            :py:meth:`Rewriter.available_rules` for the list of rule
            names. If omitted, all the default rules are used.

        Raises
        ------
        ValueError
            If a rule name is not one of the available rule names.

        """
        self.rules: list[RewriteRule]
        if rules is not None:
            self.rules = []
            self.add_rules(*rules)
        else:
            self.rules = list(self._default_rules.values())

        self.apply_rewrites = Functor(target_category=grammar,
                                      ob=self._ob,
                                      ar=self._ar)

    @classmethod
    def available_rules(cls) -> list[str]:
        """The list of rule names that can be requested by name."""
        return list(cls._available_rules)

    def add_rules(self, *rules: RewriteRule | str) -> None:
        """Add rules to this rewriter.

        :py:class:`RewriteRule` instances are appended as given; any
        other value is treated as the name of an available rule.

        Raises
        ------
        ValueError
            If a name is not one of the available rule names.

        """
        for rule in rules:
            if not isinstance(rule, RewriteRule):
                try:
                    rule = self._available_rules[rule]
                except KeyError as e:
                    raise ValueError(
                        f'`{rule}` is not a valid rewrite rule.'
                    ) from e
            self.rules.append(rule)

    def __call__(self, diagram: Diagram) -> Diagram:
        """Apply the rewrite rules to the given diagram."""
        return self.apply_rewrites(diagram)

    def _ar(self, _: Functor, box: Box) -> Diagrammable:
        # The first rule (in insertion order) that produces a result
        # wins; a box that no rule rewrites is returned unchanged.
        outcomes = (rule(box) for rule in self.rules)
        return next((out for out in outcomes if out is not None), box)

    def _ob(self, _: Functor, ob: Ty) -> Ty:
        # Types are left untouched by rewriting.
        return ob
class UnknownWordsRewriteRule(RewriteRule):
    """A rewrite rule for unknown words.

    This rule matches any word not included in its vocabulary
    and, when passed a diagram, replaces all the boxes
    containing an unknown word with an `UNK` box corresponding to the
    same pregroup type.
    """

    def __init__(self,
                 vocabulary: Container[str | tuple[str, Ty]],
                 unk_token: str = '<UNK>') -> None:
        """Instantiate an UnknownWordsRewriteRule.

        Parameters
        ----------
        vocabulary : container of str or tuple of str and Ty
            A list of words (or words with specific output types) to not
            be rewritten by this rule.
        unk_token : str, default: '<UNK>'
            The string to use for the UNK token.

        """
        self.vocabulary = vocabulary
        self.unk_token = unk_token

    def matches(self, box: Box) -> bool:
        """Match word boxes that the vocabulary does not contain.

        A word counts as known if either the pair of its name and
        codomain, or its bare name, is in the vocabulary.
        """
        if not isinstance(box, Word):
            return False
        known = self.vocabulary
        return (box.name, box.cod) not in known and box.name not in known

    def rewrite(self, box: Box) -> Box:
        """Return an UNK word with the same codomain as the box."""
        return Word(self.unk_token, cod=box.cod)

    @classmethod
    def from_diagrams(cls,
                      diagrams: Iterable[Diagram],
                      min_freq: int = 1,
                      unk_token: str = '<UNK>',
                      ignore_types: bool = False) -> UnknownWordsRewriteRule:
        """Create the rewrite rule from a set of diagrams.

        The vocabulary is the set of words that occur at least
        `min_freq` times throughout the set of diagrams.

        Parameters
        ----------
        diagrams : list of Diagram
            Diagrams from which the vocabulary is created.
        min_freq : int, default: 1
            The minimum frequency required for a word to be included in
            the vocabulary.
        unk_token : str, default: '<UNK>'
            The string to use for the UNK token.
        ignore_types : bool, default: False
            Whether to just consider the word when determining frequency
            or to also consider the output type of the box (the default
            behaviour).

        Returns
        -------
        UnknownWordsRewriteRule
            A rule whose vocabulary holds the sufficiently frequent
            entries.

        """
        counts: Counter[str | tuple[str, Ty]] = Counter()
        for diagram in diagrams:
            for box in diagram.boxes:
                # Only word boxes contribute to the vocabulary.
                if isinstance(box, Word):
                    entry = box.name if ignore_types else (box.name, box.cod)
                    counts[entry] += 1

        vocabulary = {entry for entry, freq in counts.items()
                      if freq >= min_freq}
        return cls(vocabulary=vocabulary, unk_token=unk_token)