MWPToolkit API:
GroupAttention
GroupAttention.forward()
GroupAttention.get_mask()
GroupAttention.src_to_mask()
GroupAttention.training
attention()
group_mask()
src_to_mask()
EPTMultiHeadAttention
EPTMultiHeadAttention.forward()
EPTMultiHeadAttention.training
EPTMultiHeadAttentionWeights
EPTMultiHeadAttentionWeights.forward()
EPTMultiHeadAttentionWeights.hidden_dim
EPTMultiHeadAttentionWeights.num_heads
EPTMultiHeadAttentionWeights.training
MultiHeadAttention
MultiHeadAttention.forward()
MultiHeadAttention.training
SelfAttention
SelfAttention.forward()
SelfAttention.training
SelfAttentionMask
SelfAttentionMask.forward()
SelfAttentionMask.get_mask()
SelfAttentionMask.training
Attention
Attention.forward()
Attention.training
MaskedRelevantScore
MaskedRelevantScore.forward()
MaskedRelevantScore.training
RelevantScore
RelevantScore.forward()
RelevantScore.training
SeqAttention
SeqAttention.forward()
SeqAttention.training
TreeAttention
TreeAttention.forward()
TreeAttention.training