mwptoolkit.module.Attention.multi_head_attention¶
- class mwptoolkit.module.Attention.multi_head_attention.EPTMultiHeadAttention(**config)[source]¶
Bases:
Module
Class for computing multi-head attention (follows the paper 'Attention Is All You Need').
This class computes attention over key-value pairs (K, V) with query Q, i.e. Attn(Q, K, V) = softmax(QK^T / sqrt(H/N))V, where H is the hidden dimension and N is the number of heads. A usage sketch follows this class entry.
Initialize MultiHeadAttention class
- Keyword Arguments
hidden_dim (int) – Vector dimension of hidden states (H). 768 by default
num_heads (int) – Number of attention heads (N). 12 by default
dropout_p (float) – Probability of dropout. 0 by default
- forward(query: Tensor, key_value: Optional[Tensor] = None, key_ignorance_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, return_weights: bool = False, **kwargs)[source]¶
Compute multi-head attention
- Parameters
query (torch.Tensor) – FloatTensor representing the query matrix with shape [batch_size, query_sequence_length, hidden_size].
key_value (torch.Tensor) – FloatTensor used as both the key and the value matrix, with shape [batch_size, key_sequence_length, hidden_size] or [1, key_sequence_length, hidden_size]. By default, this is None (the query matrix is used as the key-value matrix).
key_ignorance_mask (torch.Tensor) – BoolTensor representing the mask for ignoring column vectors in the key matrix, with shape [batch_size, key_sequence_length]. If an element at (b, t) is True, then all returned elements at batch_size=b and key_sequence_length=t will be set to -Infinity. By default, this is None (there's no mask to apply).
attention_mask (torch.Tensor) – BoolTensor representing the attention mask for ignoring a key for each query item, with shape [query_sequence_length, key_sequence_length]. If an element at (s, t) is True, then all returned elements at query_sequence_length=s and key_sequence_length=t will be set to -Infinity. By default, this is None (there's no mask to apply).
return_weights (bool) – Use True to return attention weights. By default, this is False.
- Returns
If return_weights is True, return (Attention Output, Attention Weights); otherwise, return only the Attention Output. Attention Output: shape [batch_size, query_sequence_length, hidden_size]. Attention Weights: shape [batch_size, query_sequence_length, key_sequence_length, head_nums].
- Return type
Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]
- training: bool¶
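A minimal usage sketch, assuming EPTMultiHeadAttention is importable as documented above and that the constructor accepts the keyword arguments listed on this page; shapes follow the parameter and return descriptions.
Example::

    import torch
    from mwptoolkit.module.Attention.multi_head_attention import EPTMultiHeadAttention

    # Configuration follows the documented defaults (hidden_dim=768, num_heads=12).
    attn = EPTMultiHeadAttention(hidden_dim=768, num_heads=12, dropout_p=0.0)

    query = torch.randn(2, 5, 768)      # [batch_size, query_sequence_length, hidden_size]
    key_value = torch.randn(2, 7, 768)  # [batch_size, key_sequence_length, hidden_size]

    # True marks key positions whose attention logits are set to -Infinity.
    key_ignorance_mask = torch.zeros(2, 7, dtype=torch.bool)
    key_ignorance_mask[:, -1] = True    # e.g. ignore a trailing padded key position

    output, weights = attn(query, key_value=key_value,
                           key_ignorance_mask=key_ignorance_mask,
                           return_weights=True)
    # output:  [2, 5, 768]
    # weights: [2, 5, 7, 12], i.e. [batch, query_len, key_len, head_nums]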
- class mwptoolkit.module.Attention.multi_head_attention.EPTMultiHeadAttentionWeights(**config)[source]¶
Bases:
Module
Class for computing multi-head attention weights (follows the paper 'Attention Is All You Need').
This class computes the scaled dot-product between query Q and key K, i.e. softmax(QK^T / sqrt(H/N)), where H is the hidden dimension and N is the number of heads. A usage sketch follows this class entry.
Initialize MultiHeadAttentionWeights class
- Keyword Arguments
hidden_dim (int) – Vector dimension of hidden states (H). 768 by default.
num_heads (int) – Number of attention heads (N). 12 by default.
- forward(query: Tensor, key: Optional[Tensor] = None, key_ignorance_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, head_at_last: bool = True) → Tensor[source]¶
Compute multi-head attention weights
- Parameters
query (torch.Tensor) – FloatTensor representing the query matrix with shape [batch_size, query_sequence_length, hidden_size].
key (torch.Tensor) – FloatTensor representing the key matrix with shape [batch_size, key_sequence_length, hidden_size] or [1, key_sequence_length, hidden_size]. By default, this is None (the query matrix is used as the key matrix).
key_ignorance_mask (torch.Tensor) – BoolTensor representing the mask for ignoring column vectors in the key matrix, with shape [batch_size, key_sequence_length]. If an element at (b, t) is True, then all returned elements at batch_size=b and key_sequence_length=t will be set to -Infinity. By default, this is None (there's no mask to apply).
attention_mask (torch.Tensor) – BoolTensor representing the attention mask for ignoring a key for each query item, with shape [query_sequence_length, key_sequence_length]. If an element at (s, t) is True, then all returned elements at query_sequence_length=s and key_sequence_length=t will be set to -Infinity. By default, this is None (there's no mask to apply).
head_at_last (bool) – Use True to make the shape of the return value [batch_size, query_sequence_length, key_sequence_length, head_nums]. If False, this method returns [batch_size, head_nums, query_sequence_length, key_sequence_length]. By default, this is True.
- Returns
FloatTensor of Multi-head Attention weights.
- Return type
torch.FloatTensor
- property hidden_dim: int¶
Vector dimension of hidden states (H).
- property num_heads: int¶
Number of attention heads (N).
- training: bool¶
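A short sketch of computing raw attention weights under both head_at_last layouts; the import path and keyword arguments are taken from this page and may differ across mwptoolkit versions.
Example::

    import torch
    from mwptoolkit.module.Attention.multi_head_attention import EPTMultiHeadAttentionWeights

    weight_layer = EPTMultiHeadAttentionWeights(hidden_dim=768, num_heads=12)

    query = torch.randn(2, 5, 768)  # [batch_size, query_sequence_length, hidden_size]
    key = torch.randn(2, 7, 768)    # [batch_size, key_sequence_length, hidden_size]

    # head_at_last=True -> [batch_size, query_len, key_len, head_nums]
    w = weight_layer(query, key=key, head_at_last=True)   # torch.Size([2, 5, 7, 12])

    # head_at_last=False -> [batch_size, head_nums, query_len, key_len]
    w = weight_layer(query, key=key, head_at_last=False)  # torch.Size([2, 12, 5, 7])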
- class mwptoolkit.module.Attention.multi_head_attention.MultiHeadAttention(embedding_size, num_heads, dropout_ratio=0.0)[source]¶
Bases:
Module
Multi-head attention is proposed in the following paper: Attention Is All You Need. A usage sketch follows this class entry.
Initialize MultiHeadAttention class
- forward(query, key, value, key_padding_mask=None, attn_mask=None)[source]¶
Multi-head attention
- Parameters
query (torch.Tensor) – shape [batch_size, tgt_len, embedding_size].
key (torch.Tensor) – shape [batch_size, src_len, embedding_size].
value (torch.Tensor) – shape [batch_size, src_len, embedding_size].
key_padding_mask (torch.Tensor) – shape [batch_size, src_len].
attn_mask (torch.BoolTensor) – shape [batch_size, tgt_len, src_len].
- Returns
attn_repre: shape [batch_size, tgt_len, embedding_size]. attn_weights: shape [batch_size, tgt_len, src_len].
- Return type
tuple(torch.Tensor, torch.Tensor)
- training: bool¶
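A usage sketch for this batch-first variant; constructor arguments mirror the signature above, and the convention that True in key_padding_mask marks padded positions to ignore is an assumption based on the parameter name.
Example::

    import torch
    from mwptoolkit.module.Attention.multi_head_attention import MultiHeadAttention

    mha = MultiHeadAttention(embedding_size=512, num_heads=8, dropout_ratio=0.1)

    query = torch.randn(2, 4, 512)  # [batch_size, tgt_len, embedding_size]
    key = torch.randn(2, 6, 512)    # [batch_size, src_len, embedding_size]
    value = torch.randn(2, 6, 512)  # [batch_size, src_len, embedding_size]

    # Assumption: True marks padded source positions that should be ignored.
    key_padding_mask = torch.zeros(2, 6, dtype=torch.bool)

    attn_repre, attn_weights = mha(query, key, value, key_padding_mask=key_padding_mask)
    # attn_repre:   [2, 4, 512]
    # attn_weights: [2, 4, 6]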