# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 18:55:01
# @File: smoothed_cross_entropy_loss.py
import torch
from torch import nn
from mwptoolkit.loss.abstract_loss import AbstractLoss
class SmoothedCrossEntropyLoss(nn.Module):
"""
Computes cross entropy loss with uniformly smoothed targets.
"""
def __init__(self, smoothing: float = 0.1, ignore_index: int = -1, reduction: str = 'batchmean'):
"""
Cross entropy loss with uniformly smoothed targets.
:param float smoothing: Label smoothing factor, between 0 and 1 (exclusive; default is 0.1)
    :param int ignore_index: Target index to be ignored when computing the loss (default is -1).
    :param str reduction: Reduction style, one of 'batchmean' (default), 'none', or 'sum'.
"""
assert 0 < smoothing < 1, "Smoothing factor should be in (0.0, 1.0)"
assert reduction in {'batchmean', 'none', 'sum'}
super().__init__()
self.smoothing = smoothing
self.ignore_index = ignore_index
self.reduction = reduction
    def forward(self, input: torch.Tensor, target: torch.LongTensor) -> torch.Tensor:
"""
Computes cross entropy loss with uniformly smoothed targets.
        Since the entropy of the smoothed target distribution is always the same, this can be computed with KL-divergence.
:param torch.Tensor input: Log probability for each class. This is a Tensor with shape [B, C]
:param torch.LongTensor target: List of target classes. This is a LongTensor with shape [B]
:rtype: torch.Tensor
:return: Computed loss
"""
target = target.view(-1, 1)
# Prepare smoothed target
        # Set the probability of every ignored target to zero.
        # Since D_KL(p || q) = sum p * (log(p) - log(q)), setting p(x) ≡ 0 means these targets cannot affect the loss.
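        # For example, with smoothing=0.1 and 4 finite classes, each finite class first
        # receives smoothing / n_cls = 0.025, and the gold entry is then overwritten
        # with 1 - smoothing = 0.9, giving e.g. [0.025, 0.9, 0.025, 0.025]
        # (note the row sums to 0.975 rather than 1.0 in this implementation).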
smoothed_target = torch.zeros(input.shape, requires_grad=False, device=target.device)
        # Distribute the smoothing mass only over classes whose predictions are finite (not masked with -inf).
for r, row in enumerate(input):
tgt = target[r].item()
if tgt == self.ignore_index:
continue
            finites = torch.isfinite(row)    # classes that are not masked with -inf
            n_cls = finites.sum().item()     # number of valid classes in this row
            assert n_cls > 0, "every class is masked for a non-ignored target"
            smoothing_prob = self.smoothing / n_cls
            smoothed_target[r].masked_fill_(finites, smoothing_prob)
            smoothed_target[r, tgt] = 1.0 - self.smoothing  # gold class overwrites its smoothing share
        # Compute loss: -p log q. Masked (-inf) entries of the input are zeroed first,
        # since multiplying them by a zero target probability would otherwise yield NaN.
        loss = -smoothed_target * input.masked_fill(~torch.isfinite(input), 0.0)
if self.reduction == 'batchmean':
return loss.sum() / input.shape[0]
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
class SmoothCrossEntropyLoss(AbstractLoss):
"""
Computes cross entropy loss with uniformly smoothed targets.
"""
_NAME = "SmoothCrossEntropyLoss"
def __init__(self, weight=None, mask=None, size_average=True):
"""
Cross entropy loss with uniformly smoothed targets.
:param float smoothing: Label smoothing factor, between 0 and 1 (exclusive; default is 0.1)
:param int ignore_index: Index to be ignored. (PAD_ID by default)
:param str reduction: Style of reduction to be done. One of 'batchmean'(default), 'none', or 'sum'.
"""
super(SmoothCrossEntropyLoss, self).__init__(
self._NAME,
SmoothedCrossEntropyLoss())
self.norm_term = 1
    def get_loss(self):
"""return loss
Returns:
loss (float)
"""
        if isinstance(self.acc_loss, int):
            # No batch has been evaluated yet; the accumulator is still the initial 0.
            return 0
        loss = self.acc_loss.item()
        loss /= self.norm_term
        return loss
    def eval_batch(self, outputs, target):
"""calculate loss
Args:
outputs (Tensor): output distribution of model.
target (Tensor): target classes.
"""
        self.acc_loss += self.criterion(outputs, target)
        # The criterion already averages over the batch ('batchmean' reduction),
        # so no further normalization is needed.
        self.norm_term = 1
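

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the logits and targets below are
    # made-up values, not part of mwptoolkit's API or tests).
    torch.manual_seed(0)
    logits = torch.randn(4, 10)                    # [B=4, C=10] raw scores
    log_probs = torch.log_softmax(logits, dim=-1)  # the criterion expects log probabilities
    targets = torch.tensor([1, 3, 0, -1])          # the -1 entry is skipped via ignore_index

    criterion = SmoothedCrossEntropyLoss(smoothing=0.1, ignore_index=-1)
    print(criterion(log_probs, targets))           # scalar loss under 'batchmean' reduction

    # The AbstractLoss wrapper accumulates loss across batches (assuming
    # AbstractLoss initializes acc_loss to 0, as the += accumulation implies).
    wrapper = SmoothCrossEntropyLoss()
    wrapper.eval_batch(log_probs, targets)
    print(wrapper.get_loss())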