mwptoolkit.model.Seq2Tree.trnn
- class mwptoolkit.model.Seq2Tree.trnn.TRNN(config, dataset)
Bases: Module
- Reference:
Wang et al. “Template-Based Math Word Problem Solvers with Recursive Neural Networks” in AAAI 2019.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
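The reference paper describes a two-stage solver. A seq2seq model predicts a tree-structured template whose leaves are numbers and whose inner nodes are unknown operators, and a recursive neural network then infers those operators bottom-up. This split appears to match the class's seq2seq module and answer module, but the sketch below is only a framework-free illustration of the template-filling idea, not TRNN's code. The `<op>` placeholder, the names `fill_template` and `eval_postfix`, and the `choose_op` callback (standing in for the learned operator classifier) are all hypothetical.

```python
from typing import Callable, Dict, List, Sequence

OPERATORS = ("+", "-", "*", "/")
PLACEHOLDER = "<op>"  # hypothetical token marking an unknown operator


def fill_template(template: Sequence[str],
                  choose_op: Callable[[str, str], str]) -> List[str]:
    """Replace every placeholder in a postfix template with a concrete operator.

    A stack visits the implied tree bottom-up, so each placeholder is resolved
    only after both operand subtrees exist; `choose_op` receives the two
    subtrees (rendered as strings) and must return one of OPERATORS.
    """
    stack: List[str] = []   # rendered subtrees built so far
    filled: List[str] = []  # output postfix sequence
    for tok in template:
        if tok == PLACEHOLDER:
            right = stack.pop()
            left = stack.pop()
            op = choose_op(left, right)
            if op not in OPERATORS:
                raise ValueError(f"unknown operator {op!r}")
            filled.append(op)
            stack.append(f"({left} {op} {right})")
        else:
            filled.append(tok)
            stack.append(tok)
    return filled


def eval_postfix(tokens: Sequence[str], nums: Dict[str, float]) -> float:
    """Evaluate a filled postfix equation, looking number tokens up in `nums`."""
    stack: List[float] = []
    for tok in tokens:
        if tok in OPERATORS:
            b = stack.pop()
            a = stack.pop()
            if tok == "+":
                stack.append(a + b)
            elif tok == "-":
                stack.append(a - b)
            elif tok == "*":
                stack.append(a * b)
            else:
                stack.append(a / b)
        else:
            stack.append(float(nums[tok]))
    return stack[0]


def pick(left: str, right: str) -> str:
    """Toy classifier: multiply once the left operand is already a subtree."""
    return "*" if left.startswith("(") else "+"


equation = fill_template(["n0", "n1", "<op>", "n2", "<op>"], pick)
print(equation)  # ['n0', 'n1', '+', 'n2', '*']
print(eval_postfix(equation, {"n0": 2, "n1": 3, "n2": 4}))  # 20.0
```

The stack walk guarantees that an operator is chosen only after both of its children are known, which is the bottom-up order the paper's recursive network relies on.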
- ans_module_calculate_loss(batch_data)
Run the forward pass of the answer module, compute its loss, and back-propagate it.
- Parameters
batch_data – one batch of data.
- Returns
loss value of the answer module.
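If the answer module is trained as an operator classifier at each inner template node (an assumption taken from the reference paper, not read from this class's source), its loss reduces to the mean negative log-likelihood of the gold operator over those nodes. The plain-Python sketch below uses hypothetical names (`operator_nll`, `node_logits`, `gold_ops`) and leaves out the back-propagation step that `ans_module_calculate_loss` also performs.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one score vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def operator_nll(node_logits: Sequence[Sequence[float]],
                 gold_ops: Sequence[int]) -> float:
    """Mean negative log-likelihood of the gold operator index at each node."""
    if not node_logits or len(node_logits) != len(gold_ops):
        raise ValueError("need one gold operator per scored node")
    losses = [-log_softmax(scores)[gold]
              for scores, gold in zip(node_logits, gold_ops)]
    return sum(losses) / len(losses)


# Two operator nodes scored over (+, -, *, /); the gold operators are '+' and '*'.
print(operator_nll([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 5.0, 0.0]], [0, 2]))
```

Subtracting the maximum before exponentiating keeps the log-sum-exp stable for large scores without changing the result.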
- ans_module_forward(seq, seq_length, seq_mask, template, num_pos, equation_target=None, output_all_layers=False)
- calculate_loss(batch_data: dict) → Tuple[float, float]
Run the forward pass, compute the losses, and back-propagate.
- Parameters
batch_data – one batch of data.
- Returns
a pair (seq2seq module loss, answer module loss).
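Because `calculate_loss` already back-propagates and returns its two losses as floats, a caller only needs to unpack and aggregate them. The epoch loop below is a hypothetical sketch: `run_epoch` and the `StubTRNN` stand-in are illustrative names, and the optimizer update a real trainer would perform after each batch is left out.

```python
from typing import Iterable, Tuple


def run_epoch(model, batches: Iterable[dict]) -> Tuple[float, float]:
    """Average the (seq2seq loss, answer-module loss) pairs over one epoch.

    Assumes `model.calculate_loss(batch)` runs forward and backward itself and
    returns two floats, so no extra backward call is made here.
    """
    seq_total, ans_total, n = 0.0, 0.0, 0
    for batch in batches:
        seq_loss, ans_loss = model.calculate_loss(batch)
        seq_total += seq_loss
        ans_total += ans_loss
        n += 1
    if n == 0:
        raise ValueError("no batches")
    return seq_total / n, ans_total / n


class StubTRNN:
    """Stand-in returning fixed losses; replaces the real model for the demo."""

    def calculate_loss(self, batch_data: dict) -> Tuple[float, float]:
        return batch_data["fake_seq_loss"], batch_data["fake_ans_loss"]


batches = [{"fake_seq_loss": 1.0, "fake_ans_loss": 0.5},
           {"fake_seq_loss": 3.0, "fake_ans_loss": 1.5}]
print(run_epoch(StubTRNN(), batches))  # (2.0, 1.0)
```

Keeping the two losses separate, rather than summing them, makes it possible to tell which module is failing to converge.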
- forward(seq, seq_length, seq_mask, num_pos, template_target=None, equation_target=None, output_all_layers=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
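The note is easier to see with a toy imitation of the mechanism. `ToyModule` below is not PyTorch code. It only mimics how `nn.Module.__call__` wraps `forward()` and then runs hooks registered through a method named after PyTorch's `register_forward_hook`. Real PyTorch passes the inputs as a tuple and supports more hook types.

```python
from typing import Any, Callable, List

Hook = Callable[[Any, Any, Any], None]


class ToyModule:
    """Toy stand-in for nn.Module: __call__ runs forward() and then the hooks."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError("subclasses define the computation")

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        for hook in self._forward_hooks:  # hooks only run on this path
            hook(self, x, out)
        return out


class Doubler(ToyModule):
    def forward(self, x: Any) -> Any:
        return 2 * x


seen: List[Any] = []
m = Doubler()
m.register_forward_hook(lambda module, inp, out: seen.append(out))
m(3)          # goes through __call__, so the hook records 6
m.forward(4)  # bypasses __call__, so the hook never sees 8
print(seen)   # [6]
```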
- model_test(batch_data: dict) → tuple
Test the model on one batch.
- Parameters
batch_data – one batch of data.
- Returns
predicted equation, target equation.
batch_data should include the keys 'question', 'ques len', 'equation', 'ques mask', 'num pos', 'num list', and 'template'.
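Since `model_test` expects the seven keys listed above, a caller may want to check a batch before passing it in. The helper below is a hypothetical convenience, not part of mwptoolkit. The key strings, including their internal spaces, are copied from this docstring.

```python
from typing import Iterable, List

# Keys required by model_test, as listed in the docstring above.
MODEL_TEST_KEYS = ("question", "ques len", "equation", "ques mask",
                   "num pos", "num list", "template")


def missing_batch_keys(batch_data: dict,
                       required: Iterable[str] = MODEL_TEST_KEYS) -> List[str]:
    """Return the required keys absent from `batch_data`, in listed order."""
    return [key for key in required if key not in batch_data]


batch = {key: None for key in MODEL_TEST_KEYS}
del batch["template"]
print(missing_batch_keys(batch))  # ['template']
```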
- predict(batch_data: dict, output_all_layers=False)
Predict samples without targets.
- Parameters
batch_data (dict) – one batch of data.
output_all_layers (bool) – whether to return the outputs of all model layers.
- Returns
token_logits, symbol_outputs, all_layer_outputs
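The docstring does not spell out how `symbol_outputs` relates to `token_logits`. Under greedy decoding, one common choice, each step's symbol would be the highest-scoring vocabulary entry. The sketch below shows that relationship on plain lists. The vocabulary, the `<EOS>` stop token, and the name `greedy_symbols` are hypothetical.

```python
from typing import List, Sequence


def greedy_symbols(token_logits: Sequence[Sequence[float]],
                   vocab: Sequence[str], eos: str = "<EOS>") -> List[str]:
    """Pick the argmax symbol at each decoding step, stopping at `eos`."""
    symbols: List[str] = []
    for step_scores in token_logits:
        best = max(range(len(step_scores)), key=step_scores.__getitem__)
        if vocab[best] == eos:
            break
        symbols.append(vocab[best])
    return symbols


vocab = ["n0", "n1", "<op>", "<EOS>"]
logits = [[2.0, 0.1, 0.0, 0.0],   # step 1 -> n0
          [0.0, 1.5, 0.2, 0.0],   # step 2 -> n1
          [0.0, 0.0, 3.0, 0.1],   # step 3 -> <op>
          [0.0, 0.0, 0.1, 4.0],   # step 4 -> <EOS>, decoding stops
          [9.0, 0.0, 0.0, 0.0]]   # never reached
print(greedy_symbols(logits, vocab))  # ['n0', 'n1', '<op>']
```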
- seq2seq_calculate_loss(batch_data: dict) → float
Run the forward pass of the seq2seq module, compute its loss, and back-propagate it.
- Parameters
batch_data – one batch of data.
- Returns
loss value of the seq2seq module.
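Target sequences in a batch usually have different lengths, so a seq2seq loss is commonly averaged only over real target tokens. The sketch below assumes the per-step log-probabilities are already computed and masks out padding. `masked_sequence_nll` and the `pad_id` convention are hypothetical, and this docstring does not say whether TRNN normalizes in exactly this way.

```python
import math
from typing import Sequence


def masked_sequence_nll(log_probs: Sequence[Sequence[Sequence[float]]],
                        targets: Sequence[Sequence[int]],
                        pad_id: int = 0) -> float:
    """Mean NLL over non-padding target tokens.

    log_probs[b][t][v] is the log-probability of vocabulary id v at step t of
    sequence b; targets[b][t] is the gold id, with `pad_id` marking padding.
    """
    total, count = 0.0, 0
    for seq_lp, seq_tgt in zip(log_probs, targets):
        for step_lp, gold in zip(seq_lp, seq_tgt):
            if gold == pad_id:
                continue  # padded positions contribute nothing
            total -= step_lp[gold]
            count += 1
    if count == 0:
        raise ValueError("all target positions are padding")
    return total / count


step = [math.log(p) for p in (0.2, 0.4, 0.4)]  # same distribution at every step
log_probs = [[step, step], [step, step]]
targets = [[1, 2], [2, 0]]                     # second sequence ends in padding
print(masked_sequence_nll(log_probs, targets))  # ≈ 0.916, i.e. -log(0.4)
```

Dividing by the count of real tokens, rather than by batch size times maximum length, keeps padded batches from looking artificially easy.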
- seq2seq_decoder_forward(encoder_outputs, encoder_hidden, decoder_inputs, target=None, output_all_layers=False)
- training: bool