mwptoolkit.data.dataset.abstract_dataset
- class mwptoolkit.data.dataset.abstract_dataset.AbstractDataset(config)
Bases: object
Abstract dataset: the base class for all dataset classes.
- Parameters
config (mwptoolkit.config.configuration.Config) – the configuration object.
It is expected to include the following parameters:
model (str): model name.
dataset (str): dataset name.
equation_fix (str): [infix | postfix | prefix], the notation that equations are converted to.
dataset_dir or dataset_path (str): the path to the dataset folder.
language (str): a property of the dataset, its language.
single (bool): a property of the dataset, whether each problem has a single equation.
linear (bool): a property of the dataset, whether its equations are linear.
source_equation_fix (str): [infix | postfix | prefix], a property of the dataset, the source notation of its equations.
rebuild (bool): when loading additional dataset information, whether to rebuild it from scratch instead of loading information built earlier.
validset_divide (bool): whether to split off a validset. If True, the dataset is split into trainset-validset-testset. If False, the dataset is split into trainset-testset.
mask_symbol (str): [NUM | number], the symbol used to mask numbers in equations.
min_word_keep (int): words whose count in the dataset is greater than this value are kept in the input vocabulary.
min_generate_keep (int): generate numbers whose count is greater than this value are kept in the output symbols.
symbol_for_tree (bool): whether to build the output symbols for a tree.
share_vocab (bool): whether the encoder and decoder of the model share the same vocabulary, as is often the case in Seq2Seq models.
k_fold (int|None): if an integer, run k-fold cross validation with that many folds. If None, run the trainset-validset-testset split.
read_local_folds (bool): when running k-fold cross validation, if True, load the split folds from the dataset folder. If False, split the folds randomly.
shuffle (bool): whether to shuffle the trainset before training.
device (torch.device):
resume_training or resume (bool):
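To make the parameter list above more concrete, the sketch below collects a few of the parameters in a plain Python dict and applies the min_word_keep rule as it is worded here (a word is kept when its count is strictly greater than the value). The dict, every value in it, and the helper build_input_vocab are hypothetical illustrations. They do not show how mwptoolkit constructs a Config object or builds its vocabulary internally.

```python
from collections import Counter

# Hypothetical stand-in for a Config object: a plain dict with
# placeholder values for some of the documented parameters.
example_config = {
    "model": "some_model",          # placeholder model name
    "dataset": "some_dataset",      # placeholder dataset name
    "equation_fix": "prefix",       # one of: infix, postfix, prefix
    "source_equation_fix": "infix",
    "validset_divide": True,
    "mask_symbol": "NUM",
    "min_word_keep": 1,
    "k_fold": None,                 # None means no cross validation
    "shuffle": True,
}


def build_input_vocab(token_lists, min_word_keep):
    """Keep words whose count is strictly greater than min_word_keep.

    Hypothetical helper that follows the wording of the parameter
    description; the library's own rule may differ in detail.
    """
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    return sorted(word for word, n in counts.items() if n > min_word_keep)


questions = [
    ["tom", "has", "NUM", "apples"],
    ["tom", "eats", "NUM", "apples"],
    ["how", "many", "apples"],
]
vocab = build_input_vocab(questions, example_config["min_word_keep"])
print(vocab)
```

With min_word_keep set to 1, only words that occur at least twice in the toy questions survive the filter.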
- cross_validation_load(k_fold, start_fold_t=None)
Load the dataset for cross validation.
Build the folds for cross validation. In each round, one fold is used as the testset and the remaining folds form the trainset.
- Parameters
k_fold (int) – the number of folds, also the cross validation parameter k.
start_fold_t (int) – default None; the round t of cross validation from which training starts.
- Returns
A generator that yields the index of the current cross-validation round.
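The fold-building behaviour described above can be sketched as a standalone generator. This is a minimal illustration and not the library's implementation: the function name k_fold_rounds, the seed argument, the 0-based round index, and the choice to yield the trainset and testset alongside the index are all assumptions made for the example.

```python
import random


def k_fold_rounds(data, k_fold, start_fold_t=None, seed=0):
    """Yield (fold_t, trainset, testset) for each cross-validation round.

    Hypothetical sketch: items are shuffled once, dealt into k_fold
    folds, and rounds before start_fold_t are skipped (assumed to mean
    "resume from round t", counted from 0).
    """
    items = list(data)
    random.Random(seed).shuffle(items)
    # Deal the shuffled items into k_fold folds of near-equal size.
    folds = [items[i::k_fold] for i in range(k_fold)]
    first = 0 if start_fold_t is None else start_fold_t
    for fold_t in range(first, k_fold):
        testset = folds[fold_t]
        trainset = [x for i, fold in enumerate(folds) if i != fold_t for x in fold]
        yield fold_t, trainset, testset


rounds = list(k_fold_rounds(range(10), k_fold=5))
resumed = [fold_t for fold_t, _, _ in k_fold_rounds(range(10), k_fold=5, start_fold_t=3)]
for fold_t, trainset, testset in rounds:
    print(fold_t, len(trainset), len(testset))
```

Each item lands in the testset of exactly one round, and passing start_fold_t skips the rounds that come before it.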
- dataset_load()
Process the dataset and build the vocabulary.
In the k-fold setting, this function must be called once per fold.
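The once-per-fold call pattern can be illustrated with a toy class. ToyDataset below is a hypothetical stand-in that only mimics the two method names documented here. Its internals are assumptions, including the idea that the vocabulary is rebuilt from the current trainset, which is one plausible reason the call has to be repeated for every fold.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset class; not AbstractDataset."""

    def __init__(self, problems):
        self.problems = problems
        self.trainset, self.testset, self.vocab = [], [], []

    def cross_validation_load(self, k_fold):
        # Deal problems into folds, then expose one split per round.
        folds = [self.problems[i::k_fold] for i in range(k_fold)]
        for fold_t in range(k_fold):
            self.testset = folds[fold_t]
            self.trainset = [p for i, f in enumerate(folds) if i != fold_t for p in f]
            yield fold_t

    def dataset_load(self):
        # Rebuild the vocabulary from the trainset of the current fold.
        self.vocab = sorted({w for p in self.trainset for w in p.split()})


dataset = ToyDataset(["a b", "b c", "c d", "d e"])
vocab_per_fold = {}
for fold_t in dataset.cross_validation_load(k_fold=2):
    dataset.dataset_load()  # called once per fold, as noted above
    vocab_per_fold[fold_t] = dataset.vocab
print(vocab_per_fold)
```

In this toy setup the two folds produce different vocabularies, so skipping the per-fold call would leave a stale vocabulary from the previous round.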