Source code for mwptoolkit.data.dataloader.abstract_dataloader

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 11:34:06
# @File: abstract_dataloader.py
from typing import List

from mwptoolkit.utils.enum_type import FixType, SpecialTokens


class AbstractDataLoader(object):
    """Abstract dataloader: the base class of all dataloader classes."""

    def __init__(self, config, dataset):
        """
        :param config: it is expected that config includes the parameters below:

            model (str): model name.
            equation_fix (str): [infix | postfix | prefix], convert the equation to the specified format.
            train_batch_size (int): the training batch size.
            test_batch_size (int): the testing batch size.
            symbol_for_tree (bool): build output symbols for tree or not.
            share_vocab (bool): encoder and decoder of the model share the same vocabulary, often seen in Seq2Seq models.
            max_len (int|None): max input length.
            max_equ_len (int|None): max output length.
            add_sos (bool): add an sos token at the head of the input sequence.
            add_eos (bool): add an eos token at the tail of the input sequence.
            device (torch.device): the device tensors are placed on.

        :param dataset: the dataset object providing the vocabularies.
        """
        super().__init__()
        self.model = config["model"]
        self.equation_fix = config["equation_fix"]
        self.train_batch_size = config["train_batch_size"]
        self.test_batch_size = config["test_batch_size"]
        self.symbol_for_tree = config["symbol_for_tree"]
        self.share_vocab = config["share_vocab"]
        self.max_len = config["max_len"]
        self.max_equ_len = config["max_equ_len"]
        self.add_sos = config["add_sos"]
        self.add_eos = config["add_eos"]
        self.filt_dirty = config["filt_dirty"]
        self.device = config["device"]
        self.dataset = dataset

        # special-token indices; concrete subclasses fill these in
        self.in_pad_token = None
        self.in_unk_token = None
        self.out_pad_token = None
        self.out_unk_token = None
        self.temp_unk_token = None
        self.temp_pad_token = None

        self.trainset_batches = []
        self.validset_batches = []
        self.testset_batches = []
        self.__trainset_batch_idx = -1
        self.__validset_batch_idx = -1
        self.__testset_batch_idx = -1
        self.trainset_batch_nums = 0
        self.validset_batch_nums = 0
        self.testset_batch_nums = 0

    def _pad_input_batch(self, batch_seq, batch_seq_len):
        # pad every input sequence to a common length, capped at max_len
        max_length = max(batch_seq_len)
        if self.max_len is not None:
            if self.max_len < max_length:
                max_length = self.max_len
        for idx, length in enumerate(batch_seq_len):
            if length < max_length:
                batch_seq[idx] += [self.in_pad_token for i in range(max_length - length)]
            else:
                if self.add_sos and self.add_eos:
                    # keep the sos and eos tokens when truncating
                    batch_seq[idx] = [batch_seq[idx][0]] + batch_seq[idx][1:max_length - 1] + [batch_seq[idx][-1]]
                else:
                    batch_seq[idx] = batch_seq[idx][:max_length]
        return batch_seq

    def _pad_output_batch(self, batch_target, batch_target_len):
        # pad every output sequence to a common length, capped at max_equ_len
        max_length = max(batch_target_len)
        if self.max_equ_len is not None:
            if self.max_equ_len < max_length:
                max_length = self.max_equ_len
        for idx, length in enumerate(batch_target_len):
            if length < max_length:
                batch_target[idx] += [self.out_pad_token for i in range(max_length - length)]
            else:
                batch_target[idx] = batch_target[idx][:max_length]
        return batch_target

    def _word2idx(self, sentence):
        sentence_idx = []
        for word in sentence:
            try:
                idx = self.dataset.in_word2idx[word]
            except KeyError:
                idx = self.in_unk_token
            sentence_idx.append(idx)
        return sentence_idx

    def _idx2word(self, sentence_idx):
        sentence = []
        for idx in sentence_idx:
            try:
                word = self.dataset.in_idx2word[idx]
            except (KeyError, IndexError):
                word = SpecialTokens.UNK_TOKEN
            sentence.append(word)
        return sentence

    def _equ_symbol2idx(self, equation):
        equ_idx = []
        if self.equation_fix == FixType.MultiWayTree:
            # multi-way-tree equations are nested lists, so recurse into sub-equations
            for symbol in equation:
                if isinstance(symbol, list):
                    sub_equ_idx = self._equ_symbol2idx(symbol)
                    equ_idx.append(sub_equ_idx)
                else:
                    if self.share_vocab:
                        try:
                            idx = self.dataset.in_word2idx[symbol]
                        except KeyError:
                            idx = self.in_unk_token
                    else:
                        try:
                            idx = self.dataset.out_symbol2idx[symbol]
                        except KeyError:
                            idx = self.out_unk_token
                    equ_idx.append(idx)
        else:
            for word in equation:
                if self.share_vocab:
                    try:
                        idx = self.dataset.in_word2idx[word]
                    except KeyError:
                        idx = self.in_unk_token
                else:
                    try:
                        idx = self.dataset.out_symbol2idx[word]
                    except KeyError:
                        idx = self.out_unk_token
                equ_idx.append(idx)
        return equ_idx

    def _equ_idx2symbol(self, equation_idx):
        equation = []
        if self.equation_fix == FixType.MultiWayTree:
            for idx in equation_idx:
                if isinstance(idx, list):
                    sub_equ = self._equ_idx2symbol(idx)
                    equation.append(sub_equ)
                else:
                    if self.share_vocab:
                        try:
                            symbol = self.dataset.in_idx2word[idx]
                        except (KeyError, IndexError):
                            symbol = SpecialTokens.UNK_TOKEN
                    else:
                        try:
                            symbol = self.dataset.out_idx2symbol[idx]
                        except (KeyError, IndexError):
                            symbol = SpecialTokens.UNK_TOKEN
                    equation.append(symbol)
        else:
            for idx in equation_idx:
                if self.share_vocab:
                    try:
                        symbol = self.dataset.in_idx2word[idx]
                    except (KeyError, IndexError):
                        symbol = SpecialTokens.UNK_TOKEN
                else:
                    try:
                        symbol = self.dataset.out_idx2symbol[idx]
                    except (KeyError, IndexError):
                        symbol = SpecialTokens.UNK_TOKEN
                equation.append(symbol)
        return equation

    def _temp_symbol2idx(self, template):
        temp_idx = []
        if self.equation_fix == FixType.MultiWayTree:
            for symbol in template:
                if isinstance(symbol, list):
                    sub_temp_idx = self._temp_symbol2idx(symbol)
                    temp_idx.append(sub_temp_idx)
                else:
                    if self.share_vocab:
                        try:
                            idx = self.dataset.in_word2idx[symbol]
                        except KeyError:
                            idx = self.in_unk_token
                    else:
                        try:
                            idx = self.dataset.temp_symbol2idx[symbol]
                        except KeyError:
                            idx = self.temp_unk_token
                    temp_idx.append(idx)
        else:
            for word in template:
                if self.share_vocab:
                    try:
                        idx = self.dataset.in_word2idx[word]
                    except KeyError:
                        idx = self.in_unk_token
                else:
                    try:
                        idx = self.dataset.temp_symbol2idx[word]
                    except KeyError:
                        idx = self.temp_unk_token
                temp_idx.append(idx)
        return temp_idx

    def _temp_idx2symbol(self, template_idx):
        template = []
        if self.equation_fix == FixType.MultiWayTree:
            for idx in template_idx:
                if isinstance(idx, list):
                    sub_temp = self._temp_idx2symbol(idx)
                    template.append(sub_temp)
                else:
                    if self.share_vocab:
                        try:
                            symbol = self.dataset.in_idx2word[idx]
                        except (KeyError, IndexError):
                            symbol = SpecialTokens.UNK_TOKEN
                    else:
                        try:
                            symbol = self.dataset.temp_idx2symbol[idx]
                        except (KeyError, IndexError):
                            symbol = SpecialTokens.UNK_TOKEN
                    template.append(symbol)
        else:
            for idx in template_idx:
                if self.share_vocab:
                    try:
                        symbol = self.dataset.in_idx2word[idx]
                    except (KeyError, IndexError):
                        symbol = SpecialTokens.UNK_TOKEN
                else:
                    try:
                        symbol = self.dataset.temp_idx2symbol[idx]
                    except (KeyError, IndexError):
                        symbol = SpecialTokens.UNK_TOKEN
                template.append(symbol)
        return template

    def _get_mask(self, batch_seq_len):
        # 1 marks a real token, 0 marks padding
        max_length = max(batch_seq_len)
        batch_mask = []
        for length in batch_seq_len:
            batch_mask.append([1] * length + [0] * (max_length - length))
        return batch_mask

    def _get_input_mask(self, batch_seq_len):
        if self.max_len:
            max_length = self.max_len
        else:
            max_length = max(batch_seq_len)
        batch_mask = []
        for length in batch_seq_len:
            batch_mask.append([1] * length + [0] * (max_length - length))
        return batch_mask

    def _build_num_stack(self, equation, num_list):
        # for each equation token missing from the output vocabulary, record
        # the positions in num_list it may refer to
        num_stack = []
        for word in equation:
            temp_num = []
            flag_not = True
            if word not in self.dataset.out_idx2symbol:
                flag_not = False
                if "NUM" in word:
                    temp_num.append(int(word[4:]))
                for i, j in enumerate(num_list):
                    if j == word:
                        temp_num.append(i)
            if not flag_not and len(temp_num) != 0:
                num_stack.append(temp_num)
            if not flag_not and len(temp_num) == 0:
                num_stack.append(list(range(len(num_list))))
        num_stack.reverse()
        return num_stack

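    # A small illustration of the helpers above (all values hypothetical,
    # assuming in_pad_token == 0, add_sos == add_eos == False and max_len is None):
    #
    #     _pad_input_batch([[4, 5, 6], [7]], [3, 1]) -> [[4, 5, 6], [7, 0, 0]]
    #     _get_mask([3, 1])                          -> [[1, 1, 1], [1, 0, 0]]
    #
    # and, assuming out_idx2symbol contains "x", "=" and "+" but not the number
    # tokens, so that "NUM_0" and "NUM_1" are out-of-vocabulary:
    #
    #     _build_num_stack(["x", "=", "NUM_0", "+", "NUM_1"], ["3", "5"]) -> [[1], [0]]
    #
    # i.e. each out-of-vocabulary number token contributes the positions in
    # num_list it may refer to, in reversed (stack) order.
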
    def load_data(self):
        """load data."""
        raise NotImplementedError

    def load_next_batch(self):
        """load the next batch."""
        raise NotImplementedError

    def init_batches(self):
        """initialize batches."""
        raise NotImplementedError

    def convert_word_2_idx(self, sentence: List[str]):
        """convert the tokens of an input sequence to indices.

        :param sentence: List[str]
        :return: List[int], the token indices, with the sos/eos tokens added
            when add_sos/add_eos are set.
        """
        if self.add_sos:
            sentence = [SpecialTokens.SOS_TOKEN] + sentence
        if self.add_eos:
            sentence = sentence + [SpecialTokens.EOS_TOKEN]
        sentence_idx = self._word2idx(sentence)
        return sentence_idx

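    # For instance (hypothetical vocabulary): with add_sos and add_eos both set,
    # convert_word_2_idx(["5", "+", "3"]) looks up
    # [SpecialTokens.SOS_TOKEN, "5", "+", "3", SpecialTokens.EOS_TOKEN]
    # in dataset.in_word2idx; any out-of-vocabulary token maps to in_unk_token.
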
    def convert_idx_2_word(self, sentence_idx: List[int]):
        """convert the token indices of an input sequence back to tokens.

        :param sentence_idx: List[int]
        :return: List[str]; unknown indices map to the unk token.
        """
        sentence = self._idx2word(sentence_idx)
        return sentence

    def convert_symbol_2_idx(self, equation: List[str]):
        """convert the symbols of an equation to indices.

        :param equation: List[str]
        :return: List[int]
        """
        equation_idx = self._equ_symbol2idx(equation)
        return equation_idx

    def convert_idx_2_symbol(self, equation_idx: List[int]):
        """convert the symbol indices of an equation back to symbols.

        :param equation_idx: List[int]
        :return: List[str]
        """
        equation = self._equ_idx2symbol(equation_idx)
        return equation
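
For orientation, here is a minimal sketch of how a concrete dataloader might build on this base class. The ToyDataLoader class, the toy_dataset vocabularies and the toy_config values below are hypothetical stand-ins invented for this example, not part of mwptoolkit; only AbstractDataLoader and SpecialTokens are the library's own.

from types import SimpleNamespace

from mwptoolkit.data.dataloader.abstract_dataloader import AbstractDataLoader
from mwptoolkit.utils.enum_type import SpecialTokens


class ToyDataLoader(AbstractDataLoader):
    # Minimal concrete subclass: only the abstract hooks are stubbed in.
    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        # unknown-token indices; a real loader would derive these from the dataset
        self.in_unk_token = len(dataset.in_idx2word)
        self.out_unk_token = len(dataset.out_idx2symbol)

    def load_data(self):
        pass  # a real implementation reads the dataset and calls init_batches()

    def load_next_batch(self):
        pass

    def init_batches(self):
        pass


# Hypothetical dataset object exposing the vocabularies the base class expects.
toy_dataset = SimpleNamespace(
    in_word2idx={SpecialTokens.SOS_TOKEN: 0, SpecialTokens.EOS_TOKEN: 1, "5": 2, "+": 3, "3": 4},
    in_idx2word=[SpecialTokens.SOS_TOKEN, SpecialTokens.EOS_TOKEN, "5", "+", "3"],
    out_symbol2idx={"+": 0, "NUM_0": 1, "NUM_1": 2},
    out_idx2symbol=["+", "NUM_0", "NUM_1"],
)

# Hypothetical config; the keys mirror the parameters documented in __init__.
toy_config = {
    "model": "Seq2Seq", "equation_fix": "infix",
    "train_batch_size": 2, "test_batch_size": 2,
    "symbol_for_tree": False, "share_vocab": False,
    "max_len": None, "max_equ_len": None,
    "add_sos": True, "add_eos": True,
    "filt_dirty": False, "device": None,
}

loader = ToyDataLoader(toy_config, toy_dataset)
token_idx = loader.convert_word_2_idx(["5", "+", "3"])
print(token_idx)                             # [0, 2, 3, 4, 1]
print(loader.convert_idx_2_word(token_idx))  # round-trips to the tokens, sos/eos included

A real subclass, such as the dataloaders shipped with mwptoolkit, would also set the pad-token indices and implement the batching hooks; the stub above is only enough to exercise the vocabulary helpers.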