Source code for mwptoolkit.module.Decoder.rnn_decoder

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 11:11:07
# @File: rnn_decoder.py


import torch
from torch import nn

from mwptoolkit.module.Attention.seq_attention import SeqAttention, Attention, MaskedRelevantScore
from mwptoolkit.module.Layer.layers import Transformer
from mwptoolkit.module.Environment.stack_machine import OPERATIONS

class BasicRNNDecoder(nn.Module):
    r"""
    Basic Recurrent Neural Network (RNN) decoder.
    """

    def __init__(self, embedding_size, hidden_size, num_layers, rnn_cell_type, dropout_ratio=0.0):
        super(BasicRNNDecoder, self).__init__()
        self.rnn_cell_type = rnn_cell_type
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size

        if rnn_cell_type == 'lstm':
            self.decoder = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True, dropout=dropout_ratio)
        elif rnn_cell_type == 'gru':
            self.decoder = nn.GRU(embedding_size, hidden_size, num_layers, batch_first=True, dropout=dropout_ratio)
        elif rnn_cell_type == 'rnn':
            self.decoder = nn.RNN(embedding_size, hidden_size, num_layers, batch_first=True, dropout=dropout_ratio)
        else:
            raise ValueError("The RNN type in decoder must be in ['lstm', 'gru', 'rnn'].")

    def init_hidden(self, input_embeddings):
        r"""
        Initialize initial hidden states of RNN.

        Args:
            input_embeddings (torch.Tensor): input sequence embedding, shape: [batch_size, sequence_length, embedding_size].

        Returns:
            torch.Tensor: the initial hidden states.
        """
        batch_size = input_embeddings.size(0)
        device = input_embeddings.device
        if self.rnn_cell_type == 'lstm':
            h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
            c_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
            hidden_states = (h_0, c_0)
            return hidden_states
        elif self.rnn_cell_type == 'gru' or self.rnn_cell_type == 'rnn':
            return torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        else:
            raise NotImplementedError("No such rnn type {} for initializing decoder states.".format(self.rnn_cell_type))

    def forward(self, input_embeddings, hidden_states=None):
        r"""
        Implement the decoding process.

        Args:
            input_embeddings (torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            hidden_states (torch.Tensor): initial hidden states, default: None.

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
                hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
        """
        if hidden_states is None:
            hidden_states = self.init_hidden(input_embeddings)
        # hidden_states = hidden_states.contiguous()
        outputs, hidden_states = self.decoder(input_embeddings, hidden_states)
        return outputs, hidden_states

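
# Illustrative usage sketch (not part of the library source): a minimal run of
# BasicRNNDecoder on randomly generated target embeddings. All sizes below are
# arbitrary assumptions chosen only for demonstration.
def _example_basic_rnn_decoder():
    decoder = BasicRNNDecoder(embedding_size=64, hidden_size=128, num_layers=2,
                              rnn_cell_type='lstm', dropout_ratio=0.1)
    # fake target-side embeddings: [batch_size=4, sequence_length=10, embedding_size=64]
    input_embeddings = torch.randn(4, 10, 64)
    outputs, hidden_states = decoder(input_embeddings)
    # outputs: [4, 10, 128]; hidden_states is the (h_n, c_n) pair of the LSTM
    return outputs, hidden_states
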
class AttentionalRNNDecoder(nn.Module):
    r"""
    Attention-based Recurrent Neural Network (RNN) decoder.
    """

    def __init__(self, embedding_size, hidden_size, context_size, num_dec_layers, rnn_cell_type, dropout_ratio=0.0):
        super(AttentionalRNNDecoder, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.num_dec_layers = num_dec_layers
        self.rnn_cell_type = rnn_cell_type

        self.attentioner = SeqAttention(hidden_size, hidden_size)
        if rnn_cell_type == 'lstm':
            self.decoder = nn.LSTM(embedding_size, hidden_size, num_dec_layers, batch_first=True, dropout=dropout_ratio)
        elif rnn_cell_type == 'gru':
            self.decoder = nn.GRU(embedding_size, hidden_size, num_dec_layers, batch_first=True, dropout=dropout_ratio)
        elif rnn_cell_type == 'rnn':
            self.decoder = nn.RNN(embedding_size, hidden_size, num_dec_layers, batch_first=True, dropout=dropout_ratio)
        else:
            raise ValueError("RNN type in attentional decoder must be in ['lstm', 'gru', 'rnn'].")

        self.attention_dense = nn.Linear(hidden_size, hidden_size)

    def init_hidden(self, input_embeddings):
        r"""
        Initialize initial hidden states of RNN.

        Args:
            input_embeddings (torch.Tensor): input sequence embedding, shape: [batch_size, sequence_length, embedding_size].

        Returns:
            torch.Tensor: the initial hidden states.
        """
        batch_size = input_embeddings.size(0)
        device = input_embeddings.device
        if self.rnn_cell_type == 'lstm':
            h_0 = torch.zeros(self.num_dec_layers, batch_size, self.hidden_size).to(device)
            c_0 = torch.zeros(self.num_dec_layers, batch_size, self.hidden_size).to(device)
            hidden_states = (h_0, c_0)
            return hidden_states
        elif self.rnn_cell_type == 'gru' or self.rnn_cell_type == 'rnn':
            return torch.zeros(self.num_dec_layers, batch_size, self.hidden_size).to(device)
        else:
            raise NotImplementedError("No such rnn type {} for initializing decoder states.".format(self.rnn_cell_type))

    def forward(self, input_embeddings, hidden_states=None, encoder_outputs=None, encoder_masks=None):
        r"""
        Implement the attention-based decoding process.

        Args:
            input_embeddings (torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            hidden_states (torch.Tensor): initial hidden states, default: None.
            encoder_outputs (torch.Tensor): encoder output features, shape: [batch_size, sequence_length, hidden_size], default: None.
            encoder_masks (torch.Tensor): encoder state masks, shape: [batch_size, sequence_length], default: None.

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
                hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
        """
        if hidden_states is None:
            hidden_states = self.init_hidden(input_embeddings)
        decode_length = input_embeddings.size(1)
        all_outputs = []
        for step in range(decode_length):
            # decode one target position at a time, attending over the encoder outputs
            output, hidden_states = self.decoder(input_embeddings[:, step, :].unsqueeze(1), hidden_states)
            output, attn = self.attentioner(output, encoder_outputs, encoder_masks)
            output = self.attention_dense(output.view(-1, self.hidden_size))
            output = output.view(-1, 1, self.hidden_size)
            all_outputs.append(output)
        outputs = torch.cat(all_outputs, dim=1)
        return outputs, hidden_states

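
# Illustrative usage sketch (not part of the library source): constructing an
# AttentionalRNNDecoder and preparing its initial hidden state. The sizes are
# arbitrary assumptions; a full forward pass additionally needs encoder_outputs
# of shape [batch_size, source_length, hidden_size] and (optionally)
# encoder_masks of shape [batch_size, source_length], whose exact convention is
# defined by SeqAttention rather than by this file.
def _example_attentional_rnn_decoder():
    decoder = AttentionalRNNDecoder(embedding_size=64, hidden_size=128, context_size=128,
                                    num_dec_layers=1, rnn_cell_type='gru', dropout_ratio=0.0)
    # decoder-side input embeddings: [batch_size=4, target_length=10, embedding_size=64]
    input_embeddings = torch.randn(4, 10, 64)
    hidden_states = decoder.init_hidden(input_embeddings)
    # hidden_states: [num_dec_layers=1, batch_size=4, hidden_size=128], ready to be passed to
    # decoder(input_embeddings, hidden_states, encoder_outputs, encoder_masks)
    return hidden_states
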
class SalignedDecoder(nn.Module):
    r"""
    Stack-machine based decoder: at each step it applies the previously
    predicted operation to a per-problem stack machine, then produces logits
    for the next operation and for the argument to push.
    """

    def __init__(self, operations, dim_hidden=300, dropout_rate=0.5, device=None):
        super(SalignedDecoder, self).__init__()
        self.NOOP = operations.NOOP
        self.GEN_VAR = operations.GEN_VAR
        self.ADD = operations.ADD
        self.SUB = operations.SUB
        self.MUL = operations.MUL
        self.DIV = operations.DIV
        self.POWER = operations.POWER
        self.EQL = operations.EQL
        self.N_OPS = operations.N_OPS
        self.PAD = operations.PAD
        self.RAW_EQL = operations.RAW_EQL
        self.BRG = operations.BRG
        self._device = device

        # one transformation per binary operator, applied to the top-2 stack states
        self.transformer_add = Transformer(2 * dim_hidden)
        self.transformer_sub = Transformer(2 * dim_hidden)
        self.transformer_mul = Transformer(2 * dim_hidden)
        self.transformer_div = Transformer(2 * dim_hidden)
        self.transformer_power = Transformer(2 * dim_hidden)
        self.transformers = {
            self.ADD: self.transformer_add,
            self.SUB: self.transformer_sub,
            self.MUL: self.transformer_mul,
            self.DIV: self.transformer_div,
            self.POWER: self.transformer_power,
            self.RAW_EQL: None,
            self.BRG: None
        }

        self.gen_var = Attention(2 * dim_hidden, dim_hidden, dropout_rate=0.0)
        self.attention = Attention(2 * dim_hidden, dim_hidden, dropout_rate=dropout_rate)
        self.choose_arg = MaskedRelevantScore(dim_hidden * 2, dim_hidden * 7, dropout_rate=dropout_rate)
        self.arg_gate = torch.nn.Linear(dim_hidden * 7, 3, torch.nn.Sigmoid())
        self.rnn = torch.nn.LSTM(2 * dim_hidden, dim_hidden, 1, batch_first=True)
        self.op_selector = torch.nn.Sequential(
            torch.nn.Linear(dim_hidden * 7, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(256, self.N_OPS + 1)
        )
        self.op_gate = torch.nn.Linear(dim_hidden * 7, 3, torch.nn.Sigmoid())
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.register_buffer('noop_padding_return', torch.zeros(dim_hidden * 2))
        self.register_buffer('padding_embedding', torch.zeros(dim_hidden * 2))

    def forward(self, context, text_len, operands, stacks, prev_op, prev_output, prev_state, number_emb, N_OPS):
        """
        Args:
            context (torch.Tensor): Encoded context, with size [batch_size, text_len, dim_hidden].
            text_len (torch.Tensor): Text length for each problem in the batch.
            operands (list of torch.Tensor): List of operand embeddings for each problem in the batch.
                Each element in the list is of size [n_operands, dim_hidden].
            stacks (list of StackMachine): List of stack machines used for each problem.
            prev_op (torch.LongTensor): Previous operation, with size [batch, 1].
            prev_output (torch.Tensor): Previous decoder RNN outputs, with size [batch, dim_hidden].
                Can be None for the first step.
            prev_state (torch.Tensor): Previous decoder RNN state, with size [batch, dim_hidden].
                Can be None for the first step.
            number_emb (list of torch.Tensor): Embeddings of the numbers in each problem,
                looked up when the previous prediction pushed an operand.
            N_OPS (int): Number of operator symbols; operand indices in prev_op are offset by this value.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
                op_logits: Logits of operation selection.
                arg_logits: Logits of argument choosing.
                outputs: Outputs of decoder RNN.
                state: Hidden state of decoder RNN.
        """
        batch_size = context.size(0)

        # collect stack states
        stack_states = \
            torch.stack([stack.get_top2().view(-1,) for stack in stacks], dim=0).to(self._device)

        # skip the first step (all NOOP)
        if prev_output is not None:
            # result calculated batch-wise
            batch_result = {
                self.ADD: self.transformer_add(stack_states),
                self.SUB: self.transformer_sub(stack_states),
                self.MUL: self.transformer_mul(stack_states),
                self.DIV: self.transformer_div(stack_states),
                self.POWER: self.transformer_power(stack_states)
            }
            prev_returns = []
            # apply previous op on stacks
            for b in range(batch_size):
                # no op
                if prev_op[b].item() == self.NOOP:
                    ret = self.noop_padding_return
                elif prev_op[b].item() == self.PAD:
                    ret = self.noop_padding_return
                # generate variable
                elif prev_op[b].item() == self.GEN_VAR:
                    variable = batch_result[self.GEN_VAR][b]
                    operands[b].append(variable)
                    stacks[b].add_variable(variable)
                    ret = variable
                # OPERATIONS.ADD, SUB, MUL, DIV, POWER
                elif prev_op[b].item() in [self.ADD, self.SUB, self.MUL, self.DIV, self.POWER]:
                    transformed = batch_result[prev_op[b].item()][b]
                    ret = stacks[b].apply_embed_only(prev_op[b].item(), transformed)
                # elif prev_op[b].item() in [self.RAW_EQL, self.BRG]:
                #     ret = stacks[b].apply_embed_only(prev_op[b].item(), None)
                elif prev_op[b].item() == self.EQL:
                    ret = stacks[b].apply_eql(prev_op[b].item())
                # push operand
                else:
                    stacks[b].push(prev_op[b].item() - N_OPS)
                    # ret = operands[b][prev_op[b].item() - N_OPS]
                    ret = number_emb[b][prev_op[b].item() - N_OPS]
                prev_returns.append(ret)

            # collect stack states (after applied op)
            stack_states = \
                torch.stack([stack.get_top2().view(-1,) for stack in stacks], dim=0).to(self._device)
            # collect previous returns
            prev_returns = torch.stack(prev_returns)
            prev_returns = self.dropout(prev_returns)

        # decode
        outputs, hidden_state = self.rnn(prev_returns.unsqueeze(1), prev_state)
        outputs = outputs.squeeze(1)

        # attention over the encoded problem text
        attention = self.attention(context, outputs, text_len)

        # collect information for op selector
        gate_in = torch.cat([outputs, stack_states, attention], -1)
        op_gate_in = self.dropout(gate_in)
        op_gate = self.op_gate(op_gate_in)
        arg_gate_in = self.dropout(gate_in)
        arg_gate = self.arg_gate(arg_gate_in)
        op_in = torch.cat([op_gate[:, 0:1] * outputs, op_gate[:, 1:2] * stack_states, op_gate[:, 2:3] * attention], -1)
        arg_in = torch.cat([arg_gate[:, 0:1] * outputs, arg_gate[:, 1:2] * stack_states, arg_gate[:, 2:3] * attention], -1)
        # op_in = arg_in = torch.cat([outputs, stack_states, attention], -1)
        op_logits = self.op_selector(op_in)

        n_operands, cated_operands = \
            self.pad_and_cat(operands, self.padding_embedding)
        arg_logits = self.choose_arg(cated_operands, arg_in, n_operands)

        return op_logits, arg_logits, outputs, hidden_state

    def pad_and_cat(self, tensors, padding):
        """
        Pad lists to have the same number of elements, and concatenate those
        elements into a 3d tensor.

        Args:
            tensors (list of list of Tensors): Each list contains a list of operand embeddings.
                Each operand embedding is of size (dim_element,).
            padding (Tensor): Element used to pad lists, with size (dim_element,).

        Returns:
            n_tensors (list of int): Length of lists in tensors.
            tensors (Tensor): Concatenated tensor after padding the list.
        """
        n_tensors = [len(ts) for ts in tensors]
        pad_size = max(n_tensors)

        # pad to have the same number of operands for each problem
        tensors = [ts + (pad_size - len(ts)) * [padding] for ts in tensors]

        # tensors.size() == (batch_size, pad_size, dim_hidden)
        tensors = torch.stack([torch.stack(t) for t in tensors], dim=0)

        return n_tensors, tensors
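
# Illustrative usage sketch (not part of the library source): building a
# SalignedDecoder and exercising pad_and_cat on a toy batch. It assumes the
# OPERATIONS class imported above exposes every op-code field read in
# __init__ (NOOP, GEN_VAR, ADD, ..., PAD, RAW_EQL, BRG); the operand
# dimensionality (2 * dim_hidden = 600) follows the buffers registered in
# __init__, and everything else is an arbitrary choice for demonstration.
def _example_saligned_pad_and_cat():
    decoder = SalignedDecoder(OPERATIONS, dim_hidden=300, dropout_rate=0.5,
                              device=torch.device('cpu'))
    # two problems with different numbers of candidate operands
    operands = [
        [torch.randn(600) for _ in range(3)],
        [torch.randn(600) for _ in range(1)],
    ]
    n_operands, padded = decoder.pad_and_cat(operands, decoder.padding_embedding)
    # n_operands == [3, 1]; padded has shape [2, 3, 600]
    return n_operands, padded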