# Source code for mwptoolkit.data.dataloader.template_dataloader

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 11:35:57
# @File: template_dataloader.py


from mwptoolkit.data.dataloader.abstract_dataloader import AbstractDataLoader

class TemplateDataLoader(AbstractDataLoader):
    """Template dataloader.

    To build a dataloader from this template you need to implement
    ``TemplateDataLoader.__init_batches()``.

    The abstract method ``TemplateDataLoader.load_batch()`` was replaced by
    ``TemplateDataLoader.__init_batches()`` after version 0.0.5; their
    functions are similar.

    .. note::
        Every name in this class that starts with two underscores
        (``__init_batches`` and the three ``__*_batch_idx`` cursors) is
        name-mangled by Python with the name of the class it is written in.
        ``init_batches()`` therefore always calls the ``__init_batches``
        written in *this* class body. A subclass that defines its own
        ``__init_batches`` must also override ``init_batches()``, and its own
        ``self.__trainset_batch_idx`` is a different attribute from the
        cursor used by ``load_data()`` / ``load_next_batch()`` here.
    """

    def __init__(self, config, dataset):
        """Store the size of every split and put the batch cursors at their start.

        :param config: configuration object, forwarded to the parent class.
        :param dataset: dataset object exposing ``trainset``, ``validset`` and ``testset``.
        """
        super().__init__(config, dataset)
        self.trainset_nums = len(dataset.trainset)
        self.validset_nums = len(dataset.validset)
        self.testset_nums = len(dataset.testset)

        # -1 means "no batch served yet". Without this, load_next_batch()
        # raised AttributeError unless load_data() had been iterated first.
        self.__trainset_batch_idx = -1
        self.__validset_batch_idx = -1
        self.__testset_batch_idx = -1

    def load_data(self, type: str):
        """Load batches, return every batch data in a generator object.

        This is a generator function: nothing runs until the first item is
        requested. At that point the cursor of the selected split is reset to
        -1, and it is advanced (modulo the number of batches) right before
        each batch is yielded.

        :param type: [train | valid | test], data type.
        :return: generator over the batches of the selected split.
        :raises ValueError: if ``type`` is not one of the three data types;
            raised when the generator is first iterated.
        """
        if type == "train":
            self.__trainset_batch_idx = -1
            for batch in self.trainset_batches:
                self.__trainset_batch_idx = (self.__trainset_batch_idx + 1) % self.trainset_batch_nums
                yield batch
        elif type == "valid":
            self.__validset_batch_idx = -1
            for batch in self.validset_batches:
                self.__validset_batch_idx = (self.__validset_batch_idx + 1) % self.validset_batch_nums
                yield batch
        elif type == "test":
            self.__testset_batch_idx = -1
            for batch in self.testset_batches:
                self.__testset_batch_idx = (self.__testset_batch_idx + 1) % self.testset_batch_nums
                yield batch
        else:
            raise ValueError("{} type not in ['train', 'valid', 'test'].".format(type))

    def load_next_batch(self, type: str):
        """Return next batch data.

        The cursor of the selected split is advanced by one and wraps around
        to the first batch after the last one has been returned.

        :param type: [train | valid | test], data type.
        :return: batch data
        :raises ValueError: if ``type`` is not one of the three data types.
        """
        if type == "train":
            self.__trainset_batch_idx = (self.__trainset_batch_idx + 1) % self.trainset_batch_nums
            return self.trainset_batches[self.__trainset_batch_idx]
        elif type == "valid":
            self.__validset_batch_idx = (self.__validset_batch_idx + 1) % self.validset_batch_nums
            return self.validset_batches[self.__validset_batch_idx]
        elif type == "test":
            self.__testset_batch_idx = (self.__testset_batch_idx + 1) % self.testset_batch_nums
            return self.testset_batches[self.__testset_batch_idx]
        else:
            raise ValueError("{} type not in ['train', 'valid', 'test'].".format(type))

    def init_batches(self):
        """Initialize batches of trainset, validset and testset.

        Delegates to the ``__init_batches`` defined in this class body (the
        call is name-mangled, so it is not dispatched to a subclass).

        :return: None
        """
        self.__init_batches()

    def __init_batches(self):
        """Build the batches of every split; to be implemented by you.

        In this function, you need to implement the codes of initializing
        batches. Specifically, you need to

        1. reset the list variables ``TemplateDataLoader.trainset_batches``,
           ``TemplateDataLoader.validset_batches`` and
           ``TemplateDataLoader.testset_batches``, and save every
           corresponding batch data in them. What value every batch includes
           is designed by you.
        2. reset the integer variables
           ``TemplateDataLoader.__trainset_batch_idx``,
           ``TemplateDataLoader.__validset_batch_idx`` and
           ``TemplateDataLoader.__testset_batch_idx`` as -1.
        3. reset the integer variables
           ``TemplateDataLoader.trainset_batch_nums``,
           ``TemplateDataLoader.validset_batch_nums`` and
           ``TemplateDataLoader.testset_batch_nums``. Their values should
           equal to corresponding length of batches.

        :raises NotImplementedError: always, until the body is implemented.
        """
        raise NotImplementedError(
            "__init_batches() must be implemented in the body of this class; "
            "it is name-mangled, so a subclass must also override init_batches()."
        )
    # Legacy load_data() built on load_batch(), kept for reference only:
    #
    # def load_data(self, type):
    #     if type == "train":
    #         datas = self.dataset.trainset
    #         batch_size = self.train_batch_size
    #     elif type == "valid":
    #         datas = self.dataset.validset
    #         batch_size = self.test_batch_size
    #     elif type == "test":
    #         datas = self.dataset.testset
    #         batch_size = self.test_batch_size
    #     else:
    #         raise ValueError("{} type not in ['train', 'valid', 'test'].".format(type))
    #
    #     num_total = len(datas)
    #     batch_num = int(num_total / batch_size) + 1
    #     for batch_i in range(batch_num):
    #         start_idx = batch_i * batch_size
    #         end_idx = (batch_i + 1) * batch_size
    #         if end_idx <= num_total:
    #             batch_data = datas[start_idx:end_idx]
    #         else:
    #             batch_data = datas[start_idx:num_total]
    #         if batch_data != []:
    #             batch_data = self.load_batch(batch_data)
    #             yield batch_data