import re
from copy import deepcopy
from collections import OrderedDict
from typing import Tuple
import nltk
from tqdm import tqdm
from mwptoolkit.utils.utils import str2float, lists2dict
from mwptoolkit.utils.enum_type import DatasetName, MaskSymbol, NumMask, SpecialTokens, TaskType
from mwptoolkit.utils.preprocess_tool.number_operator import english_word_2_num, joint_fraction
[docs]def number_transfer(datas, dataset_name, task_type, mask_type, min_generate_keep, linear_dataset, equ_split_symbol=';',
vocab_level='word', word_lower=False) -> Tuple[list, list, int, list]:
"""
number transfer
:param list datas: dataset.
:param str dataset_name: dataset name.
:param str task_type: [single_equation | multi_equation], task type.
:param mask_type:
:param int min_generate_keep: generate number that count greater than the value, will be kept in output symbols.
:param bool linear_dataset:
:param str equ_split_symbol: equation split symbol, in multiple-equation dataset, symbol to split equations, this symbol will be repalced with special token SpecialTokens.BRG
:param str vocab_level:
:param bool word_lower:
:return: processed datas, generate number list, copy number, unk symbol list.
"""
if dataset_name == DatasetName.math23k:
transfer = number_transfer_math23k
elif dataset_name == DatasetName.ape200k:
transfer = number_transfer_ape200k
elif dataset_name == DatasetName.asdiv_a:
transfer = number_transfer_asdiv_a
elif dataset_name == DatasetName.SVAMP:
transfer = number_transfer_svamp
elif dataset_name == DatasetName.mawps_single:
transfer = number_transfer_mawps_single
elif dataset_name == DatasetName.mawps:
transfer = number_transfer_mawps
elif dataset_name == DatasetName.alg514:
transfer = num_transfer_alg514
elif dataset_name == DatasetName.draw:
transfer = num_transfer_draw
elif dataset_name == DatasetName.hmwp:
transfer = num_transfer_hmwp
else:
if task_type == TaskType.SingleEquation:
transfer = number_transfer_single
elif task_type == TaskType.MultiEquation:
transfer = num_transfer_multi
else:
raise NotImplementedError
generate_nums = []
generate_nums_dict = {}
copy_nums = 0
processed_datas = []
unk_symbol = []
for data in tqdm(datas,desc='word segmentation and number mapping'):
if task_type == TaskType.SingleEquation:
new_data = transfer(data, mask_type, linear_dataset, vocab_level, word_lower)
elif task_type == TaskType.MultiEquation:
new_data = transfer(data, mask_type, equ_split_symbol, vocab_level, word_lower)
else:
raise NotImplementedError
if dataset_name == DatasetName.mawps_single and task_type == TaskType.SingleEquation and '=' in new_data[
"equation"]:
continue
num_list = new_data["number list"]
out_seq = new_data["equation"]
copy_num = len(new_data["number list"])
for idx, s in enumerate(out_seq):
# tag the num which is generated
if s[0] == '-' and len(s) >= 2 and s[1].isdigit() and s not in generate_nums and s not in num_list:
generate_nums.append(s)
generate_nums_dict[s] = 0
if s[0].isdigit() and s not in generate_nums and s not in num_list:
generate_nums.append(s)
generate_nums_dict[s] = 0
if s in generate_nums and s not in num_list:
generate_nums_dict[s] = generate_nums_dict[s] + 1
if copy_num > copy_nums:
copy_nums = copy_num
# get unknown number
if task_type == TaskType.SingleEquation:
if linear_dataset:
for s in out_seq:
if len(s) == 1 and s.isalpha():
if s in unk_symbol:
continue
else:
unk_symbol.append(s)
else:
pass
elif task_type == TaskType.MultiEquation:
for s in out_seq:
if len(s) == 1 and s.isalpha():
if s in unk_symbol:
continue
else:
unk_symbol.append(s)
else:
raise NotImplementedError
processed_datas.append(new_data)
# keep generate number
generate_number = []
for g in generate_nums:
if generate_nums_dict[g] >= min_generate_keep:
generate_number.append(g)
return processed_datas, generate_number, copy_nums, unk_symbol
[docs]def seg_and_tag_single(st, nums_fraction, nums): # seg the equation and tag the num
res = []
pos_st = re.search(r"-\d+\.\d+%?|-\d+%?", st) # search negative number but filtate minus symbol
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_single(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
res.append(st_num)
if p_end < len(st):
res += seg_and_tag_single(st[p_end:], nums_fraction, nums)
return res
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_single(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_single(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st)
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_single(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
res.append(st_num)
if p_end < len(st):
res += seg_and_tag_single(st[p_end:], nums_fraction, nums)
return res
for ss in st:
if ss == ' ':
continue
res.append(ss)
return res
[docs]def seg_and_tag_math23k(st, nums_fraction, nums): # seg the equation and tag the num
res = []
pos_st = re.search(r"([+]|-|[*]|/|[(]|=)-(([(]\d+\.\d+[)])|([(]\d+/\d+[)]))",
st) # search negative number but filtate minus symbol
if pos_st:
p_start = pos_st.start() + 1
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_math23k(st[:p_start], nums_fraction, nums)
try:
st_num = str(eval(st[p_start:p_end]))
except: # % in number
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
try:
number = str(int(eval(st_num)))
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(st_num)
except:
res.append(st_num)
if p_end < len(st):
res += seg_and_tag_math23k(st[p_end:], nums_fraction, nums)
return res
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_math23k(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_math23k(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st)
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_math23k(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
try:
number = str(int(eval(st_num)))
res.append(nums[number])
except:
res.append(st_num)
if p_end < len(st):
res += seg_and_tag_math23k(st[p_end:], nums_fraction, nums)
return res
for ss in st:
res.append(ss)
return res
[docs]def seg_and_tag_ape200k(st, nums_fraction, nums): # seg the equation and tag the num
res = []
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_ape200k(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_ape200k(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st)
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_ape200k(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
try:
number = str(int(eval(st_num)))
res.append(nums[number])
except:
res.append(st_num)
if p_end < len(st):
res += seg_and_tag_ape200k(st[p_end:], nums_fraction, nums)
return res
for ss in st:
res.append(ss)
return res
[docs]def seg_and_tag_asdiv_a(st, nums_fraction, nums): # seg the equation and tag the num
res = []
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_asdiv_a(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_asdiv_a(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st)
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_asdiv_a(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
try:
number = str(int(eval(st_num)))
res.append(nums[number])
except:
number = str(str2float(st_num))
try:
res.append(nums[number])
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_asdiv_a(st[p_end:], nums_fraction, nums)
return res
for ss in st:
if ss == ' ':
continue
res.append(ss)
return res
[docs]def seg_and_tag_svamp(st, nums_fraction, nums): # seg the equation and tag the num
res = []
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_svamp(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_svamp(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st)
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_svamp(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float((st_num)))
try:
res.append(nums[number])
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_svamp(st[p_end:], nums_fraction, nums)
return res
for ss in st:
if ss == " ":
continue
res.append(ss)
return res
[docs]def seg_and_tag_multi(st, nums_fraction, nums): # seg the equation and tag the num
res = []
pos_st = re.search(r"([+]|-|[*]|/|[(]|=)-((\d+\.?\d*))", st) # search negative number but filtate minus symbol
if pos_st:
p_start = pos_st.start() + 1
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_multi(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_multi(st[p_end:], nums_fraction, nums)
return res
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_multi(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_multi(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st) # search number including number with % symbol
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_multi(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_multi(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("<BRG>", st)
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_multi(st[:p_start], nums_fraction, nums)
res.append(st[p_start:p_end])
if p_end < len(st):
res += seg_and_tag_multi(st[p_end:], nums_fraction, nums)
return res
for ss in st:
if ss.isalpha():
res.append(ss.lower())
elif ss == " ":
continue
else:
res.append(ss)
return res
[docs]def seg_and_tag_hmwp(st, nums_fraction, nums): # seg the equation and tag the num
res = []
pos_st = re.search(r"([+]|-|[*]|/|[(]|=)\s-\s((\d+\.?\d*))", st) # search negative number but filtate minus symbol
if pos_st:
p_start = pos_st.start() + 2
p_end = pos_st.end()
num_str = ''.join(st[p_start:p_end].split(" "))
if p_start > 0:
res += seg_and_tag_hmwp(st[:p_start], nums_fraction, nums)
st_num = num_str
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_hmwp(st[p_end:], nums_fraction, nums)
return res
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_hmwp(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_hmwp(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st) # search number including number with % symbol
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_hmwp(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_hmwp(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("<BRG>", st)
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_hmwp(st[:p_start], nums_fraction, nums)
res.append(st[p_start:p_end])
if p_end < len(st):
res += seg_and_tag_hmwp(st[p_end:], nums_fraction, nums)
return res
for ss in st:
if ss.isalpha():
res.append(ss.lower())
elif ss == " ":
continue
else:
res.append(ss)
return res
[docs]def seg_and_tag_mawps_single(st, nums_fraction, nums):
res = []
pos_st = re.search(r"([+]|-|[*]|/|[(]|=)-((\d+\.?\d*))", st) # search negative number but filtate minus symbol
if pos_st:
p_start = pos_st.start() + 1
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_mawps_single(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_mawps_single(st[p_end:], nums_fraction, nums)
return res
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_mawps_single(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_mawps_single(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st) # search number including number with % symbol
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_mawps_single(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_mawps_single(st[p_end:], nums_fraction, nums)
return res
for ss in st:
if ss.isalpha():
res.append(ss.lower())
elif ss == " ":
continue
else:
res.append(ss)
return res
[docs]def seg_and_tag_mawps(st, nums_fraction, nums): # seg the equation and tag the num
res = []
pos_st = re.search(r"([+]|-|[*]|/|[(]|=)-((\d+\.?\d*))", st) # search negative number but filtate minus symbol
if pos_st:
p_start = pos_st.start() + 1
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_mawps(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_mawps(st[p_end:], nums_fraction, nums)
return res
for n in nums_fraction:
if n in st:
p_start = st.find(n)
p_end = p_start + len(n)
if p_start > 0:
res += seg_and_tag_mawps(st[:p_start], nums_fraction, nums)
try:
res.append(nums[n])
except:
res.append(n)
if p_end < len(st):
res += seg_and_tag_mawps(st[p_end:], nums_fraction, nums)
return res
pos_st = re.search("\d+\.\d+%?|\d+%?", st) # search number including number with % symbol
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
res += seg_and_tag_mawps(st[:p_start], nums_fraction, nums)
st_num = st[p_start:p_end]
try:
res.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
res.append(nums[number])
else:
res.append(number)
except:
res.append(number)
if p_end < len(st):
res += seg_and_tag_mawps(st[p_end:], nums_fraction, nums)
return res
for ss in st:
if ss.isalpha():
res.append(ss.lower())
elif ss == " ":
continue
else:
res.append(ss)
return res
[docs]def number_transfer_single(data, mask_type, linear, vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
if word_lower:
data["question"] = data["question"].lower()
seg = data["question"].split(" ")
equations = data["equation"]
if linear:
if equations.startswith('x=') or equations.startswith('X='):
equations = equations[2:]
elif equations.endswith('=x') or equations.endswith('=X'):
equations = equations[:-2]
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(s[pos.start():pos.end()])
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if s == ' ' or s == '':
continue
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_single(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
source[pos] = num_str
source = ' '.join(source)
assert len(num_list) == len(num_pos)
new_data = data
new_data["question"] = input_seq
new_data["ques source 1"] = source
new_data["equation"] = out_seq
new_data["number list"] = num_list
new_data["number position"] = num_pos
return new_data
[docs]def number_transfer_math23k(data, mask_type, linear, vocab_level='word', word_lower=False):
# pattern = re.compile("\data*\(\data+/\data+\)\data*|\data+\.\data+%?|\data+%?")
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
if word_lower:
data["segmented_text"] = data["segmented_text"].lower()
seg = data["segmented_text"].split(" ")
equations = data["equation"][2:]
if '千' in equations:
equations = equations[:equations.index('千')]
num_pos_dict = {}
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(s[pos.start():pos.end()])
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if s == ' ' or s == '':
continue
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_math23k(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
new_data = data
new_data["question"] = input_seq
new_data["ques source 1"] = source
new_data["equation"] = out_seq
new_data["number list"] = num_list
new_data["number position"] = num_pos
return new_data
[docs]def number_transfer_ape200k(data, mask_type, linear, vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
if word_lower:
data["segmented_text"] = data["segmented_text"].lower()
seg = data["segmented_text"].split(" ")
seg = joint_fraction(seg)
equations = data["equation"]
if "x=" == equations[:2] or "X=" == equations[:2]:
equations = equations[2:]
equations = equations.replace('**','^',100)
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(s[pos.start():pos.end()])
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if s == ' ' or s == '':
continue
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq_ = seg_and_tag_ape200k(equations, nums_fraction, nums)
out_seq = []
i = 0
while i<len(out_seq_):
s = out_seq_[i]
if s == '%':
out_seq.append('/')
out_seq.append('100')
i+=1
elif s == ':':
out_seq.append('/')
i+=1
else:
out_seq.append(s)
i+=1
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
new_data = data
new_data["question"] = input_seq
new_data["ques source 1"] = source
new_data["equation"] = out_seq
new_data["number list"] = num_list
new_data["number position"] = num_pos
return new_data
[docs]def number_transfer_asdiv_a(data, mask_type, linear, vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
if word_lower:
data["Body"] = data["Body"].lower()
data["Question"] = data["Question"].lower()
seg = nltk.word_tokenize(data["Body"] + ' ' + data["Question"])
formula = data["Formula"]
equations = formula[:formula.index('=')]
ans = formula[formula.index('=') + 1:]
num_pos_dict = {}
for idx, word in enumerate(seg):
if re.match(r"(\d+\,\d+)+", word):
new_word = "".join(word.split(","))
seg[idx] = new_word
seg = english_word_2_num(seg)
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if s == ' ' or s == '':
continue
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_asdiv_a(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
new_data = data
new_data['id'] = data['@ID']
new_data['ans'] = ans
new_data["question"] = input_seq
new_data["ques source 1"] = source
new_data["equation"] = out_seq
new_data["number list"] = num_list
new_data["number position"] = num_pos
return new_data
[docs]def number_transfer_svamp(data, mask_type, linear, vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
if word_lower:
data["Body"] = data["Body"].lower()
data["Question"] = data["Question"].lower()
seg = nltk.word_tokenize(data["Body"] + ' ' + data["Question"])
equations = data["Equation"]
if equations.startswith('( ') and equations.endswith(' )'):
equations = equations[2:-2]
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_svamp(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
new_data = data
new_data["question"] = input_seq
new_data["ques source 1"] = source
new_data["equation"] = out_seq
new_data["number list"] = num_list
new_data["number position"] = num_pos
new_data["id"] = data["ID"]
new_data["ans"] = data["Answer"]
return new_data
[docs]def number_transfer_mawps_single(data, mask_type, linear, vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
if word_lower:
data["sQuestion"] = data["sQuestion"].lower()
seg = nltk.word_tokenize(data["sQuestion"])
equations = data["lEquations"][0]
if equations[:2] == 'x=' or equations[:2] == 'X=':
equations = equations[2:]
if equations[-2:] == '=x' or equations[-2:] == '=X':
equations = equations[:-2]
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if s == ' ' or s == '':
continue
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_mawps_single(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
new_data = data
new_data['id'] = data['iIndex']
new_data["question"] = input_seq
new_data["ques source 1"] = source
new_data["equation"] = out_seq
new_data["number list"] = num_list
new_data["number position"] = num_pos
new_data["ans"] = data['lSolutions'][0]
return new_data
[docs]def number_transfer_mawps(data, mask_type, linear, vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?|(-\d+)")
if word_lower:
data["original_text"] = data["original_text"].lower()
seg = data["original_text"].split(" ")
equations = data["equation"]
equations = re.sub(r"[a-zA-Z]{2,}", "x", equations)
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if s == '':
continue
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
if data['id'] == 46:
x = 1
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_mawps(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
# copy data
new_data = data
new_data["question"] = input_seq
new_data["equation"] = out_seq
new_data["ques source 1"] = source
new_data["number list"] = num_list
new_data["number position"] = num_pos
return new_data
[docs]def num_transfer_multi(data, mask_type, equ_split_symbol=";", vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?|(-\d+)")
if word_lower:
data["original_text"] = data["original_text"].lower()
seg = data["original_text"].split(" ")
equations = data["equation"]
equations = re.sub(r"[a-zA-Z]{2,}", "x", equations)
equations = re.sub(equ_split_symbol, SpecialTokens.BRG_TOKEN, equations)
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if s == '':
continue
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_multi(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
# copy data
new_data = data
new_data["question"] = input_seq
new_data["equation"] = out_seq
new_data["ques source 1"] = source
new_data["number list"] = num_list
new_data["number position"] = num_pos
return new_data
[docs]def num_transfer_alg514(data, mask_type, equ_split_symbol=";", vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?|(-\d+)")
if word_lower:
data["original_text"] = data["original_text"].lower()
seg = nltk.word_tokenize(data["original_text"])
for idx, word in enumerate(seg):
if re.match(r"(\d+\,\d+)+", word):
new_word = "".join(word.split(","))
seg[idx] = new_word
seg = english_word_2_num(seg)
equations = data["equation"]
equations = re.sub(r"[a-zA-Z]{2,}", "x", equations)
equations = re.sub(equ_split_symbol, SpecialTokens.BRG_TOKEN, equations)
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
# input_seq.append(s[pos.start():pos.end()])
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_multi(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
# copy data
new_data = data
new_data["question"] = input_seq
new_data["equation"] = out_seq
new_data["ques source 1"] = source
new_data["number list"] = num_list
new_data["number position"] = num_pos
if num_list == []:
new_data["number list"] = ["-inf"]
new_data["number position"] = [-1]
return new_data
[docs]def num_transfer_draw(data, mask_type, equ_split_symbol=";", vocab_level='word', word_lower=False):
# pattern = re.compile(r"\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?|(-\d+)")
pattern = re.compile(r"\d+\/\d+|\d+\.\d+%?|\d+%?|(-\d+)")
if word_lower:
data["original_text"] = data["original_text"].lower()
seg = data["original_text"].split(" ")
for idx, word in enumerate(seg):
if re.match(r"(\d+\,\d+)+", word):
new_word = "".join(word.split(","))
seg[idx] = new_word
elif re.match(r"\.\d+", word):
new_word = "0" + word
seg[idx] = new_word
seg = english_word_2_num(seg, 3)
equations = data["equation"]
equations = re.sub(r"[a-zA-Z]{2,}", "x", equations)
equations = re.sub(equ_split_symbol, SpecialTokens.BRG_TOKEN, equations)
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = []
pos_st = re.search(r"^-((\d+\.?\d*))", equations) # search negative number starting
if pos_st:
p_start = pos_st.start()
p_end = pos_st.end()
if p_start > 0:
out_seq += seg_and_tag_multi(equations[:p_start], nums_fraction, nums)
st_num = equations[p_start:p_end]
try:
out_seq.append(nums[st_num])
except:
number = str(str2float(st_num))
try:
if abs(eval(number) - eval(st_num)) < 1e-4:
out_seq.append(nums[number])
else:
out_seq.append(number)
except:
out_seq.append(number)
if p_end < len(equations):
out_seq += seg_and_tag_multi(equations[p_end:], nums_fraction, nums)
else:
out_seq = seg_and_tag_multi(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
# copy data
new_data = data
new_data["question"] = input_seq
new_data["equation"] = out_seq
new_data["ques source 1"] = source
new_data["number list"] = num_list
new_data["number position"] = num_pos
if num_list == []:
new_data["number list"] = ["-inf"]
new_data["number position"] = [-1]
return new_data
[docs]def num_transfer_hmwp(data, mask_type, equ_split_symbol=";", vocab_level='word', word_lower=False):
pattern = re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?|(-\d+)")
if word_lower:
data["original_text"] = data["original_text"].lower()
seg = data["original_text"].split(" ")
equations = data["equation"]
equations = re.sub(r"[a-zA-Z]{2,}", "x", equations)
equations = re.sub(equ_split_symbol, SpecialTokens.BRG_TOKEN, equations)
# match and split number
input_seq = []
for s in seg:
pos = re.search(pattern, s)
if pos and pos.start() == 0:
# input_seq.append(s[pos.start():pos.end()])
input_seq.append(str(str2float(s[pos.start():pos.end()])))
if pos.end() < len(s):
if vocab_level == 'char':
input_seq += [c for c in s[pos.end():]]
else:
input_seq.append(s[pos.end():])
else:
if vocab_level == 'char':
input_seq += [c for c in s]
else:
input_seq.append(s)
input_seq, num_list, num_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction = get_num_pos(input_seq,
mask_type,
pattern)
out_seq = seg_and_tag_hmwp(equations, nums_fraction, nums)
source = deepcopy(input_seq)
for pos in all_pos:
for key, value in num_pos_dict.items():
if pos in value:
num_str = key
break
num = str(str2float(num_str))
source[pos] = num
source = ' '.join(source)
assert len(num_list) == len(num_pos)
# copy data
new_data = data
new_data["question"] = input_seq
new_data["equation"] = out_seq
new_data["ques source 1"] = source
new_data["number list"] = num_list
new_data["number position"] = num_pos
return new_data
[docs]def get_num_pos(input_seq, mask_type, pattern):
if mask_type == MaskSymbol.NUM:
sent_mask_list = NumMask.NUM
equ_mask_list = NumMask.number
elif mask_type == MaskSymbol.alphabet:
sent_mask_list = NumMask.alphabet
equ_mask_list = NumMask.alphabet
elif mask_type == MaskSymbol.number:
sent_mask_list = NumMask.number
equ_mask_list = NumMask.number
nums = OrderedDict()
num_list = []
num_pos = []
num_pos_dict = {}
if mask_type == MaskSymbol.NUM:
# find all number position
for word_pos, word in enumerate(input_seq):
pos = re.search(pattern, word)
if pos and pos.start() == 0:
num_list.append(word)
num_pos.append(word_pos)
if word in num_pos_dict:
num_pos_dict[word].append(word_pos)
else:
num_pos_dict[word] = [word_pos]
mask_list = equ_mask_list[:len(num_list)]
new_num_list = []
new_mask_list = []
for i in num_list:
if num_list.count(i) != 1:
x = 1
if num_list.count(i) == 1:
new_num_list.append(i)
new_mask_list.append(mask_list[num_list.index(i)])
else:
pass
nums = lists2dict(new_num_list, new_mask_list)
else:
# find all number position
for word_pos, word in enumerate(input_seq):
pos = re.search(pattern, word)
if pos and pos.start() == 0:
if word in num_pos_dict:
num_pos_dict[word].append(word_pos)
else:
num_list.append(word)
num_pos_dict[word] = [word_pos]
num_list = sorted(num_list, key=lambda x: max(num_pos_dict[x]), reverse=False)
nums = lists2dict(num_list, equ_mask_list[:len(num_list)])
nums_for_ques = lists2dict(num_list, sent_mask_list[:len(num_list)])
# all number position
all_pos = []
if mask_type == MaskSymbol.NUM:
all_pos = deepcopy(num_pos)
else:
for num, mask in nums_for_ques.items():
for pos in num_pos_dict[num]:
all_pos.append(pos)
# final numbor position
final_pos = []
if mask_type == MaskSymbol.NUM:
final_pos = deepcopy(num_pos)
else:
for num in num_list:
# select the latest position as the number position
# if the number corresponds multiple positions
final_pos.append(max(num_pos_dict[num]))
# number transform
for num, mask in nums_for_ques.items():
for pos in num_pos_dict[num]:
input_seq[pos] = mask
# nums_fraction = []
# for num, mask in nums.items():
# if re.search("\data*\(\data+/\data+\)\data*", num):
# nums_fraction.append(num)
# nums_fraction = sorted(nums_fraction, key=lambda x: len(x), reverse=True)
nums_fraction = []
for num, mask in nums.items():
if re.search("\d*\(\d+/\d+\)\d*", num):
nums_fraction.append(num)
nums_fraction = sorted(nums_fraction, key=lambda x: len(x), reverse=True)
return input_seq, num_list, final_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction