Source code for mwptoolkit.evaluate.evaluator

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 19:19:55
# @File: evaluator.py


import copy
import re
import threading
from typing import Type, Union

import sympy as sym

from mwptoolkit.config.configuration import Config
from mwptoolkit.utils.enum_type import SpecialTokens, OPERATORS, NumMask, MaskSymbol, FixType
from mwptoolkit.utils.preprocess_tools import from_infix_to_postfix


class Solver(threading.Thread):
    r"""Time-limited equation-solving mechanism based on threading.

    The solving function runs in its own thread, so the caller can ``join``
    with a timeout and give up on computations that take too long.
    """

    def __init__(self, func, equations, unk_symbol):
        """
        Args:
            func (function): a function to solve equations, called as
                ``func(equations, unk_symbol)``.
            equations (list): list of expressions.
            unk_symbol (list): list of unknown symbols.
        """
        super(Solver, self).__init__()
        self.func = func
        self.equations = equations
        self.unk_symbol = unk_symbol
        # Defined up front so get_result() is valid even if the thread has
        # not finished (or was never started).
        self.result = None

    def run(self):
        """Run the equation solving process and store its result.

        Any exception raised by ``func`` is treated as "no result".
        """
        try:
            self.result = self.func(self.equations, self.unk_symbol)
        except Exception:
            self.result = None

    def get_result(self):
        """Return the result.

        Returns:
            The value returned by ``func``, or None if the function failed
            or has not produced a value yet.
        """
        return self.result
class AbstractEvaluator(object):
    """Abstract evaluator.

    Stores the configuration entries shared by all evaluators and declares
    the two evaluation entry points that subclasses implement.
    """

    def __init__(self, config):
        """
        Args:
            config: mapping-like configuration providing the keys
                ``share_vocab``, ``mask_symbol``, ``task_type``, ``single``
                and ``linear``.
        """
        super().__init__()
        self.share_vocab = config["share_vocab"]
        self.mask_symbol = config["mask_symbol"]
        self.task_type = config["task_type"]
        self.single = config["single"]
        self.linear = config["linear"]

    def result(self, test_exp=None, tar_exp=None):
        """Evaluate a single equation (to be implemented by subclasses).

        The parameters mirror the signature of the concrete evaluators; they
        are optional so that argument-less calls keep working.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def result_multi(self, test_exp=None, tar_exp=None):
        """Evaluate multiple equations (to be implemented by subclasses).

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class InfixEvaluator(AbstractEvaluator):
    r"""Evaluator for infix equation sequences.

    Expressions are converted to postfix with ``from_infix_to_postfix`` and
    then either evaluated numerically (single linear equation) or solved
    with sympy (every other setting).
    """

    def __init__(self, config):
        super().__init__(config)

    def result(self, test_exp, tar_exp):
        """Evaluate a single equation.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if not (self.single and self.linear):
            # Anything but "single and linear" is handled by the solver path.
            return self.result_multi(test_exp, tar_exp)
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_val = self._compute_expression_by_postfix(test_exp)
            tar_val = self._compute_expression_by_postfix(tar_exp)
            # A None value (invalid expression) raises here and counts as wrong.
            val_ac = bool(abs(test_val - tar_val) < 1e-4)
        except Exception:
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    def result_multi(self, test_exp, tar_exp):
        """Evaluate multiple equations.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_solves, test_unk = self._compute_expression_by_postfix_multi(test_exp)
            tar_solves, tar_unk = self._compute_expression_by_postfix_multi(tar_exp)
            val_ac = self._answers_match(test_solves, test_unk, tar_solves, tar_unk)
        except Exception:
            # Unsolvable / malformed expressions simply count as wrong.
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    @staticmethod
    def _answers_match(test_solves, test_unk, tar_solves, tar_unk):
        """Tell whether two solver outputs describe the same answers.

        Args:
            test_solves: solver output for the test expression.
            test_unk (dict): unknown symbols of the test expression.
            tar_solves: solver output for the target expression.
            tar_unk (dict): unknown symbols of the target expression.

        Returns:
            bool: True if the answers agree within ``1e-4``.

        Raises:
            Exception: whatever indexing/arithmetic on incompatible solver
                outputs raises; callers treat it as a mismatch.
        """
        tol = 1e-4
        if len(test_unk) != len(tar_unk):
            return False
        matched = False
        if len(tar_unk) == 1:
            if len(tar_solves) == 1:
                test_ans = test_solves[list(test_unk.values())[0]]
                tar_ans = tar_solves[list(tar_unk.values())[0]]
                if abs(test_ans - tar_ans) < tol:
                    matched = True
            elif len(test_solves) == len(tar_solves):
                # Length check: zip() would silently ignore extra or missing
                # solutions and report a match on a truncated comparison.
                matched = True
                for test_ans, tar_ans in zip(test_solves, tar_solves):
                    if abs(test_ans[0] - tar_ans[0]) > tol:
                        matched = False
                        break
        elif len(tar_solves) == len(tar_unk):
            matched = True
            for tar_x in list(tar_unk.values()):
                if abs(test_solves[tar_x] - tar_solves[tar_x]) > tol:
                    matched = False
                    break
        else:
            # NOTE(review): this branch never sets the flag to True, so it
            # can only fail (by raising) or fall through to the comparison
            # below. Kept as-is; confirm the intended behaviour.
            for test_ans, tar_ans in zip(test_solves, tar_solves):
                try:
                    te_ans = float(test_ans[0])
                except Exception:
                    te_ans = float(test_ans[1])
                try:
                    ta_ans = float(tar_ans[0])
                except Exception:
                    ta_ans = float(tar_ans[1])
                if abs(te_ans - ta_ans) > tol:
                    matched = False
                    break
        if matched:
            return True
        # Fallback: ignore which unknown holds which value and compare the
        # sorted solution values.
        test_vals = sorted(test_solves.values())
        tar_vals = sorted(tar_solves.values())
        if len(test_vals) != len(tar_vals):
            return False
        return not any(abs(v1 - v2) > tol for v1, v2 in zip(test_vals, tar_vals))

    @staticmethod
    def _operand_value(token):
        """Turn a numeric token into a number.

        Mixed numbers such as ``"1(1/2)"`` are evaluated as ``1+(1/2)``,
        tokens ending in ``%`` are divided by 100, and everything else is
        evaluated as a Python expression.
        """
        mixed = re.search(r"\d+\(", token)
        # SECURITY NOTE(review): eval() executes arbitrary Python. Tokens are
        # assumed to come from the dataset / model vocabulary - confirm
        # before feeding untrusted input.
        if mixed:
            return eval(token[mixed.start():mixed.end() - 1] + "+" + token[mixed.end() - 1:])
        if token[-1] == "%":
            return float(token[:-1]) / 100
        return eval(token)

    @staticmethod
    def _apply_operator(op, left, right):
        """Apply a binary operator; return None when the operation is rejected.

        Division by zero and exponents other than 2 or 3 are rejected.
        ``=`` builds a one-element list holding a sympy equation and
        ``<BRG>`` concatenates two such lists.
        """
        if op == "+" or op == "<BRG>":
            return left + right
        if op == "-":
            return left - right
        if op == "*":
            return left * right
        if op == "/":
            if right == 0:
                return None
            return left / right
        if op == "^":
            if float(right) != 2.0 and float(right) != 3.0:
                return None
            return left ** right
        if op == "=":
            return [sym.Eq(left, right)]
        return None

    @staticmethod
    def _solve_with_timeout(equations, unknowns, timeout=10):
        """Solve ``equations`` with sympy in a worker thread.

        Returns:
            The solver output, or None if solving failed or did not finish
            within ``timeout`` seconds.
        """
        worker = Solver(sym.solve, equations, unknowns)
        # Daemon thread: a solve that never finishes must not block exit.
        worker.daemon = True
        worker.start()
        worker.join(timeout)
        return worker.get_result()

    def _compute_postfix_expression(self, post_fix):
        """Evaluate a postfix expression; return None if it is invalid."""
        stack = []
        for token in post_fix:
            if token not in ("+", "-", "^", "*", "/"):
                stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None
            right = stack.pop()
            left = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None
            stack.append(value)
        if len(stack) == 1:
            return stack.pop()
        return None

    def _compute_postfix_expression_multi(self, post_fix):
        """Build equations from a postfix sequence and solve them.

        Returns:
            tuple: ``(solutions, unk_symbols)``. ``solutions`` is None when
            the sequence is malformed or the solver gave no result;
            ``unk_symbols`` maps unknown names to sympy symbols.
        """
        stack = []
        unk_symbols = {}
        for token in post_fix:
            if token not in ("+", "-", "^", "*", "/", "=", "<BRG>"):
                if token.isalpha():
                    # Alphabetic tokens are unknowns; reuse one symbol per name.
                    if token not in unk_symbols:
                        unk_symbols[token] = sym.symbols(token)
                    stack.append(unk_symbols[token])
                else:
                    stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None, unk_symbols
            right = stack.pop()
            left = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None, unk_symbols
            stack.append(value)
        if len(stack) != 1:
            return None, unk_symbols
        equations = stack.pop()
        return self._solve_with_timeout(equations, list(unk_symbols.values())), unk_symbols

    def _compute_expression_by_postfix(self, expression):
        """Convert an infix expression to postfix and evaluate it.

        Returns None if the conversion or the evaluation is not possible.
        """
        try:
            post_exp = from_infix_to_postfix(expression)
        except Exception:
            return None
        return self._compute_postfix_expression(post_exp)

    def _compute_expression_by_postfix_multi(self, expression):
        r"""return solves and unknown number list

        Returns ``(None, None)`` if the infix-to-postfix conversion fails.
        """
        try:
            post_exp = from_infix_to_postfix(expression)
        except Exception:
            return None, None
        return self._compute_postfix_expression_multi(post_exp)
class PrefixEvaluator(AbstractEvaluator):
    r"""Evaluator for prefix equations.

    A single linear equation is evaluated numerically; every other setting
    is solved with sympy.
    """

    def __init__(self, config):
        super().__init__(config)

    def result(self, test_exp, tar_exp):
        """Evaluate a single equation.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if not (self.single and self.linear):
            # Anything but "single and linear" is handled by the solver path.
            return self.result_multi(test_exp, tar_exp)
        # ``== []`` rather than ``is []``: an identity test against a fresh
        # list literal can never be true.
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_val = self._compute_prefix_expression(test_exp)
            tar_val = self._compute_prefix_expression(tar_exp)
            # A None value (invalid expression) raises here and counts as wrong.
            val_ac = bool(abs(test_val - tar_val) < 1e-4)
        except Exception:
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    def result_multi(self, test_exp, tar_exp):
        """Evaluate multiple equations.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_solves, test_unk = self._compute_prefix_expression_multi(test_exp)
            tar_solves, tar_unk = self._compute_prefix_expression_multi(tar_exp)
            val_ac = self._answers_match(test_solves, test_unk, tar_solves, tar_unk)
        except Exception:
            # Unsolvable / malformed expressions simply count as wrong.
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    @staticmethod
    def _answers_match(test_solves, test_unk, tar_solves, tar_unk):
        """Tell whether two solver outputs describe the same answers.

        Args:
            test_solves: solver output for the test expression.
            test_unk (dict): unknown symbols of the test expression.
            tar_solves: solver output for the target expression.
            tar_unk (dict): unknown symbols of the target expression.

        Returns:
            bool: True if the answers agree within ``1e-4``.

        Raises:
            Exception: whatever indexing/arithmetic on incompatible solver
                outputs raises; callers treat it as a mismatch.
        """
        tol = 1e-4
        if len(test_unk) != len(tar_unk):
            return False
        matched = False
        if len(tar_unk) == 1:
            if len(tar_solves) == 1:
                test_ans = test_solves[list(test_unk.values())[0]]
                tar_ans = tar_solves[list(tar_unk.values())[0]]
                if abs(test_ans - tar_ans) < tol:
                    matched = True
            elif len(test_solves) == len(tar_solves):
                # Length check: zip() would silently ignore extra or missing
                # solutions and report a match on a truncated comparison.
                matched = True
                for test_ans, tar_ans in zip(test_solves, tar_solves):
                    if abs(test_ans[0] - tar_ans[0]) > tol:
                        matched = False
                        break
        elif len(tar_solves) == len(tar_unk):
            matched = True
            for tar_x in list(tar_unk.values()):
                if abs(test_solves[tar_x] - tar_solves[tar_x]) > tol:
                    matched = False
                    break
        else:
            # NOTE(review): this branch never sets the flag to True, so it
            # can only fail (by raising) or fall through to the comparison
            # below. Kept as-is; confirm the intended behaviour.
            for test_ans, tar_ans in zip(test_solves, tar_solves):
                try:
                    te_ans = float(test_ans[0])
                except Exception:
                    te_ans = float(test_ans[1])
                try:
                    ta_ans = float(tar_ans[0])
                except Exception:
                    ta_ans = float(tar_ans[1])
                if abs(te_ans - ta_ans) > tol:
                    matched = False
                    break
        if matched:
            return True
        # Fallback: ignore which unknown holds which value and compare the
        # sorted solution values.
        test_vals = sorted(test_solves.values())
        tar_vals = sorted(tar_solves.values())
        if len(test_vals) != len(tar_vals):
            return False
        return not any(abs(v1 - v2) > tol for v1, v2 in zip(test_vals, tar_vals))

    @staticmethod
    def _operand_value(token):
        """Turn a numeric token into a number.

        Mixed numbers such as ``"1(1/2)"`` are evaluated as ``1+(1/2)``,
        tokens ending in ``%`` are divided by 100, and everything else is
        evaluated as a Python expression.
        """
        mixed = re.search(r"\d+\(", token)
        # SECURITY NOTE(review): eval() executes arbitrary Python. Tokens are
        # assumed to come from the dataset / model vocabulary - confirm
        # before feeding untrusted input.
        if mixed:
            return eval(token[mixed.start():mixed.end() - 1] + "+" + token[mixed.end() - 1:])
        if token[-1] == "%":
            return float(token[:-1]) / 100
        return eval(token)

    @staticmethod
    def _apply_operator(op, left, right):
        """Apply a binary operator; return None when the operation is rejected.

        Division by zero and exponents other than 2 or 3 are rejected.
        ``=`` builds a one-element list holding a sympy equation and
        ``<BRG>`` concatenates two such lists.
        """
        if op == "+" or op == "<BRG>":
            return left + right
        if op == "-":
            return left - right
        if op == "*":
            return left * right
        if op == "/":
            if right == 0:
                return None
            return left / right
        if op == "^":
            if float(right) != 2.0 and float(right) != 3.0:
                return None
            return left ** right
        if op == "=":
            return [sym.Eq(left, right)]
        return None

    @staticmethod
    def _solve_with_timeout(equations, unknowns, timeout=10):
        """Solve ``equations`` with sympy in a worker thread.

        Returns:
            The solver output, or None if solving failed or did not finish
            within ``timeout`` seconds.
        """
        worker = Solver(sym.solve, equations, unknowns)
        # Daemon thread: a solve that never finishes must not block exit.
        worker.daemon = True
        worker.start()
        worker.join(timeout)
        return worker.get_result()

    def _compute_prefix_expression(self, pre_fix):
        """Evaluate a prefix expression; return None if it is invalid.

        Tokens are scanned right to left, so the first value popped for an
        operator is its left operand.
        """
        stack = []
        for token in reversed(pre_fix):
            if token not in ("+", "-", "^", "*", "/"):
                stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None
            left = stack.pop()
            right = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None
            stack.append(value)
        if len(stack) == 1:
            return stack.pop()
        return None

    def _compute_prefix_expression_multi(self, pre_fix):
        """Build equations from a prefix sequence and solve them.

        Returns:
            tuple: ``(solutions, unk_symbols)``. ``solutions`` is None when
            the sequence is malformed or the solver gave no result;
            ``unk_symbols`` maps unknown names to sympy symbols. A 2-tuple
            is returned on every path because callers unpack two values.
        """
        stack = []
        unk_symbols = {}
        for token in reversed(pre_fix):
            if token not in ("+", "-", "^", "*", "/", "=", "<BRG>"):
                if token.isalpha():
                    # Alphabetic tokens are unknowns; reuse one symbol per name.
                    if token not in unk_symbols:
                        unk_symbols[token] = sym.symbols(token)
                    stack.append(unk_symbols[token])
                else:
                    stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None, unk_symbols
            left = stack.pop()
            right = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None, unk_symbols
            stack.append(value)
        if len(stack) != 1:
            return None, unk_symbols
        equations = stack.pop()
        return self._solve_with_timeout(equations, list(unk_symbols.values())), unk_symbols

    def eval_source(self, test_res, test_tar, num_list, num_stack=None):
        """Not implemented for this evaluator.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class PostfixEvaluator(AbstractEvaluator):
    r"""Evaluator for postfix equations.

    A single linear equation is evaluated numerically; every other setting
    is solved with sympy.
    """

    def __init__(self, config):
        super().__init__(config)

    def result(self, test_exp, tar_exp):
        """Evaluate a single equation.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if not (self.single and self.linear):
            # Anything but "single and linear" is handled by the solver path.
            return self.result_multi(test_exp, tar_exp)
        # ``== []`` rather than ``is []``: an identity test against a fresh
        # list literal can never be true.
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_val = self._compute_postfix_expression(test_exp)
            tar_val = self._compute_postfix_expression(tar_exp)
            # A None value (invalid expression) raises here and counts as wrong.
            val_ac = bool(abs(test_val - tar_val) < 1e-4)
        except Exception:
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    def result_multi(self, test_exp, tar_exp):
        """Evaluate multiple equations.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_solves, test_unk = self._compute_postfix_expression_multi(test_exp)
            tar_solves, tar_unk = self._compute_postfix_expression_multi(tar_exp)
            val_ac = self._answers_match(test_solves, test_unk, tar_solves, tar_unk)
        except Exception:
            # Unsolvable / malformed expressions simply count as wrong.
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    @staticmethod
    def _answers_match(test_solves, test_unk, tar_solves, tar_unk):
        """Tell whether two solver outputs describe the same answers.

        Args:
            test_solves: solver output for the test expression.
            test_unk (dict): unknown symbols of the test expression.
            tar_solves: solver output for the target expression.
            tar_unk (dict): unknown symbols of the target expression.

        Returns:
            bool: True if the answers agree within ``1e-4``.

        Raises:
            Exception: whatever indexing/arithmetic on incompatible solver
                outputs raises; callers treat it as a mismatch.
        """
        tol = 1e-4
        if len(test_unk) != len(tar_unk):
            return False
        matched = False
        if len(tar_unk) == 1:
            if len(tar_solves) == 1:
                test_ans = test_solves[list(test_unk.values())[0]]
                tar_ans = tar_solves[list(tar_unk.values())[0]]
                if abs(test_ans - tar_ans) < tol:
                    matched = True
            elif len(test_solves) == len(tar_solves):
                # Length check: zip() would silently ignore extra or missing
                # solutions and report a match on a truncated comparison.
                matched = True
                for test_ans, tar_ans in zip(test_solves, tar_solves):
                    if abs(test_ans[0] - tar_ans[0]) > tol:
                        matched = False
                        break
        elif len(tar_solves) == len(tar_unk):
            matched = True
            for tar_x in list(tar_unk.values()):
                if abs(test_solves[tar_x] - tar_solves[tar_x]) > tol:
                    matched = False
                    break
        else:
            # NOTE(review): this branch never sets the flag to True, so it
            # can only fail (by raising) or fall through to the comparison
            # below. Kept as-is; confirm the intended behaviour.
            for test_ans, tar_ans in zip(test_solves, tar_solves):
                try:
                    te_ans = float(test_ans[0])
                except Exception:
                    te_ans = float(test_ans[1])
                try:
                    ta_ans = float(tar_ans[0])
                except Exception:
                    ta_ans = float(tar_ans[1])
                if abs(te_ans - ta_ans) > tol:
                    matched = False
                    break
        if matched:
            return True
        # Fallback: ignore which unknown holds which value and compare the
        # sorted solution values.
        test_vals = sorted(test_solves.values())
        tar_vals = sorted(tar_solves.values())
        if len(test_vals) != len(tar_vals):
            return False
        return not any(abs(v1 - v2) > tol for v1, v2 in zip(test_vals, tar_vals))

    @staticmethod
    def _operand_value(token):
        """Turn a numeric token into a number.

        Mixed numbers such as ``"1(1/2)"`` are evaluated as ``1+(1/2)``,
        tokens ending in ``%`` are divided by 100, and everything else is
        evaluated as a Python expression.
        """
        mixed = re.search(r"\d+\(", token)
        # SECURITY NOTE(review): eval() executes arbitrary Python. Tokens are
        # assumed to come from the dataset / model vocabulary - confirm
        # before feeding untrusted input.
        if mixed:
            return eval(token[mixed.start():mixed.end() - 1] + "+" + token[mixed.end() - 1:])
        if token[-1] == "%":
            return float(token[:-1]) / 100
        return eval(token)

    @staticmethod
    def _apply_operator(op, left, right):
        """Apply a binary operator; return None when the operation is rejected.

        Division by zero and exponents other than 2 or 3 are rejected.
        ``=`` builds a one-element list holding a sympy equation and
        ``<BRG>`` concatenates two such lists.
        """
        if op == "+" or op == "<BRG>":
            return left + right
        if op == "-":
            return left - right
        if op == "*":
            return left * right
        if op == "/":
            if right == 0:
                return None
            return left / right
        if op == "^":
            if float(right) != 2.0 and float(right) != 3.0:
                return None
            return left ** right
        if op == "=":
            return [sym.Eq(left, right)]
        return None

    @staticmethod
    def _solve_with_timeout(equations, unknowns, timeout=10):
        """Solve ``equations`` with sympy in a worker thread.

        Returns:
            The solver output, or None if solving failed or did not finish
            within ``timeout`` seconds.
        """
        worker = Solver(sym.solve, equations, unknowns)
        # Daemon thread: a solve that never finishes must not block exit.
        worker.daemon = True
        worker.start()
        worker.join(timeout)
        return worker.get_result()

    def _compute_postfix_expression(self, post_fix):
        """Evaluate a postfix expression; return None if it is invalid."""
        stack = []
        for token in post_fix:
            if token not in ("+", "-", "^", "*", "/"):
                stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None
            right = stack.pop()
            left = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None
            stack.append(value)
        if len(stack) == 1:
            return stack.pop()
        return None

    def _compute_postfix_expression_multi(self, post_fix):
        """Build equations from a postfix sequence and solve them.

        Returns:
            tuple: ``(solutions, unk_symbols)``. ``solutions`` is None when
            the sequence is malformed or the solver gave no result;
            ``unk_symbols`` maps unknown names to sympy symbols.
        """
        stack = []
        unk_symbols = {}
        for token in post_fix:
            if token not in ("+", "-", "^", "*", "/", "=", "<BRG>"):
                if token.isalpha():
                    # Alphabetic tokens are unknowns; reuse one symbol per name.
                    if token not in unk_symbols:
                        unk_symbols[token] = sym.symbols(token)
                    stack.append(unk_symbols[token])
                else:
                    stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None, unk_symbols
            right = stack.pop()
            left = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None, unk_symbols
            stack.append(value)
        if len(stack) != 1:
            return None, unk_symbols
        equations = stack.pop()
        return self._solve_with_timeout(equations, list(unk_symbols.values())), unk_symbols

    def eval_source(self):
        """Not implemented for this evaluator.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class MultiWayTreeEvaluator(AbstractEvaluator):
    r"""Evaluator for multi-way-tree equations.

    Expressions are converted to postfix with ``from_infix_to_postfix`` and
    then either evaluated numerically (single linear equation) or solved
    with sympy (every other setting).
    """

    def __init__(self, config):
        super().__init__(config)

    def result(self, test_exp, tar_exp):
        """Evaluate a single equation.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if not (self.single and self.linear):
            # Anything but "single and linear" is handled by the solver path.
            return self.result_multi(test_exp, tar_exp)
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_val = self._compute_expression_by_postfix(test_exp)
            tar_val = self._compute_expression_by_postfix(tar_exp)
            # A None value (invalid expression) raises here and counts as wrong.
            val_ac = bool(abs(test_val - tar_val) < 1e-4)
        except Exception:
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    def result_multi(self, test_exp, tar_exp):
        r"""Evaluate multiple equations.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            test_solves, test_unk = self._compute_expression_by_postfix_multi(test_exp)
            tar_solves, tar_unk = self._compute_expression_by_postfix_multi(tar_exp)
            val_ac = self._answers_match(test_solves, test_unk, tar_solves, tar_unk)
        except Exception:
            # Unsolvable / malformed expressions simply count as wrong.
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    @staticmethod
    def _answers_match(test_solves, test_unk, tar_solves, tar_unk):
        """Tell whether two solver outputs describe the same answers.

        Args:
            test_solves: solver output for the test expression.
            test_unk (dict): unknown symbols of the test expression.
            tar_solves: solver output for the target expression.
            tar_unk (dict): unknown symbols of the target expression.

        Returns:
            bool: True if the answers agree within ``1e-4``.

        Raises:
            Exception: whatever indexing/arithmetic on incompatible solver
                outputs raises; callers treat it as a mismatch.
        """
        tol = 1e-4
        if len(test_unk) != len(tar_unk):
            return False
        matched = False
        if len(tar_unk) == 1:
            if len(tar_solves) == 1:
                test_ans = test_solves[list(test_unk.values())[0]]
                tar_ans = tar_solves[list(tar_unk.values())[0]]
                if abs(test_ans - tar_ans) < tol:
                    matched = True
            elif len(test_solves) == len(tar_solves):
                # Length check: zip() would silently ignore extra or missing
                # solutions and report a match on a truncated comparison.
                matched = True
                for test_ans, tar_ans in zip(test_solves, tar_solves):
                    if abs(test_ans[0] - tar_ans[0]) > tol:
                        matched = False
                        break
        elif len(tar_solves) == len(tar_unk):
            matched = True
            for tar_x in list(tar_unk.values()):
                if abs(test_solves[tar_x] - tar_solves[tar_x]) > tol:
                    matched = False
                    break
        else:
            # NOTE(review): this branch never sets the flag to True, so it
            # can only fail (by raising) or fall through to the comparison
            # below. Kept as-is; confirm the intended behaviour.
            for test_ans, tar_ans in zip(test_solves, tar_solves):
                try:
                    te_ans = float(test_ans[0])
                except Exception:
                    te_ans = float(test_ans[1])
                try:
                    ta_ans = float(tar_ans[0])
                except Exception:
                    ta_ans = float(tar_ans[1])
                if abs(te_ans - ta_ans) > tol:
                    matched = False
                    break
        if matched:
            return True
        # Fallback: ignore which unknown holds which value and compare the
        # sorted solution values.
        test_vals = sorted(test_solves.values())
        tar_vals = sorted(tar_solves.values())
        if len(test_vals) != len(tar_vals):
            return False
        return not any(abs(v1 - v2) > tol for v1, v2 in zip(test_vals, tar_vals))

    @staticmethod
    def _operand_value(token):
        """Turn a numeric token into a number.

        Mixed numbers such as ``"1(1/2)"`` are evaluated as ``1+(1/2)``,
        tokens ending in ``%`` are divided by 100, and everything else is
        evaluated as a Python expression.
        """
        mixed = re.search(r"\d+\(", token)
        # SECURITY NOTE(review): eval() executes arbitrary Python. Tokens are
        # assumed to come from the dataset / model vocabulary - confirm
        # before feeding untrusted input.
        if mixed:
            return eval(token[mixed.start():mixed.end() - 1] + "+" + token[mixed.end() - 1:])
        if token[-1] == "%":
            return float(token[:-1]) / 100
        return eval(token)

    @staticmethod
    def _apply_operator(op, left, right):
        """Apply a binary operator; return None when the operation is rejected.

        Division by zero and exponents other than 2 or 3 are rejected.
        ``=`` builds a one-element list holding a sympy equation and
        ``<BRG>`` concatenates two such lists.
        """
        if op == "+" or op == "<BRG>":
            return left + right
        if op == "-":
            return left - right
        if op == "*":
            return left * right
        if op == "/":
            if right == 0:
                return None
            return left / right
        if op == "^":
            if float(right) != 2.0 and float(right) != 3.0:
                return None
            return left ** right
        if op == "=":
            return [sym.Eq(left, right)]
        return None

    @staticmethod
    def _solve_with_timeout(equations, unknowns, timeout=10):
        """Solve ``equations`` with sympy in a worker thread.

        Returns:
            The solver output, or None if solving failed or did not finish
            within ``timeout`` seconds.
        """
        worker = Solver(sym.solve, equations, unknowns)
        # Daemon thread: a solve that never finishes must not block exit.
        worker.daemon = True
        worker.start()
        worker.join(timeout)
        return worker.get_result()

    def _compute_postfix_expression(self, post_fix):
        """Evaluate a postfix expression; return None if it is invalid."""
        stack = []
        for token in post_fix:
            if token not in ("+", "-", "^", "*", "/"):
                stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None
            right = stack.pop()
            left = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None
            stack.append(value)
        if len(stack) == 1:
            return stack.pop()
        return None

    def _compute_postfix_expression_multi(self, post_fix):
        """Build equations from a postfix sequence and solve them.

        Returns:
            tuple: ``(solutions, unk_symbols)``. ``solutions`` is None when
            the sequence is malformed or the solver gave no result;
            ``unk_symbols`` maps unknown names to sympy symbols.
        """
        stack = []
        unk_symbols = {}
        for token in post_fix:
            if token not in ("+", "-", "^", "*", "/", "=", "<BRG>"):
                if token.isalpha():
                    # Alphabetic tokens are unknowns; reuse one symbol per name.
                    if token not in unk_symbols:
                        unk_symbols[token] = sym.symbols(token)
                    stack.append(unk_symbols[token])
                else:
                    stack.append(self._operand_value(token))
                continue
            if len(stack) < 2:
                return None, unk_symbols
            right = stack.pop()
            left = stack.pop()
            value = self._apply_operator(token, left, right)
            if value is None:
                return None, unk_symbols
            stack.append(value)
        if len(stack) != 1:
            return None, unk_symbols
        equations = stack.pop()
        return self._solve_with_timeout(equations, list(unk_symbols.values())), unk_symbols

    def _compute_expression_by_postfix(self, expression):
        """Convert an infix expression to postfix and evaluate it.

        Returns None if the conversion or the evaluation is not possible.
        """
        try:
            post_exp = from_infix_to_postfix(expression)
        except Exception:
            return None
        return self._compute_postfix_expression(post_exp)

    def _compute_expression_by_postfix_multi(self, expression):
        r"""return solves and unknown number list

        Returns ``(None, None)`` if the infix-to-postfix conversion fails.
        """
        try:
            post_exp = from_infix_to_postfix(expression)
        except Exception:
            return None, None
        return self._compute_postfix_expression_multi(post_exp)
class MultiEncDecEvaluator(PostfixEvaluator, PrefixEvaluator):
    r"""Evaluator for deep-learning model MultiE&D.

    The model produces both a prefix and a postfix equation, so evaluation
    is exposed through the ``prefix_*`` and ``postfix_*`` methods; the
    generic ``result`` / ``result_multi`` entry points are disabled.
    """

    def __init__(self, config):
        super().__init__(config)

    def prefix_result(self, test_exp, tar_exp):
        """Evaluate a single prefix equation.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        return self._single_result(self._compute_prefix_expression, self.prefix_result_multi, test_exp, tar_exp)

    def prefix_result_multi(self, test_exp, tar_exp):
        """Evaluate multiple prefix equations.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        return self._multi_result(self._compute_prefix_expression_multi, test_exp, tar_exp)

    def postfix_result(self, test_exp, tar_exp):
        """Evaluate a single postfix equation.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        return self._single_result(self._compute_postfix_expression, self.postfix_result_multi, test_exp, tar_exp)

    def postfix_result_multi(self, test_exp, tar_exp):
        """Evaluate multiple postfix equations.

        Args:
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.

        Returns:
            (tuple(bool,bool,list,list))
            val_ac (bool): the correctness of test expression answer compared to target expression answer.
            equ_ac (bool): the correctness of test expression compared to target expression.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        return self._multi_result(self._compute_postfix_expression_multi, test_exp, tar_exp)

    def result(self, test_exp, tar_exp):
        """Disabled; use ``prefix_result`` or ``postfix_result``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def result_multi(self, test_exp, tar_exp):
        """Disabled; use ``prefix_result_multi`` or ``postfix_result_multi``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def _single_result(self, compute, multi_result, test_exp, tar_exp):
        """Shared single-equation evaluation.

        Args:
            compute (function): evaluates one expression to a number.
            multi_result (function): evaluation used when the task is not
                "single and linear".
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if not (self.single and self.linear):
            return multi_result(test_exp, tar_exp)
        # ``== []`` rather than ``is []``: an identity test against a fresh
        # list literal can never be true.
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            # A None value (invalid expression) raises here and counts as wrong.
            val_ac = bool(abs(compute(test_exp) - compute(tar_exp)) < 1e-4)
        except Exception:
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    def _multi_result(self, solve, test_exp, tar_exp):
        """Shared multi-equation evaluation.

        Args:
            solve (function): maps an expression to ``(solutions, unknowns)``.
            test_exp (list): list of test expression.
            tar_exp (list): list of target expression.
        """
        if test_exp == []:
            return False, False, test_exp, tar_exp
        if test_exp == tar_exp:
            return True, True, test_exp, tar_exp
        try:
            # Unpacking stays inside the try: a solver helper that returns
            # something other than a 2-tuple counts as a wrong answer.
            test_solves, test_unk = solve(test_exp)
            tar_solves, tar_unk = solve(tar_exp)
            val_ac = self._answers_match(test_solves, test_unk, tar_solves, tar_unk)
        except Exception:
            val_ac = False
        return val_ac, False, test_exp, tar_exp

    @staticmethod
    def _answers_match(test_solves, test_unk, tar_solves, tar_unk):
        """Tell whether two solver outputs describe the same answers.

        Args:
            test_solves: solver output for the test expression.
            test_unk (dict): unknown symbols of the test expression.
            tar_solves: solver output for the target expression.
            tar_unk (dict): unknown symbols of the target expression.

        Returns:
            bool: True if the answers agree within ``1e-4``.

        Raises:
            Exception: whatever indexing/arithmetic on incompatible solver
                outputs raises; callers treat it as a mismatch.
        """
        tol = 1e-4
        if len(test_unk) != len(tar_unk):
            return False
        matched = False
        if len(tar_unk) == 1:
            if len(tar_solves) == 1:
                test_ans = test_solves[list(test_unk.values())[0]]
                tar_ans = tar_solves[list(tar_unk.values())[0]]
                if abs(test_ans - tar_ans) < tol:
                    matched = True
            elif len(test_solves) == len(tar_solves):
                # Length check: zip() would silently ignore extra or missing
                # solutions and report a match on a truncated comparison.
                matched = True
                for test_ans, tar_ans in zip(test_solves, tar_solves):
                    if abs(test_ans[0] - tar_ans[0]) > tol:
                        matched = False
                        break
        elif len(tar_solves) == len(tar_unk):
            matched = True
            for tar_x in list(tar_unk.values()):
                if abs(test_solves[tar_x] - tar_solves[tar_x]) > tol:
                    matched = False
                    break
        else:
            # NOTE(review): this branch never sets the flag to True, so it
            # can only fail (by raising) or fall through to the comparison
            # below. Kept as-is; confirm the intended behaviour.
            for test_ans, tar_ans in zip(test_solves, tar_solves):
                try:
                    te_ans = float(test_ans[0])
                except Exception:
                    te_ans = float(test_ans[1])
                try:
                    ta_ans = float(tar_ans[0])
                except Exception:
                    ta_ans = float(tar_ans[1])
                if abs(te_ans - ta_ans) > tol:
                    matched = False
                    break
        if matched:
            return True
        # Fallback: ignore which unknown holds which value and compare the
        # sorted solution values.
        test_vals = sorted(test_solves.values())
        tar_vals = sorted(tar_solves.values())
        if len(test_vals) != len(tar_vals):
            return False
        return not any(abs(v1 - v2) > tol for v1, v2 in zip(test_vals, tar_vals))
def get_evaluator(config):
    """Build the evaluator selected by the configuration.

    The evaluator class is chosen from ``config["equation_fix"]``; when
    ``config['model']`` is ``multiencdec`` (case-insensitive) a
    ``MultiEncDecEvaluator`` is returned instead.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Evaluator: Constructed evaluator.

    Raises:
        NotImplementedError: if ``config["equation_fix"]`` is not a supported fix type.
    """
    # Checked in order; the first group containing the configured fix wins.
    candidates = (
        ((FixType.Prefix,), PrefixEvaluator),
        ((FixType.Nonfix, FixType.Infix), InfixEvaluator),
        ((FixType.Postfix,), PostfixEvaluator),
        ((FixType.MultiWayTree,), MultiWayTreeEvaluator),
    )
    fix_type = config["equation_fix"]
    for fix_types, evaluator_cls in candidates:
        if fix_type in fix_types:
            evaluator = evaluator_cls(config)
            break
    else:
        raise NotImplementedError

    # The model-specific evaluator replaces the fix-type based one.
    if config['model'].lower() == 'multiencdec':
        evaluator = MultiEncDecEvaluator(config)
    return evaluator
def get_evaluator_module(config: Config) -> Type[Union[PrefixEvaluator,InfixEvaluator,PostfixEvaluator,MultiWayTreeEvaluator,AbstractEvaluator,MultiEncDecEvaluator]]:
    """Return the evaluator class selected by the configuration.

    The class is chosen from ``config["equation_fix"]`` and falls back to
    ``AbstractEvaluator`` for an unknown fix type; when ``config['model']``
    is ``multiencdec`` (case-insensitive) ``MultiEncDecEvaluator`` is
    returned instead.

    :param config: An instance object of Config, used to record parameter information.
    :return: evaluator module
    """
    # Checked in order; the first group containing the configured fix wins.
    candidates = (
        ((FixType.Prefix,), PrefixEvaluator),
        ((FixType.Nonfix, FixType.Infix), InfixEvaluator),
        ((FixType.Postfix,), PostfixEvaluator),
        ((FixType.MultiWayTree,), MultiWayTreeEvaluator),
    )
    fix_type = config["equation_fix"]
    evaluator_module = AbstractEvaluator
    for fix_types, evaluator_cls in candidates:
        if fix_type in fix_types:
            evaluator_module = evaluator_cls
            break

    # The model-specific evaluator replaces the fix-type based one.
    if config['model'].lower() == 'multiencdec':
        evaluator_module = MultiEncDecEvaluator
    return evaluator_module