mwptoolkit.data.dataset.template_dataset

class mwptoolkit.data.dataset.template_dataset.TemplateDataset(config)[source]

Bases: AbstractDataset

Template dataset.

You need to implement:

TemplateDataset._preprocess()

TemplateDataset._build_symbol()

TemplateDataset._build_template_symbol()

Overwrite TemplateDataset._build_vocab() if necessary.
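The requirements above can be sketched as a minimal subclass skeleton. The base class here is a stub stand-in so the example runs without mwptoolkit installed; in real code you would subclass mwptoolkit.data.dataset.template_dataset.TemplateDataset.

```python
# Stub stand-in for the real TemplateDataset, so this sketch is
# self-contained; the real class lives in
# mwptoolkit.data.dataset.template_dataset.
class TemplateDataset:
    def __init__(self, config):
        self.config = config

class MyDataset(TemplateDataset):
    """Sketch: the three methods every subclass must implement."""

    def _preprocess(self):
        # Format every sample; return generate_list, operator_list,
        # special_token_list and copy_nums (see _preprocess below).
        raise NotImplementedError

    def _build_symbol(self):
        # Build the output vocabulary;
        # return {'out_idx2symbol': out_idx2symbol}.
        raise NotImplementedError

    def _build_template_symbol(self):
        # Build the template vocabulary;
        # return {'temp_idx2symbol': temp_idx2symbol}.
        raise NotImplementedError
```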

Parameters

config (mwptoolkit.config.configuration.Config) –

It is expected that config includes the parameters below:

model (str): model name.

dataset (str): dataset name.

equation_fix (str): [infix | postfix | prefix], convert equations to the specified format.

dataset_dir or dataset_path (str): the path of the dataset folder.

language (str): a property of the dataset, its language.

single (bool): a property of the dataset, whether its equations are single or not.

linear (bool): a property of the dataset, whether its equations are linear or not.

source_equation_fix (str): [infix | postfix | prefix], a property of the dataset, the source format of its equations.

rebuild (bool): when loading additional dataset information, whether to build the information anew or load information built before.

validset_divide (bool): whether to split out a validset. If True, the dataset is split into trainset-validset-testset; if False, into trainset-testset.

mask_symbol (str): [NUM | number], the symbol used to mask numbers in equations.

min_word_keep (int): words whose count in the dataset is greater than this value will be kept in the input vocabulary.

min_generate_keep (int): generated numbers whose count is greater than this value will be kept in the output symbols.

symbol_for_tree (bool): whether to build output symbols for tree decoding.

share_vocab (bool): whether the encoder and decoder of the model share the same vocabulary, often seen in Seq2Seq models.

k_fold (int|None): if an integer, run k-fold cross validation; if None, run a trainset-validset-testset split.

read_local_folds (bool): when running k-fold cross validation, if True, load the split folds from the dataset folder; if False, split folds randomly.

shuffle (bool): whether to shuffle the trainset before training.

device (torch.device): the device to run on.

resume_training or resume (bool): whether to resume training from a saved checkpoint.
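For illustration, the expected parameters might look like the following plain dictionary. All values here are examples, not defaults, and the real object is an mwptoolkit Config rather than a dict.

```python
# Illustrative configuration values for the parameters listed above.
config = {
    'model': 'GTS',                    # example model name
    'dataset': 'math23k',              # example dataset name
    'equation_fix': 'prefix',          # one of: infix | postfix | prefix
    'dataset_dir': 'dataset/math23k',
    'language': 'zh',
    'single': True,
    'linear': True,
    'source_equation_fix': 'infix',
    'rebuild': False,
    'validset_divide': True,           # trainset-validset-testset split
    'mask_symbol': 'NUM',
    'min_word_keep': 5,
    'min_generate_keep': 5,
    'symbol_for_tree': True,
    'share_vocab': False,
    'k_fold': None,                    # None -> no cross validation
    'read_local_folds': False,
    'shuffle': True,
}
```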

_build_symbol()[source]

In this function, you need to implement the code for building the output vocabulary.

Specifically, you need to

  1. reset the list variable TemplateDataset.out_idx2symbol and append the generated symbols into it.

You should return a dictionary object like {'out_idx2symbol': out_idx2symbol}.
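A sketch of this contract, using a minimal stand-in class; the concrete symbols (operators, constants, masked-number slots) are illustrative, not the toolkit's actual token set.

```python
class SimpleDataset:
    """Minimal stand-in illustrating _build_symbol's contract."""

    def _build_symbol(self):
        # Combine operators, generated constants, and masked-number
        # placeholders into one output vocabulary list.
        out_idx2symbol = []
        out_idx2symbol += ['+', '-', '*', '/']          # operators
        out_idx2symbol += ['1', '3.14']                 # generated constants
        out_idx2symbol += ['NUM_0', 'NUM_1', 'NUM_2']   # masked number slots
        self.out_idx2symbol = out_idx2symbol            # reset the attribute
        return {'out_idx2symbol': out_idx2symbol}       # documented return shape
```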

_build_template_symbol()[source]

In this function, you need to implement the code for building the output vocabulary for equation templates.

Specifically, you need to

  1. reset the list variable TemplateDataset.temp_idx2symbol and append the generated symbols into it. You can do nothing in this function if you don't need templates.

You should return a dictionary object like {'temp_idx2symbol': temp_idx2symbol}.
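A sketch of this contract with a stand-in class. The '&lt;OPT&gt;' placeholder name is illustrative: the idea of a template vocabulary is that concrete operators collapse into a generic operator token, so equations with the same structure share one template.

```python
class SimpleTemplateDataset:
    """Minimal stand-in illustrating _build_template_symbol's contract."""

    def _build_template_symbol(self):
        # Replace concrete operators with a generic placeholder token
        # (name is illustrative) plus number slots and constants.
        temp_idx2symbol = ['<OPT>', 'NUM_0', 'NUM_1', 'NUM_2', '1']
        self.temp_idx2symbol = temp_idx2symbol          # reset the attribute
        return {'temp_idx2symbol': temp_idx2symbol}     # documented return shape
```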

_preprocess()[source]

In this function, you need to implement the code for data preprocessing.

Specifically, you need to

  1. format the input and output of every data sample, including trainset, validset and testset.

  2. reset the list variables TemplateDataset.generate_list, TemplateDataset.operator_list and TemplateDataset.special_token_list.

  3. reset the integer variable TemplateDataset.copy_nums.

You should return a dictionary object like {'generate_list': generate_list, 'operator_list': operator_list, 'special_token_list': special_token_list, 'copy_nums': copy_nums}.
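A sketch of this contract with a stand-in class; the actual tokenization and equation conversion are elided, and all concrete values are illustrative.

```python
class SimplePreprocessDataset:
    """Minimal stand-in illustrating _preprocess's contract."""

    def _preprocess(self):
        # A real implementation would tokenize questions, mask numbers,
        # and convert equations to the configured fix format for
        # trainset, validset and testset here.
        generate_list = ['1', '3.14']      # constants the decoder may generate
        operator_list = ['+', '-', '*', '/', '^']
        special_token_list = ['<PAD>', '<UNK>', '<SOS>', '<EOS>']
        copy_nums = 3                      # max numbers copied from one problem
        return {
            'generate_list': generate_list,
            'operator_list': operator_list,
            'special_token_list': special_token_list,
            'copy_nums': copy_nums,
        }
```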

get_vocab_size()[source]
classmethod load_from_pretrained(pretrained_dir: str, resume_training=False)[source]

load dataset parameters from file.

Parameters
  • pretrained_dir – (str) the folder in which the parameter file was saved.

  • resume_training – (bool) whether to load parameters for resuming training.

Returns

an instantiated object

save_dataset(save_dir: str)[source]

save dataset parameters to file.

Parameters

save_dir – (str) the folder in which to save the parameter file.
