mwptoolkit.data.dataset.abstract_dataset

class mwptoolkit.data.dataset.abstract_dataset.AbstractDataset(config)[source]

Bases: object

Abstract dataset.

The base class for all dataset classes.

Parameters

config (mwptoolkit.config.configuration.Config) –

The config is expected to include the following parameters:

model (str): model name.

dataset (str): dataset name.

equation_fix (str): [infix | postfix | prefix], convert equation to specified format.

dataset_dir or dataset_path (str): the path to the dataset folder.

language (str): a property of dataset, the language of dataset.

single (bool): a property of dataset, whether each problem has a single equation.

linear (bool): a property of dataset, whether the equations are linear.

source_equation_fix (str): [infix | postfix | prefix], a property of dataset, the source format of equation of dataset.

rebuild (bool): when loading additional dataset information, whether to build the information anew (True) or load information built previously (False).

validset_divide (bool): whether to split off a validset. If True, the dataset is split into trainset-validset-testset. If False, the dataset is split into trainset-testset.

mask_symbol (str): [NUM | number], the symbol to mask numbers in equation.

min_word_keep (int): words that occur in the dataset more often than this value are kept in the input vocabulary.

min_generate_keep (int): generate numbers whose count is greater than this value are kept in the output symbols.

symbol_for_tree (bool): build output symbols for tree or not.

share_vocab (bool): encoder and decoder of the model share the same vocabulary, often seen in Seq2Seq models.

k_fold (int|None): if it is an integer, run k-fold cross validation. If it is None, run the trainset-validset-testset split.

read_local_folds (bool): when running k-fold cross validation, if True, load the split folds from the dataset folder. If False, split the folds randomly.

shuffle (bool): whether to shuffle trainset before training.

device (torch.device):

resume_training or resume (bool):
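As a rough illustration of the parameter list above, the following sketch collects the expected keys in a plain dict and checks that none is missing. Only the key names come from the list; every value is a hypothetical placeholder, the helper `missing_keys` is not part of MWPToolkit, and a real run would pass a `mwptoolkit.config.configuration.Config` object rather than a dict.

```python
# Expected entries, taken from the parameter list; tuples hold alternative names.
EXPECTED_KEYS = [
    ("model",), ("dataset",), ("equation_fix",),
    ("dataset_dir", "dataset_path"),
    ("language",), ("single",), ("linear",), ("source_equation_fix",),
    ("rebuild",), ("validset_divide",), ("mask_symbol",),
    ("min_word_keep",), ("min_generate_keep",), ("symbol_for_tree",),
    ("share_vocab",), ("k_fold",), ("read_local_folds",), ("shuffle",),
    ("device",), ("resume_training", "resume"),
]

# Hypothetical values, chosen only for illustration.
example_config = {
    "model": "GTS",
    "dataset": "math23k",
    "equation_fix": "prefix",
    "dataset_dir": "dataset/math23k",
    "language": "zh",
    "single": True,
    "linear": True,
    "source_equation_fix": "infix",
    "rebuild": False,
    "validset_divide": True,
    "mask_symbol": "NUM",
    "min_word_keep": 5,
    "min_generate_keep": 5,
    "symbol_for_tree": True,
    "share_vocab": False,
    "k_fold": None,
    "read_local_folds": False,
    "shuffle": True,
    "device": "cpu",  # a torch.device in practice; a string keeps the sketch dependency-free
    "resume_training": False,
}


def missing_keys(config):
    """Return the expected entries for which no alternative key is present."""
    return [
        " or ".join(alternatives)
        for alternatives in EXPECTED_KEYS
        if not any(key in config for key in alternatives)
    ]


print(missing_keys(example_config))
```

A check of this kind can be run before constructing the dataset to catch a missing parameter early.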

_load_dataset()[source]

read dataset from files

_load_fold_dataset()[source]

read one fold of dataset from file.

cross_validation_load(k_fold, start_fold_t=None)[source]

Load the dataset for cross validation.

Builds the folds for cross validation. One fold is chosen as the testset and the remaining folds form the trainset.

Parameters
  • k_fold (int) – the number of folds, also the cross validation parameter k.

  • start_fold_t (int) – default None; the fold at which training starts (the t-th round of cross validation).

Returns

A generator that yields the current training index of the cross validation.
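The fold logic described above can be sketched in plain Python. This is a minimal stand-in, not the toolkit's implementation: the function names are hypothetical, folds are split randomly (as when `read_local_folds` is False), `start_fold_t` is assumed to be 0-based with None meaning 0, and the generator yields the split alongside the fold index purely for illustration.

```python
import random


def split_folds(items, k_fold, seed=0):
    """Randomly partition items into k_fold folds of near-equal size."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return [shuffled[i::k_fold] for i in range(k_fold)]


def cross_validation_folds(items, k_fold, start_fold_t=None, seed=0):
    """Yield (fold_t, trainset, testset) for each fold from start_fold_t on."""
    folds = split_folds(items, k_fold, seed)
    start = 0 if start_fold_t is None else start_fold_t  # assumption: None means fold 0
    for fold_t in range(start, k_fold):
        testset = folds[fold_t]
        # Every fold except the current one contributes to the trainset.
        trainset = [x for i, fold in enumerate(folds) if i != fold_t for x in fold]
        yield fold_t, trainset, testset


for fold_t, trainset, testset in cross_validation_folds(range(10), k_fold=5):
    print(fold_t, len(trainset), len(testset))
```

Passing `start_fold_t=3` to this sketch skips folds 0 to 2, which mirrors resuming an interrupted cross-validation run.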

dataset_load()[source]

Process the dataset and build the vocabulary.

In the k-fold setting, this function must be called once per fold.

en_rule1_process(k)[source]
en_rule2_process()[source]
fix_process(fix)[source]

equation infix/postfix/prefix process.

Parameters

fix (function) – a function that converts an equation to infix, postfix, or prefix notation, or None.
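To make the idea of a `fix` function concrete, here is a small sketch of two converters and a loop that applies one of them to each equation. The converters are a standard shunting-yard implementation written for this example, not the toolkit's own functions; `apply_fix`, the token names such as `NUM_0`, and the treatment of None as "leave equations unchanged" are assumptions.

```python
PRECEDENCE = {"+": 1, "-": 1, "*": 2, "/": 2, "^": 3}
RIGHT_ASSOC = {"^"}


def infix_to_postfix(tokens):
    """Convert a tokenized infix equation to postfix (shunting-yard)."""
    output, stack = [], []
    for tok in tokens:
        if tok == "(":
            stack.append(tok)
        elif tok == ")":
            while stack and stack[-1] != "(":
                output.append(stack.pop())
            stack.pop()  # discard the matching "("
        elif tok in PRECEDENCE:
            # Pop operators that bind tighter, or equally tight when left-associative.
            while stack and stack[-1] != "(" and (
                PRECEDENCE[stack[-1]] > PRECEDENCE[tok]
                or (PRECEDENCE[stack[-1]] == PRECEDENCE[tok] and tok not in RIGHT_ASSOC)
            ):
                output.append(stack.pop())
            stack.append(tok)
        else:
            output.append(tok)  # operand
    while stack:
        output.append(stack.pop())
    return output


def infix_to_prefix(tokens):
    """Convert infix to prefix by rebuilding the expression from its postfix form."""
    stack = []
    for tok in infix_to_postfix(tokens):
        if tok in PRECEDENCE:
            right = stack.pop()
            left = stack.pop()
            stack.append([tok] + left + right)
        else:
            stack.append([tok])
    return stack[0]


def apply_fix(equations, fix):
    """Apply fix to every equation; assumed to be a no-op when fix is None."""
    if fix is None:
        return equations
    return [fix(eq) for eq in equations]


equation = ["(", "NUM_0", "+", "NUM_1", ")", "*", "NUM_2"]
print(apply_fix([equation], infix_to_postfix))
# [['NUM_0', 'NUM_1', '+', 'NUM_2', '*']]
print(apply_fix([equation], infix_to_prefix))
# [['*', '+', 'NUM_0', 'NUM_1', 'NUM_2']]
```

Passing the converter as an argument keeps the per-equation loop independent of the target notation, which is presumably why `fix_process` takes a function rather than a format name.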

classmethod load_from_pretrained(pretrained_dir)[source]
operator_mask_process()[source]

Apply the operator mask process to the equations.

parameters_to_dict()[source]

Return the parameters of the dataset as a dict.

reset_dataset()[source]
save_dataset(trained_dir)[source]