MWPToolkit

mwptoolkit.module.Strategy.beam_search

class mwptoolkit.module.Strategy.beam_search.Beam(score, input_var, hidden, token_logits, outputs, all_output=None)

Bases: object
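
Beam has no docstring on this page. Judging only from its constructor arguments, it bundles a score with the decoder input, hidden state and outputs needed to continue one hypothesis. The sketch below shows the usual way such records are used in beam search: rank them by score and keep the best beam_size. It assumes that a higher score means a better hypothesis. SimpleBeam and prune are hypothetical names, not part of mwptoolkit.

```python
import heapq
from dataclasses import dataclass, field
from typing import List


@dataclass
class SimpleBeam:
    """Hypothetical stand-in for Beam: a score plus the token ids decoded so far."""
    score: float                                    # cumulative log-probability
    tokens: List[int] = field(default_factory=list)


def prune(beams: List[SimpleBeam], beam_size: int) -> List[SimpleBeam]:
    """Keep the beam_size highest-scoring beams, best first (hypothetical helper)."""
    return heapq.nlargest(beam_size, beams, key=lambda b: b.score)


candidates = [SimpleBeam(-1.2, [3]), SimpleBeam(-0.4, [5]), SimpleBeam(-2.0, [7])]
best = prune(candidates, beam_size=2)
print([b.tokens for b in best])  # [[5], [3]]
```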

class mwptoolkit.module.Strategy.beam_search.BeamNode(score, nodes_hidden, node_stacks, tree_stacks, decoder_outputs_list, sequence_symbols_list)

Bases: object

copy()
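
BeamNode.copy() is not documented here. BeamNode holds per-hypothesis stacks (node_stacks, tree_stacks, sequence_symbols_list), so a plausible purpose of copy() is to let each expanded child change its own stacks without affecting its siblings. The minimal sketch below assumes that behaviour. SketchNode is a hypothetical stand-in, not the mwptoolkit class.

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import List


@dataclass
class SketchNode:
    """Hypothetical stand-in for BeamNode: a score plus mutable per-branch stacks."""
    score: float
    node_stacks: List[List[int]]
    sequence_symbols_list: List[int]

    def copy(self) -> "SketchNode":
        # Copy the nested lists so the new node shares no mutable state with self.
        return SketchNode(self.score,
                          deepcopy(self.node_stacks),
                          list(self.sequence_symbols_list))


parent = SketchNode(0.0, [[1, 2]], [4])
child = parent.copy()
child.node_stacks[0].pop()             # expanding the child edits only its own stack
child.sequence_symbols_list.append(9)
print(parent.node_stacks, parent.sequence_symbols_list)  # [[1, 2]] [4]
print(child.node_stacks, child.sequence_symbols_list)    # [[1]] [4, 9]
```
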

class mwptoolkit.module.Strategy.beam_search.Beam_Search_Hypothesis(beam_size, sos_token_idx, eos_token_idx, device, idx2token)

Bases: object

Beam search over decoded token sequences that keeps up to beam_size hypotheses at each step.

generate()

Pick the hypothesis with the highest probability among the beam_size hypotheses.

Returns

the generated tokens

Return type

List[str]
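
The selection performed by generate() can be sketched in plain Python. The sketch rests on several assumptions:

- each hypothesis is a list of token ids with one cumulative score;
- idx2token maps ids to strings;
- the SOS and EOS markers are dropped from the returned tokens.

pick_best is a hypothetical helper, not the method's actual code.

```python
from typing import Dict, List


def pick_best(hyp_token_ids: List[List[int]], hyp_scores: List[float],
              idx2token: Dict[int, str], sos_idx: int, eos_idx: int) -> List[str]:
    """Hypothetical helper: tokens of the highest-scoring hypothesis, minus SOS/EOS."""
    # Index of the hypothesis with the largest cumulative score.
    best = max(range(len(hyp_scores)), key=hyp_scores.__getitem__)
    # Map ids back to strings, skipping the start/end markers.
    return [idx2token[i] for i in hyp_token_ids[best] if i not in (sos_idx, eos_idx)]


idx2token = {0: "<SOS>", 1: "<EOS>", 2: "x", 3: "+", 4: "1", 5: "2"}
hyps = [[0, 2, 3, 4, 1], [0, 2, 3, 5, 1]]
print(pick_best(hyps, [-1.7, -0.9], idx2token, sos_idx=0, eos_idx=1))  # ['x', '+', '2']
```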

step(gen_idx, token_logits, decoder_states=None, encoder_output=None, encoder_mask=None, input_type='token')

Perform one step of beam search.

Parameters
  • gen_idx (int) – the index of the current generation step.

  • token_logits (torch.Tensor) – the logits distribution, shape: [hyp_num, sequence_length, vocab_size].

  • decoder_states (torch.Tensor, optional) – the decoder states from which the states of the chosen hypotheses are selected, shape: [hyp_num, sequence_length, hidden_size], default: None.

  • encoder_output (torch.Tensor, optional) – the encoder output to be copied for the chosen hypotheses, shape: [hyp_num, sequence_length, hidden_size], default: None.

  • encoder_mask (torch.Tensor, optional) – the encoder mask to be copied for the chosen hypotheses, shape: [hyp_num, sequence_length], default: None.

Returns

  • torch.Tensor – the next input sequence, shape: [hyp_num].

  • torch.Tensor, optional – the chosen decoder states, shape: [new_hyp_num, sequence_length, hidden_size].

  • torch.Tensor, optional – the copied encoder output, shape: [new_hyp_num, sequence_length, hidden_size].

  • torch.Tensor, optional – the copied encoder mask, shape: [new_hyp_num, sequence_length].

Return type

torch.Tensor
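
A beam search step works in four stages:

1. Turn each hypothesis's logits into log-probabilities.
2. Add them to that hypothesis's running score.
3. Keep the beam_size best (hypothesis, token) pairs.
4. Reorder any per-hypothesis data (decoder states, encoder output, encoder mask) by the surviving hypothesis indices.

The pure-Python sketch below illustrates this under simplifying assumptions:

- it uses only the logits of the position being generated;
- it ignores hypotheses that have already emitted EOS;
- it uses lists instead of torch.Tensor.

beam_step and select are hypothetical names, not mwptoolkit functions.

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def beam_step(hyp_scores: List[float], last_logits: List[List[float]],
              beam_size: int) -> Tuple[List[int], List[int], List[float]]:
    """Hypothetical helper: one beam search step.

    hyp_scores: running log-probability of each current hypothesis.
    last_logits: [hyp_num][vocab_size] logits for the position being generated.
    Returns (next_tokens, chosen_hyp, new_scores), best candidate first.
    """
    candidates = []
    for h, row in enumerate(last_logits):
        for tok, lp in enumerate(log_softmax(row)):
            candidates.append((hyp_scores[h] + lp, h, tok))
    # Keep the beam_size best (score, hypothesis, token) triples.
    top = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_size]
    return [t for _, _, t in top], [h for _, h, _ in top], [s for s, _, _ in top]


def select(per_hyp: list, chosen_hyp: List[int]) -> list:
    """Reorder (and duplicate) per-hypothesis data to follow the chosen hypotheses."""
    return [per_hyp[h] for h in chosen_hyp]


next_tokens, chosen_hyp, new_scores = beam_step(
    hyp_scores=[0.0, -5.0],
    last_logits=[[2.0, 1.0, 0.0], [0.0, 3.0, 0.0]],
    beam_size=2,
)
print(next_tokens, chosen_hyp)                     # [0, 1] [0, 0]
print(select(["state_A", "state_B"], chosen_hyp))  # ['state_A', 'state_A']
```

In this sketch, select is what turns [hyp_num, ...] inputs into [new_hyp_num, ...] outputs. One hypothesis can be chosen several times, so its data is duplicated, while unchosen hypotheses are dropped.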

stop()

Determine if the beam search is over.

Returns

True if the search is over, False if it is still running.

Return type

bool
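
This page does not state which condition stop() checks. The sketch below assumes a common criterion: stop once beam_size hypotheses have produced EOS, or once a maximum length is reached. should_stop and max_len are hypothetical, and the loop only simulates hypotheses finishing.

```python
from typing import List


def should_stop(completed: List[List[int]], beam_size: int, step: int, max_len: int) -> bool:
    """Hypothetical criterion: enough finished hypotheses, or the length limit reached."""
    return len(completed) >= beam_size or step >= max_len


completed: List[List[int]] = []
step = 0
while not should_stop(completed, beam_size=2, step=step, max_len=10):
    step += 1
    if step % 3 == 0:              # simulate a hypothesis emitting EOS every third step
        completed.append([step])
print(step, len(completed))        # 6 2
```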

class mwptoolkit.module.Strategy.beam_search.TreeBeam(score, node_stack, embedding_stack, left_childs, out, token_logit=None)

Bases: object


© Copyright 2021, ''. Revision 0993beb3.

Built with Sphinx using a theme provided by Read the Docs.