Source code for lambeq.bobcat.parser

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
import math
from typing import overload, Tuple

from lambeq.bobcat.grammar import Grammar
from lambeq.bobcat.lexicon import Atom, Category
from lambeq.bobcat.lexicon import CATEGORIES
from lambeq.bobcat.rules import Rules
from lambeq.bobcat.tree import Dependency, Lexical, ParseTree

# A span of the sentence, given as a pair of word indices; used as the
# key type for span scores and chart cells in this module.
SpanT = Tuple[int, int]

# The lowest possible score: used where the log of a probability is
# undefined and as the initial score threshold of a chart cell.
NEGATIVE_INFINITY = float('-inf')


@dataclass
class Supertag:
    """A string category, annotated with its log probability.

    Attributes
    ----------
    category : str
        The category, in string form.
    probability : float
        The log probability assigned to the category.

    """

    category: str
    probability: float
@dataclass
class Sentence:
    """An input sentence.

    Attributes
    ----------
    words : list of str
        The tokens in the sentence.
    input_supertags : list of list of Supertag
        A list of supertags for each word.
    span_scores : dict of tuple of int and int to dict of int to float
        Mapping of a span to a dict of category (indices) mapped to
        their log probability.

    Raises
    ------
    ValueError
        If `words` and `input_supertags` differ in length.

    """

    words: list[str]
    input_supertags: list[list[Supertag]]
    span_scores: dict[SpanT, dict[int, float]]

    def __post_init__(self) -> None:
        # Every word needs its own (possibly empty) list of supertags.
        if len(self.words) != len(self.input_supertags):
            raise ValueError(
                '`words` must be the same length as `input_supertags`')

    def __len__(self) -> int:
        """Return the number of words in the sentence."""
        return len(self.words)
@dataclass
class Cell:
    """A cell in the chart.

    The cell maintains a list of trees in sorted order (highest score
    first), up to the beam size (though may be larger if there are ties
    at the bottom), with the further restriction that only one tree is
    allowed per category.

    Attributes
    ----------
    beam_size : int
        The number of trees to keep (not counting ties at the bottom).
    trees : list of ParseTree
        The trees in the cell, in descending order of score.
    trees_map : dict of Category to ParseTree
        The tree currently held for each category.
    min_score : float
        The score of the tree at the edge of the beam, once the cell
        holds at least `beam_size` trees; negative infinity until then.

    """

    beam_size: int
    trees: list[ParseTree] = field(default_factory=list)
    trees_map: dict[Category, ParseTree] = field(default_factory=dict)
    min_score: float = NEGATIVE_INFINITY

    def find(self, score: float) -> int:
        """Find the index where a tree with the given score can go.

        This is a binary search over `trees`, which is sorted by
        descending score. If a tree with exactly this score exists, the
        index of one such tree is returned (not necessarily the first).

        """
        trees = self.trees
        lo = 0
        hi = len(trees)
        while lo < hi:
            mid = (lo + hi) // 2
            cmp = trees[mid].score
            if score == cmp:
                return mid
            elif score > cmp:
                # Higher scores live towards the front of the list.
                hi = mid
            else:
                lo = mid + 1
        return lo

    def add(self, to_add: Iterable[ParseTree]) -> int:
        """Add the trees to the cell.

        For each tree that is to be added, it is checked against the
        existing trees to determine whether it should be added, and if
        so, is added using a binary search; then, the beam is applied.

        Parameters
        ----------
        to_add : iterable of ParseTree
            The candidate trees.

        Returns
        -------
        int
            The number of insertions made, minus the number of trees
            removed when the beam was applied.

        """
        # Best candidates first, so that the loop can stop at the first
        # candidate that falls below the beam.
        to_add = sorted(to_add, key=lambda tree: -tree.score)
        trees = self.trees
        trees_map = self.trees_map
        added = 0

        b = self.beam_size
        for tree in to_add:
            score = tree.score
            if len(trees) >= b and score < trees[-1].score:
                break

            # Check whether there exists a tree with the same category.
            # If there does, and it has a lower score, then remove the
            # old tree before inserting the new tree.
            # If the score is higher, then do nothing.
            insert: bool
            try:
                old_tree = trees_map[tree.cat]
            except KeyError:
                insert = True
            else:
                old_score = old_tree.score
                insert = score > old_score
                if insert:
                    # `find` lands somewhere inside the run of trees
                    # sharing `old_score`, so scan forwards through the
                    # run first, then backwards, to locate `old_tree`.
                    old_index = self.find(old_score)
                    deleted = False
                    for i in range(old_index, len(trees)):
                        if trees[i] is old_tree:
                            del trees[i]
                            deleted = True
                            break
                        elif trees[i].score != old_score:
                            break
                    if not deleted:
                        for i in reversed(range(old_index)):
                            if trees[i] is old_tree:
                                del trees[i]
                                break

            if insert:
                trees.insert(self.find(score), tree)
                trees_map[tree.cat] = tree
                # NOTE(review): a same-category replacement counts as
                # one addition although the number of trees in the cell
                # is unchanged -- confirm this is the intended count.
                added += 1

        # Apply the beam. Either lookup raises IndexError when the cell
        # holds too few trees, in which case there is nothing to prune.
        try:
            cutoff = self.min_score = trees[b - 1].score
            if trees[b].score < cutoff:
                added -= len(trees) - b
                for tree in trees[b:]:
                    del trees_map[tree.cat]
                del trees[b:]
        except IndexError:
            pass
        return added


@dataclass
class Chart:
    """The parse chart, containing a mapping from span to cell.

    A span (i, j) represents the phrase from the ith word to the jth
    word (inclusive), indexed from 0.

    Attributes
    ----------
    beam_size : int
        The beam size given to every cell created by the chart.
    chart : dict of tuple of int and int to Cell
        The cells, keyed by span.
    parse_tree_count : int
        The running total of the counts returned by `Cell.add`.

    """

    beam_size: int
    chart: dict[SpanT, Cell] = field(default_factory=dict)
    parse_tree_count: int = 0

    def __getitem__(self, index: SpanT) -> list[ParseTree]:
        """Get the trees in the cell for a span.

        Raises
        ------
        KeyError
            If no cell exists for the span.

        """
        return self.chart[index].trees

    def min_score(self, start: int, end: int) -> float:
        """Get the lowest score needed to add a tree to the given cell."""
        try:
            return self.chart[start, end].min_score
        except KeyError:
            return NEGATIVE_INFINITY

    def add(self,
            start: int,
            end: int,
            to_add: Iterable[ParseTree]) -> None:
        """Add parse trees to the cell in the chart.

        The cell is created on demand. `to_add` is tested for
        truthiness, so that an empty container does not create a cell.

        """
        if not to_add:
            return

        try:
            cell = self.chart[start, end]
        except KeyError:
            cell = self.chart[start, end] = Cell(self.beam_size)
        self.parse_tree_count += cell.add(to_add)


@dataclass
class ParseResult:
    """The result of a parse.

    This acts as a list of the most probable parse trees, in order, i.e.
    use `parse_result[0]` to access the most probable parse tree.

    Parameters
    ----------
    chart : Chart
        The parse chart.

    Attributes
    ----------
    words : list[str]
        The words in the sentence.
    root : list[ParseTree]
        The most probable parse trees, in order.

    """

    chart: Chart

    words: list[str] = field(init=False)
    root: list[ParseTree] = field(init=False)

    def __post_init__(self) -> None:
        # Recover the words from the single-word cells (0, 0), (1, 1),
        # ..., stopping at the first span without a cell.
        self.words = []
        while True:
            try:
                tree = self.chart[len(self.words), len(self.words)][0]
            except KeyError:
                break
            else:
                # Follow the left children down to the node that
                # carries the word.
                while tree.left:
                    tree = tree.left
                self.words.append(tree.word)

        # The root trees are those spanning all the recovered words.
        try:
            self.root = self.chart[0, len(self.words) - 1]
        except KeyError:
            self.root = []

    def __bool__(self) -> bool:
        """Whether any tree spans the whole sentence."""
        return len(self.root) != 0

    def __len__(self) -> int:
        """The number of trees spanning the whole sentence."""
        return len(self.root)

    @overload
    def __getitem__(self, index: int) -> ParseTree: ...

    @overload
    def __getitem__(self, index: slice) -> list[ParseTree]: ...

    def __getitem__(self,
                    index: int | slice) -> ParseTree | list[ParseTree]:
        """Index into the root trees, most probable first."""
        return self.root[index]

    def deps(
        self,
        tree: ParseTree | None = None
    ) -> tuple[list[Dependency], list[str]]:  # pragma: no cover
        """Get the dependencies and output tags of the parse.

        If `tree` is not specified, then this looks for the best scoring
        tree at the root of the parse; if there is none, then it
        amalgamates results from the best-scoring trees in the chart.

        """
        if tree is None:
            try:
                tree = self.root[0]
            except IndexError:
                return self._skim_deps()
        return tree.deps_and_tags

    def _skim_deps(
        self,
        start: int = 0,
        end: int | None = None
    ) -> tuple[list[Dependency], list[str]]:  # pragma: no cover
        """Collect dependencies and tags from partial parses.

        The best-scoring tree over the longest span inside
        `[start, end]` that has any tree is taken, and the words to its
        left and to its right are then handled recursively.

        """
        if end is None:
            end = len(self.words) - 1
        if start > end:
            return [], []

        result_start = result_end = max_tree = None
        # Try the longest spans first.
        for span_length in reversed(range(end + 1 - start)):
            max_score = NEGATIVE_INFINITY
            for i in range(start, end + 1 - span_length):
                try:
                    cell = self.chart[i, i + span_length]
                except KeyError:
                    pass
                else:
                    tree = cell[0]
                    if tree.score > max_score:
                        max_score = tree.score
                        max_tree = tree
                        result_start = i
                        result_end = i + span_length
            if max_tree:
                break

        # No tree was selected (no cells in range, or every candidate
        # scored negative infinity): there is nothing to report for
        # this range, and the arithmetic below would fail on `None`.
        if max_tree is None or result_start is None or result_end is None:
            return [], []

        left_deps, left_tags = self._skim_deps(start, result_start - 1)
        tree_deps, tree_tags = max_tree.deps_and_tags
        right_deps, right_tags = self._skim_deps(result_end + 1, end)
        return (left_deps + tree_deps + right_deps,
                left_tags + tree_tags + right_tags)
class ChartParser:
    """A chart parser that combines supertagged words bottom-up.

    Calling the parser on a `Sentence` fills a `Chart` span by span,
    scoring each tree from the scores of its children and the span
    scores supplied with the sentence, and returns a `ParseResult`.

    """

    def __init__(self,
                 grammar: Grammar,
                 cats: Iterable[str],
                 root_cats: Iterable[str] | None,
                 eisner_normal_form: bool,
                 max_parse_trees: int,
                 beam_size: int,
                 input_tag_score_weight: float,
                 missing_cat_score: float,
                 missing_span_score: float) -> None:
        """Set up the parser.

        Parameters
        ----------
        grammar : Grammar
            The grammar; its `categories` are parsed into `Category`
            objects and it is passed on to the rules.
        cats : iterable of str
            The category (chain) strings, in the order of their IDs in
            the span scores. The categories in a chain are separated by
            `::`; a single category ending in `[conj]` denotes a
            conjunction.
        root_cats : iterable of str, or None
            The categories allowed at the root of a parse, or `None`
            to allow any category.
        eisner_normal_form : bool
            Passed on to the rules.
        max_parse_trees : int
            Parsing stops once the parse tree count of the chart
            exceeds this number.
        beam_size : int
            The beam size of each cell in the chart.
        input_tag_score_weight : float
            The factor applied to the log probability of each input
            supertag.
        missing_cat_score : float
            The probability used for a category (chain) that is not in
            `cats`; its logarithm is stored, or negative infinity if
            the logarithm is undefined.
        missing_span_score : float
            The probability used for a category (chain) that has no
            score in a span; stored in the same way.

        """
        self.max_parse_trees = max_parse_trees

        self.categories = {
            plain_cat: Category.parse(markedup_cat)
            for plain_cat, markedup_cat in grammar.categories.items()
        }

        self.rules = Rules(eisner_normal_form, grammar, self.categories)
        self.input_tag_score_weight = input_tag_score_weight
        self.beam_size = beam_size

        self.missing_cat_score = self._log_or_negative_infinity(
            missing_cat_score)
        self.missing_span_score = self._log_or_negative_infinity(
            missing_span_score)

        CONJ_TAG = '[conj]'
        # Maps a (label, categories) pair to the position of the
        # category string in `cats`.
        self.result_cats: dict[tuple[str, tuple[Category, ...]], int] = {}
        for cat_id, cat_str in enumerate(cats):
            chain = cat_str.split('::')
            res_cats: tuple[Category, ...]

            if len(chain) == 1 and chain[0].endswith(CONJ_TAG):
                # A conjunction of X is stored under the category X\X,
                # with X bracketed first if it is itself complex.
                base_cat = chain[0][:-len(CONJ_TAG)]
                if '/' in base_cat or '\\' in base_cat:
                    base_cat = f'({base_cat})'
                cat_modified = fr'({base_cat}\{base_cat})'
                res_cats = (Category.parse(cat_modified),)
                label = 'conj'
            else:
                res_cats = tuple(map(Category.parse, chain))
                label = 'unary' if len(chain) > 1 else 'binary'
            self.result_cats[label, res_cats] = cat_id

        self.set_root_cats(root_cats)

    @staticmethod
    def _log_or_negative_infinity(probability: float) -> float:
        """Return the log of a probability, or -inf if it is undefined."""
        try:
            return math.log(probability)
        except ValueError:
            return NEGATIVE_INFINITY

    def set_root_cats(self,
                      root_cats: Iterable[Category | str] | None) -> None:
        """Set the categories allowed at the root of a parse.

        Parameters
        ----------
        root_cats : iterable of Category or str, or None
            The allowed categories; strings are looked up in
            `CATEGORIES`. `None` disables root filtering.

        Raises
        ------
        ValueError
            If the lookup of a string category fails.

        """
        if root_cats is None:
            self.root_cats = None
        else:
            try:
                self.root_cats = [(cat if isinstance(cat, Category)
                                   else CATEGORIES[cat, 0])
                                  for cat in root_cats]
            except KeyError as e:
                raise ValueError('Grammar does not contain root category: '
                                 f'{repr(e.args[0])}') from e

    def filter_root(self, trees: list[ParseTree]) -> list[ParseTree]:
        """Keep only the trees whose category matches a root category.

        If no root categories are set, `trees` is returned unchanged.

        """
        if self.root_cats is None:
            return trees
        root_cats = self.root_cats
        return [tree for tree in trees
                if any(cat.matches(tree.cat) for cat in root_cats)]

    def __call__(self, sentence: Sentence) -> ParseResult:
        """Parse a sentence."""
        chart = Chart(self.beam_size)
        num_words = len(sentence)

        # Fill the single-word cells from the input supertags.
        for i, (word, supertags) in enumerate(zip(sentence.words,
                                                  sentence.input_supertags)):
            results = []
            for supertag in supertags:
                tree = Lexical(self.categories[supertag.category],
                               word,
                               i + 1)
                tree.score = (self.input_tag_score_weight
                              * supertag.probability)
                results.append(tree)

            try:
                span_scores = sentence.span_scores[i, i]
            except KeyError:
                # No span scores for this word: keep the lexical trees
                # only.
                pass
            else:
                if num_words > 1:
                    results += self.rules.type_change(results)
                    results += self.rules.type_raise(results)
                for tree in results:
                    if tree.left:
                        self.calc_score_unary(tree, span_scores)

            # filter root cats
            if num_words == 1:
                results = self.filter_root(results)

            chart.add(i, i, results)

        # Fill the remaining cells, shortest spans first.
        for span_length in range(1, num_words):
            for end in range(span_length, num_words):
                # Once the budget is exceeded, no further span can add
                # trees, so stop parsing altogether.
                if chart.parse_tree_count > self.max_parse_trees:
                    return ParseResult(chart)

                start = end - span_length
                try:
                    span_scores = sentence.span_scores[start, end]
                except KeyError:
                    continue

                # The highest span score any new tree could receive.
                max_span_score = max((self.missing_cat_score,
                                      self.missing_span_score,
                                      *span_scores.values()))

                for split in range(start + 1, end + 1):
                    try:
                        left_trees = chart[start, split - 1]
                        right_trees = chart[split, end]
                    except KeyError:
                        continue

                    for left in left_trees:
                        for right in right_trees:
                            # The trees are sorted by score, so once the
                            # best possible score misses the threshold
                            # of the cell, so does every later `right`.
                            max_score = (left.score
                                         + right.score
                                         + max_span_score)
                            if max_score < chart.min_score(start, end):
                                break

                            results = self.rules.combine(left, right)
                            if results and num_words > span_length + 1:
                                results += self.rules.type_change(results)
                                results += self.rules.type_raise(results)

                            # filter root cats
                            if span_length == num_words - 1:
                                results = self.filter_root(results)

                            for tree in results:
                                if tree.right:
                                    self.calc_score_binary(tree,
                                                           span_scores)
                                else:
                                    self.calc_score_unary(tree,
                                                          span_scores)
                            chart.add(start, end, results)

        return ParseResult(chart)

    def calc_score_unary(self,
                         tree: ParseTree,
                         span_scores: Mapping[int, float]) -> None:
        """Calculate the score for a unary tree (chain).

        The chain consists of `tree`, its child and, if that child is
        itself unary, its grandchild. The score is that of the tree at
        the bottom of the chain (the sum of its children's scores if it
        is binary) plus the span score of the whole chain.

        """
        left = tree.left
        res_cat: tuple[str, tuple[Category, ...]]
        if left.right is None and left.left is not None:
            # The child is unary too: a chain of three categories.
            base = left.left
            res_cat = ('unary', (tree.cat, left.cat, left.left.cat))
        else:
            base = left
            res_cat = ('unary', (tree.cat, left.cat))

        if base.right is not None:
            tree.score = base.left.score + base.right.score
        else:
            tree.score = base.score

        cat_id = self.result_cats.get(res_cat)
        tree.score += self.get_span_score(span_scores, cat_id)

    def calc_score_binary(self,
                          tree: ParseTree,
                          span_scores: Mapping[int, float]) -> None:
        """Calculate the score for a binary tree.

        The score is the sum of the scores of the two children and the
        span score of the category of the tree.

        """
        if tree.coordinated:
            cat_id = self.result_cats.get(('conj', (tree.cat,)))
        else:
            cat = tree.cat
            try:
                cat_id = self.result_cats['binary', (tree.cat,)]
            except KeyError:
                if cat.atom == Atom.NP:
                    # Fall back to the category built from the atom
                    # alone.
                    cat_no_nb = Category(cat.atom)
                    cat_id = self.result_cats.get(('binary', (cat_no_nb,)))
                else:
                    cat_id = None

        tree.score = (tree.left.score
                      + tree.right.score
                      + self.get_span_score(span_scores, cat_id))

    def get_span_score(self,
                       span_scores: Mapping[int, float],
                       cat_id: int | None) -> float:
        """Get the score in a span for a category (chain) ID.

        Returns `missing_cat_score` if `cat_id` is `None`, and
        `missing_span_score` if the span has no score for the ID.

        """
        if cat_id is None:
            return self.missing_cat_score
        try:
            return span_scores[cat_id]
        except KeyError:
            return self.missing_span_score