Source code for mwptoolkit.data.utils

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 21:39:08
# @File: utils.py
from typing import Union, Type

from mwptoolkit.config.configuration import Config

from mwptoolkit.data.dataset.abstract_dataset import AbstractDataset
from mwptoolkit.data.dataset.single_equation_dataset import SingleEquationDataset
from mwptoolkit.data.dataset.multi_equation_dataset import MultiEquationDataset
from mwptoolkit.data.dataset.dataset_multiencdec import DatasetMultiEncDec
from mwptoolkit.data.dataset.dataset_ept import DatasetEPT
from mwptoolkit.data.dataset.pretrain_dataset import PretrainDataset
from mwptoolkit.data.dataset.dataset_hms import DatasetHMS
from mwptoolkit.data.dataset.dataset_gpt2 import DatasetGPT2

from mwptoolkit.data.dataloader.abstract_dataloader import AbstractDataLoader
from mwptoolkit.data.dataloader.single_equation_dataloader import SingleEquationDataLoader
from mwptoolkit.data.dataloader.multi_equation_dataloader import MultiEquationDataLoader
from mwptoolkit.data.dataloader.dataloader_multiencdec import DataLoaderMultiEncDec
from mwptoolkit.data.dataloader.dataloader_ept import DataLoaderEPT
from mwptoolkit.data.dataloader.pretrain_dataloader import PretrainDataLoader
from mwptoolkit.data.dataloader.dataloader_hms import DataLoaderHMS
from mwptoolkit.data.dataloader.dataloader_gpt2 import DataLoaderGPT2

from mwptoolkit.utils.enum_type import TaskType


def create_dataset(config):
    """Create a dataset instance according to config.

    The dataset class is resolved in this order:

    1. A model-specific class named ``Dataset<model>`` that is visible in
       this module's namespace (e.g. ``DatasetEPT`` when
       ``config['model']`` is ``'EPT'``).
    2. ``PretrainDataset`` when ``config['transformers_pretrained_model']``
       or ``config['pretrained_model']`` is not None.
    3. ``SingleEquationDataset`` or ``MultiEquationDataset`` according to
       the lower-cased ``config['task_type']``; otherwise
       ``AbstractDataset``.

    Args:
        config (mwptoolkit.config.configuration.Config): An instance object of Config,
            used to record parameter information.

    Returns:
        Dataset: Constructed dataset.

    Raises:
        Exception: Whatever the selected dataset constructor raises. Errors
            from a model-specific constructor are propagated instead of
            silently falling back to another dataset type.
    """
    # Resolve the class by name in the module namespace rather than eval()-ing
    # a string built from configuration, which would execute arbitrary code.
    try:
        model_name = config['model']
    except KeyError:
        # A plain mapping without a 'model' key simply has no
        # model-specific dataset; continue with the generic fallbacks.
        model_name = None
    dataset_cls = globals().get('Dataset{}'.format(model_name))
    if dataset_cls is not None:
        # Constructed outside any try block so constructor failures surface.
        return dataset_cls(config)

    if config['transformers_pretrained_model'] is not None or config['pretrained_model'] is not None:
        return PretrainDataset(config)

    task_type = config['task_type'].lower()
    if task_type == TaskType.SingleEquation:
        return SingleEquationDataset(config)
    elif task_type == TaskType.MultiEquation:
        return MultiEquationDataset(config)
    else:
        return AbstractDataset(config)
def create_dataloader(config):
    """Select the dataloader class according to config.

    Despite the name, this returns the dataloader *class*, not an instance;
    the caller is responsible for instantiating it.

    The class is resolved in this order:

    1. A model-specific class named ``DataLoader<model>`` that is visible in
       this module's namespace (e.g. ``DataLoaderEPT`` when
       ``config['model']`` is ``'EPT'``).
    2. ``PretrainDataLoader`` when ``config['transformers_pretrained_model']``
       or ``config['pretrained_model']`` is not None.
    3. ``SingleEquationDataLoader`` or ``MultiEquationDataLoader`` according
       to the lower-cased ``config['task_type']``; otherwise
       ``AbstractDataLoader``.

    Args:
        config (mwptoolkit.config.configuration.Config): An instance object of Config,
            used to record parameter information.

    Returns:
        Dataloader module
    """
    # Resolve the class by name in the module namespace rather than eval()-ing
    # a string built from configuration, which would execute arbitrary code.
    try:
        model_name = config['model']
    except KeyError:
        # A plain mapping without a 'model' key simply has no
        # model-specific dataloader; continue with the generic fallbacks.
        model_name = None
    dataloader_cls = globals().get('DataLoader{}'.format(model_name))
    if dataloader_cls is not None:
        return dataloader_cls

    if config['transformers_pretrained_model'] is not None or config['pretrained_model'] is not None:
        return PretrainDataLoader

    task_type = config['task_type'].lower()
    if task_type == TaskType.SingleEquation:
        return SingleEquationDataLoader
    elif task_type == TaskType.MultiEquation:
        return MultiEquationDataLoader
    else:
        return AbstractDataLoader
def get_dataset_module(config: Config) \
        -> Type[Union[
            DatasetMultiEncDec, DatasetEPT, DatasetHMS, DatasetGPT2,
            PretrainDataset, SingleEquationDataset, MultiEquationDataset, AbstractDataset]]:
    """Return the dataset class according to config.

    The class is resolved in this order:

    1. A model-specific class named ``Dataset<model>`` that is visible in
       this module's namespace (e.g. ``DatasetEPT`` when
       ``config['model']`` is ``'EPT'``).
    2. ``PretrainDataset`` when ``config['transformers_pretrained_model']``
       or ``config['pretrained_model']`` is not None.
    3. ``SingleEquationDataset`` or ``MultiEquationDataset`` according to
       the lower-cased ``config['task_type']``; otherwise
       ``AbstractDataset``.

    Args:
        config (mwptoolkit.config.configuration.Config): An instance object of Config,
            used to record parameter information.

    Returns:
        The dataset class (not an instance).
    """
    # Resolve the class by name in the module namespace rather than eval()-ing
    # a string built from configuration, which would execute arbitrary code.
    try:
        model_name = config['model']
    except KeyError:
        # A plain mapping without a 'model' key simply has no
        # model-specific dataset; continue with the generic fallbacks.
        model_name = None
    dataset_cls = globals().get('Dataset{}'.format(model_name))
    if dataset_cls is not None:
        return dataset_cls

    if config['transformers_pretrained_model'] is not None or config['pretrained_model'] is not None:
        return PretrainDataset

    task_type = config['task_type'].lower()
    if task_type == TaskType.SingleEquation:
        return SingleEquationDataset
    elif task_type == TaskType.MultiEquation:
        return MultiEquationDataset
    else:
        return AbstractDataset
def get_dataloader_module(config: Config) \
        -> Type[Union[
            DataLoaderMultiEncDec, DataLoaderEPT, DataLoaderHMS, DataLoaderGPT2,
            PretrainDataLoader, SingleEquationDataLoader, MultiEquationDataLoader, AbstractDataLoader]]:
    """Return the dataloader class according to config.

    The class is resolved in this order:

    1. A model-specific class named ``DataLoader<model>`` that is visible in
       this module's namespace (e.g. ``DataLoaderEPT`` when
       ``config['model']`` is ``'EPT'``).
    2. ``PretrainDataLoader`` when ``config['transformers_pretrained_model']``
       or ``config['pretrained_model']`` is not None.
    3. ``SingleEquationDataLoader`` or ``MultiEquationDataLoader`` according
       to the lower-cased ``config['task_type']``; otherwise
       ``AbstractDataLoader``.

    Args:
        config (mwptoolkit.config.configuration.Config): An instance object of Config,
            used to record parameter information.

    Returns:
        Dataloader module
    """
    # Resolve the class by name in the module namespace rather than eval()-ing
    # a string built from configuration, which would execute arbitrary code.
    try:
        model_name = config['model']
    except KeyError:
        # A plain mapping without a 'model' key simply has no
        # model-specific dataloader; continue with the generic fallbacks.
        model_name = None
    dataloader_cls = globals().get('DataLoader{}'.format(model_name))
    if dataloader_cls is not None:
        return dataloader_cls

    if config['transformers_pretrained_model'] is not None or config['pretrained_model'] is not None:
        return PretrainDataLoader

    task_type = config['task_type'].lower()
    if task_type == TaskType.SingleEquation:
        return SingleEquationDataLoader
    elif task_type == TaskType.MultiEquation:
        return MultiEquationDataLoader
    else:
        return AbstractDataLoader