Source code for mwptoolkit.utils.data_structure

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 22:14:55
# @File: data_structure.py


from mwptoolkit.utils.enum_type import SpecialTokens, NumMask


class Node:
    """A single vertex of a binary expression tree.

    Attributes set at construction:
        node_value: the payload passed by the caller.
        is_leaf: copy of the ``isleaf`` argument.
        embedding: starts as ``None``.
        left_node / right_node: child links, both start as ``None``.
    """

    def __init__(self, node_value, isleaf=True):
        """Store the payload and leaf flag; children and embedding start empty."""
        self.node_value, self.is_leaf = node_value, isleaf
        self.embedding = None
        self.left_node = self.right_node = None

    def set_left_node(self, node):
        """Attach ``node`` as the left child, replacing any previous one."""
        self.left_node = node

    def set_right_node(self, node):
        """Attach ``node`` as the right child, replacing any previous one."""
        self.right_node = node
class AbstractTree:
    """Base class for the tree containers in this module.

    It only stores a ``root`` reference (initially ``None``). The two
    conversion hooks below always raise ``NotImplementedError`` and are meant
    to be overridden by subclasses.
    """

    def __init__(self):
        self.root = None

    def equ2tree(self, *args, **kwargs):
        """Build the tree from an equation; must be overridden.

        The hook accepts ``self`` plus arbitrary arguments so that calling it
        on an instance (with whatever signature a subclass would use) reaches
        the ``NotImplementedError`` below instead of failing earlier with a
        ``TypeError`` about the argument count.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def tree2equ(self, *args, **kwargs):
        """Serialise the tree back to an equation; must be overridden.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class BinaryTree(AbstractTree):
    """Binary expression tree built from a postfix token sequence.

    Operands are pushed on a stack; each operator pops its right operand
    first, then its left operand, and pushes the combined node back.
    """

    def __init__(self, root_node=None):
        super().__init__()
        self.root = root_node

    def equ2tree(self, equ_list, out_idx2symbol, op_list, input_var, emb):
        """Build the tree from a postfix sequence of vocabulary indices.

        Args:
            equ_list: iterable of indices into ``out_idx2symbol``.
            out_idx2symbol: list mapping index -> symbol; it must contain
                ``SpecialTokens.PAD_TOKEN`` and ``SpecialTokens.EOS_TOKEN``.
            op_list: container of operator symbols.
            input_var: compared element-wise with each operand index and then
                asked for ``.nonzero()`` -- presumably a tensor of input
                token ids; TODO confirm against callers.
            emb: indexed with the positions found in ``input_var``.

        Raises:
            ValueError: if PAD or EOS is missing from ``out_idx2symbol``.
            IndexError: if an operator lacks operands or nothing was parsed.
        """
        # Resolve the two stop indices once instead of rescanning the
        # vocabulary for every token of the equation.
        pad_idx = out_idx2symbol.index(SpecialTokens.PAD_TOKEN)
        eos_idx = out_idx2symbol.index(SpecialTokens.EOS_TOKEN)

        stack = []
        for idx in equ_list:
            if idx == pad_idx or idx == eos_idx:
                break
            if out_idx2symbol[idx] in op_list:
                node = Node(idx, isleaf=False)
                node.set_right_node(stack.pop())
                node.set_left_node(stack.pop())
            else:
                node = Node(idx, isleaf=True)
                position = (input_var == idx).nonzero()
                # Store under the attribute ``Node`` actually declares.
                node.embedding = emb[position]
                # Misspelled name kept as an alias for any code that still
                # reads the attribute the old implementation wrote.
                node.node_embeding = node.embedding
            stack.append(node)
        self.root = stack.pop()

    def equ2tree_(self, equ_list):
        """Build the tree from a postfix sequence of symbols.

        Parsing stops at the first EOS or PAD token.

        Raises:
            IndexError: if an operator lacks operands, if nothing was parsed,
                or if more than one subtree is left on the stack.
        """
        terminators = (SpecialTokens.EOS_TOKEN, SpecialTokens.PAD_TOKEN)
        operators = ('+', '-', '*', '/', '^', '=',
                     SpecialTokens.BRG_TOKEN, SpecialTokens.OPT_TOKEN)

        stack = []
        for symbol in equ_list:
            if symbol in terminators:
                break
            if symbol in operators:
                node = Node(symbol, isleaf=False)
                node.set_right_node(stack.pop())
                node.set_left_node(stack.pop())
            else:
                node = Node(symbol, isleaf=True)
            stack.append(node)
        if len(stack) > 1:
            raise IndexError(
                "malformed postfix equation: more than one subtree left on the stack"
            )
        self.root = stack.pop()

    def tree2equ(self, node):
        """Return the postfix token list of the subtree rooted at ``node``."""
        if node.is_leaf:
            return [node.node_value]
        right_equ = self.tree2equ(node.right_node)
        left_equ = self.tree2equ(node.left_node)
        return left_equ + right_equ + [node.node_value]
class PrefixTree(BinaryTree):
    """Binary expression tree built from a prefix token sequence."""

    def __init__(self, root_node=None):
        # ``root_node`` now defaults to ``None`` like the parent class, so an
        # empty tree can be created before calling ``prefix2tree``.
        super().__init__(root_node=root_node)

    def prefix2tree(self, equ_list):
        """Build the tree from a prefix sequence of symbols.

        Only the tokens before the first EOS or PAD token are parsed.

        Raises:
            IndexError: if an operator lacks operands or nothing was parsed.
        """
        terminators = (SpecialTokens.EOS_TOKEN, SpecialTokens.PAD_TOKEN)
        operators = ('+', '-', '*', '/', '^', '=',
                     SpecialTokens.BRG_TOKEN, SpecialTokens.OPT_TOKEN)

        # Cut the sequence at the first terminator *before* reversing it.
        # Testing for terminators while walking the reversed list would stop
        # on trailing padding and discard the whole equation.
        tokens = []
        for symbol in equ_list:
            if symbol in terminators:
                break
            tokens.append(symbol)

        stack = []
        for symbol in reversed(tokens):
            if symbol in operators:
                node = Node(symbol, isleaf=False)
                # Scanning a prefix expression right-to-left leaves the LEFT
                # operand on top of the stack, so it is popped first.
                node.set_left_node(stack.pop())
                node.set_right_node(stack.pop())
            else:
                node = Node(symbol, isleaf=True)
            stack.append(node)
        self.root = stack.pop()
class GoldTree(AbstractTree):
    """Expression tree for a reference equation, stored with its answer.

    Besides ``root`` and ``gold_ans`` the instance carries ``result``, a
    scratch slot written by :meth:`lca` and reset by :meth:`query`.
    """

    def __init__(self, root_node=None, gold_ans=None):
        super().__init__()
        self.root = root_node
        self.gold_ans = gold_ans
        # Initialised here so that ``lca`` can be called directly without
        # hitting an AttributeError when ``query`` has not run yet.
        self.result = None

    def equ2tree(self, equ_list, out_idx2symbol, op_list, num_list, ans):
        """Build the tree from a postfix sequence of vocabulary indices.

        Leaf values are resolved as follows: a symbol found in
        ``NumMask.number`` is replaced by the entry of ``num_list`` at the
        same position, ``SpecialTokens.UNK_TOKEN`` becomes the string
        ``'-inf'``, and anything else is kept as is.

        Args:
            equ_list: iterable of indices into ``out_idx2symbol``.
            out_idx2symbol: list mapping index -> symbol; it must contain
                ``SpecialTokens.PAD_TOKEN`` and ``SpecialTokens.EOS_TOKEN``.
            op_list: container of operator symbols.
            num_list: values substituted for number-mask symbols.
            ans: stored on the instance as ``gold_ans``.

        Raises:
            ValueError: if PAD or EOS is missing from ``out_idx2symbol``.
            IndexError: if an operator lacks operands, nothing was parsed, or
                a number mask points past the end of ``num_list``.
        """
        # Look the stop indices up once rather than once per token.
        pad_idx = out_idx2symbol.index(SpecialTokens.PAD_TOKEN)
        eos_idx = out_idx2symbol.index(SpecialTokens.EOS_TOKEN)

        stack = []
        for idx in equ_list:
            if idx == pad_idx or idx == eos_idx:
                break
            symbol = out_idx2symbol[idx]
            if symbol in op_list:
                node = Node(symbol, isleaf=False)
                node.set_right_node(stack.pop())
                node.set_left_node(stack.pop())
            elif symbol in NumMask.number:
                value = num_list[NumMask.number.index(symbol)]
                node = Node(value, isleaf=True)
            elif symbol == SpecialTokens.UNK_TOKEN:
                node = Node('-inf', isleaf=True)
            else:
                node = Node(symbol, isleaf=True)
            stack.append(node)
        self.root = stack.pop()
        self.gold_ans = ans

    def is_float(self, num_str, num_list):
        """Return whether ``num_str`` is contained in ``num_list``.

        Despite the name, no numeric parsing happens here.
        """
        return num_str in num_list

    def is_equal(self, v1, v2):
        """Return ``True`` when ``v1 == v2`` is truthy, else ``False``."""
        return bool(v1 == v2)

    def lca(self, root, va, vb, parent):
        """Search the subtree at ``root`` for the values ``va`` and ``vb``.

        The traversal stops descending once ``self.result`` is set. A node
        is recorded in ``self.result`` when exactly two of these hold: the
        left subtree reported a match, the right subtree reported a match,
        the node's own value equals ``va`` or ``vb``.

        Returns:
            Whether this subtree contains a node matching ``va`` or ``vb``.
        """
        left = False
        right = False
        if not self.result and root.left_node:
            left = self.lca(root.left_node, va, vb, root)
        if not self.result and root.right_node:
            right = self.lca(root.right_node, va, vb, root)
        mid = self.is_equal(root.node_value, va) or self.is_equal(root.node_value, vb)
        if not self.result and (left + right + mid) == 2:
            # NOTE(review): when the current node itself is one of the two
            # matches, its parent (``None`` at the tree root) is recorded
            # instead of the node -- presumably to report the enclosing
            # operator; confirm the intent with the callers.
            self.result = parent if mid else root
        return left or mid or right

    def is_in_rel_quants(self, value, rel_quants):
        """Return whether ``value`` is contained in ``rel_quants``."""
        return value in rel_quants

    def query(self, va, vb):
        """Return the value of the node :meth:`lca` records for ``va``/``vb``.

        Returns ``None`` when the tree is empty or no node was recorded.
        """
        if self.root is None:
            return None
        self.result = None
        self.lca(self.root, va, vb, None)
        if self.result:
            return self.result.node_value
        return self.result
class DependencyNode:
    """One token of a dependency parse, with ordered child lists.

    Attributes set at construction:
        node_value, position, relation: stored exactly as passed in.
        embedding: starts as ``None``.
        left_nodes / right_nodes: two separate, initially empty lists.
        is_leaf: copy of the ``is_leaf`` argument.
    """

    def __init__(self, node_value, position, relation, is_leaf=True):
        """Store the token data and start with no children."""
        self.node_value, self.position, self.relation = node_value, position, relation
        self.embedding = None
        self.left_nodes, self.right_nodes = [], []
        self.is_leaf = is_leaf

    def add_left_node(self, node):
        """Append ``node`` to the end of ``left_nodes``."""
        self.left_nodes.append(node)

    def add_right_node(self, node):
        """Append ``node`` to the end of ``right_nodes``."""
        self.right_nodes.append(node)
class DependencyTree:
    """Tree of ``DependencyNode`` objects built from dependency triples."""

    def __init__(self, root_node=None):
        self.root = root_node

    def sentence2tree(self, sentence, dependency_info):
        """Build the tree and store its root on ``self.root``.

        Args:
            sentence: indexable by token position; supplies node values.
            dependency_info: iterable of ``(relation, child, father)``
                triples. The first triple whose father is ``-1`` names the
                root token.

        Raises:
            KeyError: if no triple has ``-1`` as its father.
        """
        # Group (relation, child) pairs under the index of their head.
        children_of = {}
        for rel, child, head in dependency_info:
            children_of.setdefault(head, []).append((rel, child))

        root_relation, root_position = children_of[-1][0]
        root_children = children_of.get(root_position, [])

        root = DependencyNode(
            sentence[root_position],
            root_position,
            root_relation,
            is_leaf=not root_children,
        )
        if root_children:
            before, after = self._build_sub_node(
                root_position, root_children, children_of, sentence
            )
            for dependent in before:
                root.add_left_node(dependent)
            for dependent in after:
                root.add_right_node(dependent)
        self.root = root

    def _build_sub_node(self, father_idx, child_list, node_dict, sentence):
        """Recursively create the nodes for ``child_list``.

        Returns:
            ``(left_list, right_list)``: children whose index is smaller than
            ``father_idx`` go to the first list, all others to the second.
        """
        lefts, rights = [], []
        for relation, child_idx in child_list:
            grandchildren = node_dict.get(child_idx, [])
            child_node = DependencyNode(
                sentence[child_idx],
                child_idx,
                relation,
                is_leaf=not grandchildren,
            )
            if grandchildren:
                sub_lefts, sub_rights = self._build_sub_node(
                    child_idx, grandchildren, node_dict, sentence
                )
                for descendant in sub_lefts:
                    child_node.add_left_node(descendant)
                for descendant in sub_rights:
                    child_node.add_right_node(descendant)
            side = lefts if child_idx < father_idx else rights
            side.append(child_node)
        return lefts, rights
class Tree:
    """Generic n-ary tree whose children are sub-trees or plain values.

    ``num_children`` is maintained by :meth:`add_child` alongside the
    ``children`` list; ``parent`` is set on a child when that child is an
    instance of the parent's own class.
    """

    def __init__(self):
        self.parent = None
        self.num_children = 0
        self.children = []

    def __str__(self, level=0):
        """Render one value per line, indented with one tab per nesting level."""
        lines = [
            child.__str__(level + 1)
            if isinstance(child, type(self))
            else "\t" * level + str(child) + "\n"
            for child in self.children
        ]
        return "".join(lines)

    def add_child(self, c):
        """Append ``c`` as the last child and bump ``num_children``."""
        if isinstance(c, type(self)):
            c.parent = self
        self.children.append(c)
        self.num_children += 1

    def to_string(self):
        """Concatenate the children; sub-trees are wrapped as ``( ... )``."""
        rendered = []
        for position in range(self.num_children):
            child = self.children[position]
            if isinstance(child, Tree):
                rendered.append("( {} )".format(child.to_string()))
            else:
                rendered.append(str(child))
        return "".join(rendered)

    def to_list(self, out_idx2symbol):
        """Return the children as a nested list.

        Sub-trees become nested lists. Children equal to the index of
        ``SpecialTokens.NON_TOKEN`` or ``SpecialTokens.EOS_TOKEN`` in
        ``out_idx2symbol`` are left out.
        """
        collected = []
        for position in range(self.num_children):
            child = self.children[position]
            if isinstance(child, type(self)):
                collected.append(child.to_list(out_idx2symbol))
                continue
            # The vocabulary lookups stay inline and short-circuit so they
            # are only evaluated for non-tree children.
            if (child == out_idx2symbol.index(SpecialTokens.NON_TOKEN)
                    or child == out_idx2symbol.index(SpecialTokens.EOS_TOKEN)):
                continue
            collected.append(child)
        return collected