mwptoolkit.module.Attention.tree_attention

class mwptoolkit.module.Attention.tree_attention.TreeAttention(input_size, hidden_size)

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(hidden, encoder_outputs, seq_mask=None)
Parameters
  • hidden (torch.Tensor) – hidden representation, shape [1, batch_size, hidden_size].

  • encoder_outputs (torch.Tensor) – output from the encoder, shape [sequence_length, batch_size, hidden_size].

  • seq_mask (torch.Tensor, optional) – sequence mask, shape [batch_size, sequence_length]. Defaults to None.
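seq_mask is documented only by its shape. As a hedged sketch, a boolean mask of shape [batch_size, sequence_length] can be built from per-example sequence lengths. The helper name make_seq_mask and the convention that True marks padded positions (the ones attention should ignore) are assumptions for illustration, not taken from mwptoolkit.

```python
import torch


def make_seq_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: build a [batch_size, max_len] padding mask.

    Assumed convention: True marks padded positions beyond each length.
    """
    positions = torch.arange(max_len, device=lengths.device)  # [max_len]
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)      # [batch_size, max_len]


# Two sequences of true length 3 and 5, padded to length 5.
lengths = torch.tensor([3, 5])
mask = make_seq_mask(lengths, max_len=5)
print(mask)
```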

Returns

attention energies, shape [batch_size, 1, sequence_length].

Return type

attn_energies (torch.Tensor)
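One common way to produce attention of shape [batch_size, 1, sequence_length] from these inputs is additive (Bahdanau-style) attention. The following is a minimal sketch under that assumption, not mwptoolkit's implementation: the class name AdditiveTreeAttention, the layer names attn and score, the tanh scoring, the masking convention (True = padding) and the final softmax are all assumptions, and the docstring above does not say whether the returned energies are normalized. The sketch also assumes hidden has last dimension hidden_size and encoder_outputs has last dimension input_size; the docstring lists both as hidden_size, which corresponds to the case input_size == hidden_size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveTreeAttention(nn.Module):
    """Hypothetical additive-attention sketch matching the documented shapes."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # Assumed layers: project [hidden; encoder_output] then score to a scalar.
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden, encoder_outputs, seq_mask=None):
        # hidden: [1, B, hidden_size]; encoder_outputs: [S, B, input_size]
        seq_len = encoder_outputs.size(0)
        hidden = hidden.expand(seq_len, -1, -1)                    # [S, B, hidden_size]
        energy_in = torch.cat((hidden, encoder_outputs), dim=2)    # [S, B, hidden + input]
        energies = self.score(torch.tanh(self.attn(energy_in)))    # [S, B, 1]
        energies = energies.squeeze(2).transpose(0, 1)             # [B, S]
        if seq_mask is not None:
            # Assumed convention: True marks padding, pushed to ~zero weight.
            energies = energies.masked_fill(seq_mask, -1e12)
        weights = F.softmax(energies, dim=1)                       # [B, S]
        return weights.unsqueeze(1)                                # [B, 1, S]


torch.manual_seed(0)
S, B, H = 4, 2, 8
attn = AdditiveTreeAttention(input_size=H, hidden_size=H)
hidden = torch.randn(1, B, H)
encoder_outputs = torch.randn(S, B, H)
seq_mask = torch.tensor([[False, False, True, True], [False, False, False, False]])

weights = attn(hidden, encoder_outputs, seq_mask)
print(weights.shape)  # torch.Size([2, 1, 4])

# Weighted sum of batch-first encoder outputs gives a context vector.
context = torch.bmm(weights, encoder_outputs.transpose(0, 1))
print(context.shape)  # torch.Size([2, 1, 8])
```

The singleton middle dimension in the [batch_size, 1, sequence_length] output lets the weights go straight into torch.bmm against the batch-first encoder outputs, yielding a [batch_size, 1, hidden_size] context vector without extra reshaping.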

training: bool