Source code for lambeq.text2diagram.bobcat_parser

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Bobcat parser
=============
A chart-based parser based on the C&C parser, with scores predicted by a
transformer.

"""

from __future__ import annotations

__all__ = ['BobcatParser', 'BobcatParseError']

from collections.abc import Iterable
import json
from pathlib import Path
import sys
from typing import Any

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from lambeq.bobcat import (BertForChartClassification, Category,
                           ChartParser, Grammar, ParseTree,
                           Sentence, Supertag, Tagger)
from lambeq.bobcat.tagger import TaggerOutputSentence
from lambeq.core.globals import VerbosityLevel
from lambeq.core.utils import (SentenceBatchType,
                               tokenised_batch_type_check,
                               untokenised_batch_type_check)
from lambeq.text2diagram.ccg_parser import CCGParser
from lambeq.text2diagram.ccg_rule import CCGRule
from lambeq.text2diagram.ccg_tree import CCGTree
from lambeq.text2diagram.ccg_type import CCGType
from lambeq.text2diagram.model_downloader import (ModelDownloader,
                                                  ModelDownloaderError,
                                                  MODELS)
from lambeq.typing import StrPathT


class BobcatParseError(Exception):
    """Error raised when Bobcat fails to parse a sentence.

    Attributes
    ----------
    sentence : str
        The sentence that could not be parsed.

    """

    def __init__(self, sentence: str) -> None:
        """Store the sentence that failed to parse."""
        self.sentence = sentence

    def __str__(self) -> str:
        """Return the error message, quoting the sentence with `repr`."""
        return 'Bobcat failed to parse {!r}.'.format(self.sentence)
class BobcatParser(CCGParser):
    """CCG parser using Bobcat as the backend."""

    def __init__(self,
                 model_name_or_path: str = 'bert',
                 root_cats: Iterable[str] | None = None,
                 device: int = -1,
                 cache_dir: StrPathT | None = None,
                 force_download: bool = False,
                 verbose: str = VerbosityLevel.PROGRESS.value,
                 **kwargs: Any) -> None:
        """Instantiate a BobcatParser.

        Parameters
        ----------
        model_name_or_path : str, default: 'bert'
            Can be either:
                - The path to a directory containing a Bobcat model.
                - The name of a pre-trained model.
                  By default, it uses the "bert" model.
                  See also: `BobcatParser.available_models()`
        root_cats : iterable of str, optional
            A list of the categories allowed at the root of the parse
            tree.
        device : int, default: -1
            The GPU device ID on which to run the model, if
            non-negative (`0` selects `cuda:0`). If negative (the
            default), run on the CPU.
        cache_dir : str or os.PathLike, optional
            The directory to which a downloaded pre-trained model should
            be cached instead of the standard cache
            (`$XDG_CACHE_HOME` or `~/.cache`).
        force_download : bool, default: False
            Force the model to be downloaded, even if it is already
            available locally.
        verbose : str, default: 'progress'
            See :py:class:`VerbosityLevel` for options.
        **kwargs : dict, optional
            Additional keyword arguments to be passed to the underlying
            parsers (see Other Parameters). By default, they are set to
            the values in the `pipeline_config.json` file in the model
            directory.

        Other Parameters
        ----------------
        Tagger parameters:
        batch_size : int, optional
            The number of sentences per batch.
        tag_top_k : int, optional
            The maximum number of tags to keep. If 0, keep all tags.
        tag_prob_threshold : float, optional
            The probability multiplier used for the threshold to keep
            tags.
        tag_prob_threshold_strategy : {'relative', 'absolute'}
            If "relative", the probability threshold is relative to the
            highest scoring tag. Otherwise, the probability is an
            absolute threshold.
        span_top_k : int, optional
            The maximum number of entries to keep per span. If 0, keep
            all entries.
        span_prob_threshold : float, optional
            The probability multiplier used for the threshold to keep
            entries for a span.
        span_prob_threshold_strategy : {'relative', 'absolute'}
            If "relative", the probability threshold is relative to the
            highest scoring entry. Otherwise, the probability is an
            absolute threshold.

        Chart parser parameters:
        eisner_normal_form : bool, default: True
            Whether to use eisner normal form.
        max_parse_trees : int, optional
            A safety limit to the number of parse trees that can be
            generated per parse before automatically failing.
        beam_size : int, optional
            The beam size to use in the chart cells.
        input_tag_score_weight : float, optional
            A scaling multiplier to the log-probabilities of the input
            tags. This means that a weight of 0 causes all of the input
            tags to have the same score.
        missing_cat_score : float, optional
            The default score for a category that is generated but not
            part of the grammar.
        missing_span_score : float, optional
            The default score for a category that is part of the grammar
            but has no score, due to being below the threshold kept by
            the tagger.

        Raises
        ------
        ValueError
            If `verbose` is not a valid :py:class:`VerbosityLevel`
            value.
        TypeError
            If `kwargs` contains a key that is not present in any
            section of `pipeline_config.json`.
        ModelDownloaderError
            If the model download fails and there is no usable local
            copy to fall back to.

        """
        self.verbose = verbose

        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` is not a valid verbose value for '
                             'BobcatParser.')
        model_dir = Path(model_name_or_path)

        if not model_dir.is_dir():
            # Check for updates only if a local model path is not
            # specified in `model_name_or_path`
            downloader = ModelDownloader(model_name_or_path, cache_dir)
            model_dir = downloader.model_dir
            if (force_download
                    or not model_dir.is_dir()
                    or downloader.model_is_stale()):
                try:
                    downloader.download_model(verbose)
                except ModelDownloaderError as e:
                    local_model_version = downloader.get_local_model_version()

                    if (model_dir.is_dir()
                            and local_model_version is not None):
                        print('Failed to update model with '
                              f'exception: {e}')
                        print('Attempting to continue with version '
                              f'{local_model_version}')
                    else:
                        # No local version to fall back to; re-raise
                        # the active exception with its traceback
                        # untouched.
                        raise

        # JSON is UTF-8 by specification, so do not depend on the
        # platform's locale encoding.
        with open(model_dir / 'pipeline_config.json',
                  encoding='utf-8') as f:
            config = json.load(f)

        # Override the stored defaults with any matching keyword
        # argument. Each kwarg is consumed by the first section that
        # declares it, so whatever is left over afterwards is unknown.
        for subconfig in config.values():
            for key in subconfig:
                if key in kwargs:
                    subconfig[key] = kwargs.pop(key)

        if kwargs:
            raise TypeError('BobcatParser got unexpected keyword argument(s): '
                            f'{", ".join(map(repr, kwargs))}')

        device_ = torch.device('cpu' if device < 0 else f'cuda:{device}')
        model = (BertForChartClassification.from_pretrained(model_dir)
                                           .eval()
                                           .to(device_))
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.tagger = Tagger(model, tokenizer, **config['tagger'])

        grammar = Grammar.load(model_dir / 'grammar.json')
        self.parser = ChartParser(grammar,
                                  self.tagger.model.config.cats,
                                  root_cats,
                                  **config['parser'])

    @staticmethod
    def _prepare_sentence(sent: TaggerOutputSentence,
                          tags: list[str]) -> Sentence:
        """Turn JSON input into a Sentence for parsing.

        Each `(index, probability)` pair in `sent.tags` becomes a
        `Supertag` whose category is looked up in `tags`, and each
        `(start, end, scores)` triple in `sent.spans` becomes a
        dictionary of scores keyed by `(start, end)`.

        """
        sent_tags = [[Supertag(tags[tag_id], prob)
                      for tag_id, prob in supertags]
                     for supertags in sent.tags]
        spans = {(start, end): {cat_id: score for cat_id, score in scores}
                 for start, end, scores in sent.spans}
        return Sentence(sent.words, sent_tags, spans)

    def sentences2trees(
        self,
        sentences: SentenceBatchType,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
        verbose: str | None = None
    ) -> list[CCGTree] | None:
        """Parse multiple sentences into a list of :py:class:`.CCGTree` s.

        The `sentences` argument is never modified.

        Parameters
        ----------
        sentences : list of str, or list of list of str
            The sentences to be parsed, passed either as strings or as
            lists of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if a
            sentence fails to parse, instead of raising an exception,
            its return entry is :py:obj:`None`.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        list of CCGTree or None
            The parsed trees. (May contain :py:obj:`None` if exceptions
            are suppressed)

        Raises
        ------
        ValueError
            If `verbose` is invalid, if `sentences` does not match the
            type implied by `tokenised`, or if a sentence is empty and
            `suppress_exceptions` is :py:obj:`False`.
        BobcatParseError
            If a sentence fails to parse and `suppress_exceptions` is
            :py:obj:`False`.

        """
        if verbose is None:
            verbose = self.verbose
        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` is not a valid verbose value for '
                             'BobcatParser.')

        if tokenised:
            if not tokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `True`, but variable '
                                 '`sentences` does not have type '
                                 '`List[List[str]]`.')
        else:
            if not untokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `False`, but variable '
                                 '`sentences` does not have type '
                                 '`List[str]`.')
            sent_list: list[str] = [str(s) for s in sentences]
            sentences = [sentence.split() for sentence in sent_list]

        # Collect the non-empty sentences in a fresh list rather than
        # deleting from `sentences`: in tokenised mode that is the
        # caller's own list and must not be mutated.
        empty_indices = []
        non_empty = []
        for i, sentence in enumerate(sentences):
            if sentence:
                non_empty.append(sentence)
            elif suppress_exceptions:
                empty_indices.append(i)
            else:
                raise ValueError('sentence is empty.')

        # Entries are `None` for sentences that are empty or fail to
        # parse while `suppress_exceptions` is set.
        trees: list[CCGTree] = []
        if non_empty:
            if verbose == VerbosityLevel.TEXT.value:
                print('Tagging sentences.', file=sys.stderr)
            tag_results = self.tagger(non_empty, verbose=verbose)
            tags = tag_results.tags
            if verbose == VerbosityLevel.TEXT.value:
                print('Parsing tagged sentences.', file=sys.stderr)
            for sent in tqdm(
                    tag_results.sentences,
                    desc='Parsing tagged sentences',
                    leave=False,
                    disable=verbose != VerbosityLevel.PROGRESS.value):
                try:
                    sentence_input = self._prepare_sentence(sent, tags)
                    result = self.parser(sentence_input)
                    trees.append(self._build_ccgtree(result[0]))
                except Exception as e:
                    # Deliberately broad: any failure while parsing
                    # this sentence is reported as a parse error (or
                    # as `None` when suppressed), with the cause
                    # chained.
                    if suppress_exceptions:
                        trees.append(None)
                    else:
                        raise BobcatParseError(' '.join(sent.words)) from e

        # Re-insert placeholders in ascending order so every index
        # lines up with the position of the sentence in the input.
        for i in empty_indices:
            trees.insert(i, None)

        return trees

    @staticmethod
    def _to_biclosed(cat: Category) -> CCGType:
        """Transform a Bobcat category into a biclosed type.

        Raises
        ------
        ValueError
            If `cat` is atomic but is neither punctuation nor one of
            the atoms `N`, `NP`, `S`, `PP` or `conj`.

        """
        if not cat.atomic:
            result = BobcatParser._to_biclosed(cat.result)
            argument = BobcatParser._to_biclosed(cat.argument)
            return result.slash(cat.dir, argument)

        if cat.atom.is_punct:
            return CCGType.PUNCTUATION

        atom = str(cat.atom)
        if atom == 'N':
            return CCGType.NOUN
        elif atom == 'NP':
            return CCGType.NOUN_PHRASE
        elif atom == 'S':
            return CCGType.SENTENCE
        elif atom == 'PP':
            return CCGType.PREPOSITIONAL_PHRASE
        elif atom == 'conj':
            return CCGType.CONJUNCTION
        raise ValueError(f'Invalid atomic type: {cat.atom!r}')

    @staticmethod
    def _build_ccgtree(tree: ParseTree) -> CCGTree:
        """Transform a Bobcat parse tree into a `CCGTree`.

        The original Bobcat tree is kept under the `'original'` key of
        the resulting tree's metadata.

        """
        children = [BobcatParser._build_ccgtree(child)
                    for child in filter(None, (tree.left, tree.right))]

        # Bobcat's `ADJ_CONJ` rule is mapped onto forward application;
        # every other rule is looked up in `CCGRule` by name.
        if tree.rule.name == 'ADJ_CONJ':
            rule = CCGRule.FORWARD_APPLICATION
        else:
            rule = CCGRule(tree.rule.name)
        return CCGTree(text=tree.word if tree.is_leaf else None,
                       rule=rule,
                       biclosed_type=BobcatParser._to_biclosed(tree.cat),
                       children=children,
                       metadata={'original': tree})

    @staticmethod
    def available_models() -> list[str]:
        """List the available models."""
        return [*MODELS]