# Source code for mwptoolkit.module.Layer.tree_layers

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 22:11:58
# @File: tree_layers.py

import torch
from torch import nn
from torch.nn import functional as F

from mwptoolkit.utils.enum_type import SpecialTokens

class TreeNode:
    """A node saved on a tree decoder's node stack.

    Args:
        embedding: representation stored for this node.
        left_flag (bool): flag stored verbatim on the node.
        terminal (bool): flag stored verbatim on the node.
    """

    def __init__(self, embedding, left_flag=False, terminal=False):
        self.embedding = embedding
        self.left_flag = left_flag
        # Previously accepted but silently discarded; keep it so callers that
        # pass ``terminal`` can read it back (same attribute as TreeEmbedding).
        self.terminal = terminal
class NodeEmbeddingNode:
    """Decoder-stack entry bundling a node's hidden state with optional extras.

    Args:
        node_hidden: hidden representation of the node.
        node_context: context vector attached to the node, or ``None``.
        label_embedding: embedding of the node's label, or ``None``.
    """

    def __init__(self, node_hidden, node_context=None, label_embedding=None):
        self.node_hidden, self.node_context, self.label_embedding = (
            node_hidden,
            node_context,
            label_embedding,
        )
class Node:
    """Binary expression-tree node.

    Exposes ``node_value``, ``is_leaf``, ``embedding``, ``left_node`` and
    ``right_node`` (the attributes RecursiveNN reads and writes).

    Args:
        node_value: symbol held by the node.
        isleaf (bool): whether the node is a leaf; stored as ``is_leaf``.
    """

    def __init__(self, node_value, isleaf=True):
        self.node_value = node_value
        self.is_leaf = isleaf
        # Filled in later by whoever embeds the tree; children come from the setters.
        self.embedding = None
        self.left_node = self.right_node = None

    def set_left_node(self, node):
        """Attach ``node`` as the left child."""
        self.left_node = node

    def set_right_node(self, node):
        """Attach ``node`` as the right child."""
        self.right_node = node
class TreeEmbedding:
    """Tree-stack entry: an embedding plus a ``terminal`` flag.

    Args:
        embedding: representation of the (sub-)tree or operator.
        terminal (bool): True for a finished sub-tree, False for a pending
            operator (as used by TreeEmbeddingModel.forward).
    """

    def __init__(self, embedding, terminal=False):
        self.terminal = terminal
        self.embedding = embedding
class Score(nn.Module):
    """Additive scorer of a query against a set of candidate (number) embeddings.

    Args:
        input_size (int): feature size of the query ``hidden``.
        hidden_size (int): feature size of each candidate in ``num_embeddings``
            and width of the intermediate projection.
    """

    def __init__(self, input_size, hidden_size):
        super(Score, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, num_embeddings, num_mask=None):
        """
        Args:
            hidden (torch.Tensor): query representation, shape [batch_size, 1, input_size].
            num_embeddings (torch.Tensor): number embedding, shape [batch_size, number_size, hidden_size].
            num_mask (torch.BoolTensor): number mask, shape [batch_size, number_size];
                positions where it is True are filled with -1e12.

        Returns:
            score (torch.Tensor): unnormalised scores, shape [batch_size, number_size].
        """
        max_len = num_embeddings.size(1)
        repeat_dims = [1] * hidden.dim()
        repeat_dims[1] = max_len
        hidden = hidden.repeat(*repeat_dims)  # B x O x input_size
        this_batch_size = num_embeddings.size(0)
        # Pair the query with every candidate, then flatten to (B * O) rows.
        energy_in = torch.cat((hidden, num_embeddings), 2).view(-1, self.input_size + self.hidden_size)
        score = self.score(torch.tanh(self.attn(energy_in)))  # (B x O) x 1
        score = score.squeeze(1)
        score = score.view(this_batch_size, -1)  # B x O
        if num_mask is not None:
            score = score.masked_fill_(num_mask.bool(), -1e12)
        return score
class ScoreModel(nn.Module):
    """Scores every candidate token embedding against a (hidden, context) pair.

    Args:
        hidden_size (int): size of ``hidden``, ``context`` and each token embedding.
    """

    def __init__(self, hidden_size):
        super(ScoreModel, self).__init__()
        self.w = nn.Linear(hidden_size * 3, hidden_size)
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden, context, token_embeddings):
        """
        Args:
            hidden (torch.Tensor): shape [batch_size, hidden_size].
            context (torch.Tensor): shape [batch_size, hidden_size].
            token_embeddings (torch.Tensor): shape [batch_size, class_size, hidden_size].

        Returns:
            torch.Tensor: unnormalised scores, shape [batch_size, class_size].
        """
        batch_size, class_size, _ = token_embeddings.size()
        hc = torch.cat((hidden, context), dim=-1)  # (b, 2h)
        hc = hc.unsqueeze(1).expand(-1, class_size, -1)  # (b, c, 2h)
        hidden = torch.cat((hc, token_embeddings), dim=-1)  # (b, c, 3h)
        hidden = F.leaky_relu(self.w(hidden))
        score = self.score(hidden).view(batch_size, class_size)
        return score
class NodeGenerater(nn.Module):
    """Gated generator of left/right child representations from a parent node.

    Args:
        hidden_size (int): size of node and context representations.
        op_nums (int): number of node labels (rows of the label embedding).
        embedding_size (int): size of a label embedding.
        dropout (float): dropout probability.
    """

    def __init__(self, hidden_size, op_nums, embedding_size, dropout=0.5):
        super(NodeGenerater, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.embeddings = nn.Embedding(op_nums, embedding_size)
        self.em_dropout = nn.Dropout(dropout)
        self.generate_left = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.generate_right = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.generate_left_g = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.generate_right_g = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)

    def forward(self, node_embedding, node_label, current_context):
        """
        Args:
            node_embedding (torch.Tensor): parent representation, shape [batch_size, hidden_size]
                (a singleton dim 1 is squeezed away).
            node_label (torch.Tensor): label indices, shape [batch_size].
            current_context (torch.Tensor): context, shape [batch_size, hidden_size]
                (a singleton dim 1 is squeezed away).

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor): left child [batch_size, hidden_size],
            right child [batch_size, hidden_size], label embedding before dropout
            [batch_size, embedding_size].
        """
        node_label_ = self.embeddings(node_label)
        node_label = self.em_dropout(node_label_)
        node_embedding = self.em_dropout(node_embedding.squeeze(1))
        current_context = self.em_dropout(current_context.squeeze(1))
        # All four projections read the same features; build them once.
        features = torch.cat((node_embedding, current_context, node_label), 1)
        l_child = torch.tanh(self.generate_left(features)) * torch.sigmoid(self.generate_left_g(features))
        r_child = torch.tanh(self.generate_right(features)) * torch.sigmoid(self.generate_right_g(features))
        return l_child, r_child, node_label_
class NodeEmbeddingLayer(nn.Module):
    """Label-embedding layer that hands the parent embedding to both children unchanged.

    Args:
        op_nums (int): number of node labels.
        embedding_size (int): size of each label embedding.
    """

    def __init__(self, op_nums, embedding_size):
        super(NodeEmbeddingLayer, self).__init__()
        self.op_nums = op_nums
        self.embedding_size = embedding_size
        self.embeddings = nn.Embedding(op_nums, embedding_size)

    def forward(self, node_embedding, node_label, current_context):
        """
        Args:
            node_embedding (torch.Tensor): node embedding, shape [batch_size, num_directions * hidden_size].
            node_label (torch.Tensor): label indices, shape [batch_size].
            current_context: accepted for interface compatibility; not used.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor):
                l_child and r_child, both the very same ``node_embedding`` object;
                node_label_, label embedding, shape [batch_size, embedding_size].
        """
        label_repr = self.embeddings(node_label)
        return node_embedding, node_embedding, label_repr
class TreeEmbeddingModel(nn.Module):
    """Incrementally builds sub-tree embeddings from a stream of predicted labels.

    Args:
        hidden_size (int): size of label and tree embeddings.
        op_set: collection of label ids treated as operators (membership-tested).
        dropout (float): dropout probability.
    """

    def __init__(self, hidden_size, op_set, dropout=0.4):
        super(TreeEmbeddingModel, self).__init__()
        self.op_set = op_set
        self.dropout = nn.Dropout(p=dropout)
        self.combine = GateNN(hidden_size, hidden_size * 2, dropout=dropout, single_layer=True)

    def merge(self, op_embedding, left_embedding, right_embedding):
        """Combine an operator embedding with its two operand embeddings through a gated layer."""
        te_input = torch.cat((left_embedding, right_embedding), dim=-1)
        te_input = self.dropout(te_input)
        op_embedding = self.dropout(op_embedding)
        tree_embed = self.combine(op_embedding, te_input)
        return tree_embed

    def forward(self, class_embedding, tree_stacks, embed_node_index):
        """Push each example's selected label onto its tree stack, merging finished sub-trees.

        Args:
            class_embedding (torch.Tensor): indexed as ``class_embedding[batch_index, label]``;
                presumably [batch_size, class_size, hidden_size] — TODO confirm.
            tree_stacks (list[list[TreeEmbedding]]): one stack per example, mutated in place.
            embed_node_index (torch.Tensor): selected label id per example, shape [batch_size].

        Returns:
            torch.Tensor: embeddings of the selected labels (before any merging).
        """
        # Build the index on the labels' own device; ``.cuda()`` would use the
        # *current* CUDA device, which can differ (e.g. labels on cuda:1).
        batch_index = torch.arange(embed_node_index.size(0), device=embed_node_index.device)
        labels_embedding = class_embedding[batch_index, embed_node_index]
        for node_label, tree_stack, label_embedding in zip(embed_node_index.cpu().tolist(), tree_stacks,
                                                           labels_embedding):
            if node_label in self.op_set:
                # Operator: push as a pending (non-terminal) entry.
                tree_node = TreeEmbedding(label_embedding, terminal=False)
            else:
                # Operand: while a finished left sub-tree sits on top of a pending
                # operator, fold (operator, left, right) into a single sub-tree.
                right_embedding = label_embedding
                while len(tree_stack) >= 2 and tree_stack[-1].terminal and not tree_stack[-2].terminal:
                    left_embedding = tree_stack.pop().embedding
                    op_embedding = tree_stack.pop().embedding
                    right_embedding = self.merge(op_embedding, left_embedding, right_embedding)
                tree_node = TreeEmbedding(right_embedding, terminal=True)
            tree_stack.append(tree_node)
        return labels_embedding
class SubTreeMerger(nn.Module):
    """Gated merge of an operator embedding with two sub-tree representations.

    Args:
        hidden_size (int): size of each sub-tree representation and of the output.
        embedding_size (int): size of the operator embedding.
        dropout (float): dropout probability.
    """

    def __init__(self, hidden_size, embedding_size, dropout=0.5):
        super(SubTreeMerger, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.em_dropout = nn.Dropout(dropout)
        self.merge = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.merge_g = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)

    def forward(self, node_embedding, sub_tree_1, sub_tree_2):
        """
        Args:
            node_embedding (torch.Tensor): operator embedding, shape [N, embedding_size].
            sub_tree_1 (torch.Tensor): shape [N, hidden_size].
            sub_tree_2 (torch.Tensor): shape [N, hidden_size].

        Returns:
            torch.Tensor: merged representation, shape [N, hidden_size].
        """
        sub_tree_1 = self.em_dropout(sub_tree_1)
        sub_tree_2 = self.em_dropout(sub_tree_2)
        node_embedding = self.em_dropout(node_embedding)
        # Both projections read the same features; build them once.
        merged_input = torch.cat((node_embedding, sub_tree_1, sub_tree_2), 1)
        sub_tree = torch.tanh(self.merge(merged_input))
        sub_tree_g = torch.sigmoid(self.merge_g(merged_input))
        return sub_tree * sub_tree_g
class RecursiveNN(nn.Module):
    """Recursive network that composes expression-tree embeddings bottom-up and
    classifies the operator at every internal node.

    Args:
        emb_size (int): size of node embeddings.
        op_size (int): number of operator classes.
        op_list (list): operator symbols; a symbol's position is its class index.
    """

    def __init__(self, emb_size, op_size, op_list):
        super().__init__()
        self.emb_size = emb_size
        self.op_size = op_size
        self.W = nn.Linear(emb_size * 2, emb_size, bias=True)
        self.generate_linear = nn.Linear(emb_size, op_size, bias=True)
        self.classes = op_list

    def _collect(self, expression_tree, num_embedding, look_up, out_idx2symbol, traverse):
        """Shared driver for ``forward``/``test``: embed the leaves, walk the tree with
        ``traverse`` and stack the recorded operator logits and labels.

        Returns the two (empty) accumulator lists unchanged when the tree has no
        internal node, i.e. its root is a leaf.
        """
        device = num_embedding.device
        self.out_idx2symbol = out_idx2symbol
        self.leaf_emb(expression_tree, num_embedding, look_up)
        self.nodeProbList = []
        self.labelList = []
        traverse(expression_tree)
        if not self.nodeProbList:
            return self.nodeProbList, self.labelList
        nodeProb = torch.cat(self.nodeProbList, dim=0).to(device)
        label = torch.tensor(self.labelList, device=device)
        return nodeProb, label

    def forward(self, expression_tree, num_embedding, look_up, out_idx2symbol):
        """Training pass: returns (operator logits [num_internal, op_size], gold labels [num_internal])."""
        return self._collect(expression_tree, num_embedding, look_up, out_idx2symbol, self.traverse)

    def test(self, expression_tree, num_embedding, look_up, out_idx2symbol):
        """Inference pass: like ``forward`` but relabels internal nodes with predictions.

        Returns:
            tuple: (operator logits, predicted labels, the relabelled ``expression_tree``).
        """
        nodeProb, label = self._collect(expression_tree, num_embedding, look_up, out_idx2symbol,
                                        self.test_traverse)
        return nodeProb, label, expression_tree

    def leaf_emb(self, node, num_embed, look_up):
        """Recursively set ``node.embedding`` on every leaf.

        A leaf whose ``node_value`` appears in ``look_up`` gets
        ``num_embed[look_up.index(value)]``; any other leaf falls back to ``num_embed[0]``.
        """
        if node.is_leaf:
            symbol = node.node_value
            if symbol not in look_up:
                node.embedding = num_embed[0]
            else:
                node.embedding = num_embed[look_up.index(symbol)]
        else:
            self.leaf_emb(node.left_node, num_embed, look_up)
            self.leaf_emb(node.right_node, num_embed, look_up)

    def traverse(self, node):
        """Post-order walk: compose child embeddings, record operator logits and the
        gold operator index taken from ``node.node_value``."""
        if node.is_leaf:
            return node.embedding.unsqueeze(0)
        left_vector = self.traverse(node.left_node)
        right_vector = self.traverse(node.right_node)
        combined_v = torch.cat((left_vector, right_vector), 1)
        currentNode, op_prob = self.RecurCell(combined_v)
        node.embedding = currentNode.squeeze(0)
        self.nodeProbList.append(op_prob)
        self.labelList.append(self.classes.index(node.node_value))
        return currentNode

    def test_traverse(self, node):
        """Post-order walk that overwrites each internal ``node.node_value`` with the
        top-1 predicted operator and records that prediction as the label."""
        if node.is_leaf:
            return node.embedding.unsqueeze(0)
        left_vector = self.test_traverse(node.left_node)
        right_vector = self.test_traverse(node.right_node)
        combined_v = torch.cat((left_vector, right_vector), 1)
        currentNode, op_prob = self.RecurCell(combined_v)
        node.embedding = currentNode.squeeze(0)
        op_idx = torch.topk(op_prob, 1, 1)[1]
        self.nodeProbList.append(op_prob)
        # op_prob has a single row here, so op_idx holds exactly one index.
        node.node_value = self.classes[op_idx.item()]
        self.labelList.append(self.classes.index(node.node_value))
        return currentNode

    def RecurCell(self, combine_emb):
        """One composition step: returns (tanh(W [left; right]), unnormalised operator logits)."""
        node_embedding = torch.tanh(self.W(combine_emb))
        op = self.generate_linear(node_embedding)
        return node_embedding, op
class Dec_LSTM(nn.Module):
    """LSTM cell whose input is the token embedding concatenated with parent and sibling states.

    Args:
        embedding_size (int): size of ``x``.
        hidden_size (int): size of the cell/hidden states and of ``parent_h``/``sibling_state``.
        dropout_ratio (float): dropout applied to the candidate cell update.
    """

    def __init__(self, embedding_size, hidden_size, dropout_ratio):
        super(Dec_LSTM, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.dropout_ratio = dropout_ratio
        self.i2h = nn.Linear(self.embedding_size + 2 * self.hidden_size, 4 * self.hidden_size)
        self.h2h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_ratio)

    def forward(self, x, prev_c, prev_h, parent_h, sibling_state):
        """
        Args:
            x (torch.Tensor): shape [batch_size, embedding_size].
            prev_c (torch.Tensor): previous cell state, shape [batch_size, hidden_size].
            prev_h (torch.Tensor): previous hidden state, shape [batch_size, hidden_size].
            parent_h (torch.Tensor): shape [batch_size, hidden_size].
            sibling_state (torch.Tensor): shape [batch_size, hidden_size].

        Returns:
            tuple(torch.Tensor, torch.Tensor): new cell state ``cy`` and hidden state ``hy``.
        """
        input_cat = torch.cat((x, parent_h, sibling_state), 1)
        gates = self.i2h(input_cat) + self.h2h(prev_h)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        cellgate = self.dropout(cellgate)
        cy = (forgetgate * prev_c) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)
        return cy, hy
class DQN(nn.Module):
    """Two-layer linear network producing action scores and an observation vector.

    Args:
        input_size (int): size of the input features.
        embedding_size (int): size of the returned observation vector.
        hidden_size (int): size of the shared first layer.
        output_size (int): number of actions.
        dropout_ratio (float): accepted but not used by this module.
    """

    def __init__(self, input_size, embedding_size, hidden_size, output_size, dropout_ratio):
        super(DQN, self).__init__()
        self.hidden_layer_1 = nn.Linear(input_size, hidden_size)
        self.hidden_layer_2 = nn.Linear(hidden_size, embedding_size)
        self.action_pred = nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        """Return ``(action_scores, observation)``.

        Both heads read the output of ``hidden_layer_1`` directly; no
        non-linearity is applied in between.
        """
        shared = self.hidden_layer_1(inputs)
        return self.action_pred(shared), self.hidden_layer_2(shared)

    def play_one(self, inputs):
        """Select the top-scoring action along dim 0; return it with the observation."""
        scores, observation = self.forward(inputs)
        best = scores.topk(1, dim=0)[1]
        return best, observation
class GateNN(nn.Module):
    """One- or two-layer gated feed-forward block, each layer ``tanh(W x) * sigmoid(G x)``.

    Args:
        hidden_size (int): size of ``hidden`` and of the output.
        input1_size (int): size of ``input1``.
        input2_size (int): size of ``input2`` for the second layer. When ``input2`` is
            omitted at call time the second layer consumes only the first layer's
            output, so this must then be 0.
        dropout (float): dropout applied to the first layer's output (two-layer mode only).
        single_layer (bool): if True, only the first layer is built and used.
    """

    def __init__(self, hidden_size, input1_size, input2_size=0, dropout=0.4, single_layer=False):
        super(GateNN, self).__init__()
        self.single_layer = single_layer
        self.hidden_l1 = nn.Linear(input1_size + hidden_size, hidden_size)
        self.gate_l1 = nn.Linear(input1_size + hidden_size, hidden_size)
        if not single_layer:
            self.dropout = nn.Dropout(p=dropout)
            self.hidden_l2 = nn.Linear(input2_size + hidden_size, hidden_size)
            self.gate_l2 = nn.Linear(input2_size + hidden_size, hidden_size)

    def forward(self, hidden, input1, input2=None):
        """
        Args:
            hidden (torch.Tensor): shape [..., hidden_size].
            input1 (torch.Tensor): shape [..., input1_size].
            input2 (torch.Tensor, optional): shape [..., input2_size]; two-layer mode only.

        Returns:
            torch.Tensor: gated output, shape [..., hidden_size].
        """
        input1 = torch.cat((hidden, input1), dim=-1)
        h = torch.tanh(self.hidden_l1(input1))
        g = torch.sigmoid(self.gate_l1(input1))
        h = h * g
        if not self.single_layer:
            h1 = self.dropout(h)
            if input2 is not None:
                # Second layer sees the first layer's output plus the extra input.
                input2 = torch.cat((h1, input2), dim=-1)
            else:
                input2 = h1
            h = torch.tanh(self.hidden_l2(input2))
            g = torch.sigmoid(self.gate_l2(input2))
            h = h * g
        return h
class DecomposeModel(nn.Module):
    """Top-down decomposition of parent hidden states into child hidden states for a
    stack-based tree decoder.

    Args:
        hidden_size (int): size of node hidden states.
        dropout (float): dropout probability.
        device: device on which the padding hidden state is allocated.
    """

    def __init__(self, hidden_size, dropout, device):
        super(DecomposeModel, self).__init__()
        # Plain tensor attribute (not a registered buffer), so ``module.to()`` does not move it.
        self.pad_hidden = torch.zeros(hidden_size, device=device)
        self.dropout = nn.Dropout(p=dropout)
        self.l_decompose = GateNN(hidden_size, hidden_size * 2, 0, dropout=dropout, single_layer=False)
        self.r_decompose = GateNN(hidden_size, hidden_size * 2, hidden_size, dropout=dropout, single_layer=False)

    def forward(self, node_stacks, tree_stacks, nodes_context, labels_embedding, pad_node=True):
        """Produce the next child hidden state for every example in the batch.

        Args:
            node_stacks (list[list[NodeEmbeddingNode]]): per-example decoder stacks, mutated in place.
            tree_stacks (list[list[TreeEmbedding]]): per-example tree stacks; only read here.
            nodes_context: per-example context vectors for the current step.
            labels_embedding: per-example label embeddings for the current step.
            pad_node (bool): when an example's stack is exhausted, push a padding node onto it.

        Returns:
            torch.Tensor: child hidden states stacked along dim 0.
        """
        children_hidden = []
        for node_stack, tree_stack, node_context, label_embedding in zip(node_stacks, tree_stacks, nodes_context,
                                                                         labels_embedding):
            # An empty stack means this example has finished decoding.
            if len(node_stack) > 0:
                if not tree_stack[-1].terminal:
                    # Left child. The parent stays on the stack, now also carrying the
                    # context and label it will need when its right child is built.
                    node_hidden = node_stack[-1].node_hidden
                    node_stack[-1] = NodeEmbeddingNode(node_hidden, node_context, label_embedding)
                    l_input = torch.cat((node_context, label_embedding), dim=-1)
                    l_input = self.dropout(l_input)
                    node_hidden = self.dropout(node_hidden)
                    child_hidden = self.l_decompose(node_hidden, l_input, None)
                    node_stack.append(NodeEmbeddingNode(child_hidden, None, None))
                else:
                    # Right child. Drop the finished left child, then consume its parent.
                    node_stack.pop()
                    if len(node_stack) > 0:
                        parent_node = node_stack.pop()
                        node_hidden = parent_node.node_hidden
                        node_context = parent_node.node_context
                        label_embedding = parent_node.label_embedding
                        left_embedding = tree_stack[-1].embedding  # completed left sub-tree
                        left_embedding = self.dropout(left_embedding)
                        r_input = torch.cat((node_context, label_embedding), dim=-1)
                        r_input = self.dropout(r_input)
                        node_hidden = self.dropout(node_hidden)
                        child_hidden = self.r_decompose(node_hidden, r_input, left_embedding)
                        node_stack.append(NodeEmbeddingNode(child_hidden, None, None))
            if len(node_stack) == 0:
                # Finished (or just finished) decoding: emit the padding state.
                child_hidden = self.pad_hidden
                if pad_node:
                    node_stack.append(NodeEmbeddingNode(child_hidden, None, None))
            children_hidden.append(child_hidden)
        children_hidden = torch.stack(children_hidden, dim=0)
        return children_hidden
class Prediction(nn.Module):
    """Seq2tree goal-driven prediction step with problem-aware dynamic encoding.

    Args:
        hidden_size (int): size of node, context and number representations.
        op_nums (int): number of operator classes.
        input_size (int): number of learned embeddings for constants / generated numbers.
        dropout (float): dropout probability.
    """

    def __init__(self, hidden_size, op_nums, input_size, dropout=0.5):
        super(Prediction, self).__init__()
        # Keep for reference
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.op_nums = op_nums
        # Define layers
        self.dropout = nn.Dropout(dropout)
        # Learned embeddings for computational symbols and generated numbers.
        self.embedding_weight = nn.Parameter(torch.randn(1, input_size, hidden_size))
        self.concat_l = nn.Linear(hidden_size, hidden_size)
        self.concat_r = nn.Linear(hidden_size * 2, hidden_size)
        self.concat_lg = nn.Linear(hidden_size, hidden_size)
        self.concat_rg = nn.Linear(hidden_size * 2, hidden_size)
        self.ops = nn.Linear(hidden_size * 2, op_nums)
        self.attn = TreeAttention(hidden_size, hidden_size)
        self.score = Score(hidden_size * 2, hidden_size)

    def forward(self, node_stacks, left_childs, encoder_outputs, num_pades, padding_hidden, seq_mask, mask_nums):
        """
        Args:
            node_stacks (list): node stacks; the top entry's ``.embedding`` is read.
            left_childs (list): representation of left childs (``None`` where absent).
            encoder_outputs (torch.Tensor): output from encoder, shape [sequence_length, batch_size, hidden_size].
            num_pades (torch.Tensor): number representation, shape [batch_size, number_size, hidden_size].
            padding_hidden (torch.Tensor): padding hidden, shape [1, hidden_size].
            seq_mask (torch.BoolTensor): sequence mask, shape [batch_size, sequence_length].
            mask_nums (torch.BoolTensor): number mask, shape [batch_size, input_size + number_size].

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
                num_score, number score, shape [batch_size, input_size + number_size].
                op, operator score, shape [batch_size, op_nums].
                current_node, current node representation, shape [batch_size, 1, hidden_size].
                current_context, current context representation, shape [batch_size, 1, hidden_size].
                embedding_weight, candidate embeddings, shape [batch_size, input_size + number_size, hidden_size].
        """
        current_embeddings = []
        for st in node_stacks:
            if len(st) == 0:
                current_embeddings.append(padding_hidden)
            else:
                current_embeddings.append(st[-1].embedding)

        current_node_temp = []
        for left, cur in zip(left_childs, current_embeddings):
            if left is None:
                # No left-subtree representation: gate the current embedding alone.
                cur = self.dropout(cur)
                g = torch.tanh(self.concat_l(cur))
                t = torch.sigmoid(self.concat_lg(cur))
            else:
                left = self.dropout(left)
                cur = self.dropout(cur)
                joint = torch.cat((left, cur), 1)
                g = torch.tanh(self.concat_r(joint))
                t = torch.sigmoid(self.concat_rg(joint))
            current_node_temp.append(g * t)

        current_node = torch.stack(current_node_temp)
        current_embeddings = self.dropout(current_node)
        current_attn = self.attn(current_embeddings.transpose(0, 1), encoder_outputs, seq_mask)
        current_context = current_attn.bmm(encoder_outputs.transpose(0, 1))  # B x 1 x N

        # Candidate leaves: learned constant embeddings followed by the problem's numbers.
        batch_size = current_embeddings.size(0)
        repeat_dims = [1] * self.embedding_weight.dim()
        repeat_dims[0] = batch_size
        embedding_weight = self.embedding_weight.repeat(*repeat_dims)  # B x input_size x N
        embedding_weight = torch.cat((embedding_weight, num_pades), dim=1)  # B x O x N

        leaf_input = torch.cat((current_node, current_context), 2)
        leaf_input = leaf_input.squeeze(1)
        leaf_input = self.dropout(leaf_input)

        embedding_weight_ = self.dropout(embedding_weight)
        num_score = self.score(leaf_input.unsqueeze(1), embedding_weight_, mask_nums)
        op = self.ops(leaf_input)
        return num_score, op, current_node, current_context, embedding_weight
class GenerateNode(nn.Module):
    """Gated generator of left/right child representations from a parent node.

    Args:
        hidden_size (int): size of node and context representations.
        op_nums (int): number of node labels (rows of the label embedding).
        embedding_size (int): size of a label embedding.
        dropout (float): dropout probability.
    """

    def __init__(self, hidden_size, op_nums, embedding_size, dropout=0.5):
        super(GenerateNode, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.embeddings = nn.Embedding(op_nums, embedding_size)
        self.em_dropout = nn.Dropout(dropout)
        self.generate_l = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.generate_r = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.generate_lg = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.generate_rg = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)

    def forward(self, node_embedding, node_label, current_context):
        """
        Args:
            node_embedding (torch.Tensor): node embedding, shape [batch_size, hidden_size]
                (a singleton dim 1 is squeezed away).
            node_label (torch.Tensor): label indices looked up in ``self.embeddings``, shape [batch_size].
            current_context (torch.Tensor): current context, shape [batch_size, hidden_size]
                (a singleton dim 1 is squeezed away).

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor):
                l_child, representation of left child, shape [batch_size, hidden_size].
                r_child, representation of right child, shape [batch_size, hidden_size].
                node_label_, label embedding before dropout, shape [batch_size, embedding_size].
        """
        node_label_ = self.embeddings(node_label)
        node_label = self.em_dropout(node_label_)
        node_embedding = self.em_dropout(node_embedding.squeeze(1))
        current_context = self.em_dropout(current_context.squeeze(1))
        # All four projections read the same features; build them once.
        features = torch.cat((node_embedding, current_context, node_label), 1)
        l_child = torch.tanh(self.generate_l(features)) * torch.sigmoid(self.generate_lg(features))
        r_child = torch.tanh(self.generate_r(features)) * torch.sigmoid(self.generate_rg(features))
        return l_child, r_child, node_label_
class Merge(nn.Module):
    """Gated merge of an operator embedding with two sub-tree representations.

    Args:
        hidden_size (int): size of each sub-tree representation and of the output.
        embedding_size (int): size of the operator embedding.
        dropout (float): dropout probability.
    """

    def __init__(self, hidden_size, embedding_size, dropout=0.5):
        super(Merge, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.em_dropout = nn.Dropout(dropout)
        self.merge = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
        self.merge_g = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)

    def forward(self, node_embedding, sub_tree_1, sub_tree_2):
        """
        Args:
            node_embedding (torch.Tensor): node embedding, shape [1, embedding_size].
            sub_tree_1 (torch.Tensor): representation of sub tree 1, shape [1, hidden_size].
            sub_tree_2 (torch.Tensor): representation of sub tree 2, shape [1, hidden_size].

        Returns:
            torch.Tensor: representation of merged tree, shape [1, hidden_size].
        """
        sub_tree_1 = self.em_dropout(sub_tree_1)
        sub_tree_2 = self.em_dropout(sub_tree_2)
        node_embedding = self.em_dropout(node_embedding)
        # Both projections read the same features; build them once.
        merged_input = torch.cat((node_embedding, sub_tree_1, sub_tree_2), 1)
        sub_tree = torch.tanh(self.merge(merged_input))
        sub_tree_g = torch.sigmoid(self.merge_g(merged_input))
        return sub_tree * sub_tree_g
class TreeAttention(nn.Module):
    """Additive attention of a query over sequence-first encoder outputs.

    Args:
        input_size (int), hidden_size (int): the query and encoder-output feature
            sizes must sum to ``input_size + hidden_size``; the intermediate
            projection has width ``hidden_size``.
    """

    def __init__(self, input_size, hidden_size):
        super(TreeAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden, encoder_outputs, seq_mask=None):
        """
        Args:
            hidden (torch.Tensor): hidden representation, shape [1, batch_size, hidden_size].
            encoder_outputs (torch.Tensor): output from encoder, shape [sequence_length, batch_size, hidden_size].
            seq_mask (torch.Tensor): sequence mask, shape [batch_size, sequence_length];
                True positions receive (numerically) zero weight.

        Returns:
            attn_energies (torch.Tensor): attention weights summing to 1 over the
            sequence, shape [batch_size, 1, sequence_length].
        """
        max_len = encoder_outputs.size(0)
        repeat_dims = [1] * hidden.dim()
        repeat_dims[0] = max_len
        hidden = hidden.repeat(*repeat_dims)  # S x B x H
        this_batch_size = encoder_outputs.size(1)
        # Pair the query with every encoder position, then flatten to (S * B) rows.
        energy_in = torch.cat((hidden, encoder_outputs), 2).view(-1, self.input_size + self.hidden_size)
        score_feature = torch.tanh(self.attn(energy_in))
        attn_energies = self.score(score_feature)  # (S x B) x 1
        attn_energies = attn_energies.squeeze(1)
        attn_energies = attn_energies.view(max_len, this_batch_size).transpose(0, 1)  # B x S
        if seq_mask is not None:
            attn_energies = attn_energies.masked_fill_(seq_mask.bool(), -1e12)
        attn_energies = nn.functional.softmax(attn_energies, dim=1)  # B x S
        return attn_energies.unsqueeze(1)
class SemanticAlignmentModule(nn.Module):
    """Projects an attention-aligned encoder context and the decoder hidden state
    into a shared space.

    Args:
        encoder_hidden_size (int): feature size of ``encoder_outputs``.
        decoder_hidden_size (int): feature size of ``decoder_hidden``.
        hidden_size (int): size of the shared projection space.
        batch_first (bool): selects which axis of ``encoder_outputs`` gets the
            inserted singleton dimension and how the context is gathered.
    """

    def __init__(self, encoder_hidden_size, decoder_hidden_size, hidden_size, batch_first=False):
        super(SemanticAlignmentModule, self).__init__()
        self.batch_first = batch_first
        self.attn = TreeAttention(encoder_hidden_size, decoder_hidden_size)
        self.encoder_linear1 = nn.Linear(encoder_hidden_size, hidden_size)
        self.encoder_linear2 = nn.Linear(hidden_size, hidden_size)
        self.decoder_linear1 = nn.Linear(decoder_hidden_size, hidden_size)
        self.decoder_linear2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, decoder_hidden, encoder_outputs):
        """Return ``(projected aligned encoder context, projected decoder hidden)``.

        NOTE(review): the non-batch-first path looks like it expects
        ``decoder_hidden`` [B, H] and per-example ``encoder_outputs`` that become
        [S, 1, H]; the ``batch_first`` path's expected shapes are unclear — confirm
        against callers.
        """
        decoder_hidden = decoder_hidden.unsqueeze(0)
        encoder_outputs = encoder_outputs.unsqueeze(0 if self.batch_first else 1)
        attn_weights = self.attn(decoder_hidden, encoder_outputs, None)
        if self.batch_first:
            align_context = attn_weights.bmm(encoder_outputs)  # B x 1 x H
        else:
            align_context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # B x 1 x H
            # Back to the sequence-first layout of decoder_hidden.
            align_context = align_context.transpose(0, 1)
        encoder_proj = self.encoder_linear2(torch.tanh(self.encoder_linear1(align_context)))
        decoder_proj = self.decoder_linear2(torch.tanh(self.decoder_linear1(decoder_hidden)))
        return encoder_proj, decoder_proj
