mwptoolkit.module.Attention.group_attention

class mwptoolkit.module.Attention.group_attention.GroupAttention(h, d_model, dropout=0.1)[source]

Bases: Module

Group attention module that takes in the model size and the number of heads.

forward(query, key, value, mask=None)[source]
Parameters
  • query (torch.Tensor) – shape [batch_size, head_nums, sequence_length, dim_k].

  • key (torch.Tensor) – shape [batch_size, head_nums, sequence_length, dim_k].

  • value (torch.Tensor) – shape [batch_size, head_nums, sequence_length, dim_k].

  • mask (torch.Tensor) – group attention mask, shape [batch_size, head_nums, sequence_length, sequence_length].

Returns

shape [batch_size, sequence_length, hidden_size].

Return type

torch.Tensor

get_mask(src, split_list, pad=0)[source]
Parameters
  • src (torch.Tensor) – source sequence, shape [batch_size, sequence_length].

  • split_list (list) – group split indices (token indices at which the source sequence is split into groups).

  • pad (int) – pad token index.

Returns

group attention mask, shape [batch_size, 4, sequence_length, sequence_length].

Return type

torch.Tensor

src_to_mask(src, split_list)[source]
training: bool
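
A minimal usage sketch for building the group attention mask is shown below; the tensor sizes, the pad index, and the separator token ids placed in split_list are assumptions chosen for illustration, not values prescribed by the library. The resulting mask is what forward() expects for its mask argument.

    import torch

    from mwptoolkit.module.Attention.group_attention import GroupAttention

    # Arbitrary sizes for illustration; d_model must be divisible by h.
    batch_size, seq_len, d_model, heads = 2, 12, 512, 8
    group_attn = GroupAttention(h=heads, d_model=d_model, dropout=0.1)

    # A padded batch of source token ids; 0 is assumed to be the pad index here.
    src = torch.randint(1, 100, (batch_size, seq_len))
    src[:, -3:] = 0

    # Hypothetical separator token ids used as the group split indices.
    split_list = [7, 11]

    # Group attention mask of shape [batch_size, 4, seq_len, seq_len];
    # this is the mask argument documented for forward() above.
    mask = group_attn.get_mask(src, split_list, pad=0)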
mwptoolkit.module.Attention.group_attention.attention(query, key, value, mask=None, dropout=None)[source]

Compute scaled dot-product attention.

Parameters
  • query (torch.Tensor) – shape [batch_size, sequence_length, hidden_size].

  • key (torch.Tensor) – shape [batch_size, sequence_length, hidden_size].

  • value (torch.Tensor) – shape [batch_size, sequence_length, hidden_size].

  • mask (torch.Tensor) – group attention mask, shape [batch_size, 4, sequence_length, sequence_length].

Return type

tuple(torch.Tensor, torch.Tensor)
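
This helper computes the standard scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, with an optional mask and dropout. A self-contained PyTorch sketch of that formulation (an illustration of the computation, not the library's exact code) is:

    import math

    import torch

    def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
        """Standard scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.

        query/key/value have shape [..., sequence_length, d_k]; mask, if given,
        must be broadcastable to the score shape [..., sequence_length,
        sequence_length], with zeros marking positions that must not be attended.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # block masked positions
        p_attn = torch.softmax(scores, dim=-1)
        if dropout is not None:  # dropout is assumed to be an nn.Dropout module or None
            p_attn = dropout(p_attn)
        # Return the attended values together with the attention weights,
        # mirroring the tuple(torch.Tensor, torch.Tensor) return type above.
        return torch.matmul(p_attn, value), p_attn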

mwptoolkit.module.Attention.group_attention.group_mask(batch, type='self', pad=0)[source]
mwptoolkit.module.Attention.group_attention.src_to_mask(src, vocab_dict)[source]