mwptoolkit.model.Seq2Tree.trnn

class mwptoolkit.model.Seq2Tree.trnn.TRNN(config, dataset)[source]

Bases: Module

Reference:

Wang et al. “Template-Based Math Word Problem Solvers with Recursive Neural Networks” in AAAI 2019.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

ans_module_calculate_loss(batch_data)[source]

Performs forward propagation, loss computation, and back-propagation for the answer module.

Parameters

batch_data – one batch of data.

Returns

loss value of the answer module.

ans_module_forward(seq, seq_length, seq_mask, template, num_pos, equation_target=None, output_all_layers=False)[source]
calculate_loss(batch_data: dict) Tuple[float, float][source]

Performs forward propagation, loss computation, and back-propagation for both modules.

Parameters

batch_data – one batch of data.

Returns

seq2seq module loss, answer module loss.

convert_idx2symbol(output, num_list)[source]
convert_in_idx_2_temp_idx(output)[source]
convert_temp_idx2symbol(output)[source]
convert_temp_idx_2_in_idx(output)[source]
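The convert_* helpers above map between model output indices and human-readable equation symbols. A minimal, self-contained sketch of the idea behind convert_idx2symbol is shown below; the vocabulary list and the "NUM_k" placeholder convention are illustrative assumptions, as the real mapping is built from the dataset object.

```python
# Illustrative vocabulary; in mwptoolkit this comes from the dataset.
IDX2SYMBOL = ["+", "-", "*", "/", "NUM_0", "NUM_1", "NUM_2", "<EOS>"]

def convert_idx2symbol(output, num_list):
    """Map a list of vocabulary indices to symbols, replacing each
    NUM_k placeholder with the k-th number found in the problem text."""
    symbols = []
    for idx in output:
        sym = IDX2SYMBOL[idx]
        if sym == "<EOS>":
            break
        if sym.startswith("NUM_"):
            sym = str(num_list[int(sym.split("_")[1])])
        symbols.append(sym)
    return symbols

# Indices for "NUM_0 + NUM_1" with the problem's numbers [3, 5]:
print(convert_idx2symbol([4, 0, 5, 7], [3, 5]))  # ['3', '+', '5']
```

The inverse helpers (convert_temp_idx_2_in_idx and convert_in_idx_2_temp_idx) translate between the template vocabulary and the input vocabulary in the same index-lookup style.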
forward(seq, seq_length, seq_mask, num_pos, template_target=None, equation_target=None, output_all_layers=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

init_seq2seq_decoder_inputs(target, device, batch_size)[source]
mask2num(output, num_list)[source]
model_test(batch_data: dict) tuple[source]

Tests the model on one batch of data.

Parameters

batch_data – one batch of data.

Returns

predicted equation, target equation.

batch_data should include the keys ‘question’, ‘ques len’, ‘equation’, ‘ques mask’, ‘num pos’, ‘num list’, ‘template’.
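The required keys are listed above; a sketch of what such a batch might look like follows. The shapes and dummy values are illustrative assumptions, not the library's actual collate output.

```python
# Hypothetical single-sample batch for model_test; every value is a
# placeholder chosen only to show the expected structure.
batch_data = {
    "question": [[12, 7, 44, 3]],   # token ids of the problem text
    "ques len": [4],                # sequence length per sample
    "equation": [[4, 0, 5]],        # target equation token ids
    "ques mask": [[1, 1, 1, 1]],    # padding mask over the question
    "num pos": [[2, 3]],            # positions of numbers in the question
    "num list": [[3, 5]],           # the numbers themselves
    "template": [[4, 6, 5]],        # equation template token ids
}
```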

predict(batch_data: dict, output_all_layers=False)[source]

Predicts samples without targets.

Parameters
  • batch_data (dict) – one batch of data.

  • output_all_layers (bool) – whether to return all layer outputs of the model.

Returns

token_logits, symbol_outputs, all_layer_outputs

seq2seq_calculate_loss(batch_data: dict) float[source]

Performs forward propagation, loss computation, and back-propagation for the seq2seq module.

Parameters

batch_data – one batch of data.

Returns

loss value of the seq2seq module.

seq2seq_decoder_forward(encoder_outputs, encoder_hidden, decoder_inputs, target=None, output_all_layers=False)[source]
seq2seq_encoder_forward(seq_emb, seq_length, output_all_layers=False)[source]
seq2seq_forward(seq, seq_length, target=None, output_all_layers=False)[source]
seq2seq_generate_t(encoder_outputs, encoder_hidden, decoder_inputs)[source]
seq2seq_generate_without_t(encoder_outputs, encoder_hidden, decoder_input)[source]
symbol2idx(symbols)[source]

Converts equation symbols to equation indices.
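A minimal sketch of this symbol-to-index conversion, assuming a hypothetical symbol vocabulary (the real one is constructed from the dataset):

```python
# Illustrative symbol vocabulary; an assumption for this sketch only.
SYMBOL2IDX = {"+": 0, "-": 1, "*": 2, "/": 3, "NUM_0": 4, "NUM_1": 5}

def symbol2idx(symbols):
    """Map each equation symbol to its vocabulary index."""
    return [SYMBOL2IDX[sym] for sym in symbols]

print(symbol2idx(["NUM_0", "+", "NUM_1"]))  # [4, 0, 5]
```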

template2tree(template)[source]
training: bool
tree2equation(tree)[source]
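template2tree and tree2equation round-trip between a flat template and the expression tree that the recursive answer module operates on. The sketch below shows the idea for a prefix-order template; the Node class, the "<OPT>" placeholder for an unfilled operator slot, and the prefix convention are illustrative assumptions rather than mwptoolkit's actual classes.

```python
class Node:
    """Minimal binary expression-tree node (illustrative)."""
    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

# "<OPT>" marks an operator slot the answer module has to fill in.
OPERATORS = {"+", "-", "*", "/", "^", "<OPT>"}

def template2tree(template):
    """Build a binary tree from a prefix-order template token list."""
    pos = 0
    def build():
        nonlocal pos
        tok = template[pos]
        pos += 1
        if tok in OPERATORS:
            left = build()
            right = build()
            return Node(tok, left, right)
        return Node(tok)  # operand leaf
    return build()

def tree2equation(tree):
    """Flatten a tree back into a prefix-order token list."""
    if tree.left is None:
        return [tree.value]
    return [tree.value] + tree2equation(tree.left) + tree2equation(tree.right)

template = ["<OPT>", "NUM_0", "<OPT>", "NUM_1", "NUM_2"]
print(tree2equation(template2tree(template)))
# ['<OPT>', 'NUM_0', '<OPT>', 'NUM_1', 'NUM_2']
```

In TRNN the seq2seq module first predicts such a template, and the recursive answer module then scores operator choices at each internal tree node to recover the final equation.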