# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field
from functools import cached_property
from typing import Optional

from lambeq.backend.grammar import Diagram, Ty

# Index reported for the (non-existent) parent of a root node.
ROOT_INDEX = -1


class PregroupTreeNodeError(Exception):
    """Error raised when a sentence cannot be turned into a pregroup tree.

    Parameters
    ----------
    sentence : str
        The sentence that failed to parse; kept on the instance as
        `sentence` and quoted in the error message.

    """

    def __init__(self, sentence: str) -> None:
        self.sentence = sentence

    def __str__(self) -> str:
        """Return the human-readable message naming the sentence."""
        quoted = repr(self.sentence)
        return f'PregroupTreeNode failed to parse {quoted}.'
[docs]@dataclassclassPregroupTreeNode:""" A node in a pregroup tree. A pregroup tree is a compact tree-like representation of a pregroup diagram. Each node in the tree represents a token in the sentence, annotated with the pregroup type of the outcome wire(s) expected from this word, e.g. `s` for a verb, `n` for an adjective, `n.r@s` for an adverb and so on. The root of the tree is the head word in the sentence (i.e a word with free wires, usually an `s`, that deliver the state of the sentence after the composition), and the branches of the tree represent cups identifying input wires (cups) to the parent node. Examples -------- Consider the sentence "John gave Mary a flower", with the following pregroup diagram: .. code-block:: console John gave Mary a flower ──── ───────────── ──── ───── ────── n n.r·s·n.l·n.l n n·n.l n ╰────╯ │ │ ╰────╯ │ ╰─────╯ │ ╰─────────────╯ The tree for this diagram becomes: .. code-block:: console gave_1 (s) ├ John_0 (n) ├ Mary_2 (n) └ a_3 (n) └ flower_4 (n) where the numbers after the underscore indicate the order of each word in the sentence. This representation is sufficient for reconstructing the original pregroup diagram, since the original types of the nodes can be recovered by following the parent-child relationships in the tree and adding the necessary type adjoints to accomodate the arguments in the type of the parent. Notes ----- Since the original pregroup diagram can contain cycles, any nodes with more than one parents will be duplicated to allow a tree-like representation. 
"""word:strind:inttyp:Tytyp_indxs:list[int]=field(default_factory=list)parent:Optional['PregroupTreeNode']=Nonechildren:list['PregroupTreeNode']=field(default_factory=list)_children_words:list[tuple[str,int]]=field(default_factory=list)def__post_init__(self)->None:"""Create a list of `(word, word index)` tuples which is used in several conversion functions."""self._children_words=[(child.word,child.ind)forchildinself.children]forchildinself.children:child.parent=selfdef__repr__(self)->str:"""Return the string representation of the node."""returnf'{self.word}_{self.ind} ({self.typ})'def__lt__(self,other:'PregroupTreeNode')->bool:returnself.ind<other.inddef__gt__(self,other:'PregroupTreeNode')->bool:returnself.ind>other.inddef__eq__(self,other:object)->bool:"""Check if these are the same instances, including the children and parent."""ifnotisinstance(other,PregroupTreeNode):returnNotImplementedreturn(self.word==other.wordandself.ind==other.indandself.typ==other.typand((self.parent.indifself.parentelseNone)==(other.parent.indifother.parentelseNone))andsorted(self.children)==sorted(other.children))
[docs]defis_same_word(self,other:object)->bool:"""Check if these words are the same words in the sentence. This is a relaxed version of the `__eq__` function which doesn't check equality of the children - essentially, this just checks if `other` is the same token."""ifnotisinstance(other,PregroupTreeNode):returnNotImplemented# type: ignore[no-any-return]return(self.word==other.wordandself.ind==other.ind)
@cached_propertydef_tree_repr(self)->str:"""The string representation of the entire tree."""ifnotself.children:# Leafreturnstr(self)else:out=str(self)n_children=len(self.children)fori,childinenumerate(self.children):child_lines=child._tree_repr.split('\n')lines=[]forj,linenumerate(child_lines):ifj==0:prefix='└'ifi==n_children-1else'├'else:prefix='│'ifi!=n_children-1else' 'lines.append(f'\n{prefix}{l}')out+=''.join(lines)returnout
[docs]defdraw(self)->None:"""Draw the tree."""print(self._tree_repr)
@cached_propertydefheight(self)->int:"""The height of the tree."""h=1curr_nodes=[self]next_nodes=[]whilecurr_nodes:fornodeincurr_nodes:next_nodes.extend(node.children)iflen(next_nodes):h+=1curr_nodes=next_nodesnext_nodes=[]returnh
[docs]defget_types(self,as_str:bool=True)->list[list[str]]|list[list[Ty]]:"""Return the types of each node in the tree. Parameters ---------- as_str : bool Whether to return the types as str or as Ty Returns ------- list[list[str] | list[Ty]] List of the string representations of the types or the types of each node indexed by the word order. """nodes_list=self.get_nodes()types_list=[[n.typforninnodes]fornodesinnodes_list]ifas_str:return[[str(t)fortintys]fortysintypes_list]returntypes_list
[docs]defget_parents(self)->list[list[int]]:"""Return the indices of the parents of each node in the tree. Returns ------- list[list[int]] List of the indices of the parents of each node, in the original sentence. The parent of the root node is assigned an index of [-1]. This is indexed by the word order. """nodes_list=self.get_nodes()parents_list=[[n.parent.indifn.parentelse-1forninnodes]fornodesinnodes_list]returnparents_list
[docs]defget_words(self)->list[str]:"""Return the words for each node in the tree. Returns ------- list of str List of the words corresponding to each node indexed by the word order. """nodes_list=self.get_nodes()words_list=[nodes[0].wordfornodesinnodes_list]returnwords_list
[docs]defget_word_indices(self)->list[int]:"""Return the indices of the word (in the original sentence) for each node in the tree. Returns ------- list of int List of the indices of the words corresponding to each node indexed by the word order. Notes ----- This is useful when the subtree doesn't form a span of the original sentence. """nodes_list=self.get_nodes()word_indices=[nodes[0].indfornodesinnodes_list]returnword_indices
def_get_nodes_flat(self)->list['PregroupTreeNode']:"""Return the nodes of the tree following the word order in the sentence. Returns ------- list of PregroupTreeNode List of the nodes corresponding to each word indexed by the word order but flattened. """nodes_list=[self]forchildinself.children:nodes_list.extend(child._get_nodes_flat())returnsorted(nodes_list,key=lambdan:n.ind)
[docs]defget_nodes(self)->list[list['PregroupTreeNode']]:"""Collect nodes from `_get_nodes_flat` into a list for each index so we have the same length as the number of words. Returns ------- list of list[PregroupTreeNode] List of the nodes corresponding to each word indexed by the word order. If cycles are present, multiple nodes will be assigned to the word. """flat_nodes_list=self._get_nodes_flat()nodes_list:list[list['PregroupTreeNode']]=[[]for_inrange(flat_nodes_list[-1].ind+1)]# Merge nodes with the same `ind` into a listfornodeinflat_nodes_list:nodes_list[node.ind].append(node)nodes_list=[nforninnodes_listifn]returnnodes_list
[docs]defget_root(self)->'PregroupTreeNode':"""Return the root of the tree where this node belongs to."""root=selfwhileroot.parentisnotNone:root=root.parentreturnroot
[docs]defget_depth(self,node:Optional['PregroupTreeNode']=None)->int:"""Return the depth of `self` node in the (sub)tree with `node` as its root Parameters ---------- node : PregroupTreeNode, optional, default is `None` The node which we will treat as the root. If not given, will try to find the root of the tree where `self` belongs to and compute the depth from that node. Returns ------- int The depth of this node in the tree. This is -1 if `self` is not in the tree rooted at `node`. """ifnodeisNone:node=self.get_root()depth=0curr_node:Optional['PregroupTreeNode']=selfnot_in_tree=Falsewhilecurr_node!=node:depth+=1ifcurr_nodeisNone:not_in_tree=Truebreakcurr_node=curr_node.parentifnot_in_tree:return-1returndepth
[docs]defmerge(self)->None:""" If `self` has only one children of the same type, this merges the words into one token while preserving the type. The minimum index is taken as the index of the new node. This modifies the calling node. """iflen(self.children)!=1:print('Cannot perform merge on node that '+"doesn't have exactly one child.")else:child=self.children[0]ifself.typ==child.typandabs(self.ind-child.ind)==1:# Perform mergeifself.ind<child.ind:self.word+=f' {child.word}'else:self.word=f'{child.word}{self.word}'self.ind=min(self.ind,child.ind)self.children=child.children# Modify parent of childchild.parent=Noneforcinself.children:c.parent=selfelse:print('Cannot perform merge when parent and child '+"types don't match or tokens are not consecutive.")
[docs]defremove_self_cycles(self)->None:"""Removes the children of this node that is the same token, i.e. self-cycles. This is used before breaking cycles. """new_children=[]forcinself.children:ifself.is_same_word(c):c.parent=Noneelse:new_children.append(c)self.children=new_childrenforcinself.children:c.remove_self_cycles()