# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass, replace
from enum import Enum
from functools import cached_property
from typing import Any

from lambeq.bobcat.lexicon import Atom, Category, Feature, Relation


@dataclass
class IndexedWord:
    """A word in a sentence, annotated with its position (1-indexed)."""

    word: str
    index: int

    def __repr__(self) -> str:
        """Return the word and its index joined by an underscore."""
        return f'{self.word}_{self.index}'


@dataclass
class Dependency:
    """A relation anchored at a head word, optionally with a filler.

    Instances are created with ``filler`` unset by :meth:`generate`
    and with ``filler`` set (and ``var`` reset to 0) by :meth:`fill`.
    """

    relation: Relation
    head: IndexedWord
    # Variable id; `generate` takes it from the category (`cat.var`).
    var: int
    unary_rule_id: int
    filler: IndexedWord | None = None

    def replace(self,
                var: int,
                unary_rule_id: int | None = None) -> Dependency:
        """Return a copy with a new ``var``.

        Args:
            var: Variable id for the copy.
            unary_rule_id: Rule id for the copy; when ``None`` the
                current ``unary_rule_id`` is kept.
        """
        if unary_rule_id is None:
            unary_rule_id = self.unary_rule_id
        # Inside a method body the bare name `replace` resolves to the
        # module-level `dataclasses.replace`, not to this method.
        return replace(self, var=var, unary_rule_id=unary_rule_id)

    @classmethod
    def generate(cls,
                 cat: Category,
                 unary_rule_id: int,
                 head: IndexedWord | Variable) -> list[Dependency]:
        """Collect the dependencies declared by ``cat`` and its parts.

        If ``cat.relation`` is truthy, one dependency is created for
        ``head`` when it is an :class:`IndexedWord`, otherwise one per
        entry of ``head.fillers``. When ``cat.complex`` is truthy the
        result and argument categories are processed recursively, in
        that order, and their dependencies are appended.

        Args:
            cat: Category to walk.
            unary_rule_id: Rule id stored on every created dependency.
            head: The head word, or a variable whose fillers each act
                as a head.

        Returns:
            The dependencies found, outermost category first.
        """
        if cat.relation:
            if isinstance(head, IndexedWord):
                deps = [cls(cat.relation, head, cat.var, unary_rule_id)]
            else:
                deps = [cls(cat.relation, filler, cat.var, unary_rule_id)
                        for filler in head.fillers]
        else:
            deps = []

        if cat.complex:
            for c in (cat.result, cat.argument):
                deps += cls.generate(c, unary_rule_id, head)
        return deps

    def fill(self, var: Variable) -> list[Dependency]:
        """Return one filled copy of this dependency per filler.

        Each returned dependency keeps the relation, head and rule id,
        has ``var`` set to 0 and ``filler`` set to one entry of
        ``var.fillers``.
        """
        return [Dependency(self.relation,
                           self.head,
                           0,
                           self.unary_rule_id,
                           filler) for filler in var.fillers]

    def __str__(self) -> str:
        """Return head, relation, filler and rule id, space-separated."""
        # The fields are separated by single spaces so that adjacent
        # values cannot run into each other in the printed form.
        return (f'{self.head} {self.relation} {self.filler} '
                f'{self.unary_rule_id}')


@dataclass
class Variable:
    """A list of filler words together with a ``filled`` flag."""

    fillers: list[IndexedWord]
    filled: bool

    def __init__(self, word: IndexedWord | None = None) -> None:
        """Create a variable holding ``word`` (or no fillers at all).

        ``filled`` always starts out as ``True``.
        """
        # An explicitly defined __init__ is kept by @dataclass.
        if word is not None:
            self.fillers = [word]
        else:
            self.fillers = []
        self.filled = True

    def __add__(self, other: Any) -> Variable:
        """Return a new variable with the fillers of both operands.

        The result is freshly constructed, so its ``filled`` flag is
        ``True`` regardless of the operands' flags.
        """
        ret = Variable()
        ret.fillers = self.fillers + other.fillers
        return ret

    def as_filled(self, filled: bool) -> Variable:
        """Return a variable with the requested ``filled`` flag.

        Returns ``self`` when the flag already matches; otherwise a new
        variable that shares (does not copy) this one's fillers list.
        """
        if filled == self.filled:
            return self
        ret = Variable()
        ret.fillers = self.fillers
        ret.filled = filled
        return ret

    @property
    def filler(self) -> IndexedWord:
        """The first filler; raises ``IndexError`` if there are none."""
        return self.fillers[0]


class Unify:
    """Renumber the variables of two trees' categories consistently.

    Variable ids of the left and right categories are mapped into one
    shared numbering that starts at 1. ``trans_left`` / ``trans_right``
    map old ids to new ids and ``old_left`` / ``old_right`` hold the
    inverse mappings. ``feature`` records the feature chosen while
    unifying ``S`` atoms (``Feature.NONE`` until one is recorded).
    """

    def __init__(self,
                 left: ParseTree,
                 right: ParseTree,
                 result_is_left: bool) -> None:
        """Set up empty translation tables for ``left`` and ``right``.

        Args:
            left: Left tree.
            right: Right tree.
            result_is_left: If true, ``left.cat`` is stored as ``res``
                and ``right.cat`` as ``arg``; otherwise the roles are
                swapped. ``trans_res`` / ``trans_arg`` alias the
                matching left/right translation tables.
        """
        self.feature = Feature.NONE
        self.num_variables = 1  # next unused id in the new numbering
        self.trans_left: dict[int, int] = {}
        self.trans_right: dict[int, int] = {}
        self.old_left: dict[int, int] = {}
        self.old_right: dict[int, int] = {}
        self.left = left
        self.right = right
        self.result_is_left = result_is_left

        if result_is_left:
            self.res, self.arg = left.cat, right.cat
            self.trans_res, self.trans_arg = (self.trans_left,
                                              self.trans_right)
        else:
            self.arg, self.res = left.cat, right.cat
            self.trans_arg, self.trans_res = (self.trans_left,
                                              self.trans_right)

    def unify(self, arg: Category, res: Category) -> bool:
        """Unify ``arg`` with ``res`` and number remaining variables.

        The two categories are assigned to the left/right side
        according to ``result_is_left`` and unified recursively. On
        success, every variable of the stored ``arg`` and ``res``
        categories that has no translation yet receives a fresh id.

        Returns:
            ``True`` on success, ``False`` if the categories do not
            unify (in which case no extra variables are numbered).
        """
        if self.result_is_left:
            left, right = res, arg
        else:
            left, right = arg, res

        if not self.unify_recursive(left, right):
            return False

        self.add_vars(self.arg, self.trans_arg)
        self.add_vars(self.res, self.trans_res)
        return True

    def unify_recursive(self, left: Category, right: Category) -> bool:
        """Unify two categories structurally, pairing their variables.

        Atomic ``left``: the atoms must be equal. For ``Atom.S`` a
        ``Feature.X`` on either side adopts the other side's feature
        (stored in ``self.feature``); otherwise the features must
        match. Non-atomic ``left``: the directions must be equal and
        results and arguments must unify recursively.

        Afterwards, if neither category's variable has been translated
        yet, both are mapped to the same fresh id. This fails when both
        variables are present in their trees' ``var_map`` and both are
        already filled.

        Returns:
            ``True`` if the categories unify, else ``False``.
        """
        if left.atomic:
            if left.atom != right.atom:
                return False
            if left.atom == Atom.S:
                if left.feature == Feature.X:
                    self.feature = right.feature
                elif right.feature == Feature.X:
                    self.feature = left.feature
                elif left.feature != right.feature:
                    return False
        else:
            if not (left.dir == right.dir
                    and self.unify_recursive(left.result, right.result)
                    and self.unify_recursive(left.argument,
                                             right.argument)):
                return False

        if (left.var not in self.trans_left
                and right.var not in self.trans_right):
            try:
                v1 = self.left.var_map[left.var]
                v2 = self.right.var_map[right.var]
            except KeyError:
                # At least one variable is absent from its var_map, so
                # the "both filled" conflict check does not apply.
                pass
            else:
                if v1.filled and v2.filled:
                    return False

            self.trans_left[left.var] = self.num_variables
            self.trans_right[right.var] = self.num_variables
            self.old_left[self.num_variables] = left.var
            self.old_right[self.num_variables] = right.var
            self.num_variables += 1
        return True

    def add_vars(self, cat: Category, trans: dict[int, int]) -> None:
        """Give a fresh id to each variable of ``cat`` not in ``trans``.

        ``trans`` must be ``trans_left`` or ``trans_right``; the
        matching inverse table is updated alongside it (any table other
        than ``trans_left`` is treated as the right-hand one).
        """
        old = self.old_left if trans is self.trans_left else self.old_right
        for var in cat.vars:
            if var not in trans:
                trans[var] = self.num_variables
                old[self.num_variables] = var
                self.num_variables += 1

    def get_new_outer_var(self) -> int:
        """Return the new id of the left tree's category variable.

        Falls back to 0 when that variable has no translation.
        """
        return self.trans_left.get(self.left.cat.var, 0)

    def translate_arg(self, category: Category) -> Category:
        """Translate ``category`` using the argument-side table."""
        return category.translate(self.trans_arg, self.feature)

    def translate_res(self, category: Category) -> Category:
        """Translate ``category`` using the result-side table."""
        return category.translate(self.trans_res, self.feature)


class Rule(Enum):
    """The possible CCG rules."""

    NONE = 0
    L = 1
    U = 2
    BA = 3
    FA = 4
    BC = 5
    FC = 6
    BX = 7
    GBC = 8
    GFC = 9
    GBX = 10
    LP = 11
    RP = 12
    BTR = 13
    FTR = 14
    CONJ = 15
    ADJ_CONJ = 16
@dataclass
class ParseTree:
    """A node of a parse tree, with its category and dependencies.

    A node whose ``left`` child is falsy contributes its own category
    as a tag in :attr:`deps_and_tags`; otherwise the tags and
    dependencies of its truthy children are merged in.
    """

    rule: Rule
    cat: Category
    left: ParseTree
    right: ParseTree
    unfilled_deps: list[Dependency]
    filled_deps: list[Dependency]
    # Maps variable ids (such as `cat.var`) to their variables.
    var_map: dict[int, Variable]
    score: float = 0

    @property
    def word(self) -> str:
        """The word of a leaf node.

        Raises:
            AttributeError: If this node is not a leaf, or if its
                variable is missing from ``var_map``.
        """
        if not self.is_leaf:
            raise AttributeError('only leaves have words')
        return self.variable.filler.word

    @property
    def variable(self) -> Variable:
        """The variable that ``var_map`` holds for ``cat.var``.

        Raises:
            AttributeError: If ``cat.var`` is not a key of ``var_map``.
        """
        try:
            return self.var_map[self.cat.var]
        except KeyError as e:
            raise AttributeError('variable is not in map') from e

    @property
    def is_leaf(self) -> bool:
        """Whether this node was produced by ``Rule.L``."""
        return self.rule == Rule.L

    @property
    def coordinated_or_type_raised(self) -> bool:
        """Whether the rule is ``CONJ``, ``BTR`` or ``FTR``."""
        return self.rule in (Rule.CONJ, Rule.BTR, Rule.FTR)

    @property
    def coordinated(self) -> bool:
        """Whether the rule is ``CONJ``."""
        return self.rule == Rule.CONJ

    @property
    def bwd_comp(self) -> bool:
        """Whether the rule is ``BC`` or ``GBC``."""
        return self.rule in (Rule.BC, Rule.GBC)

    @property
    def fwd_comp(self) -> bool:
        """Whether the rule is ``FC`` or ``GFC``."""
        return self.rule in (Rule.FC, Rule.GFC)

    @cached_property
    def deps_and_tags(self) -> tuple[list[Dependency],
                                     list[str]]:  # pragma: no cover
        """The filled dependencies and tags of this subtree.

        Returns:
            A pair ``(deps, tags)``. ``deps`` holds this node's filled
            dependencies plus those of its children, sorted by head
            index and then filler index. ``tags`` holds, in
            left-to-right order, the category strings (with ``'[X]'``
            removed) of the nodes that have no left child.

        The result is computed once per node and then cached.
        """
        # Work on a copy so that `filled_deps` itself is never mutated.
        deps = self.filled_deps.copy()
        tags: list[str] = []
        if self.left:
            for child in (self.left, self.right):
                if child:
                    child_deps, child_tags = child.deps_and_tags
                    deps += child_deps
                    tags += child_tags
        else:
            tags.append(str(self.cat).replace('[X]', ''))
        deps.sort(key=lambda dep: (dep.head.index, dep.filler.index))
        return deps, tags

    @property
    def deps(self) -> list[Dependency]:
        """The sorted filled dependencies of this subtree."""
        return self.deps_and_tags[0]