Source code for mwptoolkit.loss.nll_loss

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 18:54:55
# @File: nll_loss.py


from torch import nn

from mwptoolkit.loss.abstract_loss import AbstractLoss
class NLLLoss(AbstractLoss):
    """Negative log-likelihood loss accumulated over batches.

    Wraps :class:`torch.nn.NLLLoss` (always built with ``reduction="mean"``)
    and sums the per-batch losses into ``self.acc_loss``, counting calls in
    ``self.norm_term``. Those two attributes (and ``self.criterion``) are
    presumably initialized/reset by :class:`AbstractLoss` -- confirm there.
    """

    _NAME = "Avg NLLLoss"

    def __init__(self, weight=None, mask=None, size_average=True):
        """
        Args:
            weight (Tensor, optional): a manual rescaling weight given to each class.
            mask (Tensor, optional): index of classes whose weight is set to 0.
                Requires ``weight``. The caller's ``weight`` is not modified;
                a masked copy is used instead.
            size_average (bool, optional): if True, ``get_loss`` divides the
                accumulated loss by the number of ``eval_batch`` calls.
                Default: True.

        Raises:
            ValueError: if ``mask`` is given without ``weight``.
        """
        self.mask = mask
        self.size_average = size_average

        if mask is not None:
            if weight is None:
                raise ValueError("Must provide weight with a mask.")
            # Work on a copy so masking does not zero entries in the
            # caller's tensor (which may be shared with other losses/models).
            weight = weight.clone()
            weight[mask] = 0

        # NOTE(review): reduction is always "mean", so each batch contributes
        # its mean loss regardless of size_average; size_average only controls
        # the division in get_loss.
        super().__init__(self._NAME, nn.NLLLoss(weight=weight, reduction="mean"))

    def get_loss(self):
        """Return the accumulated loss as a Python number.

        Returns:
            loss (float): ``0`` when ``acc_loss`` is still an int (no batch has
            been accumulated into it yet); otherwise the accumulated loss value,
            divided by ``norm_term`` when ``size_average`` is True.
        """
        if isinstance(self.acc_loss, int):
            # acc_loss only becomes a tensor after the first eval_batch call.
            return 0
        loss = self.acc_loss.item()
        if self.size_average:
            loss /= self.norm_term
        return loss

    def eval_batch(self, outputs, target):
        """Compute the loss of one batch and add it to the running total.

        Args:
            outputs (Tensor): output distribution of model; expected to be
                log-probabilities, as required by ``torch.nn.NLLLoss``.
            target (Tensor): target classes.
        """
        self.acc_loss += self.criterion(outputs, target)
        self.norm_term += 1