Source code for lambeq.bobcat.lexicon

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

__all__ = ['Atom', 'Feature', 'Relation', 'Category']

from collections.abc import Mapping
from dataclasses import dataclass
import re
from typing import Any, ClassVar, TYPE_CHECKING

from lambeq.bobcat.fast_int_enum import FastIntEnum


class Atom(FastIntEnum):
    """The possible atomic types for a category."""

    values = ['', 'N', 'NP', 'S', 'PP', 'conj', ',', ';', ':', '.', 'LQU',
              'RQU', 'LRB', 'RRB']
    names = ['NONE', 'N', 'NP', 'S', 'PP', 'CONJ', 'COMMA', 'SEMICOLON',
             'COLON', 'PERIOD']

    is_punct: bool


for atom in Atom._member_map_.values():
    if TYPE_CHECKING:
        from typing import cast
        atom = cast(Atom, atom)
    atom.is_punct = atom >= Atom.COMMA


class Feature(FastIntEnum):
    """The possible features for a category."""

    values = ['', 'X', 'adj', 'as', 'asup', 'b', 'bem', 'dcl', 'em', 'expl',
              'for', 'frg', 'intj', 'inv', 'nb', 'ng', 'num', 'poss', 'pss',
              'pt', 'q', 'qem', 'thr', 'to', 'wq']
    names = ['NONE']

    @property
    def is_free(self) -> bool:
        return self in (Feature.NONE, Feature.X)


@dataclass
class Relation:
    category: str
    slot: int

    def __repr__(self) -> str:
        return f'{self.category} {self.slot}'

    CONJ: ClassVar[Relation]


Relation.CONJ = Relation('conj', 1)


[docs]@dataclass class Category: r"""The type of a constituent in a CCG. A category may be atomic (e.g. N) or complex (e.g. S/NP). """ # atomic arguments atom: Atom = Atom.NONE feature: Feature = Feature.NONE # shared arguments var: int = 0 relation: Relation | None = None # complex arguments dir: str = '\0' result: Category | None = None argument: Category | None = None # in type raised categories only type_raising_dep_var: int = 0 def __post_init__(self) -> None: self.atomic = self.dir == '\0' self.complex = not self.atomic self.hash = self._hash() self.vars = set() if self.var: self.vars.add(self.var) if self.complex: self.vars.update(self.argument.vars, self.result.vars)
[docs] def slash(self, dir: str, argument: Category, var: int = 0, relation: Relation | None = None, type_raising_dep_var: int = 0) -> Category: """Create a complex category.""" return Category(Atom.NONE, Feature.NONE, var, relation, dir, self, argument, type_raising_dep_var)
[docs] def translate(self, var_map: Mapping[int, int], feature: Feature = Feature.NONE) -> Category: """Translate a category. Parameters ---------- var_map : dict of int to int A mapping to relabel variable slots. feature : Feature, optional The concrete feature for variable features. """ new_var = var_map[self.var] if self.atomic: if self.feature == Feature.X and feature != Feature.NONE: new_feature = feature else: new_feature = self.feature return Category(self.atom, new_feature, new_var, self.relation) else: result = self.result.translate(var_map, feature) argument = self.argument.translate(var_map, feature) return result.slash(self.dir, argument, new_var, self.relation)
def _str(self, full: bool = False, slot_counter: int = 0) -> tuple[str, int]: # pragma: no cover """Helper function to stringify a Category.""" if self.atomic: output = f'{self.atom}' if self.feature != Feature.NONE: output += f'[{self.feature}]' else: strings = [] for cat in (self.result, self.argument): string, slot_counter = cat._str(full, slot_counter) if cat.complex and not string.endswith(('}', '>')): string = f'({string})' strings.append(string) output = f'{strings[0]}{self.dir}{strings[1]}' if full and (self.var or self.relation): if self.complex: output = f'({output})' if self.var: output += f'{{{VARIABLES[self.var]}}}' if self.relation: slot_counter += 1 output += f'<{slot_counter}>' return output, slot_counter def __repr__(self) -> str: return self._str(full=True)[0] def __str__(self) -> str: return self._str()[0] def _hash(self) -> int: """Helper function to hash a Category.""" t: tuple[Any, ...] if self.atomic: t = (self.atom, Feature.NONE if self.feature == Feature.X else self.feature) else: t = (self.result, self.argument, self.dir) return hash(t) def __hash__(self) -> int: return self.hash def _equals(self, other: Category) -> bool: """Helper function to test Category equality.""" if self.hash != other.hash: return False if self.atomic: return (self.atom == other.atom and (self.feature == other.feature or (self.atom == Atom.S and self.feature.is_free and other.feature.is_free))) else: return (self.dir == other.dir and self.result._equals(other.result) and self.argument._equals(other.argument)) def __eq__(self, other: Any) -> bool: return (self is other or isinstance(other, Category) and self._equals(other)) def _matches(self, other: Category) -> bool: """Helper function to test Category pattern matching.""" if self.atomic: return (self.atom == other.atom and self.feature in (Feature.NONE, other.feature)) else: return (self.dir == other.dir and self.result._matches(other.result) and self.argument._matches(other.argument))
[docs] def matches(self, other: Any) -> bool: """Check if the template set out in this matches the argument. Like == but the NONE feature matches with everything. """ return (self is other or isinstance(other, Category) and self._matches(other))
@property def bwd(self) -> bool: """Whether this is a backward complex category.""" return self.dir == '\\' @property def fwd(self) -> bool: """Whether this is a forward complex category.""" return self.dir == '/'
[docs] @staticmethod def parse(string: str, type_raising_dep_var: str = '+') -> Category: """Parse a category string.""" return parse(string, type_raising_dep_var)
VAR_SLOT_REGEX = re.compile(r'(\{(?P<var>[_A-Z]+)\*?})?' r'(<(?P<slot>\d+)>)?', re.VERBOSE) CAT_REGEX = re.compile(r'(?P<atom>[A-Z]+|conj|[,.;:])' r'(\[(?P<feature>[Xa-z]+)])?' + VAR_SLOT_REGEX.pattern, re.VERBOSE) CATEGORIES: dict[tuple[str, int], Category] = {} VARIABLES = '+_YZWVUTRQAB' def parse_variable_id(string: str) -> int: assert len(string) == 1 return VARIABLES.index(string) def parse(string: str, type_raising_dep_var: str = '+') -> Category: var = parse_variable_id(type_raising_dep_var) try: return CATEGORIES[string, var] except KeyError: category, pos, _ = _parse(string, var) if pos != len(string): category, pos, _ = _parse(f'({string})', var) assert pos == len(string) + 2 CATEGORIES[string, var] = category return category def _parse(string: str, type_raising_dep_var: int, pos: int = 0, slots: int = 0, in_result: bool = True) -> tuple[Category, int, int]: if string[pos] == '(': left, pos, slots = _parse(string, 0, pos + 1, slots, in_result) dir = string[pos] pos += 1 if in_result: slots += 1 right, pos, slots = _parse(string, 0, pos, slots, False) assert string[pos] == ')' pos += 1 match = VAR_SLOT_REGEX.match(string, pos=pos) assert match, string[:pos] var = parse_variable_id(match['var']) if match['var'] else 0 relation = Relation(string, slots) if match['slot'] else None cat = left.slash(dir, right, var, relation, type_raising_dep_var) else: match = CAT_REGEX.match(string, pos=pos) assert match, string[pos:] feature = (Feature(match['feature']) if match['feature'] else Feature.NONE) var = parse_variable_id(match['var']) if match['var'] else 0 relation = Relation(string, slots) if match['slot'] else None cat = Category(Atom(match['atom']), feature, var, relation) return cat, match.end(), slots