Source code for lambeq.text2diagram.model_based_reader.base
# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Model-based reader
==================
Base class for readers that use pre-trained models for
generating diagrams.

"""

from __future__ import annotations

__all__ = ['ModelBasedReader']

# Standard library
from abc import abstractmethod
from pathlib import Path
from typing import Any

# Third-party
import torch

# Project-local
from lambeq.core.globals import VerbosityLevel
from lambeq.core.utils import (
    SentenceBatchType,
    tokenised_batch_type_check,
    TokenisedSentenceBatchType,
    untokenised_batch_type_check,
)
from lambeq.text2diagram.base import Reader
from lambeq.text2diagram.model_based_reader.model_downloader import (
    ModelDownloader,
    ModelDownloaderError,
    MODELS,
)
from lambeq.typing import StrPathT
class ModelBasedReader(Reader):
    """Base class for readers that use pre-trained models.

    This is an abstract base class that provides common functionality
    for model-based readers. Subclasses must implement the specific
    model initialisation and inference logic.

    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        device: int | str | torch.device = 'cpu',
        cache_dir: StrPathT | None = None,
        force_download: bool = False,
        verbose: str = VerbosityLevel.PROGRESS.value,
    ) -> None:
        """Initialise the model-based reader.

        Parameters
        ----------
        model_name_or_path : str
            Can be either:
                - The path to a directory containing a model.
                - The name of a pre-trained model.
            Although the signature default is `None`, a value must be
            supplied (e.g. by a subclass); passing `None` raises
            :py:exc:`ValueError`.
        device : int, str, or torch.device, default: 'cpu'
            Specifies the device on which to run the model.
            - For CPU, use `'cpu'`.
            - For CUDA devices, use `'cuda:<device_id>'` or
              `<device_id>`.
            - For Apple Silicon (MPS), use `'mps'`.
            - You may also pass a :py:class:`torch.device` object.
            - For other devices, refer to the PyTorch documentation.
        cache_dir : str or os.PathLike, optional
            The directory to which a downloaded pre-trained model should
            be cached instead of the standard cache.
        force_download : bool, default: False
            Force the model to be downloaded, even if it is already
            available locally.
        verbose : str, default: 'progress'
            See :py:class:`VerbosityLevel` for options.

        Raises
        ------
        ValueError
            If `model_name_or_path` is `None`.

        """
        super().__init__(verbose=verbose)

        if model_name_or_path is None:
            raise ValueError(f'Invalid value `{model_name_or_path}`'
                             ' for argument `model_name_or_path`.')

        self.model_name_or_path = model_name_or_path
        self.device = device
        self.cache_dir = cache_dir
        self.force_download = force_download
        # Set by `_prepare_model_artifacts` below.
        self.model_dir: Path | None = None

        # Prepare model artifacts
        self._prepare_model_artifacts()

    def _prepare_model_artifacts(self) -> None:
        """Download model artifacts to disk.

        If `self.model_name_or_path` is an existing local directory, it
        is used as-is. Otherwise it is treated as a model name. The
        model is (re)downloaded when `force_download` is set, when no
        local copy exists, or when the local copy is stale.

        Raises
        ------
        ModelDownloaderError
            If the download fails and there is no usable local copy to
            fall back to.

        """
        self.model_dir = Path(self.model_name_or_path)
        if self.model_dir.is_dir():
            # A local model directory was given: no download/update check.
            return

        downloader = ModelDownloader(self.model_name_or_path,
                                     self.cache_dir)
        self.model_dir = downloader.model_dir
        needs_download = (self.force_download
                          or not self.model_dir.is_dir()
                          or downloader.model_is_stale())
        if not needs_download:
            return

        try:
            downloader.download_model(self.verbose)
        except ModelDownloaderError as e:
            local_model_version = downloader.get_local_model_version()
            if (self.model_dir.is_dir()
                    and local_model_version is not None):
                # Best-effort update: keep using the existing local copy.
                print('Failed to update model with '
                      f'exception: {e}')
                print('Attempting to continue with version '
                      f'{local_model_version}')
            else:
                # No local version to fall back to
                raise

    @abstractmethod
    def _initialise_model(self, **kwargs: Any) -> None:
        """Initialise the model and put it into the appropriate device.

        Also handle required miscellaneous initialisation steps here.

        """

    def validate_sentence_batch(
        self,
        sentences: SentenceBatchType,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
    ) -> tuple[TokenisedSentenceBatchType, list[int]]:
        """Prepare input sentences for parsing.

        Untokenised sentences are tokenised by splitting on whitespace.
        Empty sentences are either rejected or removed from the batch.

        Parameters
        ----------
        sentences : list of str, or list of list of str
            The sentences to be parsed, passed either as strings or as
            lists of tokens.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions for empty sentences. If
            :py:obj:`True`, empty sentences are removed from the batch
            and their indices are reported instead of raising an
            exception. Type-check errors are raised regardless.

        Returns
        -------
        tuple of (list of list of str, list of int)
            The tokenised, non-empty sentences, and the indices (into
            `sentences`, in ascending order) of the empty sentences
            that were removed. The index list is empty unless
            `suppress_exceptions` is :py:obj:`True`.

        Raises
        ------
        ValueError
            If `sentences` does not match the type implied by
            `tokenised`, or if a sentence is empty and
            `suppress_exceptions` is :py:obj:`False`.

        """
        tokenised_sentences: TokenisedSentenceBatchType
        if tokenised:
            if not tokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `True`, but variable '
                                 '`sentences` does not have type '
                                 '`List[List[str]]`.')
            tokenised_sentences = list(sentences)  # type: ignore[arg-type]
        else:
            if not untokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `False`, but variable '
                                 '`sentences` does not have type '
                                 '`List[str]`.')
            sent_list: list[str] = [str(s) for s in sentences]
            tokenised_sentences = [sentence.split() for sentence in sent_list]

        # Remove empty sentences
        empty_indices: list[int] = []
        for i, sentence in enumerate(tokenised_sentences):
            if not sentence:
                if not suppress_exceptions:
                    raise ValueError('sentence is empty.')
                empty_indices.append(i)

        if empty_indices:
            # Single O(n) pass instead of repeated `del` (O(n) each);
            # the removed entries are exactly the falsy (empty) ones.
            tokenised_sentences = [sentence
                                   for sentence in tokenised_sentences
                                   if sentence]

        return tokenised_sentences, empty_indices

    @staticmethod
    def available_models() -> list[str]:
        """List the available models."""
        return [*MODELS]