# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Oncilla parser
==============
An end-to-end pregroup parser that generates pregroup diagrams directly
from text, bypassing CCG as an intermediate representation.
"""
from __future__ import annotations
__all__ = ['OncillaParser', 'OncillaParseError']
import sys
from typing import Any
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from lambeq.backend.grammar import Diagram
from lambeq.core.globals import VerbosityLevel
from lambeq.core.utils import (SentenceBatchType,
SentenceType)
from lambeq.oncilla import (BertForSentenceToTree,
SentenceToTreeBertConfig)
from lambeq.text2diagram.model_based_reader.base import ModelBasedReader
from lambeq.text2diagram.pregroup_tree import PregroupTreeNode
from lambeq.text2diagram.pregroup_tree_converter import (generate_tree,
remove_cycles)
from lambeq.typing import StrPathT
class OncillaParseError(Exception):
    """Raised when Oncilla fails to parse a sentence.

    Attributes
    ----------
    sentence : str
        The sentence that failed to parse.
    reason : str
        Optional explanation of the failure; empty if none was given.

    """

    def __init__(self, sentence: str, reason: str = '') -> None:
        # Forward the constructor arguments to ``Exception`` so that
        # ``self.args`` is populated; pickle rebuilds exceptions by
        # calling ``cls(*self.args)``, which would otherwise fail.
        super().__init__(sentence, reason)
        self.sentence = sentence
        self.reason = reason

    def __str__(self) -> str:
        out = f'Oncilla failed to parse {self.sentence!r}'
        if self.reason:
            out += f': {self.reason}'
        out += '.'
        return out
class OncillaParser(ModelBasedReader):
    """Parser using Oncilla as the backend.

    The model predicts pregroup types and parents for the words of a
    sentence; these predictions are assembled into a
    :py:class:`lambeq.text2diagram.PregroupTreeNode`, which can then
    be converted into a lambeq diagram.

    """

    def __init__(
        self,
        model_name_or_path: str = 'oncilla',
        device: int | str | torch.device = 'cpu',
        cache_dir: StrPathT | None = None,
        force_download: bool = False,
        verbose: str = VerbosityLevel.PROGRESS.value,
    ) -> None:
        """Instantiate an OncillaParser.

        Parameters
        ----------
        model_name_or_path : str, default: 'oncilla'
            Can be either:
                - The path to a directory containing an Oncilla model.
                - The name of a pre-trained model.
            By default, it uses the "oncilla" model.
            See also: `OncillaParser.available_models()`
        device : int, str, or torch.device, default: 'cpu'
            Specifies the device on which to run the parser model.
                - For CPU, use `'cpu'`.
                - For CUDA devices, use `'cuda:<device_id>'` or
                  `<device_id>`.
                - For Apple Silicon (MPS), use `'mps'`.
                - You may also pass a :py:class:`torch.device` object.
                - For other devices, refer to the PyTorch documentation.
        cache_dir : str or os.PathLike, optional
            The directory to which a downloaded pre-trained model should
            be cached instead of the standard cache
            (`$XDG_CACHE_HOME` or `~/.cache`).
        force_download : bool, default: False
            Force the model to be downloaded, even if it is already
            available locally.
        verbose : str, default: 'progress'
            See :py:class:`VerbosityLevel` for options.

        """
        super().__init__(model_name_or_path=model_name_or_path,
                         device=device,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         verbose=verbose)

        # Load the tokenizer, config and weights from ``self.model_dir``.
        self._initialise_model()

    def _initialise_model(self, **kwargs: Any) -> None:
        """Load the tokenizer and the Oncilla model from ``self.model_dir``.

        The model is switched to evaluation mode and moved to
        ``self.device``. ``kwargs`` is currently ignored.

        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        self.model_config = SentenceToTreeBertConfig.from_pretrained(
            self.model_dir
        )
        self.model = BertForSentenceToTree.from_pretrained(
            self.model_dir, config=self.model_config
        ).eval().to(self.device)

    def _sentences2pregrouptrees(
        self,
        sentences: SentenceBatchType,
        break_cycles: bool = False,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
        verbose: str | None = None,
    ) -> list[PregroupTreeNode | None]:
        """Parse multiple sentences into a list of pregroup trees.

        Parameters
        ----------
        sentences : list of str, or list of list of str
            The sentences to be parsed.
        break_cycles : bool, default: False
            Flag that indicates whether cycles will be broken in
            the output pregroup tree. This is done by removing
            duplicate nodes, keeping the copy of the node that is closest
            to its parent in the original sentence.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if a
            sentence fails to parse, instead of raising an exception,
            its return entry is :py:obj:`None`.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. Not all parsers
            implement all three levels of progress reporting, see the
            respective documentation for each parser. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        list of :py:class:`lambeq.text2diagram.PregroupTreeNode` or None
            The pregroup trees. May contain :py:obj:`None` if
            exceptions are suppressed.

        Raises
        ------
        ValueError
            If `verbose` is not a valid :py:class:`VerbosityLevel` value.
        OncillaParseError
            If a sentence fails to parse, or does not yield exactly one
            tree, and `suppress_exceptions` is :py:obj:`False`.

        """
        if verbose is None:
            verbose = self.verbose
        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` is not a valid verbose value '
                             'for `OncillaParser`.')

        sentences_valid, empty_indices = self.validate_sentence_batch(
            sentences,
            tokenised=tokenised,
            suppress_exceptions=suppress_exceptions
        )

        pregroup_trees: list[PregroupTreeNode | None] = []
        if sentences_valid:
            if verbose == VerbosityLevel.TEXT.value:
                print('Turning sentences to pregroup trees.', file=sys.stderr)
            for sent in tqdm(sentences_valid,
                             desc='Turning sentences to pregroup trees',
                             leave=False,
                             disable=verbose != VerbosityLevel.PROGRESS.value):
                pregroup_tree: PregroupTreeNode | None = None
                root_nodes: list[PregroupTreeNode]
                try:
                    # Remove ending '.' as this was removed from the
                    # training dataset. ``sent`` itself is left intact so
                    # error messages report the sentence as given.
                    model_input = sent[:-1] if sent[-1] == '.' else sent
                    # Predict types and parents
                    parse_output = self.model._sentence2pred(model_input,
                                                             self.tokenizer)
                    # Create tree from type and parent preds
                    root_nodes, _ = generate_tree(parse_output.words,
                                                  parse_output.types,
                                                  parse_output.parents)
                except Exception as e:
                    if not suppress_exceptions:
                        raise OncillaParseError(' '.join(sent)) from e
                else:
                    if len(root_nodes) == 1:
                        pregroup_tree = root_nodes[0]
                        if break_cycles:
                            remove_cycles(pregroup_tree)
                    elif not suppress_exceptions:
                        # Only a single connected tree is a valid parse.
                        reason = (f'Got {len(root_nodes)} disjoint trees'
                                  if root_nodes
                                  else f'Got {len(root_nodes)} trees')
                        raise OncillaParseError(' '.join(sent), reason=reason)
                pregroup_trees.append(pregroup_tree)

        # Put a ``None`` placeholder back at each index reported as empty
        # by ``validate_sentence_batch``. Inserting in iteration order
        # restores the original positions provided the indices are
        # ascending (NOTE(review): confirm against the base class).
        for i in empty_indices:
            pregroup_trees.insert(i, None)

        return pregroup_trees

    def _sentence2pregrouptree(
        self,
        sentence: SentenceType,
        break_cycles: bool = False,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
        verbose: str | None = None,
    ) -> PregroupTreeNode | None:
        """Parse a sentence into a pregroup tree.

        Thin wrapper around :py:meth:`_sentences2pregrouptrees` for a
        batch of one.

        Parameters
        ----------
        sentence : str, list[str]
            The sentence to be parsed, passed either as a string, or as
            a list of tokens.
        break_cycles : bool, default: False
            Flag that indicates whether cycles will be broken in
            the output pregroup tree. This is done by removing
            duplicate nodes, keeping the copy of the node that is closest
            to its parent in the original sentence.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if a
            sentence fails to parse, instead of raising an exception,
            its return entry is :py:obj:`None`.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. Not all parsers
            implement all three levels of progress reporting, see the
            respective documentation for each parser. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        :py:class:`lambeq.text2diagram.PregroupTreeNode` or None
            The pregroup tree, or :py:obj:`None` on failure.

        """
        return self._sentences2pregrouptrees(
            [sentence],  # type: ignore[arg-type]
            break_cycles=break_cycles,
            tokenised=tokenised,
            suppress_exceptions=suppress_exceptions,
            verbose=verbose
        )[0]

    def sentences2diagrams(
        self,
        sentences: SentenceBatchType,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
        verbose: str | None = None,
    ) -> list[Diagram | None]:
        """Parse multiple sentences into a list of lambeq diagrams.

        Parameters
        ----------
        sentences : list of str, or list of list of str
            The sentences to be parsed.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if a
            sentence fails to parse, instead of raising an exception,
            its return entry is :py:obj:`None`.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. Not all parsers
            implement all three levels of progress reporting, see the
            respective documentation for each parser. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        list of :py:class:`lambeq.backend.grammar.Diagram` or None
            The parsed diagrams. May contain :py:obj:`None` if
            exceptions are suppressed.

        Raises
        ------
        OncillaParseError
            If a sentence cannot be parsed or converted to a diagram and
            `suppress_exceptions` is :py:obj:`False`.

        """
        pregroup_trees = self._sentences2pregrouptrees(
            sentences,
            tokenised=tokenised,
            suppress_exceptions=suppress_exceptions,
            verbose=verbose
        )
        diagrams: list[Diagram | None] = []

        if verbose is None:
            verbose = self.verbose
        # Compare by value: ``is`` on strings depends on interning.
        if verbose == VerbosityLevel.TEXT.value:
            print('Turning pregroup trees to diagrams.', file=sys.stderr)

        for tree in tqdm(
            pregroup_trees,
            desc='Turning pregroup trees to diagrams',
            leave=False,
            total=len(pregroup_trees),
            disable=verbose != VerbosityLevel.PROGRESS.value
        ):
            diagram: Diagram | None = None
            if tree is not None:
                # Reset per tree so a failing ``get_words()`` never leaves
                # ``tokens`` unbound or holding the previous tree's words.
                tokens: list[str] = []
                try:
                    tokens = tree.get_words()
                    diagram = tree.to_diagram(tokens=tokens)
                except Exception as e:
                    if not suppress_exceptions:
                        raise OncillaParseError(' '.join(tokens)) from e
            diagrams.append(diagram)

        return diagrams

    def sentence2diagram(
        self,
        sentence: SentenceType,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
        verbose: str | None = None
    ) -> Diagram | None:
        """Parse a sentence into a lambeq diagram.

        Thin wrapper around :py:meth:`sentences2diagrams` for a batch
        of one.

        Parameters
        ----------
        sentence : str, or list of str
            The sentence to be parsed.
        tokenised : bool, default: False
            Whether the sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if
            the sentence fails to parse, instead of raising an
            exception, returns :py:obj:`None`.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. Not all parsers
            implement all three levels of progress reporting, see the
            respective documentation for each parser. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        :py:class:`lambeq.backend.grammar.Diagram` or None
            The parsed diagram, or :py:obj:`None` on failure.

        """
        return self.sentences2diagrams(
            [sentence],  # type: ignore[arg-type]
            tokenised=tokenised,
            suppress_exceptions=suppress_exceptions,
            verbose=verbose
        )[0]