# Source code for mwptoolkit.utils.preprocess_tool.dataset_operator

import re

from mwptoolkit.utils.utils import str2float
from mwptoolkit.utils.enum_type import EPT
from mwptoolkit.utils.preprocess_tool.sentence_operator import find_ept_numbers_in_text
from mwptoolkit.utils.preprocess_tool.number_operator import constant_number
from mwptoolkit.utils.preprocess_tool.equation_operator import orig_infix_to_postfix,infix_to_postfix


def id_reedit(trainset, validset, testset):
    r"""Make record ids unique across the train, valid and test splits.

    Every ``data['id']`` is first converted to ``str`` in place. When the
    same id occurs ``k`` times over the three splits, the occurrences are
    visited in train -> valid -> test order and renamed ``id-(k-1)``,
    ``id-(k-2)``, ..., ``id-1``; the last occurrence keeps the bare id.

    Example: two records share the id 709356. The first one visited becomes
    ``709356-1`` and the second one stays ``709356``.

    Note:
        A generated id is not checked against ids that already exist, so a
        dataset holding both a duplicated ``"7"`` and a literal ``"7-1"``
        still ends up with a clash.

    Args:
        trainset: iterable of dicts, each with an ``'id'`` key.
        validset: iterable of dicts, each with an ``'id'`` key.
        testset: iterable of dicts, each with an ``'id'`` key.

    Returns:
        tuple: ``(trainset, validset, testset)`` -- the very same objects
        that were passed in, with their records modified in place.

    Raises:
        KeyError: if a record has no ``'id'`` key.
    """
    splits = (trainset, validset, testset)

    # Pass 1: normalise ids to str and count how often each one occurs.
    remaining = {}
    for split in splits:
        for data in split:
            if not isinstance(data['id'], str):
                data['id'] = str(data['id'])
            remaining[data['id']] = remaining.get(data['id'], 0) + 1

    # Pass 2: while more than one holder of an id is left, rename the current
    # one with a descending numeric suffix and count it off.
    for split in splits:
        for data in split:
            old_id = data['id']
            if remaining[old_id] > 1:
                remaining[old_id] -= 1
                data['id'] = old_id + '-' + str(remaining[old_id])

    return trainset, validset, testset
def preprocess_ept_dataset_(train_datas, valid_datas, test_datas, dataset_name):
    """Run ``ept_preprocess`` over the train, valid and test splits.

    Args:
        train_datas: records of the training split.
        valid_datas: records of the validation split.
        test_datas: records of the test split.
        dataset_name (str): dataset identifier forwarded unchanged to
            ``ept_preprocess``.

    Returns:
        tuple: the three processed splits, in (train, valid, test) order.
    """
    splits = (train_datas, valid_datas, test_datas)
    # The generator is consumed left to right, so the splits are processed
    # in train -> valid -> test order.
    return tuple(ept_preprocess(split, dataset_name) for split in splits)
def _ept_problem_and_answers(data, dataset_name):
    """Normalise one record in place and return ``(problem, answer_list)``.

    Depending on ``dataset_name`` this fills ``data['original_text']`` and
    wraps/converts ``data['ans']``, then returns the problem text and the
    answers as a list of tuples.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == "mawps":
        answer_list = [(x,) for x in data['aux']['lSolutions']]
        masked_text = re.sub(r'\s+', ' ', data['aux']['mask_text']).strip().split(' ')
        temp_tokens = data['aux']['num_list']
        # Put the numbers back in place of their 'temp_<k>' placeholders
        # (the part after 'temp_' is used to index into num_list).
        problem = ' '.join(
            str(temp_tokens[int(token[5:])]) if token.startswith('temp_') else token
            for token in masked_text
        )
        return problem, answer_list

    if dataset_name == "SVAMP":
        data["original_text"] = data["ques source 1"].strip()
        data["ans"] = [str2float(data["Answer"])]
    elif dataset_name == "asdiv-a":
        data["original_text"] = data["ques source 1"].strip()
        if 'r' in data["ans"]:
            # Keep only the first two characters of the answer string.
            # NOTE(review): presumably a "quotient r remainder" answer whose
            # remainder is dropped here -- confirm against the dataset.
            data["ans"] = data["ans"][:2]
        data["ans"] = [str2float(data["ans"])]
    elif dataset_name == "mawps_asdiv-a_svamp":
        data["original_text"] = data["ques source 1"].strip()
        data['ans'] = [data['ans']]
    elif dataset_name == 'math23k':
        data["original_text"] = data["ques source 1"].strip()
        data["ans"] = [str2float(data["ans"])]
    elif dataset_name == 'hmwp':
        data['original_text'] = data['ques source 1']
    elif dataset_name not in ('alg514', 'draw'):
        # Without this guard an unknown name fell through every branch and
        # crashed later with an opaque UnboundLocalError on `problem`.
        raise ValueError("unsupported dataset name for EPT preprocessing: %r" % (dataset_name,))

    # 'alg514' and 'draw' records are used as they are.
    answer_list = [tuple(data['ans'])]
    problem = data["original_text"].strip()
    return problem, answer_list


def ept_preprocess(datas, dataset_name):
    """Attach the EPT-specific fields to every record of a dataset split.

    Each record is modified in place: besides the dataset-specific
    normalisation of ``'original_text'`` / ``'ans'``, it gains a ``data['ept']``
    dict with the keys ``'text'`` and ``'numbers'`` (the two values returned
    by ``find_ept_numbers_in_text`` for the problem text), ``'answer'`` (list
    of answer tuples) and ``'expr'`` (the result of
    ``refine_formula_as_prefix``).

    Args:
        datas: iterable of record dicts.
        dataset_name (str): one of ``"mawps"``, ``"SVAMP"``, ``"asdiv-a"``,
            ``"mawps_asdiv-a_svamp"``, ``"math23k"``, ``"hmwp"``,
            ``"alg514"`` or ``"draw"``.

    Returns:
        list: the same record objects, in input order.

    Raises:
        ValueError: if ``datas`` is non-empty and ``dataset_name`` is not
            supported.
        KeyError: if a record lacks a field required for its dataset.
    """
    datas_list = []
    for data in datas:
        problem, answer_list = _ept_problem_and_answers(data, dataset_name)
        text, numbers = find_ept_numbers_in_text(problem)
        # 'expr' is added only after the call below, which receives the
        # record with the other three keys already in place.
        data['ept'] = {'text': text, 'numbers': numbers, 'answer': answer_list}
        data['ept']['expr'] = refine_formula_as_prefix(data, numbers, dataset_name)
        datas_list.append(data)
    return datas_list
def refine_formula_as_prefix(item, numbers, dataset_name):
    """Convert a record's equation(s) into the token form stored in ``data['ept']['expr']``.

    Three code paths exist, selected by ``dataset_name``:

    * ``"alg514"`` / ``"draw"``: equations come from ``item['aux']['Template']``;
      template coefficients are mapped to ``N_<i>`` number tokens or constants
      using ``item['aux']['Alignment']`` and ``item['aux']['sQuestion']``.
    * ``"mawps"``: the equation comes from ``item['aux']['template_equ']``,
      cross-checked against ``item['aux']['new_equation']``.
    * anything else: ``item['infix equation']`` is split at ``'<BRG>'`` tokens
      and each piece is passed to ``infix_to_postfix``.

    NOTE(review): despite the function name, the helpers used here are the
    ``*_to_postfix`` converters -- confirm which notation consumers expect.

    Args:
        item (dict): the record. ``item['infix equation']`` is always read,
            even on the paths that later replace it.
        numbers (list): one dict per number; ``x['token']`` is iterated as
            token indices and ``x['value']`` is read as a numeric string.
        dataset_name (str): dataset identifier selecting the code path.

    Returns:
        list: ``(EPT.PREP_KEY_EQN, tokens)`` tuples, followed by one
        ``(EPT.PREP_KEY_ANS, 'X_0 X_1 ...')`` tuple when ``free_symbols`` is
        non-empty after conversion.

    Raises:
        AssertionError: on the ``"mawps"`` path when the template and the
            original equation disagree (only while asserts are enabled).
        KeyError: if a required field is missing from ``item``.
    """
    if dataset_name in ['SVAMP','asdiv-a','math23k','mawps_asdiv-a_svamp']:
        formula = item['infix equation']
        # These datasets store only an expression; prepend "x =" to form an
        # equation. This builds a new list, so the record's list is untouched.
        formula = ["x", "="]+formula
    else:
        # Same list object as in the record: the in-place edits in the final
        # `else` branch below are therefore visible through `item`.
        formula = item['infix equation']
    if dataset_name in ["alg514", 'draw']:
        # Put spaces around every operator. '-1' is hidden as '1NEG' first so
        # that its minus sign is not split off, and restored afterwards.
        formula = [re.sub('([-+*/=])', ' \\1 ', eqn.lower().replace('-1', '1NEG')).replace('1NEG', '-1')
                   for eqn in item["aux"]['Template']]  # Shorthand for linear formula
        tokens = re.split('\\s+', item['aux']['sQuestion'].strip())
        # token index -> index of the number (in `numbers`) covering that token
        number_by_tokenid = {j: i for i, x in enumerate(numbers) for j in x['token']}

        # Build map between (sentence, token in sentence) --> number token index
        number_token_sentence = {}
        sent_id = 0
        sent_token_id = 0
        for tokid, token in enumerate(tokens):
            # NOTE(review): `in '.!?'` is a substring test, so '.!' or an empty
            # token would also count as a sentence end.
            if token in '.!?':
                # End of sentence: the terminator itself gets no in-sentence index.
                sent_id += 1
                sent_token_id = 0
                continue

            if tokid in number_by_tokenid:
                number_token_sentence[(sent_id, sent_token_id)] = number_by_tokenid[tokid]

            sent_token_id += 1

        # [1] Build mapping between coefficients in the template and var names (N_0, T_0, ...)
        mappings = {}
        for align in item["aux"]['Alignment']:
            var = align['coeff']
            val = align['Value']
            sent_id = align['SentenceId']
            token_id = align['TokenId']

            if (sent_id, token_id) not in number_token_sentence:
                # If this is not in numbers recognized by our system, regard it as a constant.
                positive, const_code = constant_number(val)
                mappings[var] = [const_code]
                if not positive:
                    # Trailing '-' after the constant code for non-positive values.
                    mappings[var].append('-')

                continue

            number_id = number_token_sentence[(sent_id, token_id)]
            number_info = numbers[number_id]

            expression = ['N_%s' % number_id]
            # NOTE(review): eval() on a string taken from the data -- unsafe if
            # the dataset is not trusted; flagged, not changed.
            expr_value = eval(number_info['value'])

            # While the running value differs from the aligned value, absorb
            # the numbers at the following token positions of the same sentence.
            offset = 1
            while abs(val - expr_value) > 1E-10 and (sent_id, token_id + offset) in number_token_sentence:
                next_number_id = number_token_sentence[(sent_id, token_id + offset)]
                next_info = numbers[next_number_id]
                next_value = eval(next_info['value'])
                next_token = 'N_%s' % next_number_id

                if next_value >= 100:
                    # Multiplicative case: e.g. '[Num] million'
                    expr_value *= next_value

                    # As a postfix expression
                    expression.append(next_token)
                    expression.append('*')
                else:
                    # Additive case: e.g. '[NUM] hundred thirty-two'
                    expr_value += next_value
                    expression.append(next_token)
                    expression.append('+')

                offset += 1

            # Final check (disabled): a remaining mismatch is silently accepted.
            # assert abs(val - expr_value) < 1E-5, "%s vs %s: \n%s\n%s" % (align, expr_value, numbers, item)
            mappings[var] = expression

        # [2] Parse template and convert coefficients into our variable names.
        # Free symbols in the template denote variables representing the answer.
        new_formula = []
        # Passed to every orig_infix_to_postfix call and inspected afterwards.
        # NOTE(review): presumably the helper appends to it -- confirm.
        free_symbols = []

        for eqn in formula:
            output_tokens = orig_infix_to_postfix(eqn, mappings, free_symbols)

            if output_tokens:
                new_formula.append((EPT.PREP_KEY_EQN, output_tokens))

        if free_symbols:
            new_formula.append((EPT.PREP_KEY_ANS, ' '.join(['X_%s' % i for i in range(len(free_symbols))])))
    elif dataset_name in ['mawps']:
        # 'temp_<k>' placeholder -> ['N_<i>'] token list / raw value string
        template_to_number = {}
        template_to_value = {}
        number_by_tokenid = {j: i for i, x in enumerate(numbers) for j in x['token']}

        for tokid, token in enumerate(re.sub('\\s+', ' ', item['aux']['mask_text']).strip().split(' ')):
            if token.startswith('temp_'):
                # NOTE(review): assert used as data validation; it is skipped
                # under `python -O`.
                assert tokid in number_by_tokenid, (tokid, number_by_tokenid, item['aux'])

                num_id = number_by_tokenid[tokid]
                template_to_number[token] = ['N_%s' % num_id]
                template_to_value[token] = numbers[num_id]['value']

        # We should read both template_equ and new_equation because of NONE in norm_post_equ.
        formula = item['aux']['template_equ'].split(' ')
        original = item['aux']['new_equation'].split(' ')
        assert len(formula) == len(original)

        # Recover 'NONE' constant in the template_equ.
        for i in range(len(formula)):
            f_i = formula[i]
            o_i = original[i]

            if f_i == 'NONE':
                formula[i] = original[i]
            elif f_i.startswith('temp_'):
                # The placeholder's value must match the number in the original equation.
                assert abs(float(template_to_value[f_i]) - float(o_i)) < 1E-4,\
                    "Equation is different! '%s' vs '%s' at %i-th position" % (formula, original, i)
            else:
                # Check whether two things are the same.
                assert f_i == o_i, "Equation is different! '%s' vs '%s' at %i-th position" % (formula, original, i)

        free_symbols = []
        new_formula = [(EPT.PREP_KEY_EQN, orig_infix_to_postfix(formula, template_to_number, free_symbols))]

        if free_symbols:
            new_formula.append((EPT.PREP_KEY_ANS, ' '.join(['X_%s' % i for i in range(len(free_symbols))])))
    else:
        # Normalise square and curly brackets to parentheses.
        # NOTE(review): unless "x =" was prepended above, `formula` is the
        # record's own list, so these edits and the appended "<BRG>" also
        # change item['infix equation'] -- confirm whether callers rely on it.
        for wordid, word in enumerate(formula):
            if word == '[' or word == '{':
                formula[wordid] = '('
            elif word == ']' or word == '}':
                formula[wordid] = ')'
        # Sentinel so that the last equation is flushed by the loop below.
        formula.append("<BRG>")
        formula_list = []
        formula_string = ''
        # Join tokens with spaces, cutting a new equation string at every '<BRG>'.
        for word in formula:
            if word == '<BRG>':
                formula_list.append(formula_string.strip())
                formula_string = ''
            else:
                formula_string += word
                formula_string += ' '
        formula = formula_list
        new_formula = []
        free_symbols = []
        for eqn in formula:
            output_tokens = infix_to_postfix(eqn, free_symbols)
            if output_tokens:
                new_formula.append((EPT.PREP_KEY_EQN, output_tokens))
        if free_symbols:
            new_formula.append((EPT.PREP_KEY_ANS, ' '.join(['X_%s' % i for i in range(len(free_symbols))])))
    return new_formula