Source code for lambeq.text2diagram.ccg_tree

# Copyright 2021-2023 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

__all__ = ['CCGTree']

from collections.abc import Iterable
from copy import deepcopy
import json
from typing import Any, Dict
from typing import overload

from discopy import rigid, Word
from discopy.biclosed import biclosed2rigid_ob, unaryBoxConstructor
from discopy.biclosed import Box, Diagram, Functor, Id, Over, Ty, Under

from lambeq.text2diagram.ccg_rule import CCGRule, GBC, GBX, GFC, GFX
from lambeq.text2diagram.ccg_types import (biclosed2str, replace_cat_result,
                                           str2biclosed)
from lambeq.text2diagram.ccg_types import CCGAtomicType

# Types
_JSONDictT = Dict[str, Any]


class UnarySwap(unaryBoxConstructor('cod'), Box):  # type: ignore[misc]
    """Unary Swap box, for unary rules that change the type direction.

    For example, the unary rule `S/NP -> NP\\NP` is handled by
    `UnarySwap(n >> n)`, which becomes `Swap(n @ n.r)` in a rigid
    diagram.

    """

    def __init__(self, cod: Ty) -> None:
        if isinstance(cod, Over):
            # This defines a swap from Y.l @ X to X @ Y.l
            # since:
            #           Ty() << Y        -> Y.l
            #  Ty() << (Ty() << Y)       -> Y.l.l
            # (Ty() << (Ty() << Y)) >> X -> Y.l.l.r @ X = Y.l @ X
            dom = (Ty() << (Ty() << cod.left)) >> cod.right
        elif isinstance(cod, Under):
            # Similarly, this defines a swap from Y @ X.r to X.r @ Y
            dom = cod.left << ((cod.right >> Ty()) >> Ty())
        else:
            raise ValueError(f'Invalid type for unary swap: {cod}')
        super().__init__(f'Swap({dom} -> {cod})', dom, cod)


class PlanarBX(Box):
    """Planar Backward Crossed Composition Box."""
    def __init__(self, dom: Ty, diagram: Diagram) -> None:
        assert isinstance(dom, Over)
        assert not diagram.dom

        right = diagram.cod
        assert isinstance(right, Under)
        assert right.left == dom.left

        self.diagram = diagram
        cod = right.right << dom.right
        super().__init__(f'PlanarBX({dom}, {diagram})', dom, cod)


class PlanarFX(Box):
    """Planar Forward Crossed Composition Box."""
    def __init__(self, dom: Ty, diagram: Diagram) -> None:
        assert isinstance(dom, Under)
        assert not diagram.dom

        left = diagram.cod
        assert isinstance(left, Over)
        assert left.right == dom.right

        self.diagram = diagram
        cod = dom.left >> left.left
        super().__init__(f'PlanarFX({dom}, {diagram})', dom, cod)


class PlanarGBX(Box):
    def __init__(self, dom: Ty, diagram: Diagram):
        assert not diagram.dom

        right = diagram.cod
        assert isinstance(right, Under)

        cod, original = replace_cat_result(dom, right.left, right.right, '<|')
        assert original == right.left

        self.diagram = diagram
        super().__init__(f'PlanarGBX({dom}, {diagram})', dom, cod)


class PlanarGFX(Box):
    def __init__(self, dom: Ty, diagram: Diagram):
        assert not diagram.dom

        left = diagram.cod
        assert isinstance(left, Over)

        cod, original = replace_cat_result(dom, left.right, left.left, '>|')
        assert original == left.right

        self.diagram = diagram
        super().__init__(f'PlanarGFX({dom}, {diagram})', dom, cod)


[docs]class CCGTree: """Derivation tree for a CCG. This provides a standard derivation interface between the parser and the rest of the model. """
[docs] def __init__(self, text: str | None = None, *, rule: CCGRule | str = CCGRule.UNKNOWN, biclosed_type: Ty, children: Iterable[CCGTree] | None = None) -> None: """Initialise a CCG tree. Parameters ---------- text : str, optional The word or phrase associated to the whole tree. If :py:obj:`None`, it is inferred from its children. rule : CCGRule, default: CCGRule.UNKNOWN The final :py:class:`.CCGRule` used in the derivation. biclosed_type : discopy.biclosed.Ty The type associated to the derived phrase. children : list of CCGTree, optional A list of JSON subtrees. The types of these subtrees can be combined with the :py:obj:`rule` to produce the output :py:obj:`type`. A leaf node has an empty list of children. """ self._text = text self.rule = CCGRule(rule) self.biclosed_type = biclosed_type self.children = list(children) if children is not None else [] n_children = len(self.children) child_requirements = {CCGRule.LEXICAL: 0, CCGRule.UNARY: 1, CCGRule.FORWARD_TYPE_RAISING: 1, CCGRule.BACKWARD_TYPE_RAISING: 1} required_children = child_requirements.get(self.rule, 2) if self.rule != CCGRule.UNKNOWN and n_children != required_children: raise ValueError('Invalid number of children for rule ' f'`{self.rule}`: expected {required_children}, ' f'got {n_children}.') if text and not children: self.rule = CCGRule.LEXICAL self.is_leaf = len(self.children) == 0 self.is_unary = len(self.children) == 1 self.is_binary = len(self.children) == 2
@property def text(self) -> str: """The word or phrase associated to the tree.""" if self._text is None: self._text = ' '.join(child.text for child in self.children) return self._text @property def child(self) -> CCGTree: """Get the child of a unary tree.""" if not self.is_unary: raise ValueError('Cannot get the child of a non-unary tree.') return self.children[0] @property def left(self) -> CCGTree: """Get the left child of a binary tree.""" if not self.is_binary: raise ValueError('Cannot get the left child of a non-binary tree.') return self.children[0] @property def right(self) -> CCGTree: """Get the right child of a binary tree.""" if not self.is_binary: raise ValueError( 'Cannot get the right child of a non-binary tree.') return self.children[1] @overload @classmethod def from_json(cls, data: None) -> None: ... @overload @classmethod def from_json(cls, data: _JSONDictT | str) -> CCGTree: ...
[docs] @classmethod def from_json(cls, data: _JSONDictT | str | None) -> CCGTree | None: """Create a :py:class:`CCGTree` from a JSON representation. A JSON representation of a derivation contains the following fields: `text` : :py:obj:`str` or :py:obj:`None` The word or phrase associated to the whole tree. If :py:obj:`None`, it is inferred from its children. `rule` : :py:class:`.CCGRule` The final :py:class:`.CCGRule` used in the derivation. `type` : :py:class:`discopy.biclosed.Ty` The type associated to the derived phrase. `children` : :py:class:`list` or :py:class:`None` A list of JSON subtrees. The types of these subtrees can be combined with the :py:obj:`rule` to produce the output :py:obj:`type`. A leaf node has an empty list of children. """ if data is None: return None data_dict = json.loads(data) if isinstance(data, str) else data return cls(text=data_dict.get('text'), rule=data_dict.get('rule', CCGRule.UNKNOWN), biclosed_type=str2biclosed(data_dict['type']), children=[cls.from_json(child) for child in data_dict.get('children', [])])
[docs] def without_trivial_unary_rules(self) -> CCGTree: """Create a new CCGTree from the current tree, with all trivial unary rules (i.e. rules that map X to X) removed. This might happen because there is no exact correspondence between CCG types and pregroup types, e.g. both CCG types `NP` and `N` are mapped to the same pregroup type `n`. Returns ------- :py:class:`lambeq.text2diagram.CCGTree` A new tree free of trivial unary rules. """ new_tree = deepcopy(self) while (new_tree.rule == CCGRule.UNARY and (new_tree.biclosed_type == new_tree.children[0].biclosed_type)): new_tree = new_tree.children[0] new_tree.children = [child.without_trivial_unary_rules() for child in new_tree.children] return new_tree
[docs] def to_json(self) -> _JSONDictT: """Convert tree into JSON form.""" if self is None: # Allows doing CCGTree.to_json(X) for optional X return None # type: ignore[unreachable] data: _JSONDictT = {'type': biclosed2str(self.biclosed_type)} if self.rule != CCGRule.UNKNOWN: data['rule'] = self.rule.value if self.text != ' '.join(child.text for child in self.children): data['text'] = self.text if self.children: data['children'] = [child.to_json() for child in self.children] return data
def __eq__(self, other: Any) -> bool: return (isinstance(other, CCGTree) and self.text == other.text and self.rule == other.rule and self.biclosed_type == other.biclosed_type and len(self.children) == len(other.children) and all(c1 == c2 for c1, c2 in zip(self.children, other.children))) def __repr__(self) -> str: return f'{type(self).__name__}({self.text!r})' def _vert_deriv(self, chr_set: dict[str, str], use_slashes: bool, _prefix: str = '') -> str: # pragma: no cover """Create a vertical string representation of the CCG tree.""" output_type = biclosed2str(self.biclosed_type, not use_slashes) if self.rule == CCGRule.LEXICAL: deriv = f' {output_type} {chr_set["SUCH_THAT"]} {repr(self.text)}' else: deriv = (f'{self.rule.value}: {output_type} ' f'{chr_set["LEFT_ARROW"]} ' + ' + '.join(biclosed2str(child.biclosed_type, not use_slashes) for child in self.children)) deriv = f'{_prefix}{deriv}' if self.children: if _prefix: _prefix = _prefix[:-1] + (f'{chr_set["BAR"]} ' if _prefix[-1] == chr_set['JOINT'] else ' ') for child in self.children[:-1]: deriv += '\n' + child._vert_deriv(chr_set, use_slashes, _prefix + chr_set['JOINT']) deriv += '\n' + self.children[-1]._vert_deriv( chr_set, use_slashes, _prefix + chr_set['CORNER']) return deriv def _horiz_deriv(self, chr_set: dict[str, str], word_spacing: int, use_slashes: bool) -> str: # pragma: no cover """Create a standard CCG diagram for the tree with the words arranged horizontally.""" output_type = biclosed2str(self.biclosed_type, not use_slashes) if output_type.startswith('('): output_type = output_type[1:-1] if self.rule == CCGRule.LEXICAL: width = max(len(output_type), len(self.text)) return (f'{self.text:{" "}^{width}}\n' f'{chr_set["HEAVY_LINE"] * width}\n' f'{output_type:{" "}^{width}}') lines: list[str] = [] for child in self.children: child_deriv = child._horiz_deriv(chr_set, word_spacing, use_slashes).split('\n') for lidx in range(max(len(lines), len(child_deriv))): if lidx < len(lines) and lidx < len(child_deriv): lines[lidx] += (' ' * word_spacing) + child_deriv[lidx] elif lidx < len(lines): lines[lidx] += ' ' * (len(lines[lidx - 1]) - len(lines[lidx])) else: target_len = len(lines[lidx - 1]) if lidx > 0 else 0 lines.append(f'{child_deriv[lidx]:{" "}>{target_len}}') target_len = len(lines[lidx - 1]) if lidx > 0 else 0 if len(output_type) > target_len: target_len = len(output_type) for i in range(len(lines)): lines[i] = f'{lines[i]:{" "}^{target_len}}' rule_symbol = self.rule.symbol lines.append(f'{rule_symbol:{chr_set["LINE"]}>{target_len}}') lines.append(f'{output_type:{" "}^{target_len}}') return '\n'.join(lines)
[docs] def deriv(self, word_spacing: int = 2, use_slashes: bool = True, use_ascii: bool = False, vertical: bool = False) -> str: # pragma: no cover """Produce a string representation of the tree. Parameters ---------- word_spacing : int, default: 2 The minimum number of spaces between the words of the diagram. Only used for horizontal diagrams. use_slashes: bool, default: True Whether to use slashes in the CCG types instead of arrows. Automatically set to True when `use_ascii` is True. use_ascii: bool, default: False Whether to draw using ASCII characters only. vertical: bool, default: False Whether to create a vertical tree representation, instead of the standard horizontal one. Returns ------- str A string that contains the graphical representation of the CCG tree. """ UNICODE_CHAR_SET = { 'HEAVY_LINE': '═', 'LINE': '─', 'BAR': '│', 'SUCH_THAT': '∋', 'JOINT': '├', 'LEFT_ARROW': '←', 'CORNER': '└' } ASCII_CHAR_SET = { 'HEAVY_LINE': '=', 'LINE': '-', 'BAR': '│', 'SUCH_THAT': '<-', 'JOINT': '├', 'LEFT_ARROW': '<-', 'CORNER': '└' } if use_ascii: use_slashes = True chr_set = UNICODE_CHAR_SET if not use_ascii else ASCII_CHAR_SET if not vertical: deriv = self._horiz_deriv(chr_set, word_spacing, use_slashes) else: deriv = self._vert_deriv(chr_set, use_slashes, '') return deriv
[docs] def to_biclosed_diagram(self, planar: bool = False) -> Diagram: """Convert tree to a derivation in DisCoPy form. Parameters ---------- planar : bool, default: False Force the diagram to be planar. This only affects trees using cross composition. """ words, grammar = self._to_biclosed_diagram(planar) return words >> grammar
def _to_biclosed_diagram( self, planar: bool = False, resolved_output: Ty | None = None) -> tuple[Diagram, Diagram]: biclosed_type = resolved_output or self.biclosed_type if self.rule == CCGRule.LEXICAL: word = Box(self.text, Ty(), biclosed_type) return word, Id(biclosed_type) if (self.rule == CCGRule.UNARY and not planar and {type(biclosed_type), type(self.children[0].biclosed_type)} == {Over, Under}): this_layer = UnarySwap(biclosed_type) elif (self.rule == CCGRule.FORWARD_COMPOSITION and self.children[0].rule == CCGRule.FORWARD_TYPE_RAISING): left = biclosed_type.left right = biclosed_type.right mid = (self.children[0].biclosed_type.right.left) >> left this_layer = self.rule((left << mid) @ (mid << right), biclosed_type) else: child_types = [child.biclosed_type for child in self.children] this_layer = self.rule(Ty.tensor(*child_types), biclosed_type) children = [child._to_biclosed_diagram(planar, this_layer.dom[i:i+1]) for i, child in enumerate(self.children)] if planar and self.rule == CCGRule.BACKWARD_CROSSED_COMPOSITION: (words, left_diag), (right_words, right_diag) = children diag = (left_diag >> PlanarBX(left_diag.cod, right_words >> right_diag)) elif planar and self.rule == CCGRule.FORWARD_CROSSED_COMPOSITION: (left_words, left_diag), (words, right_diag) = children diag = (right_diag >> PlanarFX(right_diag.cod, left_words >> left_diag)) elif planar and (self.rule == CCGRule.GENERALIZED_BACKWARD_CROSSED_COMPOSITION): (words, left_diag), (right_words, right_diag) = children diag = (left_diag >> PlanarGBX(left_diag.cod, right_words >> right_diag)) elif planar and (self.rule == CCGRule.GENERALIZED_FORWARD_CROSSED_COMPOSITION): (left_words, left_diag), (words, right_diag) = children diag = (right_diag >> PlanarGFX(right_diag.cod, left_words >> left_diag)) else: words, diag = [Diagram.tensor(*d) for d in zip(*children)] diag >>= this_layer return words, diag
[docs] def to_diagram(self, planar: bool = False) -> rigid.Diagram: """Convert tree to a DisCoCat diagram. Parameters ---------- planar : bool, default: False Force the diagram to be planar. This only affects trees using cross composition. """ def ob_func(ob: Ty) -> rigid.Ty: return (rigid.Ty() if ob == CCGAtomicType.PUNCTUATION else biclosed2rigid_ob(ob)) def ar_func(box: Box) -> rigid.Diagram: # special box -> special diagram # RemovePunctuation box -> identity wire(s) # punctuation -> empty diagram # word box -> Word def split(cat: Ty, base: Ty) -> tuple[rigid.Ty, rigid.Ty, rigid.Ty]: left = right = rigid.Ty() while cat != base: if isinstance(cat, Over): right = to_rigid_diagram(cat.right).l @ right cat = cat.left else: left @= to_rigid_diagram(cat.left).r cat = cat.right return left, to_rigid_diagram(cat), right if isinstance(box, UnarySwap): cod = ob_func(box.cod) return rigid.Swap(cod[1:], cod[:1]) if isinstance(box, PlanarBX): join = to_rigid_diagram(box.dom.left) right = to_rigid_diagram(box.dom)[len(join):] inner = to_rigid_diagram(box.diagram) cups = rigid.cups(join, join.r) return (Id(join) @ inner >> cups @ Id(inner.cod[len(join):])) @ Id(right) if isinstance(box, PlanarFX): join = to_rigid_diagram(box.dom.right) left = to_rigid_diagram(box.dom)[:-len(join)] inner = to_rigid_diagram(box.diagram) cups = rigid.cups(join.l, join) return Id(left) @ (inner @ Id(join) >> Id(inner.cod[:-len(join)]) @ cups) if isinstance(box, PlanarGBX): left, join, right = split(box.dom, box.diagram.cod.left) inner = to_rigid_diagram(box.diagram) cups = rigid.cups(join, join.r) mid = (Id(join) @ inner) >> (cups @ Id(inner.cod[len(join):])) return Id(left) @ mid @ Id(right) if isinstance(box, PlanarGFX): left, join, right = split(box.dom, box.diagram.cod.right) inner = to_rigid_diagram(box.diagram) cups = rigid.cups(join.l, join) mid = (inner @ Id(join)) >> (Id(inner.cod[:-len(join)]) @ cups) return Id(left) @ mid @ Id(right) if isinstance(box, GBC): left = to_rigid_diagram(box.dom[0]) mid = to_rigid_diagram(box.dom[1].left) right = to_rigid_diagram(box.dom[1].right) return (Id(left[:-len(mid)]) @ rigid.cups(mid, mid.r) @ Id(right)) if isinstance(box, GFC): left = to_rigid_diagram(box.dom[0].left) mid = to_rigid_diagram(box.dom[0].right) right = to_rigid_diagram(box.dom[1]) return Id(left) @ rigid.cups(mid.l, mid) @ Id(right[len(mid):]) if isinstance(box, GBX): mid = to_rigid_diagram(box.dom[1].right) left, join, right = split(box.dom[0], box.dom[1].left) swaps = rigid.Diagram.swap(right, join >> mid) cups = rigid.cups(join, join.r) return Id(left) @ (Id(join) @ swaps >> cups @ Id(mid @ right)) if isinstance(box, GFX): mid = to_rigid_diagram(box.dom[0].left) left, join, right = split(box.dom[1], box.dom[0].right) cups = rigid.cups(join.l, join) return (rigid.Diagram.swap(mid << join, left) @ Id(join) >> Id(left @ mid) @ cups) @ Id(right) cod = to_rigid_diagram(box.cod) return Id(cod) if box.dom or not cod else Word(box.name, cod) to_rigid_diagram = Functor(ob=ob_func, ar=ar_func, ob_factory=rigid.Ty, ar_factory=rigid.Diagram) return to_rigid_diagram(self.to_biclosed_diagram(planar=planar))