# Copyright 2021-2024 Cambridge Quantum Computing Ltd.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Bobcat parser=============A chart-based parser based on the C&C parser, with scores predicted by atransformer."""from__future__importannotations__all__=['BobcatParser','BobcatParseError']fromcollections.abcimportIterableimportjsonfrompathlibimportPathimportsysfromtypingimportAnyimporttorchfromtqdm.autoimporttqdmfromtransformersimportAutoTokenizerfromlambeq.bobcatimport(BertForChartClassification,Category,ChartParser,Grammar,ParseTree,Sentence,Supertag,Tagger)fromlambeq.bobcat.taggerimportTaggerOutputSentencefromlambeq.core.globalsimportVerbosityLevelfromlambeq.core.utilsimport(SentenceBatchType,tokenised_batch_type_check,untokenised_batch_type_check)fromlambeq.text2diagram.ccg_parserimportCCGParserfromlambeq.text2diagram.ccg_ruleimportCCGRulefromlambeq.text2diagram.ccg_treeimportCCGTreefromlambeq.text2diagram.ccg_typeimportCCGTypefromlambeq.text2diagram.model_downloaderimport(ModelDownloader,ModelDownloaderError,MODELS)fromlambeq.typingimportStrPathT
def __str__(self) -> str:
    """Describe the failure, quoting the offending sentence via its repr."""
    template = 'Bobcat failed to parse {!r}.'
    return template.format(self.sentence)
class BobcatParser(CCGParser):
    """CCG parser using Bobcat as the backend."""

    def __init__(self,
                 model_name_or_path: str = 'bert',
                 root_cats: Iterable[str] | None = None,
                 device: int = -1,
                 cache_dir: StrPathT | None = None,
                 force_download: bool = False,
                 verbose: str = VerbosityLevel.PROGRESS.value,
                 **kwargs: Any) -> None:
        """Instantiate a BobcatParser.

        Parameters
        ----------
        model_name_or_path : str, default: 'bert'
            Can be either:
                - The path to a directory containing a Bobcat model.
                - The name of a pre-trained model.
                By default, it uses the "bert" model.
                See also: `BobcatParser.available_models()`
        root_cats : iterable of str, optional
            A list of the categories allowed at the root of the parse
            tree.
        device : int, default: -1
            The GPU device ID on which to run the model, if
            non-negative. If negative (the default), run on the CPU.
        cache_dir : str or os.PathLike, optional
            The directory to which a downloaded pre-trained model should
            be cached instead of the standard cache
            (`$XDG_CACHE_HOME` or `~/.cache`).
        force_download : bool, default: False
            Force the model to be downloaded, even if it is already
            available locally.
        verbose : str, default: 'progress'
            See :py:class:`VerbosityLevel` for options.
        **kwargs : dict, optional
            Additional keyword arguments to be passed to the underlying
            parsers (see Other Parameters). By default, they are set to
            the values in the `pipeline_config.json` file in the model
            directory.

        Other Parameters
        ----------------
        Tagger parameters:
        batch_size : int, optional
            The number of sentences per batch.
        tag_top_k : int, optional
            The maximum number of tags to keep. If 0, keep all tags.
        tag_prob_threshold : float, optional
            The probability multiplier used for the threshold to keep
            tags.
        tag_prob_threshold_strategy : {'relative', 'absolute'}
            If "relative", the probability threshold is relative to the
            highest scoring tag. Otherwise, the probability is an
            absolute threshold.
        span_top_k : int, optional
            The maximum number of entries to keep per span. If 0, keep
            all entries.
        span_prob_threshold : float, optional
            The probability multiplier used for the threshold to keep
            entries for a span.
        span_prob_threshold_strategy : {'relative', 'absolute'}
            If "relative", the probability threshold is relative to the
            highest scoring entry. Otherwise, the probability is an
            absolute threshold.

        Chart parser parameters:
        eisner_normal_form : bool, default: True
            Whether to use Eisner normal form.
        max_parse_trees : int, optional
            A safety limit to the number of parse trees that can be
            generated per parse before automatically failing.
        beam_size : int, optional
            The beam size to use in the chart cells.
        input_tag_score_weight : float, optional
            A scaling multiplier to the log-probabilities of the input
            tags. This means that a weight of 0 causes all of the input
            tags to have the same score.
        missing_cat_score : float, optional
            The default score for a category that is generated but not
            part of the grammar.
        missing_span_score : float, optional
            The default score for a category that is part of the grammar
            but has no score, due to being below the threshold kept by
            the tagger.

        Raises
        ------
        ValueError
            If `verbose` is not a valid :py:class:`VerbosityLevel` value.
        TypeError
            If `kwargs` contains a key that does not appear in any
            section of `pipeline_config.json`.
        ModelDownloaderError
            If downloading the model fails and there is no local copy
            to fall back to.

        """
        self.verbose = verbose

        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` is not a valid verbose value for '
                             'BobcatParser.')

        model_dir = Path(model_name_or_path)
        if not model_dir.is_dir():
            # Check for updates only if a local model path is not
            # specified in `model_name_or_path`
            downloader = ModelDownloader(model_name_or_path, cache_dir)
            model_dir = downloader.model_dir
            if (force_download
                    or not model_dir.is_dir()
                    or downloader.model_is_stale()):
                try:
                    downloader.download_model(verbose)
                except ModelDownloaderError as e:
                    local_model_version = downloader.get_local_model_version()
                    if (model_dir.is_dir()
                            and local_model_version is not None):
                        # Diagnostics go to stderr, like the rest of
                        # this class's progress messages.
                        print('Failed to update model with '
                              f'exception: {e}', file=sys.stderr)
                        print('Attempting to continue with version '
                              f'{local_model_version}', file=sys.stderr)
                    else:
                        # No local version to fall back to
                        raise

        with open(model_dir / 'pipeline_config.json',
                  encoding='utf-8') as f:
            config = json.load(f)

        # Each top-level section (e.g. 'tagger', 'parser') holds default
        # keyword arguments; a matching user kwarg overrides the default
        # and is consumed so that leftovers can be reported below.
        for subconfig in config.values():
            for key in subconfig:
                if key in kwargs:
                    subconfig[key] = kwargs.pop(key)

        if kwargs:
            raise TypeError('BobcatParser got unexpected keyword argument(s): '
                            f'{", ".join(map(repr, kwargs))}')

        device_ = torch.device('cpu' if device < 0 else f'cuda:{device}')
        model = (BertForChartClassification.from_pretrained(model_dir)
                                           .eval()
                                           .to(device_))
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.tagger = Tagger(model, tokenizer, **config['tagger'])

        grammar = Grammar.load(model_dir / 'grammar.json')
        self.parser = ChartParser(grammar,
                                  self.tagger.model.config.cats,
                                  root_cats,
                                  **config['parser'])

    @staticmethod
    def _prepare_sentence(sent: TaggerOutputSentence,
                          tags: list[str]) -> Sentence:
        """Turn JSON input into a Sentence for parsing.

        `sent.tags` holds, per word, pairs of (tag index, probability)
        that are resolved against `tags`; `sent.spans` holds
        (start, end, scores) triples whose scores are (id, score) pairs.
        """
        sent_tags = [[Supertag(tags[tag_id], prob)
                      for tag_id, prob in supertags]
                     for supertags in sent.tags]
        spans = {(start, end): {cat_id: score for cat_id, score in scores}
                 for start, end, scores in sent.spans}
        return Sentence(sent.words, sent_tags, spans)

    def sentences2trees(self,
                        sentences: SentenceBatchType,
                        tokenised: bool = False,
                        suppress_exceptions: bool = False,
                        verbose: str | None = None) -> list[CCGTree] | None:
        """Parse multiple sentences into a list of :py:class:`.CCGTree` s.

        Parameters
        ----------
        sentences : list of str, or list of list of str
            The sentences to be parsed, passed either as strings or as
            lists of tokens. The input list is never modified.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to suppress exceptions. If :py:obj:`True`, then if a
            sentence fails to parse, instead of raising an exception,
            its return entry is :py:obj:`None`.
        verbose : str, optional
            See :py:class:`VerbosityLevel` for options. If set, takes
            priority over the :py:attr:`verbose` attribute of the
            parser.

        Returns
        -------
        list of CCGTree or None
            The parsed trees, in the same order as `sentences`. (May
            contain :py:obj:`None` if exceptions are suppressed.)

        Raises
        ------
        ValueError
            If `verbose` is invalid, if `sentences` does not match the
            `tokenised` flag, or if a sentence is empty and exceptions
            are not suppressed.
        BobcatParseError
            If a sentence fails to parse and exceptions are not
            suppressed.

        """
        if verbose is None:
            verbose = self.verbose
        if not VerbosityLevel.has_value(verbose):
            raise ValueError(f'`{verbose}` is not a valid verbose value for '
                             'BobcatParser.')

        if tokenised:
            if not tokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `True`, but variable '
                                 '`sentences` does not have type '
                                 '`List[List[str]]`.')
        else:
            if not untokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `False`, but variable '
                                 '`sentences` does not have type '
                                 '`List[str]`.')
            sent_list: list[str] = [str(s) for s in sentences]
            sentences = [sentence.split() for sentence in sent_list]

        empty_indices = []
        for i, sentence in enumerate(sentences):
            if not sentence:
                if suppress_exceptions:
                    empty_indices.append(i)
                else:
                    raise ValueError('sentence is empty.')

        # Build a fresh list instead of deleting in place: in tokenised
        # mode `sentences` is the caller's own list and must not change.
        non_empty = [sentence for sentence in sentences if sentence]

        trees: list[CCGTree] = []
        if non_empty:
            if verbose == VerbosityLevel.TEXT.value:
                print('Tagging sentences.', file=sys.stderr)
            tag_results = self.tagger(non_empty, verbose=verbose)
            tags = tag_results.tags

            if verbose == VerbosityLevel.TEXT.value:
                print('Parsing tagged sentences.', file=sys.stderr)
            for sent in tqdm(
                    tag_results.sentences,
                    desc='Parsing tagged sentences',
                    leave=False,
                    disable=verbose != VerbosityLevel.PROGRESS.value):
                try:
                    sentence_input = self._prepare_sentence(sent, tags)
                    result = self.parser(sentence_input)
                    # Only the first parse returned by the chart parser
                    # is kept (presumably the best-scoring one).
                    trees.append(self._build_ccgtree(result[0]))
                except Exception as e:
                    if suppress_exceptions:
                        trees.append(None)
                    else:
                        raise BobcatParseError(' '.join(sent.words)) from e

        # `empty_indices` is ascending, so inserting in order restores
        # every empty sentence to its original position.
        for i in empty_indices:
            trees.insert(i, None)

        return trees

    @staticmethod
    def _to_biclosed(cat: Category) -> CCGType:
        """Transform a Bobcat category into a biclosed type.

        Raises
        ------
        ValueError
            If an atomic, non-punctuation category has an atom other
            than N, NP, S, PP or conj.

        """
        if cat.atomic:
            if cat.atom.is_punct:
                return CCGType.PUNCTUATION
            atomic_types = {
                'N': CCGType.NOUN,
                'NP': CCGType.NOUN_PHRASE,
                'S': CCGType.SENTENCE,
                'PP': CCGType.PREPOSITIONAL_PHRASE,
                'conj': CCGType.CONJUNCTION,
            }
            biclosed = atomic_types.get(str(cat.atom))
            if biclosed is None:
                raise ValueError(f'Invalid atomic type: {cat.atom!r}')
            return biclosed

        result = BobcatParser._to_biclosed(cat.result)
        argument = BobcatParser._to_biclosed(cat.argument)
        return result.slash(cat.dir, argument)

    @staticmethod
    def _build_ccgtree(tree: ParseTree) -> CCGTree:
        """Transform a Bobcat parse tree into a `CCGTree`.

        Children are converted recursively; the original Bobcat tree is
        kept under the ``'original'`` metadata key.
        """
        children = [BobcatParser._build_ccgtree(child)
                    for child in filter(None, (tree.left, tree.right))]
        if tree.rule.name == 'ADJ_CONJ':
            # Bobcat's ADJ_CONJ is mapped onto forward application
            # (presumably because CCGRule has no ADJ_CONJ member).
            rule = CCGRule.FORWARD_APPLICATION
        else:
            rule = CCGRule(tree.rule.name)
        return CCGTree(text=tree.word if tree.is_leaf else None,
                       rule=rule,
                       biclosed_type=BobcatParser._to_biclosed(tree.cat),
                       children=children,
                       metadata={'original': tree})

    @staticmethod
    def available_models() -> list[str]:
        """List the available models."""
        return list(MODELS)