Source code for mwptoolkit.module.Encoder.rnn_encoder

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 21:48:46
# @File: rnn_encoder.py


from copy import deepcopy

import torch
from torch import nn

from mwptoolkit.module.Attention.seq_attention import SeqAttention
from mwptoolkit.module.Attention.group_attention import GroupAttention
from mwptoolkit.module.Attention.hierarchical_attention import Attention
from mwptoolkit.module.Attention.seq_attention import Attention as Attention_x
from mwptoolkit.module.Embedder.position_embedder import PositionalEncoding
from mwptoolkit.module.Embedder.basic_embedder import BasicEmbedder
from mwptoolkit.module.Encoder.transformer_encoder import GroupATTEncoder
from mwptoolkit.module.Layer.transformer_layer import PositionwiseFeedForward, GAEncoderLayer


class BasicRNNEncoder(nn.Module):
    """
    Basic Recurrent Neural Network (RNN) encoder.
    """
    def __init__(self, embedding_size, hidden_size, num_layers, rnn_cell_type, dropout_ratio, bidirectional=True, batch_first=True):
        super(BasicRNNEncoder, self).__init__()
        self.rnn_cell_type = rnn_cell_type
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if self.bidirectional else 1
        self.batch_first = batch_first

        if rnn_cell_type == 'lstm':
            self.encoder = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=batch_first, dropout=dropout_ratio, bidirectional=bidirectional)
        elif rnn_cell_type == 'gru':
            self.encoder = nn.GRU(embedding_size, hidden_size, num_layers, batch_first=batch_first, dropout=dropout_ratio, bidirectional=bidirectional)
        elif rnn_cell_type == 'rnn':
            self.encoder = nn.RNN(embedding_size, hidden_size, num_layers, batch_first=batch_first, dropout=dropout_ratio, bidirectional=bidirectional)
        else:
            raise ValueError("The RNN type of encoder must be in ['lstm', 'gru', 'rnn'].")

    def init_hidden(self, input_embeddings):
        r"""Initialize the initial hidden states of the RNN.

        Args:
            input_embeddings (torch.Tensor): input sequence embedding, shape: [batch_size, sequence_length, embedding_size].

        Returns:
            torch.Tensor: the initial hidden states.
        """
        if self.batch_first:
            batch_size = input_embeddings.size(0)
        else:
            batch_size = input_embeddings.size(1)
        device = input_embeddings.device
        if self.rnn_cell_type == 'lstm':
            h_0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(device)
            c_0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(device)
            hidden_states = (h_0, c_0)
            return hidden_states
        elif self.rnn_cell_type == 'gru' or self.rnn_cell_type == 'rnn':
            tp_vec = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size)
            return tp_vec.to(device)
        else:
            raise NotImplementedError("No such rnn type {} for initializing encoder states.".format(self.rnn_cell_type))

    def forward(self, input_embeddings, input_length, hidden_states=None):
        r"""Implement the encoding process.

        Args:
            input_embeddings (torch.Tensor): source sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            input_length (torch.Tensor): length of input sequence, shape: [batch_size].
            hidden_states (torch.Tensor): initial hidden states, default: None.

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
                hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
        """
        if hidden_states is None:
            hidden_states = self.init_hidden(input_embeddings)
        packed_input_embeddings = torch.nn.utils.rnn.pack_padded_sequence(input_embeddings, input_length, batch_first=self.batch_first, enforce_sorted=True)
        outputs, hidden_states = self.encoder(packed_input_embeddings, hidden_states)
        outputs, outputs_length = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=self.batch_first)
        return outputs, hidden_states
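
# --- Illustrative usage sketch (not part of the original module) ---
# Encode a batch of padded, length-sorted word embeddings with BasicRNNEncoder.
# The sizes below are arbitrary example values.
_basic_encoder = BasicRNNEncoder(embedding_size=128, hidden_size=256, num_layers=2,
                                 rnn_cell_type='gru', dropout_ratio=0.5)
_embeddings = torch.randn(4, 10, 128)        # [batch_size, seq_len, embedding_size]
_lengths = torch.tensor([10, 9, 7, 5])       # sorted descending (enforce_sorted=True)
_outputs, _hidden = _basic_encoder(_embeddings, _lengths)
# _outputs: [4, 10, 512] (num_directions * hidden_size), _hidden: [4, 4, 256]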

class SelfAttentionRNNEncoder(nn.Module):
    """
    Self Attentional Recurrent Neural Network (RNN) encoder.
    """
    def __init__(self, embedding_size, hidden_size, context_size, num_layers, rnn_cell_type, dropout_ratio, bidirectional=True):
        super(SelfAttentionRNNEncoder, self).__init__()
        self.rnn_cell_type = rnn_cell_type
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if self.bidirectional else 1

        if rnn_cell_type == 'lstm':
            self.encoder = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
        elif rnn_cell_type == 'gru':
            self.encoder = nn.GRU(embedding_size, hidden_size, num_layers, batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
        elif rnn_cell_type == 'rnn':
            self.encoder = nn.RNN(embedding_size, hidden_size, num_layers, batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
        else:
            raise ValueError("The RNN type of encoder must be in ['lstm', 'gru', 'rnn'].")

        self.attention = SeqAttention(hidden_size, context_size)

    def init_hidden(self, input_embeddings):
        r"""Initialize the initial hidden states of the RNN.

        Args:
            input_embeddings (torch.Tensor): input sequence embedding, shape: [batch_size, sequence_length, embedding_size].

        Returns:
            torch.Tensor: the initial hidden states.
        """
        batch_size = input_embeddings.size(0)
        device = input_embeddings.device
        if self.rnn_cell_type == 'lstm':
            h_0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(device)
            c_0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(device)
            hidden_states = (h_0, c_0)
            return hidden_states
        elif self.rnn_cell_type == 'gru' or self.rnn_cell_type == 'rnn':
            tp_vec = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size)
            return tp_vec.to(device)
        else:
            raise NotImplementedError("No such rnn type {} for initializing encoder states.".format(self.rnn_cell_type))

    def forward(self, input_embeddings, input_length, hidden_states=None):
        r"""Implement the encoding process.

        Args:
            input_embeddings (torch.Tensor): source sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            input_length (torch.Tensor): length of input sequence, shape: [batch_size].
            hidden_states (torch.Tensor): initial hidden states, default: None.

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
                hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
        """
        if hidden_states is None:
            hidden_states = self.init_hidden(input_embeddings)
        packed_input_embeddings = torch.nn.utils.rnn.pack_padded_sequence(input_embeddings, input_length, batch_first=True, enforce_sorted=True)
        outputs, hidden_states = self.encoder(packed_input_embeddings, hidden_states)
        outputs, outputs_length = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        if self.bidirectional:
            # sum the forward and backward directions
            encoder_outputs = outputs[:, :, self.hidden_size:] + outputs[:, :, :self.hidden_size]
            if self.rnn_cell_type == 'lstm':
                encoder_hidden = (hidden_states[0][::2].contiguous(), hidden_states[1][::2].contiguous())
            else:
                encoder_hidden = hidden_states[::2].contiguous()

        outputs, attn = self.attention.forward(encoder_outputs, encoder_outputs, mask=None)
        return outputs, hidden_states
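
# --- Illustrative usage sketch (not part of the original module) ---
# SelfAttentionRNNEncoder runs a (bidirectional) RNN, sums the two directions, and
# applies SeqAttention over the result. The sizes below are arbitrary, and setting
# context_size equal to hidden_size is an assumption made only for this sketch.
_sa_encoder = SelfAttentionRNNEncoder(embedding_size=128, hidden_size=256, context_size=256,
                                      num_layers=1, rnn_cell_type='lstm', dropout_ratio=0.0)
_sa_embeddings = torch.randn(4, 12, 128)     # [batch_size, seq_len, embedding_size]
_sa_lengths = torch.tensor([12, 10, 8, 8])   # sorted descending (enforce_sorted=True)
_sa_outputs, _sa_hidden = _sa_encoder(_sa_embeddings, _sa_lengths)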

class GroupAttentionRNNEncoder(nn.Module):
    """
    Group Attentional Recurrent Neural Network (RNN) encoder.
    """
    def __init__(self, emb_size=100, hidden_size=128, n_layers=1, bidirectional=False,
                 rnn_cell=None, rnn_cell_name='gru', variable_lengths=True,
                 d_ff=2048, dropout=0.3, N=1):
        super(GroupAttentionRNNEncoder, self).__init__()
        self.variable_lengths = variable_lengths
        self.bidirectional = bidirectional
        self.dropout = dropout
        if bidirectional:
            self.d_model = 2 * hidden_size
        else:
            self.d_model = hidden_size

        ff = PositionwiseFeedForward(self.d_model, d_ff, dropout)
        if rnn_cell_name.lower() == 'lstm':
            self.rnn_cell = nn.LSTM
        elif rnn_cell_name.lower() == 'gru':
            self.rnn_cell = nn.GRU
        if rnn_cell is None:
            self.rnn = self.rnn_cell(emb_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional, dropout=self.dropout)
        else:
            self.rnn = rnn_cell
        self.group_attention = GroupAttention(8, self.d_model)
        self.onelayer = GroupATTEncoder(GAEncoderLayer(self.d_model, deepcopy(self.group_attention), deepcopy(ff), dropout), N)

    def forward(self, embedded, input_var, split_list, input_lengths=None):
        """
        Args:
            embedded (torch.Tensor): embedded inputs, shape [batch_size, sequence_length, embedding_size].
            input_var (torch.Tensor): source sequence, shape [batch_size, sequence_length].
            split_list (list): group split index.
            input_lengths (torch.Tensor): length of input sequence, shape: [batch_size].

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
                hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
        """
        if self.variable_lengths:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=True)
        else:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)
        output, hidden = self.rnn(embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        src_mask = self.group_attention.get_mask(input_var, split_list)
        output = self.onelayer(output, src_mask)
        return output, hidden
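
# --- Illustrative usage sketch (not part of the original module) ---
# GroupAttentionRNNEncoder combines an RNN with group attention. split_list is
# assumed here to hold the separator token ids that GroupAttention.get_mask uses
# to split the sequence into groups; the ids and sizes below are hypothetical
# placeholders for demonstration only.
_ga_encoder = GroupAttentionRNNEncoder(emb_size=100, hidden_size=128, n_layers=1,
                                       bidirectional=True, rnn_cell_name='gru',
                                       d_ff=512, dropout=0.0, N=1)
_ga_input_var = torch.randint(5, 100, (2, 15))   # token ids, [batch_size, seq_len]
_ga_embedded = torch.randn(2, 15, 100)           # [batch_size, seq_len, emb_size]
_ga_lengths = torch.tensor([15, 11])             # sorted descending
_ga_split_list = [3, 4]                          # hypothetical separator token ids
_ga_output, _ga_hidden = _ga_encoder(_ga_embedded, _ga_input_var, _ga_split_list, _ga_lengths)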

class HWCPEncoder(nn.Module):
    r"""Hierarchical word-clause-problem encoder."""
    def __init__(self, embedding_model, embedding_size, hidden_size=512, span_size=0, dropout_ratio=0.4):
        super(HWCPEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = embedding_model

        # word encoding
        self.word_rnn = nn.GRU(embedding_size, hidden_size, num_layers=2, bidirectional=True, batch_first=True, dropout=dropout_ratio)
        self.dropout = nn.Dropout(p=dropout_ratio)

        # span encoding over the span sequence
        self.span_attn = Attention(self.hidden_size, mix=True, fn=True)
        self.pos_enc = PositionalEncoding(span_size, hidden_size)

        # merge subtree/word nodes into their parent
        self.to_parent = Attention(self.hidden_size, mix=True, fn=True)

    def forward(self, input_var, input_lengths, span_length, tree=None, output_all_layers=False):
        """Encode the input spans at the word, clause and problem level."""
        device = span_length.device
        word_outputs = []
        span_inputs = []

        input_vars = input_var
        trees = tree
        bi_word_hidden = None
        for span_index, input_var in enumerate(input_vars):
            input_length = input_lengths[span_index]

            # word embedding
            embedded = self.embedding(input_var)

            # word level encoding
            word_output, bi_word_hidden = self.word_level_forward(embedded, input_length, bi_word_hidden)
            word_output, word_hidden = self.bi_combine(word_output, bi_word_hidden)

            # tree encoding / clause level
            tree_batch = trees[span_index]
            span_span_input = self.clause_level_forward(word_output, tree_batch)
            span_input = torch.cat(span_span_input, dim=0)

            span_inputs.append(span_input.unsqueeze(1))
            word_outputs.append(word_output)

        # span encoding / problem level
        span_input = torch.cat(span_inputs, dim=1)
        span_mask = self.get_mask(span_length, span_input.size(1))
        span_output, _ = self.problem_level_forword(span_input, span_mask)

        span_output = span_output * (span_mask == 0).unsqueeze(-1)
        dim0 = torch.arange(span_output.size(0)).to(device)
        span_hidden = span_output[dim0, span_length - 1].unsqueeze(0)

        # example shapes (batch_size=4, span_num=5, hidden_size=512):
        # span_output [4, 5, 512], word_outputs 5 * [4, length, 512], span_hidden [1, 4, 512]
        return (span_output, word_outputs), span_hidden

    def word_level_forward(self, embedding_inputs, input_length, bi_word_hidden=None):
        # keep at least 1 word for fully padded spans so packing does not fail
        pad_input_length = input_length.clone()
        pad_input_length[pad_input_length == 0] = 1

        embedded = nn.utils.rnn.pack_padded_sequence(embedding_inputs, pad_input_length, batch_first=True, enforce_sorted=False)
        word_output, bi_word_hidden = self.word_rnn(embedded, bi_word_hidden)
        word_output, _ = nn.utils.rnn.pad_packed_sequence(word_output, batch_first=True)
        # word_output, word_hidden = self.bi_combine(word_output, bi_word_hidden)
        return word_output, bi_word_hidden

    def clause_level_forward(self, word_output, tree_batch):
        device = word_output.device
        span_span_input = []
        for b_i, data_word_output in enumerate(word_output):
            data_word_output = data_word_output.unsqueeze(0)
            tree = tree_batch[b_i]
            if tree is not None:
                data_span_input = self.dependency_encode(data_word_output, tree.root)
            else:
                pad_hidden = torch.zeros(1, self.hidden_size).to(device)
                data_span_input = pad_hidden
            span_span_input.append(data_span_input)
        return span_span_input

    def problem_level_forword(self, span_input, span_mask):
        span_output = self.pos_enc(span_input)
        span_output = self.dropout(span_output)
        span_output, span_attn = self.span_attn(span_output, span_output, span_mask)
        return span_output, span_attn

    def bi_combine(self, output, hidden):
        # combine the forward and backward directions of the bidirectional GRU;
        # hidden has shape (num_layers * num_directions, batch, hidden_size)
        hidden = hidden[0:hidden.size(0):2] + hidden[1:hidden.size(0):2]
        output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:]
        return output, hidden

    def dependency_encode(self, word_output, node):
        pos = node.position
        word_vector = word_output[:, pos]
        if node.is_leaf:
            vector = word_vector
        else:
            children = node.left_nodes + node.right_nodes
            children_vector = [self.dependency_encode(word_output, child).unsqueeze(1) for child in children]
            children_vector = torch.cat(children_vector, dim=1)
            query = word_vector.unsqueeze(1)
            vector = self.to_parent(query, children_vector)[0].squeeze(1)
        return vector

    def get_mask(self, encode_lengths, pad_length):
        device = encode_lengths.device
        batch_size = encode_lengths.size(0)
        index = torch.arange(pad_length).to(device)
        mask = (index.unsqueeze(0).expand(batch_size, -1) >= encode_lengths.unsqueeze(-1)).byte()
        # save one position for full padding span to prevent nan in softmax;
        # invalid values in full padding spans are ignored in span level attention
        mask[mask.sum(dim=-1) == pad_length, 0] = 0
        return mask
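
# --- Illustrative sketch (not part of the original module) ---
# HWCPEncoder.get_mask masks positions at or beyond each span length and keeps one
# unmasked position in fully padded spans so the span-level softmax never sees an
# all-masked row. The same logic, stand-alone with a toy batch:
_span_length = torch.tensor([3, 0, 2])
_index = torch.arange(4)
_mask = (_index.unsqueeze(0).expand(3, -1) >= _span_length.unsqueeze(-1)).byte()
_mask[_mask.sum(dim=-1) == 4, 0] = 0
# _mask -> [[0, 0, 0, 1],
#           [0, 1, 1, 1],   # fully padded span: first position kept unmasked
#           [0, 0, 1, 1]]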

class SalignedEncoder(nn.Module):
    """
    Simple RNN encoder with attention which also extracts variable embeddings.
    """
    def __init__(self, dim_embed, dim_hidden, dim_last, dropout_rate, dim_attn_hidden=256):
        """
        Args:
            dim_embed (int): Dimension of the input embedding.
            dim_hidden (int): Dimension of the encoder RNN.
            dim_last (int): Dimension the last state will be transformed to.
            dropout_rate (float): Dropout rate.
        """
        super(SalignedEncoder, self).__init__()
        self.rnn = torch.nn.LSTM(dim_embed, dim_hidden, 1, bidirectional=True, batch_first=True)
        self.mlp1 = torch.nn.Sequential(torch.nn.Linear(dim_hidden * 2, dim_last), torch.nn.Dropout(dropout_rate), torch.nn.Tanh())
        self.mlp2 = torch.nn.Sequential(torch.nn.Linear(dim_hidden * 2, dim_last), torch.nn.Dropout(dropout_rate), torch.nn.Tanh())
        self.attn = Attention_x(dim_hidden * 2, dim_hidden * 2, dim_attn_hidden)
        self.register_buffer('padding', torch.zeros(dim_hidden * 2))
        self.embeddings = torch.nn.Parameter(torch.normal(torch.zeros(20, 2 * dim_hidden), 0.01))
        self.dim_hidden = dim_hidden

    def initialize_fix_constant(self, con_len, device):
        self.embedding_con = [torch.nn.Parameter(torch.normal(torch.zeros(2 * self.dim_hidden), 0.01)).to(device) for c in range(con_len)]

    def get_fix_constant(self):
        return self.embedding_con

    def forward(self, inputs, lengths, constant_indices):
        """
        Args:
            inputs (torch.Tensor): word embeddings of the problem text, shape [batch_size, sequence_length, dim_embed].
            lengths (torch.Tensor): Length of each input sequence, shape [batch_size].
            constant_indices (list of list of int): positions of the constants (numbers) in each sequence.

        Returns:
            tuple(torch.Tensor, tuple(torch.Tensor, torch.Tensor), list):
                encoded sequence, shape [batch_size, sequence_length, 2 * dim_hidden];
                transformed last hidden and cell states;
                operands, the encoder outputs at the constant positions.
        """
        packed = torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=True)
        hidden_state = None
        outputs, hidden_state = self.rnn(packed, hidden_state)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        # reshape hidden state from (2, batch, dim_hidden) to (batch, 2 * dim_hidden)
        hidden_state = (hidden_state[0].transpose(1, 0).contiguous().view(hidden_state[0].size(1), -1),
                        hidden_state[1].transpose(1, 0).contiguous().view(hidden_state[1].size(1), -1))
        hidden_state = (self.mlp1(hidden_state[0]).unsqueeze(0), self.mlp2(hidden_state[1]).unsqueeze(0))

        # collect the encoder outputs at the constant positions as operand embeddings
        batch_size = outputs.size(0)
        operands = [[outputs[b][i] for i in constant_indices[b]] for b in range(batch_size)]
        return outputs, hidden_state, operands
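
# --- Illustrative usage sketch (not part of the original module) ---
# Encode a batch of word embeddings with SalignedEncoder and collect the hidden
# states at the constant (number) positions. The sizes and indices below are
# arbitrary example values.
_s_encoder = SalignedEncoder(dim_embed=128, dim_hidden=256, dim_last=256, dropout_rate=0.1)
_s_inputs = torch.randn(2, 8, 128)               # [batch_size, seq_len, dim_embed]
_s_lengths = torch.tensor([8, 6])                # sorted descending for pack_padded_sequence
_s_constant_indices = [[1, 4], [2]]              # positions of the numbers in each problem
_s_outputs, (_s_h, _s_c), _s_operands = _s_encoder(_s_inputs, _s_lengths, _s_constant_indices)
# _s_outputs: [2, 8, 512]; _s_h, _s_c: [1, 2, 256]; each operand vector: [512]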