mwptoolkit.data.dataset.pretrain_dataset

class mwptoolkit.data.dataset.pretrain_dataset.PretrainDataset(config)

Bases: AbstractDataset

Dataset class for pre-trained models.

Parameters

config (mwptoolkit.config.configuration.Config) –

The config is expected to include the following parameters:

task_type (str): [single_equation | multi_equation], the type of task.

embedding (str|None): name of the embedding module. If set, a pre-trained model is used as the embedding module; if None, no pre-trained model is used.

rule1 (bool): convert equation according to rule 1.

rule2 (bool): convert equation according to rule 2.

parse_tree_file_name (str|None): the name of the file to save parse tree information.

pretrained_model or transformers_pretrained_model (str|None): path or name of the pretrained model.

model (str): model name.

dataset (str): dataset name.

equation_fix (str): [infix | postfix | prefix], convert equation to specified format.

dataset_dir or dataset_path (str): path to the dataset folder.

language (str): a property of dataset, the language of dataset.

single (bool): a property of the dataset, whether its equations are single equations.

linear (bool): a property of the dataset, whether its equations are linear.

source_equation_fix (str): [infix | postfix | prefix], a property of dataset, the source format of equation of dataset.

rebuild (bool): when loading additional dataset information, whether to build the information anew or to load information built earlier.

validset_divide (bool): whether to split off a validset. If True, the dataset is split into trainset-validset-testset. If False, the dataset is split into trainset-testset.

mask_symbol (str): [NUM | number], the symbol to mask numbers in equation.

min_word_keep (int): words whose count in the dataset is greater than this value are kept in the input vocabulary.

min_generate_keep (int): generate numbers whose count is greater than this value are kept in the output symbols.

symbol_for_tree (bool): build output symbols for tree or not.

share_vocab (bool): whether the encoder and decoder of the model share the same vocabulary, as is common in Seq2Seq models.

k_fold (int|None): if an integer, run k-fold cross-validation. If None, run a trainset-validset-testset split.

read_local_folds (bool): when running k-fold cross-validation, if True, load the split folds from the dataset folder. If False, split the folds randomly.

shuffle (bool): whether to shuffle trainset before training.

device (torch.device):

resume_training or resume (bool):
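
As a rough illustration of the parameter list above, the sketch below collects the documented keys in a plain dictionary and checks the values against the options given in brackets. Only the key names and the bracketed options come from the list; every value (model name, dataset name, path, language code, thresholds) is a placeholder, and check_choices and split_strategy are hypothetical helpers that are not part of mwptoolkit. The way split_strategy combines k_fold and validset_divide is an assumption based on the two descriptions. A real run would pass a mwptoolkit.config.configuration.Config object rather than a dict.

```python
# Placeholder values for illustration only; key names follow the list above.
example_config = {
    "task_type": "single_equation",
    "embedding": None,
    "rule1": True,
    "rule2": True,
    "parse_tree_file_name": None,
    "pretrained_model": None,
    "model": "MyModel",                    # placeholder model name
    "dataset": "my_dataset",               # placeholder dataset name
    "equation_fix": "prefix",
    "dataset_dir": "dataset/my_dataset",   # placeholder path
    "language": "en",                      # placeholder language code
    "single": True,
    "linear": True,
    "source_equation_fix": "infix",
    "rebuild": False,
    "validset_divide": True,
    "mask_symbol": "NUM",
    "min_word_keep": 5,
    "min_generate_keep": 5,
    "symbol_for_tree": False,
    "share_vocab": False,
    "k_fold": None,
    "read_local_folds": False,
    "shuffle": True,
    "device": "cpu",  # documented as torch.device; a string keeps this sketch dependency-free
    "resume_training": False,
}

# Options listed in brackets in the parameter descriptions.
ALLOWED = {
    "task_type": {"single_equation", "multi_equation"},
    "equation_fix": {"infix", "postfix", "prefix"},
    "source_equation_fix": {"infix", "postfix", "prefix"},
    "mask_symbol": {"NUM", "number"},
}


def check_choices(config):
    """Return the keys whose value is not one of the documented options."""
    return [key for key, allowed in ALLOWED.items() if config.get(key) not in allowed]


def split_strategy(config):
    """Describe the split implied by k_fold and validset_divide (assumed reading)."""
    if config.get("k_fold") is not None:
        return "k-fold"
    if config.get("validset_divide"):
        return "train-valid-test"
    return "train-test"
```

Running check_choices before building the dataset catches a misspelled option early, for example an equation_fix value other than infix, postfix, or prefix.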

get_vocab_size()
Returns

the sizes of the input vocabulary and the output symbols

Return type

(tuple(int, int))
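
A minimal usage sketch, assuming dataset is an already-built PretrainDataset: the returned tuple unpacks into two integers, input vocabulary size first. vocab_sizes is a hypothetical helper and not part of mwptoolkit.

```python
def vocab_sizes(dataset):
    """Unpack get_vocab_size() into named fields."""
    # The tuple order follows the description above: input vocabulary, then output symbols.
    input_size, output_size = dataset.get_vocab_size()
    return {"input_vocab": input_size, "output_symbols": output_size}
```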

classmethod load_from_pretrained(pretrained_dir: str, resume_training=False)

Load dataset parameters from file.

Parameters
  • pretrained_dir – (str) folder in which the parameter file was saved.

  • resume_training – (bool) whether to load the parameters for resuming training.

Returns

an instantiated object
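
load_from_pretrained is the counterpart of save_dataset, documented next. The sketch below shows one way the two might be paired, assuming dataset is an existing PretrainDataset and save_dir is a writable folder. save_and_reload is a hypothetical helper, not part of mwptoolkit, and it uses only the two signatures given on this page.

```python
def save_and_reload(dataset, save_dir, resume_training=False):
    """Save dataset parameters to save_dir, then build a new object from that folder."""
    dataset.save_dataset(save_dir)
    # load_from_pretrained is a classmethod, so it is called on the class, not the instance.
    return type(dataset).load_from_pretrained(save_dir, resume_training=resume_training)
```

Passing resume_training=True is meant for the case where training continues from the saved state.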

save_dataset(save_dir: str)

Save dataset parameters to file.

Parameters

save_dir – (str) folder in which to save the parameter file.

Returns