# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 11:35:43
# @File: pretrain_dataloader.py
import math
import torch
from typing import List
from mwptoolkit.config import Config
from mwptoolkit.data.dataset import PretrainDataset
from mwptoolkit.data.dataloader.abstract_dataloader import AbstractDataLoader
from mwptoolkit.utils.enum_type import FixType, SpecialTokens
def get_num_mask(num_size_batch, generate_nums):
    """Build a 0/1 mask over quantity slots for each problem in a batch.

    Each row has ``len(generate_nums)`` slots for dataset-level generated
    constants plus one slot per quantity in the problem; real slots are 0 and
    padding slots (up to the widest problem in the batch) are 1.

    :param num_size_batch: list[int], number of quantities in each problem.
    :param generate_nums: list, generated constant symbols shared by the batch.
    :return: list[list[int]], one mask row per problem, all rows equally wide.
    """
    const_count = len(generate_nums)
    widest = max(num_size_batch) + const_count
    return [
        [0] * (size + const_count) + [1] * (widest - size - const_count)
        for size in num_size_batch
    ]
class PretrainDataLoader(AbstractDataLoader):
    """Dataloader class for pre-train models.

    Converts :class:`PretrainDataset` records into padded, masked batches.
    Input text is indexed with the dataset's pretrained tokenizer; output
    equations/templates are indexed either with the same tokenizer
    (``share_vocab``) or with the dataset's output symbol vocabularies.
    """

    def __init__(self, config: Config, dataset: PretrainDataset):
        """
        :param config:
        :param dataset:

        expected that config includes these parameters below:

        model (str): model name.
        equation_fix (str): [infix | postfix | prefix], convert equation to specified format.
        train_batch_size (int): the training batch size.
        test_batch_size (int): the testing batch size.
        symbol_for_tree (bool): build output symbols for tree or not.
        share_vocab (bool): encoder and decoder of the model share the same vocabulary, often seen in Seq2Seq models.
        max_len (int|None): max input length.
        max_equ_len (int|None): max output length.
        add_sos (bool): add sos token at the head of input sequence.
        add_eos (bool): add eos token at the tail of input sequence.
        device (torch.device):
        """
        super().__init__(config, dataset)
        self.trainset_nums = len(dataset.trainset)
        self.validset_nums = len(dataset.validset)
        self.testset_nums = len(dataset.testset)
        # Input-side special-token ids always come from the pretrained tokenizer.
        self.in_pad_token = dataset.tokenizer.convert_tokens_to_ids(SpecialTokens.PAD_TOKEN)
        self.in_unk_token = dataset.tokenizer.convert_tokens_to_ids(SpecialTokens.UNK_TOKEN)
        if self.symbol_for_tree or self.equation_fix == FixType.MultiWayTree:
            # Tree-style output: pad with the input pad id, unk from the
            # output/template symbol vocabularies.
            self.out_pad_token = self.in_pad_token
            self.out_unk_token = dataset.out_symbol2idx[SpecialTokens.UNK_TOKEN]
            # Set temp_pad_token here as well for consistency with the other
            # branches (tree-mode batches skip output padding, so this id is
            # only a fallback).
            self.temp_pad_token = self.in_pad_token
            self.temp_unk_token = dataset.temp_symbol2idx[SpecialTokens.UNK_TOKEN]
        else:
            if self.share_vocab:
                # Encoder/decoder share one vocabulary: reuse input ids.
                self.out_pad_token = self.in_pad_token
                self.out_unk_token = self.in_unk_token
                self.temp_pad_token = self.in_pad_token
                self.temp_unk_token = self.in_unk_token
            else:
                self.out_pad_token = dataset.out_symbol2idx[SpecialTokens.PAD_TOKEN]
                self.out_unk_token = dataset.out_symbol2idx[SpecialTokens.UNK_TOKEN]
                self.temp_pad_token = dataset.temp_symbol2idx[SpecialTokens.PAD_TOKEN]
                self.temp_unk_token = dataset.temp_symbol2idx[SpecialTokens.UNK_TOKEN]
        self.__init_batches()

    def load_data(self, type: str):
        """
        Load batches, return every batch data in a generator object.

        :param type: [train | valid | test], data type.
        :return: Generator[dict], batches
        """
        if type == "train":
            self.__trainset_batch_idx = -1
            for batch in self.trainset_batches:
                self.__trainset_batch_idx = (self.__trainset_batch_idx + 1) % self.trainset_batch_nums
                yield batch
        elif type == "valid":
            self.__validset_batch_idx = -1
            for batch in self.validset_batches:
                self.__validset_batch_idx = (self.__validset_batch_idx + 1) % self.validset_batch_nums
                yield batch
        elif type == "test":
            self.__testset_batch_idx = -1
            for batch in self.testset_batches:
                self.__testset_batch_idx = (self.__testset_batch_idx + 1) % self.testset_batch_nums
                yield batch
        else:
            raise ValueError("{} type not in ['train', 'valid', 'test'].".format(type))

    def load_next_batch(self, type: str) -> dict:
        """
        Return next batch data, cycling through the pre-built batches.

        :param type: [train | valid | test], data type.
        :return: batch data
        """
        if type == "train":
            self.__trainset_batch_idx = (self.__trainset_batch_idx + 1) % self.trainset_batch_nums
            return self.trainset_batches[self.__trainset_batch_idx]
        elif type == "valid":
            self.__validset_batch_idx = (self.__validset_batch_idx + 1) % self.validset_batch_nums
            return self.validset_batches[self.__validset_batch_idx]
        elif type == "test":
            self.__testset_batch_idx = (self.__testset_batch_idx + 1) % self.testset_batch_nums
            return self.testset_batches[self.__testset_batch_idx]
        else:
            raise ValueError("{} type not in ['train', 'valid', 'test'].".format(type))

    def init_batches(self):
        """
        Initialize batches of trainset, validset and testset.

        :return: None
        """
        self.__init_batches()

    def _word2idx(self, sentence):
        """Convert a tokenized question into tokenizer ids.

        :param sentence: list[str], question tokens (special tokens included).
        :return: list[int], token ids.
        """
        return self.dataset.tokenizer.convert_tokens_to_ids(sentence)

    def _equ_symbol2idx(self, equation):
        """Convert an equation symbol sequence (possibly nested, for
        multi-way-tree equations) into index form.

        :param equation: list of symbols; elements may be sub-lists in
            MultiWayTree mode, which are converted recursively.
        :return: (nested) list[int] of symbol indices.
        """
        equ_idx = []
        if self.equation_fix == FixType.MultiWayTree:
            for symbol in equation:
                if isinstance(symbol, list):
                    # Sub-tree: recurse and keep the nested structure.
                    equ_idx.append(self._equ_symbol2idx(symbol))
                else:
                    equ_idx.append(self.__out_symbol_to_idx(symbol))
        else:
            for symbol in equation:
                equ_idx.append(self.__out_symbol_to_idx(symbol))
        return equ_idx

    def __out_symbol_to_idx(self, symbol):
        """Map one output symbol to its index (unk fallback for OOV)."""
        if self.share_vocab:
            return self.dataset.tokenizer.convert_tokens_to_ids(symbol)
        try:
            return self.dataset.out_symbol2idx[symbol]
        except KeyError:
            return self.out_unk_token

    def __build_batch(self, batch_data):
        """Build one padded, masked batch dict from raw data records."""
        ques_batch = []
        equ_batch = []
        temp_batch = []
        ques_source_batch = []
        equ_source_batch = []
        temp_source_batch = []
        ques_source_1_batch = []
        infix_equ_batch = []
        num_list_batch = []
        num_pos_batch = []
        id_batch = []
        ans_batch = []
        equ_len_batch = []
        ques_len_batch = []
        num_stack_batch = []
        group_nums_batch = []
        # Sort by question length (descending) so downstream packing works.
        batch_data = sorted(batch_data, key=lambda x: len(x['question']), reverse=True)
        for data in batch_data:
            sentence = data["question"]
            equation = data["equation"]
            template = data["template"]
            # question word to index
            if self.add_sos:
                sentence = [SpecialTokens.SOS_TOKEN] + sentence
            if self.add_eos:
                sentence = sentence + [SpecialTokens.EOS_TOKEN]
            ques_tensor = self._word2idx(sentence)
            # equation symbol to index; with a shared vocabulary the equation
            # must be re-tokenized with the pretrained tokenizer.
            if self.share_vocab:
                equation = self.dataset.tokenizer.tokenize(' '.join(data["equation"]))
                template = self.dataset.tokenizer.tokenize(' '.join(data["template"]))
            if self.symbol_for_tree or self.equation_fix == FixType.MultiWayTree:
                # Tree decoders need no explicit EOS symbol.
                pass
            else:
                equation.append(SpecialTokens.EOS_TOKEN)
                template.append(SpecialTokens.EOS_TOKEN)
            equ_tensor = self._equ_symbol2idx(equation)
            temp_tensor = self._temp_symbol2idx(template)
            equ_len_batch.append(len(equ_tensor))
            ques_len_batch.append(len(ques_tensor))
            ques_batch.append(ques_tensor)
            equ_batch.append(equ_tensor)
            temp_batch.append(temp_tensor)
            # question / equation to string
            ques_source = ' '.join(sentence)
            if self.equation_fix == FixType.MultiWayTree:
                # Nested equations have no flat string form.
                equ_source = ' '
                temp_source = ' '
            else:
                equ_source = ' '.join(equation)
                temp_source = ' '.join(template)
            ques_source_batch.append(ques_source)
            equ_source_batch.append(equ_source)
            temp_source_batch.append(temp_source)
            ques_source_1_batch.append(data["ques source 1"])
            # infix equation
            infix_equ_batch.append(data["infix equation"])
            # quantity list
            num_list_batch.append(data["number list"])
            # quantity position; pos plus one because of adding <SOS> at the
            # head of sentence
            if self.add_sos:
                num_pos = [pos + 1 for pos in data["number position"]]
            else:
                num_pos = list(data["number position"])
            num_pos_batch.append(num_pos)
            # question id and answer
            id_batch.append(data["id"])
            ans_batch.append(data["ans"])
            # "group nums" is optional in the record; default to empty.
            group_nums_batch.append(data.get("group nums", []))
            num_stack_batch.append(self._build_num_stack(equation, data["number list"]))
        # padding batch question
        ques_batch = self._pad_input_batch(ques_batch, ques_len_batch)
        if self.max_len is not None:
            # Lengths are clipped to max_len to match the truncated padding.
            ques_len_batch = [min(l, self.max_len) for l in ques_len_batch]
        # padding batch equation (tree-structured outputs are left ragged)
        if self.equation_fix != FixType.MultiWayTree:
            equ_batch = self._pad_output_batch(equ_batch, equ_len_batch)
            temp_batch = self._pad_output_batch(temp_batch, equ_len_batch)
        # question mask
        ques_mask_batch = self._get_input_mask(ques_len_batch)
        # equation mask
        equ_mask_batch = self._get_mask(equ_len_batch)
        # quantity count
        num_size_batch = [len(num_pos) for num_pos in num_pos_batch]
        # quantity mask
        num_mask_batch = get_num_mask(num_size_batch, self.dataset.generate_list)
        # shift group-num positions by one when <SOS> was prepended
        offset = 1 if self.add_sos else 0
        new_group_nums_batch = [
            [[pos + offset for pos in group_num] for group_num in group_nums]
            for group_nums in group_nums_batch
        ]
        return {
            "question": ques_batch,
            "equation": equ_batch,
            "template": temp_batch,
            "ques len": ques_len_batch,
            "equ len": equ_len_batch,
            "num list": num_list_batch,
            "num pos": num_pos_batch,
            "id": id_batch,
            "num mask": num_mask_batch,
            "ques mask": ques_mask_batch,
            "equ mask": equ_mask_batch,
            "num stack": num_stack_batch,
            "ans": ans_batch,
            "num size": num_size_batch,
            "ques_source": ques_source_batch,
            "equ_source": equ_source_batch,
            "temp_source": temp_source_batch,
            "ques source 1": ques_source_1_batch,
            "group nums": new_group_nums_batch,
            "infix equation": infix_equ_batch,
        }

    def __init_batches(self):
        """Pre-build all batches for train/valid/test and reset cursors."""
        self.trainset_batches = []
        self.validset_batches = []
        self.testset_batches = []
        for set_type in ['train', 'valid', 'test']:
            if set_type == 'train':
                datas = self.dataset.trainset
                batch_size = self.train_batch_size
            elif set_type == 'valid':
                datas = self.dataset.validset
                batch_size = self.test_batch_size
            elif set_type == 'test':
                datas = self.dataset.testset
                batch_size = self.test_batch_size
            else:
                # BUGFIX: message previously formatted the builtin `type`
                # instead of the offending value.
                raise ValueError("{} type not in ['train', 'valid', 'test'].".format(set_type))
            num_total = len(datas)
            batch_num = math.ceil(num_total / batch_size)
            for batch_i in range(batch_num):
                start_idx = batch_i * batch_size
                # slicing naturally clamps the final, shorter batch
                batch_data = datas[start_idx:start_idx + batch_size]
                built_batch = self.__build_batch(batch_data)
                if set_type == 'train':
                    self.trainset_batches.append(built_batch)
                elif set_type == 'valid':
                    self.validset_batches.append(built_batch)
                elif set_type == 'test':
                    self.testset_batches.append(built_batch)
                else:
                    raise ValueError("{} type not in ['train', 'valid', 'test'].".format(set_type))
        self.__trainset_batch_idx = -1
        self.__validset_batch_idx = -1
        self.__testset_batch_idx = -1
        self.trainset_batch_nums = len(self.trainset_batches)
        self.validset_batch_nums = len(self.validset_batches)
        self.testset_batch_nums = len(self.testset_batches)

    def build_batch_for_predict(self, batch_data: List[dict]):
        """Build a batch for inference: output-side fields are blanked
        before batching and stripped from the returned dict.

        :param batch_data: list[dict], raw records (equation/answer ignored).
        :return: dict, batch without output-side entries.
        """
        for idx, data in enumerate(batch_data):
            data['equation'] = []
            data['template'] = []
            data['infix equation'] = []
            data['ans'] = None
            if data.get('id', None) is None:
                data['id'] = 'temp_{}'.format(idx)
        batch = self.__build_batch(batch_data)
        del batch['equation']
        del batch['template']
        del batch['equ len']
        del batch['equ mask']
        del batch['ans']
        del batch['equ_source']
        del batch['temp_source']
        del batch['infix equation']
        return batch
# def load_batch(self, batch_data):
# """load one batch
#
# Args:
# batch_data (list[dict])
#
# Returns:
# loaded batch data (dict)
# """
# ques_batch = []
# equ_batch = []
# temp_batch = []
# ques_source_batch = []
# equ_source_batch = []
# temp_source_batch = []
# ques_source_1_batch = []
# infix_equ_batch = []
#
# num_list_batch = []
# num_pos_batch = []
#
# id_batch = []
# ans_batch = []
#
# ques_mask_batch = []
# equ_mask_batch = []
# num_mask_batch = []
#
# equ_len_batch = []
# ques_len_batch = []
#
# num_size_batch = []
# num_stack_batch = []
#
# group_nums_batch = []
# # for data in batch_data:
# # data['question_']=self.dataset.tokenizer.tokenize(' '.join(data["question"]))
# # batch_data=sorted(batch_data,key=lambda x:len(x['question_']),reverse=True)
# batch_data = sorted(batch_data, key=lambda x: len(x['question']), reverse=True)
# for data in batch_data:
# ques_tensor = []
# equ_tensor = []
# temp_tensor = []
# sentence = data["question"]
# equation = data["equation"]
# template = data["template"]
#
# # question word to index
# # sentence=self.dataset.tokenizer.tokenize(' '.join(data["question"]))
# if self.add_sos:
# sentence = [SpecialTokens.SOS_TOKEN] + sentence
# if self.add_eos:
# sentence = sentence + [SpecialTokens.EOS_TOKEN]
# ques_tensor = self._word2idx(sentence)
#
# # equation symbol to index
# if self.share_vocab:
# equation = self.dataset.tokenizer.tokenize(' '.join(data["equation"]))
# template = self.dataset.tokenizer.tokenize(' '.join(data["template"]))
# if self.symbol_for_tree or self.equation_fix == FixType.MultiWayTree:
# pass
# else:
# equation.append(SpecialTokens.EOS_TOKEN)
# template.append(SpecialTokens.EOS_TOKEN)
# equ_tensor = self._equ_symbol2idx(equation)
# temp_tensor = self._temp_symbol2idx(template)
#
# equ_len_batch.append(len(equ_tensor))
# ques_len_batch.append(len(ques_tensor))
# ques_batch.append(ques_tensor)
# equ_batch.append(equ_tensor)
# temp_batch.append(temp_tensor)
#
# # question / equation to string
# ques_source = ' '.join(sentence)
# if self.equation_fix == FixType.MultiWayTree:
# equ_source = ' '
# temp_source = ' '
# else:
# equ_source = ' '.join(equation)
# temp_source = ' '.join(template)
# ques_source_batch.append(ques_source)
# equ_source_batch.append(equ_source)
# temp_source_batch.append(temp_source)
# ques_source_1_batch.append(data["ques source 1"])
# # infix equation
# infix_equ_batch.append(data["infix equation"])
# # quantity list
# num_list_batch.append(data["number list"])
# # quantity position
# if self.add_sos:
# num_pos = [pos + 1 for pos in
# data["number position"]] # pos plus one because of adding <SOS> at the head of sentence
# else:
# num_pos = [pos for pos in data["number position"]]
# num_pos_batch.append(num_pos)
# # question id and answer
# id_batch.append(data["id"])
# ans_batch.append(data["ans"])
# try:
# group_nums_batch.append(data["group nums"])
# except:
# group_nums_batch.append([])
#
# num_stack_batch.append(self._build_num_stack(equation, data["number list"]))
#
# # padding batch question
# ques_batch = self._pad_input_batch(ques_batch, ques_len_batch)
# if self.max_len != None:
# ques_len_batch = [self.max_len if l > self.max_len else l for l in ques_len_batch]
# # padding batch equation
# if self.equation_fix == FixType.MultiWayTree:
# pass
# else:
# equ_batch = self._pad_output_batch(equ_batch, equ_len_batch)
# temp_batch = self._pad_output_batch(temp_batch, equ_len_batch)
# # question mask
# ques_mask_batch = self._get_input_mask(ques_len_batch)
# # equation mask
# equ_mask_batch = self._get_mask(equ_len_batch)
# # quantity count
# num_size_batch = [len(num_pos) for num_pos in num_pos_batch]
# # quantity mask
# num_mask_batch = get_num_mask(num_size_batch, self.dataset.generate_list)
#
# new_group_nums_batch = []
# for group_nums in group_nums_batch:
# new_group_nums = []
# for group_num in group_nums:
# new_group_num = []
# for pos in group_num:
# if self.add_sos:
# new_group_num.append(pos + 1)
# else:
# new_group_num.append(pos)
# new_group_nums.append(new_group_num)
# new_group_nums_batch.append(new_group_nums)
# # to tensor
# ques_tensor_batch = torch.tensor(ques_batch).to(self.device)
# if self.equation_fix == FixType.MultiWayTree:
# equ_tensor_batch = equ_batch
# temp_tensor_batch = temp_batch
# else:
# equ_tensor_batch = torch.tensor(equ_batch).to(self.device)
# temp_tensor_batch = torch.tensor(temp_batch).to(self.device)
# ques_mask_batch = torch.tensor(ques_mask_batch).to(self.device).bool()
# num_mask_batch = torch.tensor(num_mask_batch).to(self.device).bool()
# ques_len_batch = torch.tensor(ques_len_batch).long()
# equ_mask_batch = torch.tensor(equ_mask_batch).to(self.device).bool()
#
# return {
# "question": ques_tensor_batch,
# "equation": equ_tensor_batch,
# "template": temp_tensor_batch,
# "ques len": ques_len_batch,
# "equ len": equ_len_batch,
# "num list": num_list_batch,
# "num pos": num_pos_batch,
# "id": id_batch,
# "num mask": num_mask_batch,
# "ques mask": ques_mask_batch,
# "equ mask": equ_mask_batch,
# "num stack": num_stack_batch,
# "ans": ans_batch,
# "num size": num_size_batch,
# "ques_source": ques_source_batch,
# "equ_source": equ_source_batch,
# "temp_source": temp_source_batch,
# "ques source 1": ques_source_1_batch,
# "group nums": new_group_nums_batch,
# "infix equation": infix_equ_batch,
# }