mwptoolkit.module.Attention.tree_attention
- class mwptoolkit.module.Attention.tree_attention.TreeAttention(input_size, hidden_size)
Bases: torch.nn.Module
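The reference gives only the interface. As an illustration, here is a minimal sketch of a module with the same constructor and forward() signature, assuming the additive (Bahdanau-style) scoring used by GTS-style tree decoders: the hidden state is broadcast over every source position, concatenated with each encoder output, passed through a tanh layer and a linear scorer, masked, and normalized with a softmax. The class name TreeAttentionSketch, the layer names attn and score, and the mask convention (True marks padding) are assumptions for illustration, not the mwptoolkit source. The docs give both inputs a last dimension of hidden_size, so the sketch assumes input_size == hidden_size in typical use.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TreeAttentionSketch(nn.Module):
    """Hypothetical additive-attention sketch; NOT the mwptoolkit implementation."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Scores each (hidden state, encoder output) pair.
        # Assumption: encoder feature size is input_size (equal to hidden_size in typical use).
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden, encoder_outputs, seq_mask=None):
        # hidden: [1, B, hidden_size]; encoder_outputs: [S, B, input_size]
        seq_len = encoder_outputs.size(0)
        # Broadcast the single hidden state across all S source positions.
        hidden = hidden.expand(seq_len, -1, -1)  # [S, B, hidden_size]
        energy_in = torch.cat((hidden, encoder_outputs), dim=2)  # [S, B, hidden_size + input_size]
        energies = self.score(torch.tanh(self.attn(energy_in)))  # [S, B, 1]
        energies = energies.squeeze(2).transpose(0, 1)  # [B, S]
        if seq_mask is not None:
            # Assumption: True marks padding positions, which get ~zero weight.
            energies = energies.masked_fill(seq_mask, -1e12)
        weights = F.softmax(energies, dim=1)  # normalize over source positions
        return weights.unsqueeze(1)  # [B, 1, S], the documented output shape


torch.manual_seed(0)
module = TreeAttentionSketch(input_size=8, hidden_size=8)
hidden = torch.randn(1, 2, 8)               # [1, batch_size, hidden_size]
encoder_outputs = torch.randn(5, 2, 8)      # [sequence_length, batch_size, hidden_size]
seq_mask = torch.tensor([[False] * 5,
                         [False, False, False, True, True]])  # [batch_size, sequence_length]
out = module(hidden, encoder_outputs, seq_mask)
```

Because this sketch applies a softmax before returning, each row of the output sums to 1 and padded positions receive (numerically) zero weight; whether the values mwptoolkit documents as "attention energies" are normalized the same way should be checked against the library source.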
- forward(hidden, encoder_outputs, seq_mask=None)
- Parameters
hidden (torch.Tensor) – hidden representation, shape [1, batch_size, hidden_size].
encoder_outputs (torch.Tensor) – output from encoder, shape [sequence_length, batch_size, hidden_size].
seq_mask (torch.Tensor, optional) – sequence mask, shape [batch_size, sequence_length]. Defaults to None.
- Returns
attn_energies (torch.Tensor) – attention energies, shape [batch_size, 1, sequence_length].
- Return type
torch.Tensor
- training: bool
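forward() returns a tensor of shape [batch_size, 1, sequence_length], while encoder_outputs is sequence-first, [sequence_length, batch_size, hidden_size]. A common way to consume such weights is a batched matrix product after moving the batch dimension to the front, which yields one context vector per batch element. The sketch below shows that, plus one way to build a boolean seq_mask from sequence lengths. The helper names build_seq_mask and context_vector are hypothetical, and the convention that True marks padding positions is an assumption, not something this page states.

```python
import torch


def build_seq_mask(lengths, max_len):
    """Hypothetical helper: [B, S] bool mask, True at padding positions (assumed convention)."""
    positions = torch.arange(max_len).unsqueeze(0)            # [1, S]
    return positions >= torch.as_tensor(lengths).unsqueeze(1)  # broadcast to [B, S]


def context_vector(attn_weights, encoder_outputs):
    """Hypothetical helper: attention-weighted sum of encoder outputs.

    attn_weights:    [B, 1, S]  (shape documented for forward())
    encoder_outputs: [S, B, H]  (sequence-first, as documented)
    returns:         [B, 1, H]
    """
    # Move batch first so bmm contracts over the S dimension.
    return torch.bmm(attn_weights, encoder_outputs.transpose(0, 1))


mask = build_seq_mask([5, 3], max_len=5)
weights = torch.tensor([[[0.5, 0.5, 0.0]]])                  # [B=1, 1, S=3]
outputs = torch.tensor([[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]])  # [S=3, B=1, H=2]
context = context_vector(weights, outputs)
```

Keeping the weights in the [batch_size, 1, sequence_length] layout is what makes the single torch.bmm call possible: the middle dimension of size 1 becomes the row of the resulting context vector.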