mwptoolkit.module.Attention.group_attention¶
- class mwptoolkit.module.Attention.group_attention.GroupAttention(h, d_model, dropout=0.1)[source]¶
Bases: Module
Multi-head group attention layer. Takes in the model size (d_model) and the number of attention heads (h).
- forward(query, key, value, mask=None)[source]¶
- Parameters
query (torch.Tensor) – shape [batch_size, head_nums, sequence_length, dim_k].
key (torch.Tensor) – shape [batch_size, head_nums, sequence_length, dim_k].
value (torch.Tensor) – shape [batch_size, head_nums, sequence_length, dim_k].
mask (torch.Tensor) – group attention mask, shape [batch_size, head_nums, sequence_length, sequence_length].
- Returns
shape [batch_size, sequence_length, hidden_size].
- Return type
torch.Tensor
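The forward pass follows the standard multi-head attention recipe: score each head with masked scaled dot-product attention, then merge the heads back into one hidden dimension. A minimal self-contained sketch in plain PyTorch, using the shapes documented above — this is an illustration, not mwptoolkit's exact implementation, and the assumption hidden_size = head_nums * dim_k is mine:

```python
import math
import torch

def group_attention_forward(query, key, value, mask=None):
    """Illustrative multi-head group attention forward pass.

    query/key/value: [batch_size, head_nums, sequence_length, dim_k]
    mask:            [batch_size, head_nums, sequence_length, sequence_length],
                     1 = may attend, 0 = blocked
    returns:         [batch_size, sequence_length, head_nums * dim_k]
    """
    batch_size, head_nums, seq_len, dim_k = query.shape
    # Per-head scaled dot-product scores.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim_k)
    if mask is not None:
        # Blocked positions get a large negative score before softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = torch.softmax(scores, dim=-1)
    context = torch.matmul(p_attn, value)  # [batch, heads, seq, dim_k]
    # Merge heads into a single hidden dimension.
    return context.transpose(1, 2).contiguous().view(
        batch_size, seq_len, head_nums * dim_k)
```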
- get_mask(src, split_list, pad=0)[source]¶
- Parameters
src (torch.Tensor) – source sequence, shape [batch_size, sequence_length].
split_list (list) – indices used to split the source sequence into attention groups.
pad (int) – pad token index.
- Returns
group attention mask, shape [batch_size, 4, sequence_length, sequence_length].
- Return type
torch.Tensor
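A sketch of how such a group mask can be built. mwptoolkit stacks four mask variants along dimension 1 to produce the [batch_size, 4, sequence_length, sequence_length] shape; the single "same-group" variant below only illustrates the idea, and treating the split_list entries as separator token ids is an assumption, not the toolkit's exact behavior:

```python
import torch

def same_group_mask(src, split_list, pad=0):
    """Illustrative group mask: a token may attend only to tokens in its
    own group, where groups are delimited by any token id in split_list,
    and pad positions are always blocked.

    src:     [batch_size, seq_len] token ids
    returns: [batch_size, seq_len, seq_len], 1 = may attend
    """
    batch_size, seq_len = src.shape
    mask = torch.zeros(batch_size, seq_len, seq_len, dtype=torch.long)
    for b in range(batch_size):
        # Assign each position a group id; a separator closes its group.
        group_id = torch.zeros(seq_len, dtype=torch.long)
        gid = 0
        for i in range(seq_len):
            group_id[i] = gid
            if src[b, i].item() in split_list:
                gid += 1
        same = group_id.unsqueeze(0) == group_id.unsqueeze(1)
        not_pad = src[b] != pad
        mask[b] = (same & not_pad.unsqueeze(0) & not_pad.unsqueeze(1)).long()
    return mask
```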
- training: bool¶
- mwptoolkit.module.Attention.group_attention.attention(query, key, value, mask=None, dropout=None)[source]¶
Compute scaled dot-product attention.
- Parameters
query (torch.Tensor) – shape [batch_size, sequence_length, hidden_size].
key (torch.Tensor) – shape [batch_size, sequence_length, hidden_size].
value (torch.Tensor) – shape [batch_size, sequence_length, hidden_size].
mask (torch.Tensor) – group attention mask, shape [batch_size, 4, sequence_length, sequence_length].
- Returns
the attention output and the attention weight distribution.
- Return type
tuple(torch.Tensor, torch.Tensor)
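This helper is the standard scaled dot-product attention, softmax(QKᵀ/√d_k)·V. A self-contained sketch in plain PyTorch following the usual Transformer formulation (the exact masking constant and dropout placement are assumptions, not mwptoolkit's verbatim code):

```python
import math
import torch

def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns the attended values and the attention weights.
    """
    d_k = query.size(-1)
    # Scale scores by sqrt(d_k) to keep softmax gradients well-behaved.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where the mask is 0 cannot be attended to.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
```

Each row of the returned weight tensor is a probability distribution over key positions, so it sums to 1.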