mwptoolkit.loss.masked_cross_entropy_loss
- class mwptoolkit.loss.masked_cross_entropy_loss.MaskedCrossEntropyLoss
Bases: AbstractLoss
- mwptoolkit.loss.masked_cross_entropy_loss.masked_cross_entropy(logits, target, length)
- Parameters
logits – A Variable containing a FloatTensor of size (batch, max_len, num_classes) which contains the unnormalized score (logit) for each class at each step.
target – A Variable containing a LongTensor of size (batch, max_len) which contains the index of the true class for each corresponding step.
length – A Variable containing a LongTensor of size (batch,) which contains the unpadded length of each sequence in the batch.
- Returns
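loss – The average per-step loss, computed only over the steps that fall within each sequence's length (padded steps are masked out).
- Return type
torch.Tensor (a scalar)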
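To make the masking concrete, here is a dependency-free sketch of the computation in plain Python, with nested lists standing in for tensors. It assumes that the logits are normalized with a log-softmax and that the summed negative log-likelihood is divided by the total number of valid (unpadded) steps in the batch. The helper names `log_softmax` and `masked_cross_entropy_reference` are hypothetical illustrations, not part of mwptoolkit; the library function itself operates on PyTorch tensors.

```python
import math
from typing import List, Sequence


def log_softmax(scores: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: shift by the max before exponentiating.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def masked_cross_entropy_reference(
    logits: Sequence[Sequence[Sequence[float]]],  # (batch, max_len, num_classes)
    target: Sequence[Sequence[int]],              # (batch, max_len)
    length: Sequence[int],                        # (batch,)
) -> float:
    """Hypothetical reference: mean negative log-likelihood over unpadded steps.

    Assumes log-softmax normalization and averaging over all valid steps.
    """
    total_loss = 0.0
    for seq_logits, seq_target, seq_len in zip(logits, target, length):
        # Only the first seq_len steps count; the rest are padding and are masked out.
        for step in range(seq_len):
            log_probs = log_softmax(seq_logits[step])
            total_loss += -log_probs[seq_target[step]]
    # Divide by the number of valid steps in the batch, not by batch * max_len.
    return total_loss / sum(length)


logits = [
    [[0.0, 0.0], [5.0, -5.0]],  # length 1: the second step is padding
    [[0.0, 0.0], [0.0, 0.0]],   # length 2
]
target = [[0, 1], [1, 0]]
length = [1, 2]
print(masked_cross_entropy_reference(logits, target, length))  # ≈ 0.6931 (ln 2)
```

Every valid step has uniform logits over two classes, so each contributes ln 2 and the average is ln 2. The padded step, which would otherwise add a loss of roughly 10, is ignored entirely.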