Source code for mwptoolkit.module.Strategy.beam_search

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 22:12:20
# @File: beam_search.py


import copy
import torch
from torch.nn import functional as F
from mwptoolkit.utils.utils import copy_list


[docs]class Beam: # the class save the beam node def __init__(self, score, input_var, hidden,token_logits,outputs, all_output=None): self.score = score self.input_var = input_var self.hidden = hidden self.all_output = all_output self.token_logits = token_logits self.outputs = outputs
[docs]class TreeBeam: # the class save the beam node def __init__(self, score, node_stack, embedding_stack, left_childs, out, token_logit=None): self.score = score self.embedding_stack = copy_list(embedding_stack) self.node_stack = copy_list(node_stack) self.left_childs = copy_list(left_childs) self.out = copy.deepcopy(out) self.token_logit = token_logit
[docs]class BeamNode: def __init__(self, score, nodes_hidden, node_stacks, tree_stacks, decoder_outputs_list, sequence_symbols_list): self.score = score self.nodes_hidden = nodes_hidden self.node_stacks = node_stacks self.tree_stacks = tree_stacks self.decoder_outputs_list = decoder_outputs_list self.sequence_symbols_list = sequence_symbols_list return
[docs] def copy(self): node = BeamNode( self.score, self.nodes_hidden, copy_list(self.node_stacks), copy_list(self.tree_stacks), copy_list(self.decoder_outputs_list), copy_list(self.sequence_symbols_list) ) return node
[docs]class Beam_Search_Hypothesis(object): r""" Class designed for beam search. """ def __init__(self, beam_size, sos_token_idx, eos_token_idx, device, idx2token): self.beam_size = beam_size self.sos_token_idx = sos_token_idx self.eos_token_idx = eos_token_idx self.device = device self.idx2token = idx2token self.hypthetic_token_idx = [[sos_token_idx]] self.completed_hypotheses = [] self.hyp_scores = torch.zeros(1).to(device)
[docs] def generate(self): r""" Pick the hypothesis with max prob among beam_size hypothesises. Return: List[str]: the generated tokens """ generate_idx = self.hypthetic_token_idx[0][1:] if (len(self.completed_hypotheses) == 0) else max(self.completed_hypotheses, key = lambda hyp: hyp[1])[0] generate_tokens = [self.idx2token[idx.item()] for idx in generate_idx] return generate_tokens
[docs] def stop(self): r""" Determine if the beam search is over. Return: Bool: ``True`` represents the search over, `Flase` represents the search working. """ return len(self.completed_hypotheses) == self.beam_size
[docs] def step(self, gen_idx, token_logits, decoder_states=None, encoder_output=None, encoder_mask=None, input_type='token'): r""" A step for beam search. Args: gen_idx (int): the generated step number. token_logits (torch.Tensor): logits distribution, shape: [hyp_num, sequence_length, vocab_size]. decoder_states (torch.Tensor, optional): the states of decoder needed to choose, shape: [hyp_num, sequence_length, hidden_size], default: None. encoder_output (torch.Tensor, optional): the output of encoder needed to copy, shape: [hyp_num, sequence_length, hidden_size], default: None. encoder_mask (torch.Tensor, optional): the mask of encoder to copy, shape: [hyp_num, sequence_length], default: None. Return: torch.Tensor: the next input squence, shape: [hyp_num], torch.Tensor, optional: the chosen states of decoder, shape: [new_hyp_num, sequence_length, hidden_size] torch.Tensor, optional: the copyed output of encoder, shape: [new_hyp_num, sequence_length, hidden_size] torch.Tensor, optional: the copyed mask of encoder, shape: [new_hyp_num, sequence_length] """ token_probs = F.log_softmax(token_logits, dim=-1).squeeze(1) vocab_size = token_probs.shape[-1] live_hyp_num = self.beam_size - len(self.completed_hypotheses) tmp_hyp_scores = (self.hyp_scores.unsqueeze(1).expand_as(token_probs) + token_probs).view(-1) top_scores, top_pos = torch.topk(tmp_hyp_scores, k=live_hyp_num) hyp_ids = top_pos // vocab_size word_ids = top_pos % vocab_size new_hypotheses = [] new_ids = [] new_scores = [] for hyp_id, word_id, score in zip(hyp_ids, word_ids, top_scores): new_hyp = self.hypthetic_token_idx[hyp_id] + [word_id] if (word_id == self.eos_token_idx): self.completed_hypotheses.append((new_hyp[1:-1], score / (gen_idx - 1))) else: new_hypotheses.append(new_hyp) new_ids.append(hyp_id) new_scores.append(score) if (len(self.completed_hypotheses) == self.beam_size): none_cnt = (decoder_states is not None) + (encoder_output is not None) + (encoder_mask is not None) + 1 return [None] * none_cnt self.hypthetic_token_idx = new_hypotheses self.hyp_scores = torch.tensor(new_scores).to(self.device) hyp_num = len(self.hypthetic_token_idx) if (input_type == 'token'): input_seq = [hyp[-1] for hyp in self.hypthetic_token_idx] input_seq = torch.tensor(input_seq).unsqueeze(1).to(self.device) elif (input_type == 'whole'): input_seq = torch.tensor(self.hypthetic_token_idx).to(self.device) else: raise ValueError("The input type must be in ['token', 'whole'].") returns = [input_seq] if (decoder_states is not None): new_ids = torch.tensor(new_ids).to(self.device) decoder_states = decoder_states[:, new_ids, :] returns += [decoder_states] if (encoder_output is not None): encoder_output = encoder_output[0:1].repeat(hyp_num, 1, 1) encoder_mask = encoder_mask[0:1].repeat(hyp_num, 1) returns += [encoder_output, encoder_mask] return returns