# Source code for mwptoolkit.trainer.template_trainer

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 22:14:27
# @File: template_trainer.py


import time

from mwptoolkit.trainer.abstract_trainer import AbstractTrainer
from mwptoolkit.utils.enum_type import DatasetType
from mwptoolkit.utils.utils import time_since


class TemplateTrainer(AbstractTrainer):
    r"""Template trainer: a skeleton to subclass when writing a new trainer.

    The drivers ``fit``, ``evaluate``, ``test`` and ``_train_epoch`` are
    implemented here. You need to implement:

    TemplateTrainer._build_optimizer()

    TemplateTrainer._save_checkpoint()

    TemplateTrainer._load_checkpoint()

    TemplateTrainer._train_batch()

    TemplateTrainer._eval_batch()

    ``_build_loss`` and ``_idx2word_2idx`` are declared as stubs as well; no
    code in this class calls them.

    NOTE(review): this class reads ``self.config``, ``self.model``,
    ``self.dataloader``, ``self.logger``, ``self.start_epoch``,
    ``self.test_step``, the ``self.best_*`` accuracy attributes,
    ``self._save_model`` and ``self._load_model`` without defining them;
    presumably ``AbstractTrainer`` provides them -- confirm. ``self.loss`` and
    ``self.optimizer`` are used by ``_train_epoch`` and are presumably created
    by the subclass (``_build_loss`` / ``_build_optimizer``) -- confirm.
    """

    def __init__(self, config, model, dataloader, evaluator):
        """Forward all arguments unchanged to ``AbstractTrainer.__init__``."""
        super().__init__(config, model, dataloader, evaluator)

    def _build_optimizer(self):
        """Subclass hook; raises NotImplementedError here."""
        raise NotImplementedError

    def _save_checkpoint(self):
        """Subclass hook called by ``fit``; raises NotImplementedError here."""
        raise NotImplementedError

    def _load_checkpoint(self):
        """Subclass hook; raises NotImplementedError here."""
        raise NotImplementedError

    def _build_loss(self):
        """Subclass hook; raises NotImplementedError here."""
        raise NotImplementedError

    def _idx2word_2idx(self):
        """Subclass hook; raises NotImplementedError here."""
        raise NotImplementedError

    def _train_batch(self, batch):
        """Subclass hook: process one training batch and return its loss value.

        The returned value is summed by ``_train_epoch``.
        """
        raise NotImplementedError

    def _eval_batch(self, batch):
        """Subclass hook: evaluate one batch.

        Must return ``(batch_val_ac, batch_equ_ac)``: two sequences supporting
        ``.count(True)`` and ``len()`` (e.g. lists of bools, one per sample).
        """
        raise NotImplementedError

    def _train_epoch(self):
        """Run one pass over the training split.

        Returns:
            tuple: ``(loss_total, epoch_time_cost)`` -- the sum of the values
            returned by ``_train_batch`` and ``time_since(elapsed_seconds)``.
        """
        epoch_start_time = time.time()
        loss_total = 0.
        self.model.train()
        for batch_idx, batch in enumerate(self.dataloader.load_data(DatasetType.Train)):
            # 1-based index of the batch currently being processed.
            self.batch_idx = batch_idx + 1
            self.model.zero_grad()
            batch_loss = self._train_batch(batch)
            loss_total += batch_loss
            # The backward pass goes through the trainer-level loss object,
            # which is then reset so nothing carries over to the next batch.
            self.loss.backward()
            self.optimizer.step()
            self.loss.reset()
        epoch_time_cost = time_since(time.time() - epoch_start_time)
        return loss_total, epoch_time_cost

    def fit(self):
        """Train the model, periodically evaluating and saving it.

        Each epoch runs ``_train_epoch`` and logs the average loss. When
        ``epo % self.test_step == 0``, or during the last few epochs, the
        validation and test splits are evaluated; if validation value accuracy
        is at least the best seen so far, the best-result attributes are
        updated and ``_save_model`` is called. ``_save_checkpoint`` is called
        every fifth epoch.
        """
        train_batch_size = self.config["train_batch_size"]
        epoch_nums = self.config["epoch_nums"]

        # Ceiling division: the previous ``int(n / bs) + 1`` counted one batch
        # too many whenever n was an exact multiple of bs, which deflated the
        # logged average loss. Kept >= 1 so an empty training set still logs
        # 0 instead of dividing by zero.
        # NOTE(review): assumes load_data yields ceil(n / bs) batches -- confirm.
        batch_nums = int(-(-self.dataloader.trainset_nums // train_batch_size))
        self.train_batch_nums = max(1, batch_nums)

        self.logger.info("start training...")
        for epo in range(self.start_epoch, epoch_nums):
            self.epoch_i = epo + 1
            self.model.train()
            loss_total, train_time_cost = self._train_epoch()
            self.logger.info(
                "epoch [%3d] avr loss [%2.8f] | train time %s"
                % (self.epoch_i, loss_total / self.train_batch_nums, train_time_cost)
            )

            if epo % self.test_step == 0 or epo > epoch_nums - 5:
                valid_equ_ac, valid_val_ac, valid_total, valid_time_cost = self.evaluate(DatasetType.Valid)
                self.logger.info(
                    "----------- valid total [%d] | valid equ acc [%2.3f] | valid value acc [%2.3f] | valid time %s"
                    % (valid_total, valid_equ_ac, valid_val_ac, valid_time_cost)
                )
                test_equ_ac, test_val_ac, test_total, test_time_cost = self.evaluate(DatasetType.Test)
                self.logger.info(
                    "----------- test total [%d] | test equ acc [%2.3f] | test value acc [%2.3f] | test time %s"
                    % (test_total, test_equ_ac, test_val_ac, test_time_cost)
                )

                # Model selection is driven by validation value accuracy; the
                # test numbers recorded are those of the selected model.
                if valid_val_ac >= self.best_valid_value_accuracy:
                    self.best_valid_value_accuracy = valid_val_ac
                    self.best_valid_equ_accuracy = valid_equ_ac
                    self.best_test_value_accuracy = test_val_ac
                    self.best_test_equ_accuracy = test_equ_ac
                    self._save_model()

            # NOTE(review): nesting reconstructed from flattened source; the
            # checkpoint save is taken to be at loop level -- confirm.
            if epo % 5 == 0:
                self._save_checkpoint()

        # NOTE(review): any line breaks inside this message were lost in the
        # flattened source; the text is reproduced as it appears there.
        self.logger.info(
            "training finished. "
            "best valid result: equation accuracy [%2.3f] | value accuracy [%2.3f] "
            "best test result : equation accuracy [%2.3f] | value accuracy [%2.3f]"
            % (
                self.best_valid_equ_accuracy,
                self.best_valid_value_accuracy,
                self.best_test_equ_accuracy,
                self.best_test_value_accuracy,
            )
        )

    def _count_correct(self, eval_set):
        """Run ``_eval_batch`` over ``eval_set`` and tally the results.

        Returns:
            tuple: ``(value_ac, equation_ac, eval_total, time_cost)`` -- the
            number of ``True`` entries in each per-batch result, the number of
            evaluated samples, and ``time_since(elapsed_seconds)``.
        """
        value_ac = 0
        equation_ac = 0
        eval_total = 0
        start_time = time.time()
        for batch in self.dataloader.load_data(eval_set):
            batch_val_ac, batch_equ_ac = self._eval_batch(batch)
            value_ac += batch_val_ac.count(True)
            equation_ac += batch_equ_ac.count(True)
            eval_total += len(batch_val_ac)
        time_cost = time_since(time.time() - start_time)
        return value_ac, equation_ac, eval_total, time_cost

    @staticmethod
    def _accuracy(correct, total):
        """Return ``correct / total``, or 0.0 when nothing was evaluated."""
        return correct / total if total else 0.0

    def evaluate(self, eval_set):
        """Evaluate the model on one dataset split.

        Args:
            eval_set: split selector passed to ``self.dataloader.load_data``
                (``fit`` passes ``DatasetType.Valid`` / ``DatasetType.Test``).

        Returns:
            tuple: ``(equation_accuracy, value_accuracy, eval_total,
            test_time_cost)``. Both accuracies are 0.0 when the split yields no
            samples (previously this raised ZeroDivisionError).
        """
        self.model.eval()
        value_ac, equation_ac, eval_total, test_time_cost = self._count_correct(eval_set)
        return (
            self._accuracy(equation_ac, eval_total),
            self._accuracy(value_ac, eval_total),
            eval_total,
            test_time_cost,
        )

    def test(self):
        """Load the saved model via ``_load_model`` and log test-split accuracy.

        Nothing is returned; the result is written to ``self.logger``. An empty
        test split is reported with 0.0 accuracies instead of raising
        ZeroDivisionError.
        """
        self._load_model()
        self.model.eval()
        value_ac, equation_ac, eval_total, test_time_cost = self._count_correct(DatasetType.Test)
        self.logger.info(
            "test total [%d] | test equ acc [%2.3f] | test value acc [%2.3f] | test time %s"
            % (
                eval_total,
                self._accuracy(equation_ac, eval_total),
                self._accuracy(value_ac, eval_total),
                test_time_cost,
            )
        )