Source code for mwptoolkit.module.Encoder.graph_based_encoder

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 21:48:33
# @File: graph_based_encoder.py


import torch
from torch import nn
from torch.nn import functional as F

from mwptoolkit.module.Graph.graph_module import Graph_Module, Parse_Graph_Module, Num_Graph_Module
from mwptoolkit.module.Layer.graph_layers import MeanAggregator
from mwptoolkit.module.Embedder.basic_embedder import BasicEmbedder
from mwptoolkit.utils.utils import clones

def replace_masked_values(tensor, mask, replace_with):
    """Replace every position where ``mask`` is 0 with ``replace_with``; positions where ``mask`` is 1 are kept."""
    return tensor.masked_fill((1 - mask).bool(), replace_with)
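
# Illustrative usage (an addition to this listing, not part of mwptoolkit):
# a minimal sketch showing how replace_masked_values() fills masked-out
# positions with a large negative value. The tensors below are made up.
def _example_replace_masked_values():
    tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    mask = torch.tensor([[1, 0], [1, 1]])
    # The masked position (row 0, column 1) becomes -1e7; everything else is unchanged.
    return replace_masked_values(tensor, mask, -1e7)
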
class GraphBasedEncoder(nn.Module):
    def __init__(self, embedding_size, hidden_size, rnn_cell_type, bidirectional, num_layers=2, dropout_ratio=0.5):
        super(GraphBasedEncoder, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout_ratio = dropout_ratio

        if rnn_cell_type == 'lstm':
            self.encoder = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=dropout_ratio, bidirectional=bidirectional)
        elif rnn_cell_type == 'gru':
            self.encoder = nn.GRU(embedding_size, hidden_size, num_layers, dropout=dropout_ratio, bidirectional=bidirectional)
        elif rnn_cell_type == 'rnn':
            self.encoder = nn.RNN(embedding_size, hidden_size, num_layers, dropout=dropout_ratio, bidirectional=bidirectional)
        else:
            raise ValueError("The RNN type of encoder must be in ['lstm', 'gru', 'rnn'].")
        self.gcn = Graph_Module(hidden_size, hidden_size, hidden_size)

    def forward(self, input_embedding, input_lengths, batch_graph, hidden=None):
        """
        Args:
            input_embedding (torch.Tensor): input variable, shape [sequence_length, batch_size, embedding_size].
            input_lengths (torch.Tensor): length of input sequence, shape [batch_size].
            batch_graph (torch.Tensor): graph input variable, shape [batch_size, 5, sequence_length, sequence_length].

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                pade_outputs, encoded variable, shape [sequence_length, batch_size, hidden_size].
                problem_output, vector representation of problem, shape [batch_size, hidden_size].
        """
        # Note: we run this all at once (over multiple batches of multiple sequences).
        packed = torch.nn.utils.rnn.pack_padded_sequence(input_embedding, input_lengths, enforce_sorted=True)
        pade_hidden = hidden
        pade_outputs, pade_hidden = self.encoder(packed, pade_hidden)
        pade_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(pade_outputs)

        problem_output = pade_outputs[-1, :, :self.hidden_size] + pade_outputs[0, :, self.hidden_size:]
        pade_outputs = pade_outputs[:, :, :self.hidden_size] + pade_outputs[:, :, self.hidden_size:]  # S x B x H
        _, pade_outputs = self.gcn(pade_outputs, batch_graph)
        pade_outputs = pade_outputs.transpose(0, 1)
        return pade_outputs, problem_output
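
# Minimal sketch (an illustrative addition, not part of mwptoolkit) of the
# recurrent half of GraphBasedEncoder.forward: pack the padded batch, run a
# bidirectional GRU, unpack, then sum the two directions and build the problem
# vector from the last forward state and the first backward state. The sizes
# below are arbitrary, and the graph convolution (Graph_Module) is omitted
# because it is defined elsewhere in mwptoolkit.
def _example_bidirectional_sum(seq_len=6, batch_size=3, embedding_size=8, hidden_size=16):
    rnn = nn.GRU(embedding_size, hidden_size, num_layers=2, bidirectional=True)
    input_embedding = torch.randn(seq_len, batch_size, embedding_size)   # [S, B, E]
    input_lengths = torch.tensor([seq_len, seq_len - 1, seq_len - 2])    # descending, as enforce_sorted=True requires
    packed = torch.nn.utils.rnn.pack_padded_sequence(input_embedding, input_lengths, enforce_sorted=True)
    outputs, _ = rnn(packed)
    outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)         # [S, B, 2 * H]
    problem_output = outputs[-1, :, :hidden_size] + outputs[0, :, hidden_size:]  # [B, H]
    outputs = outputs[:, :, :hidden_size] + outputs[:, :, hidden_size:]          # [S, B, H]
    return outputs, problem_output
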
class GraphEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, sample_size, sample_layer, bidirectional, dropout_ratio):
        super(GraphEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.sample_size = sample_size
        self.sample_layer = sample_layer
        self.bidirectional = bidirectional
        self.embedding_size = embedding_size

        self.dropout = nn.Dropout(dropout_ratio)
        self.embedding = BasicEmbedder(vocab_size, self.embedding_size, dropout_ratio, padding_idx=0)

        # One forward and one backward aggregator per hop, up to seven hops;
        # deeper layers reuse the last aggregator (see forward()).
        self.fw_aggregators = nn.ModuleList()
        self.bw_aggregators = nn.ModuleList()
        for layer in range(7):
            self.fw_aggregators.append(MeanAggregator(2 * self.hidden_size, self.hidden_size, concat=True))
            self.bw_aggregators.append(MeanAggregator(2 * self.hidden_size, self.hidden_size, concat=True))

        self.Linear_hidden = nn.Linear(2 * self.hidden_size, self.hidden_size)

        self.embedding_bilstm = nn.LSTM(input_size=self.embedding_size,
                                        hidden_size=self.hidden_size // 2,
                                        bidirectional=True,
                                        bias=True,
                                        batch_first=True,
                                        dropout=dropout_ratio,
                                        num_layers=1)
        self.padding_vector = torch.randn(1, self.hidden_size, dtype=torch.float, requires_grad=True)

    def forward(self, fw_adj_info, bw_adj_info, feature_info, batch_nodes):
        """
        Args:
            fw_adj_info (torch.Tensor): forward adjacency table, one row of (padded) neighbour ids per node.
            bw_adj_info (torch.Tensor): backward adjacency table, same layout as ``fw_adj_info``.
            feature_info (torch.Tensor): token ids of the graph nodes, embedded and fed to the BiLSTM (``batch_first=True``).
            batch_nodes (torch.Tensor): node ids per batch entry, shape [batch_size, node_num]; flattened ids index the BiLSTM outputs.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor): node hidden states, pooled graph embedding,
            and the BiLSTM outputs of the node features.
        """
        device = batch_nodes.device

        # Embed the node tokens and contextualize them with a BiLSTM.
        feature_sentence_vector = self.embedding(feature_info)
        output_vector, (ht, _) = self.embedding_bilstm(feature_sentence_vector)
        feature_vector = output_vector.contiguous().view(-1, self.hidden_size)
        feature_embedded = feature_vector

        batch_size = feature_embedded.size()[0]
        node_repres = feature_embedded.view(batch_size, -1)

        nodes = batch_nodes.long().view(-1, )
        fw_hidden = F.embedding(nodes, node_repres)
        bw_hidden = F.embedding(nodes, node_repres)

        fw_sampled_neighbors_len = torch.tensor(0)
        bw_sampled_neighbors_len = torch.tensor(0)

        # Neighbour sampling: shuffle each node's adjacency row and keep at most
        # ``sample_size`` neighbours, in both edge directions.
        fw_tmp = fw_adj_info[nodes]
        fw_perm = torch.randperm(fw_tmp.size(1))
        fw_tmp = fw_tmp[:, fw_perm]
        fw_sampled_neighbors = fw_tmp[:, :self.sample_size]

        bw_tmp = bw_adj_info[nodes]
        bw_perm = torch.randperm(bw_tmp.size(1))
        bw_tmp = bw_tmp[:, bw_perm]
        bw_sampled_neighbors = bw_tmp[:, :self.sample_size]

        # Aggregate neighbour information for ``sample_layer`` hops.
        for layer in range(self.sample_layer):
            dim_mul = 1
            if layer == 0:
                neigh_vec_hidden = F.embedding(fw_sampled_neighbors, node_repres)
                tmp_sum = torch.sum(F.relu(neigh_vec_hidden), 2)
                tmp_mask = torch.sign(tmp_sum)
                fw_sampled_neighbors_len = torch.sum(tmp_mask, 1)
            else:
                neigh_vec_hidden = F.embedding(
                    fw_sampled_neighbors,
                    torch.cat([fw_hidden, torch.zeros([1, dim_mul * self.hidden_size]).to(device)], dim=0)
                )

            # Layers beyond the seventh reuse the last aggregator.
            if layer > 6:
                fw_hidden = self.fw_aggregators[6]((fw_hidden, neigh_vec_hidden, fw_sampled_neighbors_len))
            else:
                fw_hidden = self.fw_aggregators[layer]((fw_hidden, neigh_vec_hidden, fw_sampled_neighbors_len))

            if self.bidirectional:
                if layer == 0:
                    neigh_vec_hidden = F.embedding(bw_sampled_neighbors, node_repres)
                    tmp_sum = torch.sum(F.relu(neigh_vec_hidden), 2)
                    tmp_mask = torch.sign(tmp_sum)
                    bw_sampled_neighbors_len = torch.sum(tmp_mask, 1)
                else:
                    neigh_vec_hidden = F.embedding(
                        bw_sampled_neighbors,
                        torch.cat([bw_hidden, torch.zeros([1, dim_mul * self.hidden_size]).to(device)], dim=0)
                    )

                bw_hidden = self.dropout(bw_hidden)
                neigh_vec_hidden = self.dropout(neigh_vec_hidden)

                if layer > 6:
                    bw_hidden = self.bw_aggregators[6]((bw_hidden, neigh_vec_hidden, bw_sampled_neighbors_len))
                else:
                    bw_hidden = self.bw_aggregators[layer]((bw_hidden, neigh_vec_hidden, bw_sampled_neighbors_len))

        fw_hidden = fw_hidden.view(-1, batch_nodes.size()[1], self.hidden_size)
        if self.bidirectional:
            bw_hidden = bw_hidden.view(-1, batch_nodes.size()[1], self.hidden_size)
            hidden = torch.cat([fw_hidden, bw_hidden], 2)
        else:
            hidden = fw_hidden

        # Max-pool over the nodes of each graph to get a single graph embedding.
        pooled = torch.max(hidden, 1)[0]
        graph_embedding = pooled.view(-1, self.hidden_size)

        return hidden, graph_embedding, output_vector
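
# Minimal sketch (an illustrative addition, not part of mwptoolkit) of the
# neighbour sampling used in GraphEncoder.forward: for every queried node,
# look up its adjacency row, shuffle the candidate neighbours, and keep at
# most ``sample_size`` of them. The adjacency table and node ids below are
# made up for illustration.
def _example_neighbor_sampling(sample_size=2):
    # adjacency table: one row per node, columns are (padded) neighbour ids
    adj_info = torch.tensor([[1, 2, 3, 0],
                             [0, 2, 0, 0],
                             [0, 1, 3, 0],
                             [0, 0, 0, 0]])
    nodes = torch.tensor([0, 2])                 # nodes whose neighbourhoods we sample
    rows = adj_info[nodes]                       # [len(nodes), max_degree]
    perm = torch.randperm(rows.size(1))          # shuffle the neighbour columns
    sampled = rows[:, perm][:, :sample_size]     # keep at most sample_size neighbours per node
    return sampled
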
class GraphBasedMultiEncoder(nn.Module):
    def __init__(self, input1_size, input2_size, embed_model, embedding1_size, embedding2_size,
                 hidden_size, n_layers=2, hop_size=2, dropout=0.5):
        super(GraphBasedMultiEncoder, self).__init__()
        self.input1_size = input1_size
        self.input2_size = input2_size
        self.embedding1_size = embedding1_size
        self.embedding2_size = embedding2_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.hop_size = hop_size

        self.embedding1 = embed_model
        self.embedding2 = nn.Embedding(input2_size, embedding2_size, padding_idx=0)
        self.em_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(embedding1_size + embedding2_size, hidden_size, n_layers,
                          dropout=dropout, bidirectional=True)
        self.parse_gnn = clones(Parse_Graph_Module(hidden_size), hop_size)

    def forward(self, input1_var, input2_var, input_length, parse_graph, hidden=None):
        """
        Args:
            input1_var (torch.Tensor): token ids embedded by ``embedding1``, shape [sequence_length, batch_size].
            input2_var (torch.Tensor): token ids embedded by ``embedding2``, shape [sequence_length, batch_size].
            input_length (torch.Tensor): length of each input sequence, shape [batch_size].
            parse_graph (torch.Tensor): batched parse-graph adjacency matrices; slice ``parse_graph[:, 2]`` is used here.

        Returns:
            tuple(torch.Tensor, torch.Tensor): encoder outputs, shape [sequence_length, batch_size, hidden_size],
            and the final GRU hidden state.
        """
        # Note: we run this all at once (over multiple batches of multiple sequences).
        embedded1 = self.embedding1(input1_var)  # S x B x E
        embedded2 = self.embedding2(input2_var)
        embedded = torch.cat((embedded1, embedded2), dim=2)
        embedded = self.em_dropout(embedded)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_length)
        pade_hidden = hidden
        pade_outputs, pade_hidden = self.gru(packed, pade_hidden)
        pade_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(pade_outputs)

        # Sum the two GRU directions, then refine with ``hop_size`` parse-graph GNN hops.
        pade_outputs = pade_outputs[:, :, :self.hidden_size] + pade_outputs[:, :, self.hidden_size:]  # S x B x H
        pade_outputs = pade_outputs.transpose(0, 1)
        for i in range(self.hop_size):
            pade_outputs = self.parse_gnn[i](pade_outputs, parse_graph[:, 2])
        pade_outputs = pade_outputs.transpose(0, 1)
        return pade_outputs, pade_hidden
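
# Minimal sketch (an illustrative addition, not part of mwptoolkit) of the
# input construction in GraphBasedMultiEncoder.forward: two id sequences are
# embedded separately and concatenated along the feature dimension before
# being fed to the GRU. The vocabulary and embedding sizes below are
# arbitrary, and ``embedding1`` stands in for the ``embed_model`` argument.
def _example_dual_embedding(seq_len=5, batch_size=2):
    embedding1 = nn.Embedding(100, 32, padding_idx=0)
    embedding2 = nn.Embedding(20, 8, padding_idx=0)
    input1_var = torch.randint(1, 100, (seq_len, batch_size))   # [S, B]
    input2_var = torch.randint(1, 20, (seq_len, batch_size))    # [S, B]
    # Concatenate the two embeddings: [S, B, 32 + 8]
    embedded = torch.cat((embedding1(input1_var), embedding2(input2_var)), dim=2)
    return embedded
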
class NumEncoder(nn.Module):
    def __init__(self, node_dim, hop_size=2):
        super(NumEncoder, self).__init__()
        self.node_dim = node_dim
        self.hop_size = hop_size
        self.num_gnn = clones(Num_Graph_Module(node_dim), hop_size)

    def forward(self, encoder_outputs, num_encoder_outputs, num_pos_pad, num_order_pad):
        num_embedding = num_encoder_outputs.clone()
        batch_size = num_embedding.size(0)
        num_mask = (num_pos_pad > -1).long()
        node_mask = (num_order_pad > 0).long()

        # Build two directed comparison graphs over the numbers: an edge in
        # ``graph_greater`` means the source number has a larger order value
        # than the target; ``graph_lower`` covers the complementary case.
        greater_graph_mask = num_order_pad.unsqueeze(-1).expand(batch_size, -1, num_order_pad.size(-1)) > \
                             num_order_pad.unsqueeze(1).expand(batch_size, num_order_pad.size(-1), -1)
        lower_graph_mask = num_order_pad.unsqueeze(-1).expand(batch_size, -1, num_order_pad.size(-1)) <= \
                           num_order_pad.unsqueeze(1).expand(batch_size, num_order_pad.size(-1), -1)
        greater_graph_mask = greater_graph_mask.long()
        lower_graph_mask = lower_graph_mask.long()

        diagmat = torch.diagflat(torch.ones(num_embedding.size(1), dtype=torch.long, device=num_embedding.device))
        diagmat = diagmat.unsqueeze(0).expand(num_embedding.size(0), -1, -1)
        graph_ = node_mask.unsqueeze(1) * node_mask.unsqueeze(-1) * (1 - diagmat)
        graph_greater = graph_ * greater_graph_mask + diagmat
        graph_lower = graph_ * lower_graph_mask + diagmat

        # Refine the number embeddings with ``hop_size`` graph hops.
        for i in range(self.hop_size):
            num_embedding = self.num_gnn[i](num_embedding, graph_greater, graph_lower)

        # Scatter the refined number embeddings back to their token positions
        # (the extra last slot absorbs padded positions) and add them to the
        # sequence encoder outputs.
        gnn_info_vec = torch.zeros((batch_size, encoder_outputs.size(0) + 1, encoder_outputs.size(-1)),
                                   dtype=torch.float, device=num_embedding.device)
        clamped_number_indices = replace_masked_values(num_pos_pad, num_mask, gnn_info_vec.size(1) - 1)
        gnn_info_vec.scatter_(1, clamped_number_indices.unsqueeze(-1).expand(-1, -1, num_embedding.size(-1)),
                              num_embedding)
        gnn_info_vec = gnn_info_vec[:, :-1, :]
        gnn_info_vec = gnn_info_vec.transpose(0, 1)
        gnn_info_vec = encoder_outputs + gnn_info_vec
        num_embedding = num_encoder_outputs + num_embedding
        problem_output = torch.max(gnn_info_vec, 0).values
        return gnn_info_vec, num_embedding, problem_output
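
# Minimal sketch (an illustrative addition, not part of mwptoolkit) that
# mirrors how NumEncoder.forward builds its two comparison graphs from the
# number-order padding: node i points to node j in ``graph_greater`` when its
# order value is larger, and in ``graph_lower`` otherwise; padded slots
# (order 0) are masked out and the diagonal is always kept. Broadcasting is
# used here in place of the explicit expand() calls above, and the example
# order values are made up.
def _example_number_order_graphs():
    num_order_pad = torch.tensor([[2, 1, 3, 0]])      # [B=1, num_size=4]; 0 marks padding
    node_mask = (num_order_pad > 0).long()
    greater = (num_order_pad.unsqueeze(-1) > num_order_pad.unsqueeze(1)).long()
    lower = (num_order_pad.unsqueeze(-1) <= num_order_pad.unsqueeze(1)).long()
    diagmat = torch.eye(num_order_pad.size(1), dtype=torch.long).unsqueeze(0)
    graph_ = node_mask.unsqueeze(1) * node_mask.unsqueeze(-1) * (1 - diagmat)
    graph_greater = graph_ * greater + diagmat
    graph_lower = graph_ * lower + diagmat
    return graph_greater, graph_lower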