mwptoolkit.model.Seq2Tree.sausolver¶
- class mwptoolkit.model.Seq2Tree.sausolver.SAUSolver(config, dataset)[source]¶
Bases:
Module
- Reference:
Qin et al. “Semantically-Aligned Universal Tree-Structured Solver for Math Word Problems” in EMNLP 2020.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- calculate_loss(batch_data: dict) float [source]¶
Run forward propagation, compute the loss, and perform back-propagation.
- Parameters
batch_data – one batch data.
- Returns
loss value.
batch_data should include the keys ‘question’, ‘ques len’, ‘equation’, ‘equ len’, ‘num stack’, ‘num size’ and ‘num pos’.
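The batch contract above can be checked before calling calculate_loss. A minimal sketch, assuming placeholder values for a single problem (the helper and the sample tensors are illustrative, not part of the library):

```python
# Keys calculate_loss expects, mirroring the documentation above.
REQUIRED_KEYS = {
    "question", "ques len", "equation", "equ len",
    "num stack", "num size", "num pos",
}

def validate_batch(batch_data: dict) -> list:
    """Return the sorted list of required keys missing from batch_data."""
    return sorted(REQUIRED_KEYS - batch_data.keys())

# Hypothetical batch for one problem: token ids, lengths and number metadata.
batch = {
    "question": [[2, 15, 7, 31, 3]],  # token ids, shape [batch_size, seq_length]
    "ques len": [5],                  # true sequence lengths
    "equation": [[4, 9, 10, 3]],      # target equation token ids
    "equ len": [4],
    "num stack": [[]],                # alternative positions of identical numbers
    "num size": [2],                  # how many numbers each problem contains
    "num pos": [[1, 3]],              # token positions of those numbers
}
```

A missing key shows up immediately in the returned list, which is easier to debug than a KeyError deep inside the forward pass.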
- decoder_forward(encoder_outputs, problem_output, all_nums_encoder_outputs, nums_stack, seq_mask, num_mask, target=None, output_all_layers=False)[source]¶
- evaluate_tree(input_batch, input_length, generate_nums, num_pos, num_start, beam_size=5, max_length=30)[source]¶
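evaluate_tree decodes with beam search (beam_size=5, max_length=30 by default). The pruning idea can be sketched generically, independent of the tree decoder; this toy stand-in is not the library's implementation, and the step-function interface is an assumption made for illustration:

```python
import math

def beam_search(step_scores, beam_size=5, max_length=30):
    """Generic beam search over a step function.

    step_scores(prefix) returns {token: log_prob} for the next token;
    an empty dict marks the hypothesis as finished.
    """
    beams = [([], 0.0)]  # (token prefix, cumulative log-probability)
    for _ in range(max_length):
        candidates = []
        for prefix, score in beams:
            expansions = step_scores(prefix)
            if not expansions:  # finished hypothesis: carry it forward
                candidates.append((prefix, score))
                continue
            for token, logp in expansions.items():
                candidates.append((prefix + [token], score + logp))
        # Keep only the beam_size highest-scoring hypotheses.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams[0][0]  # best-scoring sequence

# Toy usage: a 2-token vocabulary where "a" is always likelier,
# and every hypothesis finishes after two tokens.
def toy_step(prefix):
    if len(prefix) >= 2:
        return {}
    return {"a": math.log(0.6), "b": math.log(0.4)}
```

With beam_size=1 this degenerates to greedy decoding; a larger beam trades compute for a better chance of recovering the highest-probability equation.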
- forward(seq, seq_length, nums_stack, num_size, num_pos, target=None, output_all_layers=False) Tuple[Tensor, Tensor, Dict[str, Any]] [source]¶
- Parameters
seq (torch.Tensor) – input sequence, shape: [batch_size, seq_length].
seq_length (torch.Tensor) – the length of sequence, shape: [batch_size].
nums_stack (list) – different positions of the same number, length: [batch_size].
num_size (list) – number of numbers in the input sequence, length: [batch_size].
num_pos (list) – positions of numbers in the input sequence, length: [batch_size].
target (torch.Tensor | None) – target, shape: [batch_size, target_length], default None.
output_all_layers (bool) – return output of all layers if output_all_layers is True, default False.
- Returns
token_logits, shape: [batch_size, output_length, output_size]; symbol_outputs, shape: [batch_size, output_length]; model_all_outputs.
- Return type
tuple(torch.Tensor, torch.Tensor, dict)
- get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size, hidden_size)[source]¶
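get_all_number_encoder_outputs collects, for each problem, the encoder hidden states at the positions where numbers occur, padded so every problem yields the same number count. A minimal pure-Python sketch of that gather-and-pad step (the real method operates on torch tensors and masks; the function here is illustrative only):

```python
def gather_number_states(encoder_outputs, num_pos, hidden_size):
    """Pick encoder states at number positions; pad rows to the max count.

    encoder_outputs: per-problem list of hidden vectors, one per token,
                     i.e. shape [batch_size][seq_length][hidden_size].
    num_pos:         token positions of the numbers in each problem.
    Returns shape [batch_size][max_num_size][hidden_size]; padded slots
    are zero vectors, mirroring the masked gather in the real method.
    """
    max_num_size = max(len(pos) for pos in num_pos)
    zero = [0.0] * hidden_size
    gathered = []
    for states, positions in zip(encoder_outputs, num_pos):
        row = [states[p] for p in positions]
        row += [zero] * (max_num_size - len(positions))  # pad short rows
        gathered.append(row)
    return gathered
```

The resulting per-number states are what the tree decoder attends over when it chooses which number from the problem to copy into the equation.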
- model_test(batch_data: dict) tuple [source]¶
Test the model on one batch of data.
- Parameters
batch_data – one batch data.
- Returns
predicted equation, target equation.
batch_data should include the keys ‘question’, ‘ques len’, ‘equation’, ‘num stack’, ‘num pos’ and ‘num list’.
- predict(batch_data: dict, output_all_layers=False)[source]¶
Predict samples without a target.
- Parameters
batch_data (dict) – one batch data.
output_all_layers (bool) – return all layer outputs of model.
- Returns
token_logits, symbol_outputs, all_layer_outputs
- train_tree(input_batch, input_length, target_batch, target_length, nums_stack_batch, num_size_batch, generate_nums, num_pos, unk, num_start, english=False, var_nums=[], batch_first=False)[source]¶
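train_tree trains with teacher forcing; nums_stack_batch supplies, for each unk token in the target, the candidate number positions it could stand for, so the unk can be replaced by whichever candidate the decoder currently scores highest. A hedged pure-Python sketch of that replacement step, for one problem (the helper name and score layout are assumptions for illustration, not the library's code):

```python
def resolve_unk_targets(target, unk, num_stack, num_scores, num_start):
    """Replace each unk target token with the best-scoring candidate number.

    target:     list of target token ids for one problem.
    unk:        id of the unknown token.
    num_stack:  for each unk (in order), the candidate number indices.
    num_scores: decoder score for each number appearing in the problem.
    num_start:  vocabulary offset where number tokens begin.
    """
    resolved = list(target)
    stack = list(num_stack)
    for i, tok in enumerate(resolved):
        if tok == unk and stack:
            candidates = stack.pop(0)
            best = max(candidates, key=lambda n: num_scores[n])
            resolved[i] = num_start + best  # map number index to token id
    return resolved
```

Resolving unk targets this way lets the loss reward the decoder for picking any valid occurrence of a repeated number rather than penalizing an arbitrary choice.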
- training: bool¶