mwptoolkit.module.Decoder.tree_decoder

class mwptoolkit.module.Decoder.tree_decoder.HMSDecoder(embedding_model, hidden_size, dropout, op_set, vocab_dict, class_list, device)[source]

Bases: Module

forward(targets=None, encoder_hidden=None, encoder_outputs=None, input_lengths=None, span_length=None, num_pos=None, max_length=None, beam_width=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Call the module instance itself rather than forward() directly; the instance call runs any registered hooks, while a direct forward() call silently skips them.

forward_beam(decoder_init_hidden, encoder_outputs, masks, embedding_masks, max_length, beam_width=1)[source]
forward_step(node_stacks, tree_stacks, nodes_hidden, encoder_outputs, masks, embedding_masks, decoder_nodes_class=None)[source]
forward_teacher(decoder_nodes_label, decoder_init_hidden, encoder_outputs, masks, embedding_masks, max_length=None)[source]
get_class_embedding_mask(num_pos, encoder_outputs)[source]
get_generator_embedding_mask(batch_size)[source]
get_mask(encode_lengths, pad_length)[source]
get_pad_masks(encoder_outputs, input_lengths, span_length=None)[source]
get_pointer_embedding(pointer_num_pos, encoder_outputs)[source]
get_pointer_mask(pointer_num_pos)[source]
get_pointer_meta(num_pos, sub_num_poses=None)[source]
get_predict_meta(class_list, vocab_dict, device)[source]
init_stacks(encoder_hidden)[source]
training: bool
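
Example

A minimal construction sketch. The toy embedder, vocabulary, class list, and operator set below are placeholders, not toolkit defaults; only the constructor signature is taken from the docs above.

    import torch
    import torch.nn as nn
    from mwptoolkit.module.Decoder.tree_decoder import HMSDecoder

    device = torch.device("cpu")
    embedding_model = nn.Embedding(100, 512)         # placeholder token embedder
    class_list = ["+", "-", "*", "/", "1", "NUM_0"]  # toy output symbols
    vocab_dict = {sym: idx for idx, sym in enumerate(class_list)}
    op_set = {"+", "-", "*", "/"}                    # operator symbols (container type assumed)

    decoder = HMSDecoder(embedding_model, 512, 0.5, op_set, vocab_dict,
                         class_list, device)

Judging from the method names above, passing targets drives teacher forcing via forward_teacher, while passing beam_width instead switches to beam search via forward_beam.
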
class mwptoolkit.module.Decoder.tree_decoder.LSTMBasedTreeDecoder(embedding_size, hidden_size, op_nums, generate_size, dropout=0.5)[source]

Bases: Module

forward(parent_embed, left_embed, prev_embed, encoder_outputs, num_pades, padding_hidden, seq_mask, nums_mask, hidden, tree_hidden)[source]
Parameters
  • parent_embed (list) – parent embedding, length [batch_size], list of torch.Tensor with shape [1, 2 * hidden_size].

  • left_embed (list) – left embedding, length [batch_size], list of torch.Tensor with shape [1, embedding_size].

  • prev_embed (list) – previous embedding, length [batch_size], list of torch.Tensor with shape [1, embedding_size].

  • encoder_outputs (torch.Tensor) – output from encoder, shape [batch_size, sequence_length, hidden_size].

  • num_pades (torch.Tensor) – number representation, shape [batch_size, number_size, hidden_size].

  • padding_hidden (torch.Tensor) – padding hidden, shape [1, hidden_size].

  • seq_mask (torch.BoolTensor) – sequence mask, shape [batch_size, sequence_length].

  • nums_mask (torch.BoolTensor) – number mask, shape [batch_size, number_size].

  • hidden (tuple(torch.Tensor, torch.Tensor)) – hidden states, shape [batch_size, num_directions * hidden_size].

  • tree_hidden (tuple(torch.Tensor, torch.Tensor)) – tree hidden states, shape [batch_size, num_directions * hidden_size].

Returns

  • num_score, number score, shape [batch_size, number_size].

  • op, operator score, shape [batch_size, operator_size].

  • current_embeddings, current node representation, shape [batch_size, 1, num_directions * hidden_size].

  • current_context, current context representation, shape [batch_size, 1, num_directions * hidden_size].

  • embedding_weight, embedding weight, shape [batch_size, number_size, embedding_size].

  • hidden (tuple(torch.Tensor, torch.Tensor)), hidden states, shape [batch_size, num_directions * hidden_size].

  • tree_hidden (tuple(torch.Tensor, torch.Tensor)), tree hidden states, shape [batch_size, num_directions * hidden_size].

Return type

tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple(torch.Tensor, torch.Tensor), tuple(torch.Tensor, torch.Tensor))

training: bool
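
Example

A shape-level sketch of one decoding step, built directly from the documented parameter shapes. All sizes and values are arbitrary and the mask convention (True = masked) is an assumption; the call follows the documented signature, but whether this exact step runs end-to-end depends on module internals not shown here.

    import torch
    from mwptoolkit.module.Decoder.tree_decoder import LSTMBasedTreeDecoder

    B, S, N = 2, 10, 3    # batch size, sequence length, number count
    E, H = 128, 512       # embedding_size, hidden_size

    decoder = LSTMBasedTreeDecoder(embedding_size=E, hidden_size=H,
                                   op_nums=5, generate_size=4, dropout=0.5)

    parent_embed = [torch.randn(1, 2 * H) for _ in range(B)]
    left_embed = [torch.randn(1, E) for _ in range(B)]
    prev_embed = [torch.randn(1, E) for _ in range(B)]
    encoder_outputs = torch.randn(B, S, H)
    num_pades = torch.randn(B, N, H)
    padding_hidden = torch.zeros(1, H)
    seq_mask = torch.zeros(B, S, dtype=torch.bool)
    nums_mask = torch.zeros(B, N, dtype=torch.bool)
    hidden = (torch.zeros(B, H), torch.zeros(B, H))
    tree_hidden = (torch.zeros(B, H), torch.zeros(B, H))

    (num_score, op, current_embeddings, current_context,
     embedding_weight, hidden, tree_hidden) = decoder(
        parent_embed, left_embed, prev_embed, encoder_outputs, num_pades,
        padding_hidden, seq_mask, nums_mask, hidden, tree_hidden)
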
class mwptoolkit.module.Decoder.tree_decoder.PredictModel(hidden_size, class_size, dropout=0.4)[source]

Bases: Module

forward(node_hidden, encoder_outputs, masks, embedding_masks)[source]

score_pn(hidden, context, embedding_masks)[source]
training: bool
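
Example

PredictModel appears to be an internal building block of HMSDecoder (note the shared masks/embedding_masks arguments), so only construction is shown; class_size=20 is an arbitrary placeholder.

    from mwptoolkit.module.Decoder.tree_decoder import PredictModel

    predictor = PredictModel(hidden_size=512, class_size=20, dropout=0.4)
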
class mwptoolkit.module.Decoder.tree_decoder.RNNBasedTreeDecoder(input_size, embedding_size, hidden_size, dropout_ratio)[source]

Bases: Module

forward(input_src, prev_c, prev_h, parent_h, sibling_state)[source]

training: bool
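
Example

A constructor-only sketch; the tensor shapes expected by forward() are not documented above, so the step call is omitted. All sizes are placeholders.

    from mwptoolkit.module.Decoder.tree_decoder import RNNBasedTreeDecoder

    decoder = RNNBasedTreeDecoder(input_size=1000, embedding_size=128,
                                  hidden_size=512, dropout_ratio=0.5)
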
class mwptoolkit.module.Decoder.tree_decoder.SARTreeDecoder(hidden_size, op_nums, generate_size, dropout=0.5)[source]

Bases: Module

Seq2tree decoder with Semantically-Aligned Regularization

Semantically_Aligned_Regularization(subtree_emb, s_aligned_vector)[source]
Parameters
  • subtree_emb (torch.Tensor) – subtree embedding.

  • s_aligned_vector (torch.Tensor) – semantically-aligned vector.

Returns

s_aligned_a and s_aligned_d, the projected representations used to compute the semantically-aligned regularization term.

Return type

tuple(torch.Tensor, torch.Tensor)
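
Example

A hedged sketch of how the two returned projections might feed a regularization term. The [batch_size, hidden_size] input shapes and the mean-squared pairing are assumptions, not documented toolkit behavior.

    import torch
    import torch.nn.functional as F
    from mwptoolkit.module.Decoder.tree_decoder import SARTreeDecoder

    decoder = SARTreeDecoder(hidden_size=512, op_nums=5, generate_size=4)

    subtree_emb = torch.randn(2, 512)       # assumed [batch_size, hidden_size]
    s_aligned_vector = torch.randn(2, 512)  # assumed [batch_size, hidden_size]

    s_aligned_a, s_aligned_d = decoder.Semantically_Aligned_Regularization(
        subtree_emb, s_aligned_vector)
    sar_loss = F.mse_loss(s_aligned_a, s_aligned_d)  # assumed pairing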

forward(node_stacks, left_childs, encoder_outputs, num_pades, padding_hidden, seq_mask, nums_mask)[source]
Parameters
  • node_stacks (list) – node stacks.

  • left_childs (list) – representations of left children.

  • encoder_outputs (torch.Tensor) – output from encoder, shape [sequence_length, batch_size, hidden_size].

  • num_pades (torch.Tensor) – number representation, shape [batch_size, number_size, hidden_size].

  • padding_hidden (torch.Tensor) – padding hidden, shape [1, hidden_size].

  • seq_mask (torch.BoolTensor) – sequence mask, shape [batch_size, sequence_length].

  • nums_mask (torch.BoolTensor) – number mask, shape [batch_size, number_size].

Returns

  • num_score, number score, shape [batch_size, number_size].

  • op, operator score, shape [batch_size, operator_size].

  • current_node, current node representation, shape [batch_size, 1, hidden_size].

  • current_context, current context representation, shape [batch_size, 1, hidden_size].

  • embedding_weight, embedding weight, shape [batch_size, number_size, hidden_size].

Return type

tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)

training: bool
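
Example

A shape-only sketch of the tensor inputs to forward(), taken from the parameter docs above. Note that encoder_outputs is sequence-first here, unlike LSTMBasedTreeDecoder. node_stacks and left_childs are per-example lists maintained by the tree-decoding loop and are not constructed here.

    import torch

    B, S, N, H = 2, 10, 3, 512
    encoder_outputs = torch.randn(S, B, H)  # [sequence_length, batch_size, hidden_size]
    num_pades = torch.randn(B, N, H)        # number representations
    padding_hidden = torch.zeros(1, H)
    seq_mask = torch.zeros(B, S, dtype=torch.bool)   # mask convention assumed: True = masked
    nums_mask = torch.zeros(B, N, dtype=torch.bool)
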
class mwptoolkit.module.Decoder.tree_decoder.TreeDecoder(hidden_size, op_nums, generate_size, dropout=0.5)[source]

Bases: Module

Seq2tree decoder with problem-aware dynamic encoding.

forward(node_stacks, left_childs, encoder_outputs, num_pades, padding_hidden, seq_mask, nums_mask)[source]

training: bool
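
Example

A constructor sketch; forward() shares its signature with SARTreeDecoder.forward documented above, so the same argument shapes apply. All sizes are placeholders.

    from mwptoolkit.module.Decoder.tree_decoder import TreeDecoder

    decoder = TreeDecoder(hidden_size=512, op_nums=5, generate_size=4, dropout=0.5)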