Source code for mwptoolkit.loss.mse_loss

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 18:54:15
# @File: mse_loss.py


import torch
from torch.nn import functional as F
from mwptoolkit.loss.abstract_loss import AbstractLoss

class MSELoss(AbstractLoss):
    """Mean squared error loss built on ``torch.nn.functional.mse_loss``.

    Each call to :meth:`eval_batch` adds one batch loss to ``self.acc_loss``
    and bumps ``self.norm_term`` by one; :meth:`get_loss` reports the
    accumulated loss divided by that counter.

    NOTE(review): ``acc_loss``, ``norm_term`` and ``criterion`` are not set
    in this class; presumably ``AbstractLoss`` initialises them -- confirm
    against the base class.
    """

    _Name = "mean squared error loss"

    def __init__(self):
        super().__init__(self._Name, F.mse_loss)

    def get_loss(self):
        """Return the accumulated loss divided by the normalisation term.

        Returns:
            loss (float): ``self.acc_loss.item() / self.norm_term``, or the
            integer ``0`` while ``self.acc_loss`` is still a plain ``int``.
        """
        # A plain int means nothing tensor-valued has been accumulated yet,
        # so there is no ``.item()`` to call.
        if isinstance(self.acc_loss, int):
            return 0

        averaged = self.acc_loss.item()
        averaged /= self.norm_term
        return averaged

    def eval_batch(self, outputs, target):
        """Compute the loss of one batch and add it to the running total.

        Args:
            outputs (Tensor): output distribution of model.
            target (Tensor): target distribution.
        """
        batch_loss = self.criterion(outputs, target)
        # Kept as an augmented assignment so the accumulation semantics of
        # the running total are unchanged.
        self.acc_loss += batch_loss
        self.norm_term += 1