# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, replace
from enum import Enum
from functools import cached_property
from typing import Any
from lambeq.bobcat.lexicon import Atom, Category, Feature, Relation
@dataclass
class IndexedWord:
    """A word of the input sentence paired with its position (1-indexed).

    The ``repr`` (and hence ``str``) of an instance is ``<word>_<index>``,
    for example ``dog_3``.
    """

    word: str   # surface form of the word
    index: int  # position in the sentence, counting from 1

    def __repr__(self) -> str:
        return f'{self.word}_{self.index}'
@dataclass
class Dependency:
    """A word-word dependency that may still be waiting for its filler.

    ``head`` is the head word and ``relation`` the relation label.
    ``var`` is the category variable id the filler will come from, and
    ``unary_rule_id`` is carried along unchanged unless explicitly replaced.
    ``filler`` is ``None`` until the dependency has been filled.
    """

    relation: Relation
    head: IndexedWord
    var: int
    unary_rule_id: int
    filler: IndexedWord | None = None

    def replace(self,
                var: int,
                unary_rule_id: int | None = None) -> Dependency:
        """Return a copy pointing at variable ``var``.

        The unary rule id is kept from ``self`` unless a new one is given.
        """
        # Inside this body ``replace`` is the module-level
        # ``dataclasses.replace``; the method name does not shadow it.
        rule_id = (self.unary_rule_id if unary_rule_id is None
                   else unary_rule_id)
        return replace(self, var=var, unary_rule_id=rule_id)

    @classmethod
    def generate(cls,
                 cat: Category,
                 unary_rule_id: int,
                 head: IndexedWord | Variable) -> list[Dependency]:
        """Collect the dependencies declared by ``cat`` and its parts.

        The walk is depth-first: first ``cat`` itself, then (for complex
        categories) its result and then its argument.  Each category with
        a ``relation`` yields one dependency per head word.  The head word
        is ``head`` itself when it is an ``IndexedWord``; otherwise each
        word in ``head.fillers`` is used.
        """
        found: list[Dependency] = []
        if cat.relation:
            heads = [head] if isinstance(head, IndexedWord) else head.fillers
            found.extend(cls(cat.relation, word, cat.var, unary_rule_id)
                         for word in heads)
        if cat.complex:
            for part in (cat.result, cat.argument):
                found.extend(cls.generate(part, unary_rule_id, head))
        return found

    def fill(self, var: Variable) -> list[Dependency]:
        """Return one filled copy of this dependency per filler of ``var``.

        Each copy keeps the relation, head and unary rule id.  Its ``var``
        is reset to 0 and its ``filler`` is set to the filling word.
        """
        return [Dependency(relation=self.relation,
                           head=self.head,
                           var=0,
                           unary_rule_id=self.unary_rule_id,
                           filler=word)
                for word in var.fillers]

    def __str__(self) -> str:
        return (f'{self.head} {self.relation} {self.filler} '
                f'{self.unary_rule_id}')
@dataclass
class Variable:
    """A category variable bound to zero or more filler words.

    ``filled`` is a flag the parser consults when combining trees.  A new
    variable always starts with ``filled`` set to True.
    """

    fillers: list[IndexedWord]
    filled: bool

    def __init__(self, word: IndexedWord | None = None) -> None:
        # ``@dataclass`` keeps this explicitly defined __init__ instead of
        # generating one; the field declarations above still drive the
        # generated __eq__ and __repr__.
        self.fillers = [] if word is None else [word]
        self.filled = True

    def __add__(self, other: Any) -> Variable:
        """Merge two variables into a new, filled variable.

        The merged fillers are those of ``self`` followed by ``other``'s.
        """
        merged = Variable()
        merged.fillers = self.fillers + other.fillers
        return merged

    def as_filled(self, filled: bool) -> Variable:
        """Return a variable whose ``filled`` flag equals ``filled``.

        When the flag already matches, ``self`` is returned.  Otherwise a
        new variable is built that shares (does not copy) this variable's
        ``fillers`` list.
        """
        if filled != self.filled:
            flipped = Variable()
            flipped.fillers = self.fillers
            flipped.filled = filled
            return flipped
        return self

    @property
    def filler(self) -> IndexedWord:
        """The first filler word; raises IndexError when there is none."""
        return self.fillers[0]
class Unify:
    """Unification state for combining the categories of two parse trees.

    ``left`` and ``right`` are the two trees being combined.  When
    ``result_is_left`` is true, ``left.cat`` is treated as the "result"
    category (``self.res``) and ``right.cat`` as the "argument" category
    (``self.arg``).  Otherwise the two roles are swapped.

    Unification hands out fresh, consecutive variable ids starting at 1.
    ``trans_left`` and ``trans_right`` map each side's original variable id
    to its fresh id.  ``old_left`` and ``old_right`` hold the inverse
    mappings.
    """
    def __init__(self,
                 left: ParseTree,
                 right: ParseTree,
                 result_is_left: bool) -> None:
        # Feature picked for unified S categories.  It only changes from
        # Feature.NONE when unify_recursive meets an S whose feature on
        # one side is Feature.X.
        self.feature = Feature.NONE
        # Next fresh variable id to hand out (ids start at 1).
        self.num_variables = 1
        # Original variable id -> fresh id, one map per side.
        self.trans_left: dict[int, int] = {}
        self.trans_right: dict[int, int] = {}
        # Fresh id -> original variable id, one map per side (inverses).
        self.old_left: dict[int, int] = {}
        self.old_right: dict[int, int] = {}
        self.left = left
        self.right = right
        self.result_is_left = result_is_left
        # res/arg and trans_res/trans_arg are role-based aliases.  The
        # trans_* names refer to the SAME dict objects as trans_left and
        # trans_right (not copies), so an update through either name is
        # visible through both.
        if result_is_left:
            self.res, self.arg = left.cat, right.cat
            self.trans_res, self.trans_arg = self.trans_left, self.trans_right
        else:
            self.arg, self.res = left.cat, right.cat
            self.trans_arg, self.trans_res = self.trans_left, self.trans_right

    def unify(self, arg: Category, res: Category) -> bool:
        """Unify ``arg`` with ``res``, then number all remaining variables.

        ``arg`` and ``res`` are placed on the left/right side according to
        ``result_is_left`` and matched by ``unify_recursive``.  On success,
        every variable of the full ``self.arg`` and ``self.res`` categories
        that was not paired during matching receives its own fresh id.

        NOTE(review): ``arg``/``res`` are presumably sub-categories of
        ``self.arg``/``self.res`` chosen by the caller; confirm.

        Returns
        -------
        bool
            False if the categories do not unify.  In that case the
            translation maps may already hold partial entries.
        """
        if self.result_is_left:
            left, right = res, arg
        else:
            left, right = arg, res
        if not self.unify_recursive(left, right):
            return False
        self.add_vars(self.arg, self.trans_arg)
        self.add_vars(self.res, self.trans_res)
        return True

    def unify_recursive(self, left: Category, right: Category) -> bool:
        """Structurally match ``left`` with ``right`` and pair their vars.

        Atomic categories must have the same atom.  For ``S``, the features
        must also be equal, except that ``Feature.X`` on one side matches
        anything; the other side's feature is then stored in
        ``self.feature``.  Complex categories must have the same direction,
        and their results and their arguments must unify recursively.

        After a successful structural match, the pair of variables is
        mapped to one new fresh id, but only if neither variable has been
        translated yet.  Unification fails when both variables are bound in
        their tree's ``var_map`` and both bound variables are ``filled``.
        """
        if left.atomic:
            if left.atom != right.atom:
                return False
            if left.atom == Atom.S:
                # Feature.X acts as a wildcard: adopt the other side's
                # feature instead of requiring equality.
                if left.feature == Feature.X:
                    self.feature = right.feature
                elif right.feature == Feature.X:
                    self.feature = left.feature
                elif left.feature != right.feature:
                    return False
        else:
            if not (left.dir == right.dir
                    and self.unify_recursive(left.result, right.result)
                    and self.unify_recursive(left.argument, right.argument)):
                return False
        # NOTE(review): if exactly one of the two vars is already translated,
        # no pairing is recorded here.  The untranslated one presumably gets
        # its own id later in add_vars; confirm this is intended.
        if (left.var not in self.trans_left
                and right.var not in self.trans_right):
            try:
                v1 = self.left.var_map[left.var]
                v2 = self.right.var_map[right.var]
            except KeyError:
                # At least one side has no bound Variable: no clash possible.
                pass
            else:
                # Two already-filled variables cannot be merged.
                if v1.filled and v2.filled:
                    return False
            self.trans_left[left.var] = self.num_variables
            self.trans_right[right.var] = self.num_variables
            self.old_left[self.num_variables] = left.var
            self.old_right[self.num_variables] = right.var
            self.num_variables += 1
        return True

    def add_vars(self, cat: Category, trans: dict[int, int]) -> None:
        """Give a fresh id to every variable of ``cat`` not yet in ``trans``."""
        # ``trans`` is trans_left or trans_right (possibly passed through
        # its trans_arg/trans_res alias), so an identity check selects the
        # matching inverse map.
        old = self.old_left if trans is self.trans_left else self.old_right
        for var in cat.vars:
            if var not in trans:
                trans[var] = self.num_variables
                old[self.num_variables] = var
                self.num_variables += 1

    def get_new_outer_var(self) -> int:
        """Return the fresh id of ``left.cat``'s own variable (0 if none)."""
        # NOTE(review): this always reads the LEFT tree, whatever the value
        # of result_is_left; confirm with callers that this is intended.
        return self.trans_left.get(self.left.cat.var, 0)

    def translate_arg(self, category: Category) -> Category:
        """Apply ``Category.translate`` with the argument-side id map and
        the unified feature."""
        return category.translate(self.trans_arg, self.feature)

    def translate_res(self, category: Category) -> Category:
        """Apply ``Category.translate`` with the result-side id map and
        the unified feature."""
        return category.translate(self.trans_res, self.feature)
class Rule(Enum):
    """The possible CCG rules.

    Every member has an explicit integer value from 0 to 16.  Glosses for
    rules this module does not itself construct follow the usual CCG
    abbreviations and are presumed rather than verified here.
    """

    # No rule, lexical leaves and unary rules.
    NONE = 0      # presumably a "no rule" placeholder
    L = 1         # lexical leaf; built by Lexical(), tested by is_leaf
    U = 2         # presumably a unary type-changing rule
    # Binary combinators (usual CCG abbreviations, presumed).
    BA = 3        # backward application
    FA = 4        # forward application
    BC = 5        # backward composition (counted by bwd_comp)
    FC = 6        # forward composition (counted by fwd_comp)
    BX = 7        # backward crossed composition
    GBC = 8       # generalised backward composition (counted by bwd_comp)
    GFC = 9       # generalised forward composition (counted by fwd_comp)
    GBX = 10      # generalised backward crossed composition
    # Punctuation.
    LP = 11       # LeftPunct(): node takes over the right child
    RP = 12       # RightPunct(): node takes over the left child
    # Type raising, chosen by TypeRaising() from cat.fwd.
    BTR = 13      # backward type raising
    FTR = 14      # forward type raising
    # Coordination.
    CONJ = 15     # Coordination()
    ADJ_CONJ = 16  # AdjectivalConj(): node takes over the right child
@dataclass
class ParseTree:
    """A node of a CCG derivation.

    Leaves use ``Rule.L`` and have no children.  Internal nodes record the
    rule that produced them.  Type-raised nodes have only a ``left`` child,
    so either child may be ``None``.

    Attributes
    ----------
    rule : Rule
        The rule that built this node.
    cat : Category
        The node's category.
    left, right : ParseTree or None
        The children, ``None`` where absent.
    unfilled_deps : list of Dependency
        Dependencies still waiting for a filler.
    filled_deps : list of Dependency
        Dependencies filled when this node was built.
    var_map : dict of int to Variable
        Bindings of this node's category variables.
    score : float
        Score of the node; defaults to 0.
    """

    rule: Rule
    cat: Category
    left: ParseTree | None
    right: ParseTree | None
    unfilled_deps: list[Dependency]
    filled_deps: list[Dependency]
    var_map: dict[int, Variable]
    score: float = 0

    @property
    def word(self) -> str:
        """The surface word of a leaf.

        Raises
        ------
        AttributeError
            If this node is not a leaf.
        """
        if self.is_leaf:
            return self.variable.filler.word
        else:
            raise AttributeError('only leaves have words')

    @property
    def variable(self) -> Variable:
        """The Variable bound to ``cat.var`` in ``var_map``.

        Raises
        ------
        AttributeError
            If ``cat.var`` is not bound.  The KeyError is re-raised as
            AttributeError, which is what callers catch.
        """
        try:
            return self.var_map[self.cat.var]
        except KeyError as e:
            raise AttributeError('variable is not in map') from e

    @property
    def is_leaf(self) -> bool:
        """Whether this node is a lexical leaf (``Rule.L``)."""
        return self.rule == Rule.L

    @property
    def coordinated_or_type_raised(self) -> bool:
        """Whether this node was built by coordination or type raising."""
        return self.rule in (Rule.CONJ, Rule.BTR, Rule.FTR)

    @property
    def coordinated(self) -> bool:
        """Whether this node was built by coordination."""
        return self.rule == Rule.CONJ

    @property
    def bwd_comp(self) -> bool:
        """Whether this node's rule is ``Rule.BC`` or ``Rule.GBC``."""
        return self.rule in (Rule.BC, Rule.GBC)

    @property
    def fwd_comp(self) -> bool:
        """Whether this node's rule is ``Rule.FC`` or ``Rule.GFC``."""
        return self.rule in (Rule.FC, Rule.GFC)

    @cached_property
    def deps_and_tags(self) -> tuple[list[Dependency],
                                     list[str]]:  # pragma: no cover
        """Filled dependencies and leaf tags of the whole subtree.

        Dependencies come from every node and are sorted by
        ``(head.index, filler.index)``.  Tags are the leaf categories in
        left-to-right order, rendered with ``str`` and with ``[X]`` removed.
        The result is cached on the node, and the returned lists are the
        cached objects, so callers should not mutate them.
        """
        deps = self.filled_deps.copy()
        tags: list[str] = []
        if self.left is not None:
            for child in (self.left, self.right):
                if child is not None:
                    child_deps, child_tags = child.deps_and_tags
                    deps += child_deps
                    tags += child_tags
        else:
            tags.append(str(self.cat).replace('[X]', ''))
        deps.sort(key=lambda dep: (dep.head.index, dep.filler.index))
        return deps, tags

    @property
    def deps(self) -> list[Dependency]:
        """The sorted filled dependencies of the subtree (cached list)."""
        return self.deps_and_tags[0]
def Lexical(cat: Category, word: str, index: int) -> ParseTree:
    """Build a leaf node for ``word`` at (1-indexed) position ``index``.

    The dependencies declared by ``cat`` are generated with unary rule id 0
    and left unfilled.  ``cat``'s variable is bound to a new Variable that
    holds the word.  ``cat.var`` must be truthy (checked with an assert).
    """
    token = IndexedWord(word, index)
    pending = Dependency.generate(cat, 0, token)
    assert cat.var
    bindings = {cat.var: Variable(token)}
    return ParseTree(Rule.L, cat, None, None, pending, [], bindings)
def Coordination(cat: Category,
                 left: ParseTree,
                 right: ParseTree) -> ParseTree:
    """Build a coordination (``Rule.CONJ``) node.

    The node inherits ``right``'s unfilled dependencies.  Its variable map
    is a new dict holding ``right``'s variables, each marked unfilled.  If
    ``right``'s own variable is bound and filled, an extra ``Relation.CONJ``
    dependency is added.  That dependency is headed by ``left``'s first
    filler word and waits on the variable of ``cat.argument``.
    """
    inherited = {key: variable.as_filled(False)
                 for key, variable in right.var_map.items()}
    pending = list(right.unfilled_deps)
    try:
        right_var = right.variable
    except AttributeError:
        # ``right`` has no bound variable: no CONJ dependency to add.
        return ParseTree(Rule.CONJ, cat, left, right, pending, [], inherited)
    if right_var.filled:
        pending.append(Dependency(Relation.CONJ,
                                  left.variable.filler,
                                  cat.argument.var,
                                  0))
    return ParseTree(Rule.CONJ, cat, left, right, pending, [], inherited)
def TypeChanging(rule: Rule,
                 cat: Category,
                 left: ParseTree,
                 right: ParseTree,
                 unary_rule_id: int,
                 replace: bool) -> ParseTree:
    """Build a type-changing node with category ``cat``.

    The head tree is ``right`` when ``rule`` is ``Rule.LP`` and ``left``
    otherwise.

    When ``replace`` is true, only the head's unfilled dependencies on its
    category's argument variable are kept.  They are re-targeted (with
    ``unary_rule_id``) onto ``cat.argument.var``, or onto
    ``cat.argument.argument.var`` when ``cat`` matches the modifier pattern
    parsed below.  Otherwise, if the head has a bound variable, fresh
    dependencies are generated from ``cat`` for that variable.  The node
    binds ``cat.var`` to the head's variable when both exist.
    """
    head = right if rule == Rule.LP else left
    try:
        outer_var = head.variable
    except AttributeError:
        outer_var = None

    if replace:
        if Category.parse(r'(S\NP)\(S\NP)').matches(cat):
            target = cat.argument.argument.var
        else:
            target = cat.argument.var
        unfilled_deps = [dep.replace(target, unary_rule_id)
                         for dep in head.unfilled_deps
                         if dep.var == head.cat.argument.var]
    elif outer_var:
        unfilled_deps = Dependency.generate(cat, unary_rule_id, outer_var)
    else:
        unfilled_deps = []

    var_map = {cat.var: outer_var} if cat.var and outer_var else {}
    return ParseTree(rule, cat, left, right, unfilled_deps, [], var_map)
def PassThrough(rule: Rule,
                left: ParseTree,
                right: ParseTree,
                passthrough: ParseTree) -> ParseTree:
    """Build a ``rule`` node that takes over another subtree's analysis.

    The new node takes ``passthrough``'s category, unfilled dependencies and
    variable map.  The list and dict are shared with ``passthrough``, not
    copied.  No dependencies are filled.
    """
    return ParseTree(
        rule=rule,
        cat=passthrough.cat,
        left=left,
        right=right,
        unfilled_deps=passthrough.unfilled_deps,
        filled_deps=[],
        var_map=passthrough.var_map,
    )
def LeftPunct(left: ParseTree, right: ParseTree) -> ParseTree:
    """Attach punctuation on the left: the node takes over ``right``."""
    return PassThrough(Rule.LP, left, right, passthrough=right)
def RightPunct(left: ParseTree, right: ParseTree) -> ParseTree:
    """Attach punctuation on the right: the node takes over ``left``."""
    return PassThrough(Rule.RP, left, right, passthrough=left)
def AdjectivalConj(left: ParseTree, right: ParseTree) -> ParseTree:
    """Build an ``ADJ_CONJ`` node that takes over ``right``'s analysis."""
    return PassThrough(Rule.ADJ_CONJ, left, right, passthrough=right)
def TypeRaising(cat: Category, left: ParseTree) -> ParseTree:
    """Build a type-raising node over ``left`` with category ``cat``.

    The rule is ``Rule.FTR`` when ``cat.fwd`` is true and ``Rule.BTR``
    otherwise.  If ``cat.type_raising_dep_var`` is set, all of ``left``'s
    unfilled dependencies are moved onto that variable; if not, no
    dependencies are kept.  When ``left`` has a bound variable, it is bound
    to variable id 1 of the new node.
    """
    rule = Rule.FTR if cat.fwd else Rule.BTR
    dep_var = cat.type_raising_dep_var
    if dep_var:
        unfilled_deps = [dep.replace(dep_var) for dep in left.unfilled_deps]
    else:
        unfilled_deps = []
    # NOTE(review): id 1 is presumably the raised category's outer variable;
    # confirm against Category's variable numbering.
    try:
        var_map = {1: left.variable}
    except AttributeError:
        var_map = {}
    return ParseTree(rule, cat, left, None, unfilled_deps, [], var_map)
def BinaryCombinator(rule: Rule,
                     cat: Category,
                     left: ParseTree,
                     right: ParseTree,
                     unification: Unify) -> ParseTree:
    """Build a binary node from two subtrees and a finished unification.

    For each fresh variable id, the variables it came from on both sides
    are looked up.  If both exist they are merged (fillers concatenated).
    If only one exists, a filled copy of it is used.  Each child's unfilled
    dependencies whose variable was translated are then either filled
    (when the new variable is filled) or re-targeted to the fresh id.
    Dependencies whose variable was not translated are dropped.
    """
    var_map: dict[int, Variable] = {}
    for new_id in range(1, unification.num_variables):
        from_left = left.var_map.get(unification.old_left.get(new_id))
        from_right = right.var_map.get(unification.old_right.get(new_id))
        if from_left is None:
            if from_right is not None:
                var_map[new_id] = from_right.as_filled(True)
        elif from_right is None:
            var_map[new_id] = from_left.as_filled(True)
        else:
            var_map[new_id] = from_left + from_right

    trans_left = unification.trans_left
    trans_right = unification.trans_right
    pending = [(dep, trans_left[dep.var])
               for dep in left.unfilled_deps
               if dep.var in trans_left]
    pending += [(dep, trans_right[dep.var])
                for dep in right.unfilled_deps
                if dep.var in trans_right]

    unfilled_deps: list[Dependency] = []
    filled_deps: list[Dependency] = []
    for dep, new_id in pending:
        merged = var_map.get(new_id)
        if merged is None or not merged.filled:
            unfilled_deps.append(dep.replace(new_id))
        else:
            filled_deps.extend(dep.fill(merged))

    return ParseTree(
        rule, cat, left, right, unfilled_deps, filled_deps, var_map)