# Source code for lambeq.text2diagram.model_based_reader.oncilla_parser

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Oncilla parser
==============
A end-to-end pregroup parser, that directly generates pregroup diagrams
from text bypassing CCG as an intermediate representation.

"""

from __future__ import annotations

__all__ = ['OncillaParser', 'OncillaParseError']

import sys
from typing import Any

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from lambeq.backend.grammar import Diagram
from lambeq.core.globals import VerbosityLevel
from lambeq.core.utils import (SentenceBatchType,
                               SentenceType)
from lambeq.oncilla import (BertForSentenceToTree,
                            SentenceToTreeBertConfig)
from lambeq.text2diagram.model_based_reader.base import ModelBasedReader
from lambeq.text2diagram.pregroup_tree import PregroupTreeNode
from lambeq.text2diagram.pregroup_tree_converter import (generate_tree,
                                                         remove_cycles)
from lambeq.typing import StrPathT


class OncillaParseError(Exception):
    """Raised when the Oncilla parser fails to process a sentence.

    Attributes
    ----------
    sentence : str
        The sentence that failed to parse.
    reason : str
        Optional human-readable explanation of the failure; empty when
        no specific reason is available.

    """

    def __init__(self, sentence: str, reason: str = '') -> None:
        self.sentence = sentence
        self.reason = reason

    def __str__(self) -> str:
        out = f'Oncilla failed to parse {self.sentence!r}'
        if self.reason:
            out += f': {self.reason}'
        out += '.'
        return out
class OncillaParser(ModelBasedReader):
    """Parser using Oncilla as the backend.

    Oncilla is an end-to-end pregroup parser that generates pregroup
    trees directly from text, bypassing CCG as an intermediate
    representation.

    """

    def __init__(
        self,
        model_name_or_path: str = 'oncilla',
        device: int | str | torch.device = 'cpu',
        cache_dir: StrPathT | None = None,
        force_download: bool = False,
        verbose: str = VerbosityLevel.PROGRESS.value,
    ) -> None:
        """Instantiate an OncillaParser.

        Parameters
        ----------
        model_name_or_path : str, default: 'oncilla'
            Can be either:
                - The path to a directory containing an Oncilla model.
                - The name of a pre-trained model.
                  By default, it uses the "bert" model.
                  See also: `OncillaParser.available_models()`
        device : int, str, or torch.device, default: 'cpu'
            Specifies the device on which to run the tagger model.
                - For CPU, use `'cpu'`.
                - For CUDA devices, use `'cuda:<device_id>'` or
                  `<device_id>`.
                - For Apple Silicon (MPS), use `'mps'`.
                - You may also pass a :py:class:`torch.device` object.
                - For other devices, refer to the PyTorch documentation.
        cache_dir : str or os.PathLike, optional
            The directory to which a downloaded pre-trained model should
            be cached instead of the standard cache (`$XDG_CACHE_HOME`
            or `~/.cache`).
        force_download : bool, default: False
            Force the model to be downloaded, even if it is already
            available locally.
        verbose : str, default: 'progress',
            See :py:class:`VerbosityLevel` for options.

        """
        super().__init__(model_name_or_path=model_name_or_path,
                         device=device,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         verbose=verbose)

        # Initialise model
        self._initialise_model()

    def _initialise_model(self, **kwargs: Any) -> None:
        """Load tokenizer, config and model weights from `self.model_dir`.

        The model is used for inference only, hence it is switched to
        eval mode before being moved to the configured device.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        self.model_config = SentenceToTreeBertConfig.from_pretrained(
            self.model_dir
        )
        self.model = BertForSentenceToTree.from_pretrained(
            self.model_dir, config=self.model_config
        ).eval().to(self.device)

    def _sentences2pregrouptrees(
        self,
        sentences: SentenceBatchType,
        break_cycles: bool = False,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
        verbose: str | None = None,
    ) -> list[PregroupTreeNode | None]:
        """Parse multiple sentences into a list of pregroup trees.

        Parameters
        ----------
        sentences : list of str, or list of list of str
            The sentences to be parsed.
        break_cycles : bool, default: False
            Flag that indicates whether cycles will be broken in
            the output pregroup tree. This is done by removing
            duplicate nodes, keeping the copy of the node that is
            closest to its parent in the original sentence.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if
            a sentence fails to parse, instead of raising an exception,
            its return entry is :py:obj:`None`.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. Not all parsers
            implement all three levels of progress reporting, see the
            respective documentation for each parser. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        list of :py:class:`lambeq.text2diagram.PregroupTreeNode` or None
            The pregroup trees. May contain :py:obj:`None` if
            exceptions are suppressed.

        Raises
        ------
        ValueError
            If `verbose` is not a valid verbosity level.
        OncillaParseError
            If a sentence fails to parse and `suppress_exceptions` is
            :py:obj:`False`.

        """
        if verbose is None:
            verbose = self.verbose
        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` is not a valid verbose value '
                             f'for `OncillaParser`.')

        sentences_valid, empty_indices = self.validate_sentence_batch(
            sentences,
            tokenised=tokenised,
            suppress_exceptions=suppress_exceptions
        )

        pregroup_trees: list[PregroupTreeNode | None] = []
        if sentences_valid:
            if verbose == VerbosityLevel.TEXT.value:
                print('Turning sentences to pregroup trees.',
                      file=sys.stderr)
            for sent in tqdm(
                sentences_valid,
                desc='Turning sentences to pregroup trees',
                leave=False,
                disable=verbose != VerbosityLevel.PROGRESS.value
            ):
                pregroup_tree: PregroupTreeNode | None = None
                try:
                    if sent[-1] == '.':
                        # Remove ending '.' as this was removed from
                        # the training dataset for training.
                        sent = sent[:-1]

                    # Predict types and parents
                    parse_output = self.model._sentence2pred(
                        sent, self.tokenizer
                    )

                    # Create tree from type and parent preds
                    root_nodes: list[PregroupTreeNode]
                    root_nodes, _ = generate_tree(parse_output.words,
                                                  parse_output.types,
                                                  parse_output.parents)
                except Exception as e:
                    if not suppress_exceptions:
                        raise OncillaParseError(' '.join(sent)) from e
                else:
                    # A well-formed parse yields exactly one root node.
                    if len(root_nodes) > 1:
                        if not suppress_exceptions:
                            raise OncillaParseError(
                                ' '.join(sent),
                                reason=(f'Got {len(root_nodes)} '
                                        'disjoint trees')
                            )
                    elif not len(root_nodes):
                        if not suppress_exceptions:
                            raise OncillaParseError(
                                ' '.join(sent),
                                reason=f'Got {len(root_nodes)} trees'
                            )
                    else:
                        pregroup_tree = root_nodes[0]
                        if break_cycles:
                            remove_cycles(pregroup_tree)

                pregroup_trees.append(pregroup_tree)

        # Re-insert placeholders for sentences that failed validation,
        # so output indices line up with the input batch.
        for i in empty_indices:
            pregroup_trees.insert(i, None)

        return pregroup_trees

    def _sentence2pregrouptree(
        self,
        sentence: SentenceType,
        break_cycles: bool = False,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
        verbose: str | None = None,
    ) -> PregroupTreeNode | None:
        """Parse a sentence into a pregroup tree.

        Parameters
        ----------
        sentence : str, list[str]
            The sentence to be parsed, passed either as a string, or
            as a list of tokens.
        break_cycles : bool, default: False
            Flag that indicates whether cycles will be broken in
            the output pregroup tree. This is done by removing
            duplicate nodes, keeping the copy of the node that is
            closest to its parent in the original sentence.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if
            a sentence fails to parse, instead of raising an exception,
            its return entry is :py:obj:`None`.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. Not all parsers
            implement all three levels of progress reporting, see the
            respective documentation for each parser. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        :py:class:`lambeq.text2diagram.PregroupTreeNode` or None
            The pregroup tree, or :py:obj:`None` on failure.

        """
        return self._sentences2pregrouptrees(
            [sentence],  # type: ignore[arg-type]
            break_cycles=break_cycles,
            tokenised=tokenised,
            suppress_exceptions=suppress_exceptions,
            verbose=verbose
        )[0]
[docs] def sentences2diagrams( self, sentences: SentenceBatchType, tokenised: bool = False, suppress_exceptions: bool = False, verbose: str | None = None, ) -> list[Diagram | None]: """Parse multiple sentences into a list of lambeq diagrams. Parameters ---------- sentences : list of str, or list of list of str The sentences to be parsed. tokenised : bool, default: False Whether each sentence has been passed as a list of tokens. suppress_exceptions : bool, default: False Whether to suppress exceptions. If :py:obj:`True`, then if a sentence fails to parse, instead of raising an exception, its return entry is :py:obj:`None`. verbose : str, optional See :py:class:`VerbosityLevel` for options. Not all parsers implement all three levels of progress reporting, see the respective documentation for each parser. If set, takes priority over the :py:attr:`verbose` attribute of the parser. Returns ------- list of :py:class:`lambeq.backend.grammar.Diagram` or None The parsed diagrams. May contain :py:obj:`None` if exceptions are suppressed. """ pregroup_trees = self._sentences2pregrouptrees( sentences, tokenised=tokenised, suppress_exceptions=suppress_exceptions, verbose=verbose ) diagrams: list[Diagram | None] = [] if verbose is None: verbose = self.verbose if verbose is VerbosityLevel.TEXT.value: print('Turning pregroup trees to diagrams.', file=sys.stderr) for tree in tqdm( pregroup_trees, desc='Turning pregroup trees to diagrams', leave=False, total=len(pregroup_trees), disable=verbose != VerbosityLevel.PROGRESS.value ): diagram: Diagram | None = None if tree is not None: try: tokens = tree.get_words() diagram = tree.to_diagram(tokens=tokens) except Exception as e: if not suppress_exceptions: raise OncillaParseError(' '.join(tokens)) from e diagrams.append(diagram) return diagrams
[docs] def sentence2diagram( self, sentence: SentenceType, tokenised: bool = False, suppress_exceptions: bool = False, verbose: str | None = None ) -> Diagram | None: """Parse a sentence into a lambeq diagram. Parameters ---------- sentence : str, or list of str The sentence to be parsed. tokenised : bool, default: False Whether the sentence has been passed as a list of tokens. suppress_exceptions : bool, default: False Whether to suppress exceptions. If :py:obj:`True`, then if the sentence fails to parse, instead of raising an exception, returns :py:obj:`None`. verbose : str, optional See :py:class:`VerbosityLevel` for options. Not all parsers implement all three levels of progress reporting, see the respective documentation for each parser. If set, takes priority over the :py:attr:`verbose` attribute of the parser. Returns ------- :py:class:`lambeq.backend.grammar.Diagram` or None The parsed diagram, or :py:obj:`None` on failure. """ return self.sentences2diagrams( [sentence], # type: ignore[arg-type] tokenised=tokenised, suppress_exceptions=suppress_exceptions, verbose=verbose )[0]