# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
__all__ = ['CCGRule', 'CCGRuleUseError']
from collections.abc import Sequence
from enum import Enum
from typing import Any
from lambeq.backend.grammar import Diagram, Id
from lambeq.text2diagram.ccg_type import CCGType
class CCGRuleUseError(Exception):
    """Error raised when a :py:class:`CCGRule` is applied incorrectly.

    Parameters
    ----------
    rule : CCGRule
        The rule whose use was illegal.
    message : str
        A short description of what went wrong; rendered after the
        rule in :py:meth:`__str__`, followed by a full stop.

    """

    def __init__(self, rule: CCGRule, message: str) -> None:
        self.rule, self.message = rule, message

    def __str__(self) -> str:
        template = 'Illegal use of {}: {}.'
        return template.format(self.rule, self.message)
class CCGRule(str, Enum):
    """An enumeration of the available CCG rules.

    Each member is declared as a pair ``(name, symbol)``: ``name``
    becomes the enum value (so ``CCGRule('FA')`` looks up
    :py:attr:`FORWARD_APPLICATION`) and ``symbol`` is exposed through
    the :py:attr:`symbol` property. Looking up a value that matches no
    member returns :py:attr:`UNKNOWN`.

    """

    # Standard CCG notation for each member; assigned in ``__new__``.
    _symbol: str

    UNKNOWN = 'UNK', ''
    LEXICAL = 'L', ''
    UNARY = 'U', '<U>'
    FORWARD_APPLICATION = 'FA', '>'
    BACKWARD_APPLICATION = 'BA', '<'
    FORWARD_COMPOSITION = 'FC', '>B'
    BACKWARD_COMPOSITION = 'BC', '<B'
    FORWARD_CROSSED_COMPOSITION = 'FX', '>Bx'
    BACKWARD_CROSSED_COMPOSITION = 'BX', '<Bx'
    GENERALIZED_FORWARD_COMPOSITION = 'GFC', '>Bⁿ'
    GENERALIZED_BACKWARD_COMPOSITION = 'GBC', '<Bⁿ'
    GENERALIZED_FORWARD_CROSSED_COMPOSITION = 'GFX', '>Bxⁿ'
    GENERALIZED_BACKWARD_CROSSED_COMPOSITION = 'GBX', '<Bxⁿ'
    REMOVE_PUNCTUATION_LEFT = 'LP', '<p'
    REMOVE_PUNCTUATION_RIGHT = 'RP', '>p'
    FORWARD_TYPE_RAISING = 'FTR', '>T'
    BACKWARD_TYPE_RAISING = 'BTR', '<T'
    CONJUNCTION = 'CONJ', '<&>'

    def __new__(cls, name: str, symbol: str = '') -> CCGRule:
        # Only the short name is the enum value, so value lookups
        # (``CCGRule('FA')``) use the name alone; the symbol is stored
        # separately on the member.
        obj = str.__new__(cls, name)
        obj._value_ = name
        obj._symbol = symbol
        return obj

    @property
    def symbol(self) -> str:
        """The standard CCG symbol for the rule.

        Raises
        ------
        CCGRuleUseError
            If the rule is :py:attr:`UNKNOWN`.

        """
        if self == CCGRule.UNKNOWN:
            raise CCGRuleUseError(self, 'unknown CCG rule')
        else:
            return self._symbol

    @classmethod
    def _missing_(cls, _: Any) -> CCGRule:
        # Enum hook for failed value lookups: map anything
        # unrecognised to UNKNOWN instead of raising ValueError.
        return cls.UNKNOWN

    def check_match(self, /, left: CCGType, right: CCGType) -> None:
        """Raise an exception if the two arguments do not match.

        Raises
        ------
        CCGRuleUseError
            If ``left != right``.

        """
        if left != right:
            raise CCGRuleUseError(self,
                                  f'mismatched types - {left} != {right}')

    def resolve(self,
                dom: Sequence[CCGType],
                cod: CCGType) -> tuple[CCGType, ...]:
        """Perform type resolution on this rule use.

        This is used to propagate any type changes that have occurred
        in the codomain to the domain, such that applying this rule to
        the rewritten domain produces the provided codomain, while
        remaining as compatible as possible with the provided domain.

        Parameters
        ----------
        dom : list of CCGType
            The original domain of this rule use.
        cod : CCGType
            The required codomain of this rule use.

        Returns
        -------
        tuple of CCGType
            The rewritten domain.

        Raises
        ------
        CCGRuleUseError
            If the rule is :py:attr:`UNKNOWN`; if a lexical rule is
            given a non-empty domain; if a binary rule is not given
            exactly two types; if the domain types do not match as the
            rule requires; or if no conjunction is found for
            :py:attr:`CONJUNCTION`.

        """
        if self == CCGRule.UNKNOWN:
            raise CCGRuleUseError(self, 'unknown CCG rule')
        elif self == CCGRule.LEXICAL:
            # Explicit check rather than ``assert``, which is stripped
            # when Python runs with ``-O``.
            if dom:
                raise CCGRuleUseError(
                    self,
                    f'expected an empty domain, got a domain of length '
                    f'{len(dom)}'
                )
            return ()
        elif self == CCGRule.UNARY:
            return cod,
        elif self in (CCGRule.BACKWARD_TYPE_RAISING,
                      CCGRule.FORWARD_TYPE_RAISING):
            # cod is T/(T\X) or T\(T/X); the inner argument is X.
            return cod.argument.argument,

        # binary rules
        if len(dom) != 2:
            raise CCGRuleUseError(
                self, f'expected a domain of length 2, got {len(dom)}'
            )
        left, right = dom
        new_left: CCGType | None
        new_right: CCGType | None
        if self == CCGRule.FORWARD_APPLICATION:
            return cod << right, right
        elif self == CCGRule.BACKWARD_APPLICATION:
            return left, left >> cod
        elif self == CCGRule.FORWARD_COMPOSITION:
            self.check_match(left.right, right.left)
            return cod.result << left.right, right.left << cod.argument
        elif self == CCGRule.BACKWARD_COMPOSITION:
            self.check_match(left.right, right.left)
            return cod.argument >> left.right, right.left >> cod.result
        elif self == CCGRule.FORWARD_CROSSED_COMPOSITION:
            self.check_match(left.right, right.right)
            return cod.right << left.right, cod.left >> right.right
        elif self == CCGRule.BACKWARD_CROSSED_COMPOSITION:
            self.check_match(left.left, right.left)
            return left.left << cod.right, right.left >> cod.left
        elif self == CCGRule.GENERALIZED_FORWARD_COMPOSITION:
            ll, lr = left.left, left.right
            new_right, new_left = cod.replace_result(ll, lr, '/')
            assert new_left is not None  # narrows the Optional type
            return new_left << left.right, new_right
        elif self == CCGRule.GENERALIZED_BACKWARD_COMPOSITION:
            rl, rr = right.left, right.right
            new_left, new_right = cod.replace_result(rr, rl, '\\')
            assert new_right is not None  # narrows the Optional type
            return new_left, rl >> new_right
        elif self == CCGRule.GENERALIZED_FORWARD_CROSSED_COMPOSITION:
            ll, lr = left.left, left.right
            new_right, new_left = cod.replace_result(ll, lr, r'\|')
            assert new_left is not None  # narrows the Optional type
            return new_left << lr, new_right
        elif self == CCGRule.GENERALIZED_BACKWARD_CROSSED_COMPOSITION:
            rl, rr = right.left, right.right
            new_left, new_right = cod.replace_result(rr, rl, '/|')
            assert new_right is not None  # narrows the Optional type
            return new_left, right.left >> new_right
        elif self == CCGRule.REMOVE_PUNCTUATION_LEFT:
            return left, cod
        elif self == CCGRule.REMOVE_PUNCTUATION_RIGHT:
            return cod, right
        elif self == CCGRule.CONJUNCTION:
            # The conjoinable side absorbs the type change so that it
            # takes the other side as argument and yields ``cod``.
            if left.is_conjoinable:
                return cod << right, right
            elif right.is_conjoinable:
                return left, left >> cod
            else:
                raise CCGRuleUseError(self, 'no conjunction found')
        raise AssertionError('unreachable code')

    def __call__(self,
                 dom: Sequence[CCGType],
                 cod: CCGType | None = None) -> Diagram:
        """Alias for :py:meth:`apply`."""
        return self.apply(dom, cod)

    def apply(self,
              dom: Sequence[CCGType],
              cod: CCGType | None = None) -> Diagram:
        """Produce a lambeq diagram for this rule.

        This is primarily used by CCG trees that have been resolved.
        This means, for example, that diagrams cannot be produced for
        the conjunction rule, since they are rewritten when resolved.

        Parameters
        ----------
        dom : list of CCGType
            The domain of the diagram.
        cod : CCGType, optional
            The codomain of the diagram. This is only used for
            type-raising rules.

        Returns
        -------
        :py:class:`lambeq.backend.grammar.Diagram`
            The resulting diagram.

        Raises
        ------
        CCGRuleUseError
            If a diagram cannot be produced.

        """
        if self == CCGRule.UNKNOWN:
            raise CCGRuleUseError(self, 'unknown CCG rule')
        elif self == CCGRule.LEXICAL:
            raise CCGRuleUseError(self, 'lexical rules are not applicable')
        elif self == CCGRule.CONJUNCTION:
            raise CCGRuleUseError(
                self, 'conjunctions should be resolved before drawing'
            )
        # unary rules
        elif self in (CCGRule.UNARY,
                      CCGRule.BACKWARD_TYPE_RAISING,
                      CCGRule.FORWARD_TYPE_RAISING):
            if len(dom) != 1:
                raise CCGRuleUseError(
                    self, f'expected a domain of length 1, got {len(dom)}'
                )
            if self == CCGRule.UNARY:
                return Id(dom[0].to_grammar())
            # else type-raising rule
            if cod is None:
                raise CCGRuleUseError(
                    self,
                    'The codomain is required for type-raising rules.'
                )
            result = cod.result.to_grammar()
            if self == CCGRule.BACKWARD_TYPE_RAISING:
                return Id(dom[0].to_grammar()) @ Diagram.caps(result.r, result)
            else:
                return Diagram.caps(result, result.l) @ Id(dom[0].to_grammar())

        # binary rules
        if len(dom) != 2:
            raise CCGRuleUseError(
                self, f'expected a domain of length 2, got {len(dom)}'
            )
        left, right = dom
        if self == CCGRule.FORWARD_APPLICATION:
            # X/Y + Y -> X
            # X @ Y.l + Y -> X
            return Diagram.fa(left.result.to_grammar(), right.to_grammar())
        elif self == CCGRule.BACKWARD_APPLICATION:
            # Y + X\Y -> X
            # Y + Y.r @ X -> X
            return Diagram.ba(left.to_grammar(), right.result.to_grammar())
        elif self == CCGRule.FORWARD_COMPOSITION:
            # X/Y + Y/Z -> X/Z
            # X @ Y.l + Y @ Z.l -> X @ Z.l
            return Diagram.fc(left.left.to_grammar(),
                              left.right.to_grammar(),
                              right.right.to_grammar())
        elif self == CCGRule.BACKWARD_COMPOSITION:
            # Y\Z + X\Y -> X\Z
            # Z.r @ Y + Y.r @ X -> Z.r @ X
            return Diagram.bc(left.left.to_grammar(),
                              left.right.to_grammar(),
                              right.right.to_grammar())
        elif self == CCGRule.FORWARD_CROSSED_COMPOSITION:
            # X/Y + Y\Z -> X\Z
            # X @ Y.l + Z.r @ Y -> Z.r @ X
            return Diagram.fx(left.left.to_grammar(),
                              left.right.to_grammar(),
                              right.left.to_grammar())
        elif self == CCGRule.BACKWARD_CROSSED_COMPOSITION:
            # Y/Z + X\Y -> X/Z
            # Y @ Z.l + Y.r @ X -> X @ Z.l
            return Diagram.bx(left.right.to_grammar(),
                              left.left.to_grammar(),
                              right.right.to_grammar())
        elif self == CCGRule.GENERALIZED_FORWARD_COMPOSITION:
            # X/Y + (Y/Z)/... -> (X/Z)/...
            # X @ Y.l + Y @ Z.l @ ... -> X @ Z.l @ ...
            mid = left.argument.to_grammar()
            return (Id(left.result.to_grammar())
                    @ Diagram.cups(mid.l, mid)
                    @ Id(right.to_grammar()[len(mid):]))
        elif self == CCGRule.GENERALIZED_BACKWARD_COMPOSITION:
            # (Y\Z)\... + X\Y -> (X\Z)\...
            # ... @ Z.r @ Y + Y.r @ X -> ... @ Z.r @ X
            mid = right.argument.to_grammar()
            left_ty = left.to_grammar()
            # Slice with an explicit stop: ``[:-len(mid)]`` would
            # produce an empty type when ``mid`` is empty.
            return (Id(left_ty[:len(left_ty) - len(mid)])
                    @ Diagram.cups(mid, mid.r)
                    @ Id(right.result.to_grammar()))
        elif self == CCGRule.GENERALIZED_FORWARD_CROSSED_COMPOSITION:
            # X/Y + (Y\Z)|... -> (X\Z)|...
            # X @ Y.l + ... @ Z.r @ Y @ ... -> ... @ Z.r @ X @ ...
            mid = left.left.to_grammar()
            l, join, r = right.split(left.right)
            # Swap X @ Y.l past the prefix, then cup Y.l with Y.
            return (
                Diagram.swap(mid << join, l) @ Id(join)
                >> Id(l @ mid) @ Diagram.cups(join.l, join)
            ) @ Id(r)
        elif self == CCGRule.GENERALIZED_BACKWARD_CROSSED_COMPOSITION:
            # (Y/Z)|... + X\Y -> (X/Z)|...
            # ... @ Y @ Z.l @ ... + Y.r @ X -> ... @ X @ Z.l @ ...
            mid = right.right.to_grammar()
            l, join, r = left.split(right.left)
            # Swap Y.r @ X past the suffix, then cup Y with Y.r.
            return Id(l) @ (
                Id(join) @ Diagram.swap(r, join >> mid)
                >> Diagram.cups(join, join.r) @ Id(mid @ r)
            )
        elif self == CCGRule.REMOVE_PUNCTUATION_LEFT:
            # punc + X -> X
            return Id(right.to_grammar())
        elif self == CCGRule.REMOVE_PUNCTUATION_RIGHT:
            # X + punc -> X
            return Id(left.to_grammar())
        raise AssertionError('unreachable code')

    @classmethod
    def infer_rule(cls, dom: Sequence[CCGType], cod: CCGType) -> CCGRule:
        """Infer the CCG rule that admits the given domain and codomain.

        Return :py:attr:`CCGRule.UNKNOWN` if no other rule matches.

        Parameters
        ----------
        dom : list of CCGType
            The domain of the rule.
        cod : CCGType
            The codomain of the rule.

        Returns
        -------
        CCGRule
            A CCG rule that admits the required domain and codomain.

        """
        if not dom:
            return CCGRule.LEXICAL
        elif len(dom) == 1:
            if cod.is_complex:
                # X -> T/(T\X)
                if cod == cod.result.over(cod.result.under(dom[0])):
                    return CCGRule.FORWARD_TYPE_RAISING
                # X -> T\(T/X)
                if cod == cod.result.under(cod.result.over(dom[0])):
                    return CCGRule.BACKWARD_TYPE_RAISING
            return CCGRule.UNARY
        elif len(dom) == 2:
            left, right = dom
            # Punctuation either acts as a conjunction or is dropped.
            if left == CCGType.PUNCTUATION:
                if cod == right >> right:
                    return CCGRule.CONJUNCTION
                else:
                    return CCGRule.REMOVE_PUNCTUATION_LEFT
            if right == CCGType.PUNCTUATION:
                if cod == left << left:
                    return CCGRule.CONJUNCTION
                else:
                    return CCGRule.REMOVE_PUNCTUATION_RIGHT
            if left == cod << right:
                return CCGRule.FORWARD_APPLICATION
            if right == left >> cod:
                return CCGRule.BACKWARD_APPLICATION
            if CCGType.CONJUNCTION in (left, right):
                return CCGRule.CONJUNCTION
            if cod.is_complex and left.is_complex and right.is_complex:
                ll = left.left
                lr = left.right
                rl = right.left
                rr = right.right
                # Plain (non-crossed) composition: middle types cancel
                # and all three types share one slash direction.
                if lr == rl and (cod.left, cod.right) == (ll, rr):
                    if cod.is_over and left.is_over and right.is_over:
                        return CCGRule.FORWARD_COMPOSITION
                    if cod.is_under and left.is_under and right.is_under:
                        return CCGRule.BACKWARD_COMPOSITION
                if right.is_under:
                    if left.is_over and ll == rl and cod == rr << lr:
                        return CCGRule.BACKWARD_CROSSED_COMPOSITION
                    if left.replace_result(rl, rr, '\\') == (cod, rl):
                        return CCGRule.GENERALIZED_BACKWARD_COMPOSITION
                    if left.replace_result(rl, rr, '/|') == (cod, rl):
                        return CCGRule.GENERALIZED_BACKWARD_CROSSED_COMPOSITION
                if left.is_over:
                    if right.is_under and lr == rr and cod == rl >> ll:
                        return CCGRule.FORWARD_CROSSED_COMPOSITION
                    if right.replace_result(lr, ll, '/') == (cod, lr):
                        return CCGRule.GENERALIZED_FORWARD_COMPOSITION
                    if right.replace_result(lr, ll, r'\|') == (cod, lr):
                        return CCGRule.GENERALIZED_FORWARD_CROSSED_COMPOSITION
        return CCGRule.UNKNOWN