mwptoolkit.module.Decoder.tree_decoder
- class mwptoolkit.module.Decoder.tree_decoder.HMSDecoder(embedding_model, hidden_size, dropout, op_set, vocab_dict, class_list, device)
Bases: torch.nn.Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(targets=None, encoder_hidden=None, encoder_outputs=None, input_lengths=None, span_length=None, num_pos=None, max_length=None, beam_width=None)
Defines the computation performed at every call.
Note
Although the recipe for the forward pass is defined within this function, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, while calling forward() silently ignores them.
- forward_beam(decoder_init_hidden, encoder_outputs, masks, embedding_masks, max_length, beam_width=1)
- forward_step(node_stacks, tree_stacks, nodes_hidden, encoder_outputs, masks, embedding_masks, decoder_nodes_class=None)
- forward_teacher(decoder_nodes_label, decoder_init_hidden, encoder_outputs, masks, embedding_masks, max_length=None)
- training: bool
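HMSDecoder offers forward_teacher, which takes gold labels (decoder_nodes_label) and so presumably implements teacher forcing, and forward_beam, whose beam_width parameter controls beam search at inference time. The sketch below is a generic, pure-Python beam search over token ids, meant only to illustrate what beam_width trades off. It is not HMSDecoder's tree-structured implementation, and beam_search, step_log_probs, eos_id, and toy_step are hypothetical names.

```python
import math
from typing import Callable, List, Sequence, Tuple

Hypothesis = Tuple[List[int], float]  # (token ids, cumulative log-probability)


def beam_search(
    step_log_probs: Callable[[Sequence[int]], List[float]],
    eos_id: int,
    max_length: int,
    beam_width: int = 1,
) -> List[Hypothesis]:
    """Generic beam search over token ids; beam_width=1 reduces to greedy decoding."""
    beams: List[Hypothesis] = [([], 0.0)]
    finished: List[Hypothesis] = []
    for _ in range(max_length):
        # Expand every live hypothesis by every token in the vocabulary.
        candidates: List[Hypothesis] = [
            (prefix + [tok], score + lp)
            for prefix, score in beams
            for tok, lp in enumerate(step_log_probs(prefix))
        ]
        # Keep only the beam_width best expansions overall.
        candidates.sort(key=lambda h: h[1], reverse=True)
        beams = []
        for prefix, score in candidates[:beam_width]:
            if prefix[-1] == eos_id:
                finished.append((prefix, score))
            else:
                beams.append((prefix, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses still open at max_length
    finished.sort(key=lambda h: h[1], reverse=True)
    return finished


# Toy scorer over the vocabulary {0, 1, 2=EOS} (illustrative numbers only).
def toy_step(prefix: Sequence[int]) -> List[float]:
    if not prefix:
        probs = [0.5, 0.4, 0.1]
    elif prefix[0] == 0:
        probs = [0.2, 0.2, 0.6]
    else:
        probs = [0.05, 0.05, 0.9]
    return [math.log(p) for p in probs]


print(beam_search(toy_step, eos_id=2, max_length=3, beam_width=1)[0][0])  # [0, 2]
print(beam_search(toy_step, eos_id=2, max_length=3, beam_width=2)[0][0])  # [1, 2]
```

With beam_width=1 the search commits to token 0 after the first step and finishes with probability 0.5 × 0.6 = 0.3; widening the beam to 2 keeps token 1 alive long enough to find the 0.4 × 0.9 = 0.36 hypothesis.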
- class mwptoolkit.module.Decoder.tree_decoder.LSTMBasedTreeDecoder(embedding_size, hidden_size, op_nums, generate_size, dropout=0.5)
Bases: torch.nn.Module
- forward(parent_embed, left_embed, prev_embed, encoder_outputs, num_pades, padding_hidden, seq_mask, nums_mask, hidden, tree_hidden)
- Parameters
parent_embed (list) – parent embedding, length [batch_size], list of torch.Tensor with shape [1, 2 * hidden_size].
left_embed (list) – left embedding, length [batch_size], list of torch.Tensor with shape [1, embedding_size].
prev_embed (list) – previous embedding, length [batch_size], list of torch.Tensor with shape [1, embedding_size].
encoder_outputs (torch.Tensor) – output from encoder, shape [batch_size, sequence_length, hidden_size].
num_pades (torch.Tensor) – number representation, shape [batch_size, number_size, hidden_size].
padding_hidden (torch.Tensor) – padding hidden, shape [1, hidden_size].
seq_mask (torch.BoolTensor) – sequence mask, shape [batch_size, sequence_length].
nums_mask (torch.BoolTensor) – number mask, shape [batch_size, number_size].
hidden (tuple(torch.Tensor, torch.Tensor)) – hidden states, a pair of tensors each of shape [batch_size, num_directions * hidden_size].
tree_hidden (tuple(torch.Tensor, torch.Tensor)) – tree hidden states, a pair of tensors each of shape [batch_size, num_directions * hidden_size].
- Returns
num_score (torch.Tensor) – number score, shape [batch_size, number_size].
op (torch.Tensor) – operator score, shape [batch_size, operator_size].
current_embeddings (torch.Tensor) – current node representation, shape [batch_size, 1, num_directions * hidden_size].
current_context (torch.Tensor) – current context representation, shape [batch_size, 1, num_directions * hidden_size].
embedding_weight (torch.Tensor) – embedding weight, shape [batch_size, number_size, embedding_size].
hidden (tuple(torch.Tensor, torch.Tensor)) – hidden states, each of shape [batch_size, num_directions * hidden_size].
tree_hidden (tuple(torch.Tensor, torch.Tensor)) – tree hidden states, each of shape [batch_size, num_directions * hidden_size].
- Return type
tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple(torch.Tensor, torch.Tensor), tuple(torch.Tensor, torch.Tensor))
- training: bool
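LSTMBasedTreeDecoder.forward returns separate operator and number scores and takes a number mask (nums_mask) as input. One common way to turn such outputs into the next token, assumed here rather than taken from mwptoolkit, is to fill masked number slots with a large negative value, place operator scores before number scores, and take the argmax. The pure-Python sketch below works on plain lists for a single example. next_token_scores, pick_next, NEG_FILL, and the convention that True in nums_mask means "exclude this slot" are all hypothetical, and the docstring above does not say whether the returned num_score is already masked.

```python
from typing import List, Sequence

NEG_FILL = -1e12  # assumed fill value for excluded number slots


def next_token_scores(
    op_score: Sequence[float],
    num_score: Sequence[float],
    nums_mask: Sequence[bool],
) -> List[float]:
    """Combine operator and number scores for one example (hypothetical helper)."""
    if len(num_score) != len(nums_mask):
        raise ValueError("num_score and nums_mask must have the same length")
    # Assumption: True in nums_mask marks a padded or invalid number slot.
    masked_num = [NEG_FILL if masked else s for s, masked in zip(num_score, nums_mask)]
    # Assumed output layout: operator ids first, then number ids.
    return list(op_score) + masked_num


def pick_next(scores: Sequence[float]) -> int:
    """Greedy choice: index of the highest score."""
    return max(range(len(scores)), key=lambda i: scores[i])


scores = next_token_scores([0.1, 0.3], [0.2, 0.9, 0.5], [False, True, False])
print(pick_next(scores))  # 4: the masked 0.9 slot can no longer win
```

With operators first, an index i < len(op_score) names an operator and i - len(op_score) names a number slot.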
- class mwptoolkit.module.Decoder.tree_decoder.PredictModel(hidden_size, class_size, dropout=0.4)
Bases: torch.nn.Module
- forward(node_hidden, encoder_outputs, masks, embedding_masks)
- training: bool
- class mwptoolkit.module.Decoder.tree_decoder.RNNBasedTreeDecoder(input_size, embedding_size, hidden_size, dropout_ratio)
Bases: torch.nn.Module
- forward(input_src, prev_c, prev_h, parent_h, sibling_state)
- training: bool
- class mwptoolkit.module.Decoder.tree_decoder.SARTreeDecoder(hidden_size, op_nums, generate_size, dropout=0.5)
Bases: torch.nn.Module
Seq2tree decoder with Semantically-Aligned Regularization.
- Semantically_Aligned_Regularization(subtree_emb, s_aligned_vector)
- Parameters
subtree_emb (torch.Tensor) – subtree embedding.
s_aligned_vector (torch.Tensor) – semantically aligned vector.
- Returns
s_aligned_a, s_aligned_d
- Return type
tuple(torch.Tensor, torch.Tensor)
- forward(node_stacks, left_childs, encoder_outputs, num_pades, padding_hidden, seq_mask, nums_mask)
- Parameters
node_stacks (list) – node stacks.
left_childs (list) – representation of left childs.
encoder_outputs (torch.Tensor) – output from encoder, shape [sequence_length, batch_size, hidden_size].
num_pades (torch.Tensor) – number representation, shape [batch_size, number_size, hidden_size].
padding_hidden (torch.Tensor) – padding hidden, shape [1, hidden_size].
seq_mask (torch.BoolTensor) – sequence mask, shape [batch_size, sequence_length].
nums_mask (torch.BoolTensor) – number mask, shape [batch_size, number_size].
- Returns
num_score (torch.Tensor) – number score, shape [batch_size, number_size].
op (torch.Tensor) – operator score, shape [batch_size, operator_size].
current_node (torch.Tensor) – current node representation, shape [batch_size, 1, hidden_size].
current_context (torch.Tensor) – current context representation, shape [batch_size, 1, hidden_size].
embedding_weight (torch.Tensor) – embedding weight, shape [batch_size, number_size, hidden_size].
- Return type
tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)
- training: bool
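Semantically_Aligned_Regularization returns s_aligned_a and s_aligned_d, but its docstring does not state how they enter the training loss. Assuming, as the name suggests, that training pulls the two aligned representations together with a distance penalty, a squared-Euclidean version can be sketched in pure Python with lists standing in for tensors. sar_penalty, the averaging over pairs, and the choice of squared L2 distance are all assumptions for illustration, and any weighting against the main decoding loss is left out.

```python
from typing import Sequence, Tuple


def sar_penalty(pairs: Sequence[Tuple[Sequence[float], Sequence[float]]]) -> float:
    """Mean squared-L2 distance over (s_aligned_a, s_aligned_d) pairs (assumed SAR form)."""
    if not pairs:
        return 0.0
    total = 0.0
    for a, d in pairs:
        if len(a) != len(d):
            raise ValueError("aligned vectors must have equal length")
        # Squared Euclidean distance between the two aligned representations.
        total += sum((x - y) ** 2 for x, y in zip(a, d))
    return total / len(pairs)


print(sar_penalty([([1.0, 2.0], [1.0, 0.0]), ([0.0, 0.0], [3.0, 4.0])]))  # 14.5
```

The penalty is zero exactly when every pair coincides, so minimizing it pushes each decoder-side subtree representation toward its aligned counterpart.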
- class mwptoolkit.module.Decoder.tree_decoder.TreeDecoder(hidden_size, op_nums, generate_size, dropout=0.5)
Bases: torch.nn.Module
Seq2tree decoder with problem-aware dynamic encoding.
- forward(node_stacks, left_childs, encoder_outputs, num_pades, padding_hidden, seq_mask, nums_mask)
- training: bool
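SARTreeDecoder.forward and TreeDecoder.forward both take node_stacks and left_childs, which the parameter list describes only as "node stacks" and "representation of left childs". Assuming these follow goal-driven prefix-order decoding in the style of GTS (one stack of pending goals per example, with the most recently completed subtree fed to the next step as its left child), the bookkeeping can be replayed in pure Python with strings standing in for tensors. decode_prefix and its arguments are hypothetical; this is a sketch under that assumption, not mwptoolkit code.

```python
from typing import List, Optional, Sequence, Set, Tuple


def decode_prefix(
    tokens: Sequence[str], operators: Set[str]
) -> Tuple[str, List[Optional[str]]]:
    """Replay prefix-order tree decoding for one example (hypothetical helper)."""
    node_stack: List[str] = ["root"]       # pending goals; the top is expanded next
    subtrees: List[Tuple[str, bool]] = []  # (text, complete?); operators start incomplete
    left_childs: List[Optional[str]] = []  # left-sibling subtree available after each step
    for tok in tokens:
        if not node_stack:
            raise ValueError("tokens continue after the expression is complete")
        node_stack.pop()  # this token fills the current goal
        if tok in operators:
            # Operator: open two sub-goals; push right first so left is expanded first.
            node_stack.extend(["right", "left"])
            subtrees.append((tok, False))
        else:
            # Number: a leaf. While it completes a right child, merge it with the
            # finished left subtree and its operator into a larger subtree.
            current = tok
            while subtrees and subtrees[-1][1]:
                left, _ = subtrees.pop()
                op, _ = subtrees.pop()
                current = f"({left} {op} {current})"
            subtrees.append((current, True))
        # The next goal sees the newest subtree as its left child only if it is complete.
        left_childs.append(subtrees[-1][0] if subtrees[-1][1] else None)
    if node_stack or len(subtrees) != 1:
        raise ValueError("incomplete prefix expression")
    return subtrees[0][0], left_childs


expr, lefts = decode_prefix(["+", "n0", "*", "n1", "n2"], {"+", "-", "*", "/"})
print(expr)   # (n0 + (n1 * n2))
print(lefts)  # [None, 'n0', None, 'n1', '(n0 + (n1 * n2))']
```

Decoding stops once the goal stack is empty; in a batched decoder the same bookkeeping would run once per example, which is consistent with node_stacks and left_childs being Python lists rather than tensors.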