mwptoolkit.trainer.template_trainer

class mwptoolkit.trainer.template_trainer.TemplateTrainer(config, model, dataloader, evaluator)

Bases: AbstractTrainer

Template trainer: a skeleton for building a new trainer.

To use it, you need to implement the following methods:

TemplateTrainer._build_optimizer()

TemplateTrainer._save_checkpoint()

TemplateTrainer._load_checkpoint()

TemplateTrainer._train_batch()

TemplateTrainer._eval_batch()
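The methods listed above read as hooks in the template-method pattern: a base class drives the training and testing loop, and the subclass supplies the individual steps. The following is a minimal, self-contained sketch of that idea. It does not import mwptoolkit. SketchBaseTrainer, LookupTrainer, the dict-based config and dataloader, the "epoch_nums" key, the callable evaluator, and the hook signatures (for example _train_batch(batch)) are all hypothetical stand-ins chosen for illustration, not the library's actual API.

```python
from abc import ABC, abstractmethod


class SketchBaseTrainer(ABC):
    """Hypothetical stand-in for AbstractTrainer (NOT the real mwptoolkit class).

    The base class owns the loops; subclasses only fill in the five hooks.
    """

    def __init__(self, config, model, dataloader, evaluator):
        self.config = config          # assumed: a plain dict
        self.model = model
        self.dataloader = dataloader  # assumed: {"train": [...], "test": [...]}
        self.evaluator = evaluator    # assumed: callable over per-batch results
        self._build_optimizer()

    @abstractmethod
    def _build_optimizer(self): ...

    @abstractmethod
    def _save_checkpoint(self): ...

    @abstractmethod
    def _load_checkpoint(self): ...

    @abstractmethod
    def _train_batch(self, batch): ...

    @abstractmethod
    def _eval_batch(self, batch): ...

    def fit(self):
        # Train on every batch for the configured number of epochs,
        # saving a checkpoint at the end of each epoch.
        for _ in range(self.config["epoch_nums"]):
            for batch in self.dataloader["train"]:
                self._train_batch(batch)
            self._save_checkpoint()

    def test(self):
        # Restore the saved state, then score every test batch.
        self._load_checkpoint()
        results = [self._eval_batch(b) for b in self.dataloader["test"]]
        return self.evaluator(results)


class LookupTrainer(SketchBaseTrainer):
    """Toy subclass: the 'model' is a dict mapping a question to its answer."""

    def _build_optimizer(self):
        self.optimizer = None   # a lookup table has nothing to optimize
        self.checkpoint = None

    def _save_checkpoint(self):
        self.checkpoint = dict(self.model)

    def _load_checkpoint(self):
        self.model = dict(self.checkpoint)

    def _train_batch(self, batch):
        question, answer = batch
        self.model[question] = answer

    def _eval_batch(self, batch):
        question, answer = batch
        return self.model.get(question) == answer


data = {"train": [("1+1", 2), ("2*3", 6)], "test": [("1+1", 2), ("5-4", 1)]}
trainer = LookupTrainer({"epoch_nums": 1}, {}, data, lambda r: sum(r) / len(r))
trainer.fit()
accuracy = trainer.test()
print(accuracy)  # 0.5
```

Because fit() and test() live in the base class in this sketch, a new trainer only decides how a batch is processed and how state is stored, which is the point of starting from a template.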

Parameters
  • config (Config) – An instance of Config, used to record parameter information.

  • model (Model) – A deep-learning model object.

  • dataloader (Dataloader) – A dataloader object.

  • evaluator (Evaluator) – An evaluator object.

The config is expected to include the following parameters:

test_step (int): the interval, in training epochs, at which evaluation on the test set is run.

best_folds_accuracy (list|None): when running k-fold cross-validation, this stores the accuracy of the folds that have already been run.
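A small sketch of how these two config values might be used, assuming that test evaluation fires whenever the epoch count is a multiple of test_step, and that best_folds_accuracy starts as None and gains one entry per finished fold. The helpers epochs_to_test and record_fold are hypothetical, and the real trainer's scheduling may include additional conditions.

```python
def epochs_to_test(epoch_nums, test_step):
    """Hypothetical helper: 1-based epochs after which a test evaluation runs,
    assuming evaluation fires every `test_step` epochs."""
    return [epoch for epoch in range(1, epoch_nums + 1) if epoch % test_step == 0]


def record_fold(config, fold_accuracy):
    """Hypothetical helper: append one finished fold's accuracy,
    creating the list on first use if the value is still None."""
    if config.get("best_folds_accuracy") is None:
        config["best_folds_accuracy"] = []
    config["best_folds_accuracy"].append(fold_accuracy)
    return config["best_folds_accuracy"]


config = {"test_step": 5, "best_folds_accuracy": None}
print(epochs_to_test(20, config["test_step"]))  # [5, 10, 15, 20]

# Simulate three finished folds of a k-fold run.
for acc in [0.70, 0.72, 0.68]:
    record_fold(config, acc)

folds = config["best_folds_accuracy"]
mean_accuracy = sum(folds) / len(folds)
print(mean_accuracy)
```

Keeping per-fold accuracies in a list, rather than a running total, leaves room to report both the mean and the spread across folds.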

evaluate(eval_set)

fit()

test()