mwptoolkit.module.Strategy.beam_search¶
- class mwptoolkit.module.Strategy.beam_search.Beam(score, input_var, hidden, token_logits, outputs, all_output=None)[source]¶
Bases: object
- class mwptoolkit.module.Strategy.beam_search.BeamNode(score, nodes_hidden, node_stacks, tree_stacks, decoder_outputs_list, sequence_symbols_list)[source]¶
Bases: object
- class mwptoolkit.module.Strategy.beam_search.Beam_Search_Hypothesis(beam_size, sos_token_idx, eos_token_idx, device, idx2token)[source]¶
Bases: object
Class designed for beam search.
- generate()[source]¶
Picks the hypothesis with the highest probability among the beam_size hypotheses.
- Returns
the generated tokens
- Return type
List[str]
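Conceptually, generate() reduces to selecting the hypothesis with the best accumulated score once decoding has finished. A minimal sketch in plain Python, where each hypothesis is represented as a (log_prob, tokens) pair (the representation and the function name are illustrative, not mwptoolkit's internals):

```python
def pick_best(hypotheses):
    """Return the token sequence of the highest-scoring hypothesis.

    Each hypothesis is a (log_prob, tokens) pair; higher log_prob wins.
    """
    _, tokens = max(hypotheses, key=lambda h: h[0])
    return tokens

# Three finished hypotheses with their accumulated log-probabilities.
beam = [(-1.2, ["1", "+", "2"]), (-0.4, ["1", "*", "2"]), (-3.0, ["2"])]
best = pick_best(beam)  # -> ['1', '*', '2']
```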
- step(gen_idx, token_logits, decoder_states=None, encoder_output=None, encoder_mask=None, input_type='token')[source]¶
Performs one step of beam search.
- Parameters
gen_idx (int) – the index of the current generation step.
token_logits (torch.Tensor) – logits distribution, shape: [hyp_num, sequence_length, vocab_size].
decoder_states (torch.Tensor, optional) – the decoder states to be selected along with the surviving hypotheses, shape: [hyp_num, sequence_length, hidden_size], default: None.
encoder_output (torch.Tensor, optional) – the encoder output needed for the copy mechanism, shape: [hyp_num, sequence_length, hidden_size], default: None.
encoder_mask (torch.Tensor, optional) – the encoder mask needed for the copy mechanism, shape: [hyp_num, sequence_length], default: None.
- Returns
torch.Tensor: the next input sequence, shape: [hyp_num].
torch.Tensor, optional: the chosen decoder states, shape: [new_hyp_num, sequence_length, hidden_size].
torch.Tensor, optional: the copied encoder output, shape: [new_hyp_num, sequence_length, hidden_size].
torch.Tensor, optional: the copied encoder mask, shape: [new_hyp_num, sequence_length].
- Return type
torch.Tensor
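The core of each step() call is expanding every live hypothesis with every vocabulary token and keeping the beam_size highest-scoring extensions. A minimal sketch with plain Python lists, using log-probabilities in place of raw logits (all names are illustrative, not mwptoolkit's internals):

```python
import math

def beam_step(hypotheses, log_probs, beam_size):
    """Expand each (score, tokens) hypothesis with every token's
    log-probability and keep the beam_size best extensions."""
    candidates = []
    for (score, tokens), token_lp in zip(hypotheses, log_probs):
        for token_id, lp in enumerate(token_lp):
            candidates.append((score + lp, tokens + [token_id]))
    # Higher accumulated log-probability is better.
    candidates.sort(key=lambda c: c[0], reverse=True)
    return candidates[:beam_size]

# Two live hypotheses, a vocabulary of 3 tokens.
hyps = [(0.0, [1]), (-0.5, [2])]
log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],  # next-token dist. for hyp 0
    [math.log(0.1), math.log(0.6), math.log(0.3)],  # next-token dist. for hyp 1
]
new_hyps = beam_step(hyps, log_probs, beam_size=2)
# new_hyps keeps the two best extensions: [1, 0] and [2, 1]
```

In the real class, the same top-k selection is done on batched tensors, which is why step() also returns the reordered decoder states and encoder output/mask for the surviving hypotheses.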