# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/21 04:59:55
# @File: sausolver.py
import random
from typing import Dict, Any, Tuple
import torch
from torch import nn
import copy
from mwptoolkit.module.Encoder.rnn_encoder import BasicRNNEncoder
from mwptoolkit.module.Embedder.basic_embedder import BasicEmbedder
from mwptoolkit.module.Decoder.tree_decoder import SARTreeDecoder
from mwptoolkit.module.Layer.tree_layers import NodeGenerater, SubTreeMerger, TreeNode, TreeEmbedding
from mwptoolkit.module.Layer.tree_layers import Prediction, GenerateNode, Merge, SemanticAlignmentModule
from mwptoolkit.module.Strategy.beam_search import TreeBeam
from mwptoolkit.loss.masked_cross_entropy_loss import MaskedCrossEntropyLoss, masked_cross_entropy
from mwptoolkit.loss.mse_loss import MSELoss
from mwptoolkit.utils.utils import copy_list
from mwptoolkit.utils.enum_type import NumMask, SpecialTokens
class SAUSolver(nn.Module):
"""
Reference:
Qin et al. "Semantically-Aligned Universal Tree-Structured Solver for Math Word Problems" in EMNLP 2020.
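
    Example (illustrative sketch; assumes a ``config`` and ``dataset`` prepared by the
    mwptoolkit data pipeline)::

        solver = SAUSolver(config, dataset).to(config["device"])
        loss_value = solver.calculate_loss(batch_data)      # one training step (backward() is called inside)
        predicted, target = solver.model_test(batch_data)   # one evaluation step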
"""
def __init__(self, config, dataset):
super(SAUSolver, self).__init__()
# parameter
self.hidden_size = config["hidden_size"]
self.device = config["device"]
        self.USE_CUDA = (self.device == torch.device('cuda'))
self.beam_size = config['beam_size']
self.max_out_len = config['max_output_len']
self.embedding_size = config["embedding_size"]
self.dropout_ratio = config["dropout_ratio"]
self.num_layers = config["num_layers"]
self.rnn_cell_type = config["rnn_cell_type"]
self.loss_weight = config['loss_weight']
self.batch_first = False
        self.vocab_size = len(dataset.in_idx2word)
        # pad token id of the input vocabulary, needed by forward() to build the attention mask
        # (assumes the dataset exposes ``in_word2idx`` alongside ``in_idx2word``)
        self.in_pad_token = dataset.in_word2idx[SpecialTokens.PAD_TOKEN]
self.out_symbol2idx = dataset.out_symbol2idx
self.out_idx2symbol = dataset.out_idx2symbol
generate_list = dataset.generate_list
self.generate_nums = [self.out_symbol2idx[symbol] for symbol in generate_list]
self.mask_list = NumMask.number
self.num_start = dataset.num_start
self.operator_nums = dataset.operator_nums
self.generate_size = len(generate_list)
self.unk_token = self.out_symbol2idx[SpecialTokens.UNK_TOKEN]
        try:
            self.out_sos_token = self.out_symbol2idx[SpecialTokens.SOS_TOKEN]
        except KeyError:
            self.out_sos_token = None
        try:
            self.out_eos_token = self.out_symbol2idx[SpecialTokens.EOS_TOKEN]
        except KeyError:
            self.out_eos_token = None
        try:
            self.out_pad_token = self.out_symbol2idx[SpecialTokens.PAD_TOKEN]
        except KeyError:
            self.out_pad_token = None
# module
self.embedder = BasicEmbedder(self.vocab_size, self.embedding_size, self.dropout_ratio)
# self.t_encoder = BasicRNNEncoder(self.embedding_size, self.hidden_size, self.num_layers, self.rnn_cell_type, self.dropout_ratio)
self.encoder = BasicRNNEncoder(self.embedding_size, self.hidden_size, self.num_layers, self.rnn_cell_type,
self.dropout_ratio, batch_first=False)
#self.decoder = SARTreeDecoder(self.hidden_size, self.operator_nums, self.generate_size, self.dropout_ratio)
self.decoder = Prediction(self.hidden_size,self.operator_nums,self.generate_size,self.dropout_ratio)
self.node_generater = GenerateNode(self.hidden_size, self.operator_nums, self.embedding_size,
self.dropout_ratio)
self.merge = Merge(self.hidden_size, self.embedding_size, self.dropout_ratio)
self.sa = SemanticAlignmentModule(self.hidden_size,self.hidden_size,self.hidden_size)
self.loss1 = MaskedCrossEntropyLoss()
#
    def forward(self, seq, seq_length, nums_stack, num_size, num_pos, target=None, output_all_layers=False) -> Tuple[
            torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
:param torch.Tensor seq: input sequence, shape: [batch_size, seq_length].
:param torch.Tensor seq_length: the length of sequence, shape: [batch_size].
        :param list nums_stack: different positions of the same number, length: [batch_size].
        :param list num_size: number of numbers in the input sequence, length: [batch_size].
        :param list num_pos: positions of the numbers in the input sequence, length: [batch_size].
        :param torch.Tensor | None target: target, shape: [batch_size, target_length], default None.
        :param bool output_all_layers: return outputs of all layers if output_all_layers is True, default False.
        :return: token_logits: [batch_size, output_length, output_size], symbol_outputs: [batch_size, output_length], model_all_outputs.
:rtype: tuple(torch.Tensor, torch.Tensor, dict)
"""
# sequence mask for attention
seq_mask = torch.eq(seq, self.in_pad_token).to(self.device)
num_mask = []
max_num_size = max(num_size) + len(self.generate_nums)
for i in num_size:
d = i + len(self.generate_nums)
num_mask.append([0] * d + [1] * (max_num_size - d))
num_mask = torch.BoolTensor(num_mask).to(self.device)
        if target is not None and not self.batch_first:
            target = target.transpose(0, 1)
batch_size = len(seq_length)
seq_emb = self.embedder(seq)
problem_output, encoder_outputs, encoder_layer_outputs = self.encoder_forward(seq_emb, seq_length,
output_all_layers)
copy_num_len = [len(_) for _ in num_pos]
max_num_size = max(copy_num_len)
all_nums_encoder_outputs = self.get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size,
max_num_size,
self.hidden_size)
token_logits, symbol_outputs, decoder_layer_outputs = self.decoder_forward(encoder_outputs, problem_output,
all_nums_encoder_outputs, nums_stack,
seq_mask, num_mask, target,
output_all_layers)
model_all_outputs = {}
if output_all_layers:
model_all_outputs['inputs_embedding'] = seq_emb
model_all_outputs.update(encoder_layer_outputs)
model_all_outputs['number_representation'] = all_nums_encoder_outputs
model_all_outputs.update(decoder_layer_outputs)
return token_logits, symbol_outputs, model_all_outputs
    def calculate_loss(self, batch_data: dict) -> float:
        """Run the forward pass, compute the loss and back-propagate.

        :param batch_data: one batch of data. It should contain the keys 'question', 'ques len',
            'equation', 'equ len', 'num stack', 'num size' and 'num pos'.
        :return: loss value.
        """
seq = torch.tensor(batch_data["question"]).to(self.device)
seq_length = torch.tensor(batch_data["ques len"]).long()
target = torch.tensor(batch_data["equation"]).to(self.device)
target_length = torch.LongTensor(batch_data["equ len"]).to(self.device)
nums_stack = copy.deepcopy(batch_data["num stack"])
num_size = batch_data["num size"]
num_pos = batch_data["num pos"]
generate_nums = self.generate_nums
num_start = self.num_start
# sequence mask for attention
unk = self.unk_token
loss = self.train_tree(seq, seq_length, target, target_length, nums_stack, num_size, generate_nums, num_pos, unk, num_start)
return loss
    def model_test(self, batch_data: dict) -> tuple:
        """Model test.

        :param batch_data: one batch of data. It should contain the keys 'question', 'ques len',
            'equation', 'num stack', 'num pos' and 'num list'.
        :return: predicted equation, target equation.
        """
seq = torch.tensor(batch_data["question"]).to(self.device)
seq_length = torch.tensor(batch_data["ques len"]).long()
target = torch.tensor(batch_data["equation"]).to(self.device)
nums_stack = copy.deepcopy(batch_data["num stack"])
num_pos = batch_data["num pos"]
num_list = batch_data['num list']
generate_nums = self.generate_nums
num_start = self.num_start
# sequence mask for attention
all_node_output = self.evaluate_tree(seq, seq_length, generate_nums, num_pos, num_start, self.beam_size,
self.max_out_len)
all_output = self.convert_idx2symbol(all_node_output, num_list[0], copy_list(nums_stack[0]))
targets = self.convert_idx2symbol(target[0], num_list[0], copy_list(nums_stack[0]))
return all_output, targets
    def predict(self, batch_data: dict, output_all_layers=False):
        """Predict samples without a target equation.

        :param dict batch_data: one batch of data. It should contain the keys 'question', 'ques len',
            'num stack', 'num size' and 'num pos'.
        :param bool output_all_layers: return all layer outputs of the model.
        :return: token_logits, symbol_outputs, all_layer_outputs
        """
seq = torch.tensor(batch_data["question"]).to(self.device)
seq_length = torch.tensor(batch_data["ques len"]).long()
nums_stack = copy.deepcopy(batch_data["num stack"])
num_size = batch_data["num size"]
num_pos = batch_data["num pos"]
token_logits, symbol_outputs, model_all_outputs = self.forward(seq,seq_length,nums_stack,num_size,num_pos,output_all_layers=output_all_layers)
return token_logits, symbol_outputs, model_all_outputs
    def train_tree(self, input_batch, input_length, target_batch, target_length, nums_stack_batch, num_size_batch, generate_nums, num_pos, unk, num_start,
                   english=False, var_nums=None, batch_first=False):
        # avoid a mutable default argument; var_nums lists unknown-variable tokens (empty by default)
        var_nums = var_nums if var_nums is not None else []
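        # Teacher-forced training of the tree decoder: at each step the decoder scores operators and
        # numbers, generated child nodes are pushed onto per-sample node stacks, and every finished
        # subtree embedding is aligned with the encoder states through the semantic alignment module.
        # The loss is masked cross-entropy plus 0.01 * the mean semantic-alignment MSE.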
# sequence mask for attention
seq_mask = []
max_len = max(input_length)
for i in input_length:
seq_mask.append([0 for _ in range(i)] + [1 for _ in range(i, max_len)])
seq_mask = torch.ByteTensor(seq_mask)
num_mask = []
        max_num_size = max(num_size_batch) + len(generate_nums) + len(var_nums)  # max count of number positions + constant numbers + unknown variables
for i in num_size_batch:
d = i + len(generate_nums) + len(var_nums)
num_mask.append([0] * d + [1] * (max_num_size - d))
        num_mask = torch.ByteTensor(num_mask)  # masks irrelevant number slots so that invalid Nx tokens are not generated
#unk = output_lang.word2index["UNK"]
# Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size)
input_var = input_batch.transpose(0, 1)
target = target_batch.transpose(0, 1)
padding_hidden = torch.FloatTensor([0.0 for _ in range(self.decoder.hidden_size)]).unsqueeze(0)
batch_size = len(input_length)
if self.USE_CUDA:
input_var = input_var.cuda()
seq_mask = seq_mask.cuda()
padding_hidden = padding_hidden.cuda()
num_mask = num_mask.cuda()
# Zero gradients of both optimizers
# Run words through encoder
#encoder_outputs, problem_output = self.encoder(input_var, input_length)
seq_emb = self.embedder(input_var)
pade_outputs, _ = self.encoder(seq_emb, input_length)
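        # The encoder output concatenates the two RNN directions: the problem representation (root
        # goal vector) sums the last forward state and the first backward state, and the per-token
        # encoder outputs sum the two directional halves.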
problem_output = pade_outputs[-1, :, :self.hidden_size] + pade_outputs[0, :, self.hidden_size:]
encoder_outputs = pade_outputs[:, :, :self.hidden_size] + pade_outputs[:, :, self.hidden_size:]
# Prepare input and output variables
node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)] # root embedding B x 1
max_target_length = max(target_length)
all_node_outputs = []
all_sa_outputs = []
# all_leafs = []
copy_num_len = [len(_) for _ in num_pos]
num_size = max(copy_num_len)
        # gather the embeddings of the numbers that appear in each problem
all_nums_encoder_outputs = self.get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size,
self.encoder.hidden_size)
        embeddings_stacks = [[] for _ in range(batch_size)]  # B x 1: current tree state / subtree embedding / output
left_childs = [None for _ in range(batch_size)] # B x 1
for t in range(max_target_length):
num_score, op, current_embeddings, current_context, current_nums_embeddings = self.decoder(
node_stacks, left_childs, encoder_outputs, all_nums_encoder_outputs, padding_hidden, seq_mask, num_mask)
# all_leafs.append(p_leaf)
outputs = torch.cat((op, num_score), 1)
all_node_outputs.append(outputs)
target_t, generate_input = self.generate_tree_input(target[t].tolist(), outputs, nums_stack_batch, num_start,
unk)
target[t] = target_t
if self.USE_CUDA:
generate_input = generate_input.cuda()
left_child, right_child, node_label = self.node_generater(current_embeddings, generate_input, current_context)
left_childs = []
for idx, l, r, node_stack, i, o in zip(range(batch_size), left_child.split(1), right_child.split(1),
node_stacks, target[t].tolist(), embeddings_stacks):
if len(node_stack) != 0:
node = node_stack.pop()
else:
left_childs.append(None)
continue
                # unknown variables are treated as numbers, SEP is treated as an operator
                if i < num_start:  # operator token
node_stack.append(TreeNode(r))
node_stack.append(TreeNode(l, left_flag=True))
o.append(TreeEmbedding(node_label[idx].unsqueeze(0), terminal=False))
# print(o[-1].embedding.size())
# print(encoder_outputs[idx].size())
                else:  # number token
current_num = current_nums_embeddings[idx, i - num_start].unsqueeze(0)
while len(o) > 0 and o[-1].terminal:
sub_stree = o.pop()
op = o.pop()
current_num = self.merge(op.embedding, sub_stree.embedding, current_num) # Subtree embedding
if batch_first:
encoder_mapping, decoder_mapping = self.sa(current_num, encoder_outputs[idx])
else:
temp_encoder_outputs = encoder_outputs.transpose(0, 1)
encoder_mapping, decoder_mapping = self.sa(current_num,temp_encoder_outputs[idx])
all_sa_outputs.append((encoder_mapping, decoder_mapping))
o.append(TreeEmbedding(current_num, terminal=True))
if len(o) > 0 and o[-1].terminal:
left_childs.append(o[-1].embedding)
else:
left_childs.append(None)
# all_leafs = torch.stack(all_leafs, dim=1) # B x S x 2
all_node_outputs = torch.stack(all_node_outputs, dim=1) # B x S x N
target = target.transpose(0, 1).contiguous() # B x S
if self.USE_CUDA:
# all_leafs = all_leafs.cuda()
all_node_outputs = all_node_outputs.cuda()
target = target.cuda()
new_all_sa_outputs = []
for sa_pair in all_sa_outputs:
new_all_sa_outputs.append((sa_pair[0].cuda(), sa_pair[1].cuda()))
all_sa_outputs = new_all_sa_outputs
# target_length = torch.LongTensor(target_length).cuda()
else:
pass
# target_length = torch.LongTensor(target_length)
        semantic_alignment_loss = nn.MSELoss()
        total_semantic_alignment_loss = 0
        sa_len = len(all_sa_outputs)
        for sa_pair in all_sa_outputs:
            total_semantic_alignment_loss += semantic_alignment_loss(sa_pair[0], sa_pair[1])
        if sa_len > 0:
            total_semantic_alignment_loss = total_semantic_alignment_loss / sa_len
        # op_target = target < num_start
        # loss_0 = masked_cross_entropy_without_logit(all_leafs, op_target.long(), target_length)
        loss = masked_cross_entropy(all_node_outputs, target, target_length) + 0.01 * total_semantic_alignment_loss
# loss = loss_0 + loss_1
loss.backward()
# clip the grad
# torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5)
# torch.nn.utils.clip_grad_norm_(predict.parameters(), 5)
# torch.nn.utils.clip_grad_norm_(generate.parameters(), 5)
# Update parameters with optimizers
return loss.item() # , loss_0.item(), loss_1.item()
    def evaluate_tree(self, input_batch, input_length, generate_nums, num_pos, num_start, beam_size=5, max_length=30):
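        # Single-sample (batch_size = 1) beam search over the expression tree: at each step every
        # live beam expands its top-k operator/number candidates, child nodes are pushed onto the
        # node stack, and decoding stops once every kept beam has an empty node stack.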
seq_mask = torch.BoolTensor(1, input_length).fill_(0)
# Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size)
input_var = input_batch.transpose(0, 1)
num_mask = torch.BoolTensor(1, len(num_pos[0]) + len(generate_nums)).fill_(0)
padding_hidden = torch.FloatTensor([0.0 for _ in range(self.hidden_size)]).unsqueeze(0)
batch_size = 1
if self.USE_CUDA:
input_var = input_var.cuda()
seq_mask = seq_mask.cuda()
padding_hidden = padding_hidden.cuda()
num_mask = num_mask.cuda()
# Run words through encoder
seq_emb = self.embedder(input_var)
pade_outputs, _ = self.encoder(seq_emb, input_length)
problem_output = pade_outputs[-1, :, :self.hidden_size] + pade_outputs[0, :, self.hidden_size:]
encoder_outputs = pade_outputs[:, :, :self.hidden_size] + pade_outputs[:, :, self.hidden_size:]
# Prepare input and output variables
node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)]
num_size = len(num_pos[0])
all_nums_encoder_outputs = self.get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size,
self.hidden_size)
# B x P x N
embeddings_stacks = [[] for _ in range(batch_size)]
left_childs = [None for _ in range(batch_size)]
beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_childs, [])]
for t in range(max_length):
current_beams = []
while len(beams) > 0:
b = beams.pop()
if len(b.node_stack[0]) == 0:
current_beams.append(b)
continue
# left_childs = torch.stack(b.left_childs)
left_childs = b.left_childs
num_score, op, current_embeddings, current_context, current_nums_embeddings = self.decoder(b.node_stack,
left_childs,
encoder_outputs,
all_nums_encoder_outputs,
padding_hidden,
seq_mask,
num_mask)
out_score = nn.functional.log_softmax(torch.cat((op, num_score), dim=1), dim=1)
# out_score = p_leaf * out_score
topv, topi = out_score.topk(beam_size)
for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):
current_node_stack = copy_list(b.node_stack)
current_left_childs = []
current_embeddings_stacks = copy_list(b.embedding_stack)
current_out = copy.deepcopy(b.out)
out_token = int(ti)
current_out.append(out_token)
node = current_node_stack[0].pop()
if out_token < num_start:
generate_input = torch.LongTensor([out_token])
if self.USE_CUDA:
generate_input = generate_input.cuda()
left_child, right_child, node_label = self.node_generater(current_embeddings, generate_input,
current_context)
current_node_stack[0].append(TreeNode(right_child))
current_node_stack[0].append(TreeNode(left_child, left_flag=True))
current_embeddings_stacks[0].append(TreeEmbedding(node_label[0].unsqueeze(0), False))
else:
current_num = current_nums_embeddings[0, out_token - num_start].unsqueeze(0)
while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
sub_stree = current_embeddings_stacks[0].pop()
op = current_embeddings_stacks[0].pop()
current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
current_embeddings_stacks[0].append(TreeEmbedding(current_num, True))
if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
current_left_childs.append(current_embeddings_stacks[0][-1].embedding)
else:
current_left_childs.append(None)
current_beams.append(TreeBeam(b.score + float(tv), current_node_stack, current_embeddings_stacks,
current_left_childs, current_out))
beams = sorted(current_beams, key=lambda x: x.score, reverse=True)
beams = beams[:beam_size]
flag = True
for b in beams:
if len(b.node_stack[0]) != 0:
flag = False
if flag:
break
return beams[0].out
    def encoder_forward(self, seq_emb, seq_length, output_all_layers=False):
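        # Encode the embedded sequence and build (i) the token-level encoder outputs and
        # (ii) the problem representation used as the root goal vector of the tree decoder.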
if not self.batch_first:
encoder_inputs = seq_emb.transpose(0, 1)
else:
encoder_inputs = seq_emb
pade_outputs, hidden_states = self.encoder(encoder_inputs, seq_length)
problem_output = pade_outputs[-1, :, :self.hidden_size] + pade_outputs[0, :, self.hidden_size:]
encoder_outputs = pade_outputs[:, :, :self.hidden_size] + pade_outputs[:, :, self.hidden_size:]
all_layer_outputs = {}
if output_all_layers:
all_layer_outputs['encoder_outputs']=encoder_outputs
all_layer_outputs['encoder_hidden'] = hidden_states
all_layer_outputs['inputs_representation'] = problem_output
return problem_output,encoder_outputs,all_layer_outputs
    def decoder_forward(self, encoder_outputs, problem_output, all_nums_encoder_outputs, nums_stack, seq_mask, num_mask, target=None, output_all_layers=False):
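        # With a target, run teacher-forced decoding (mirroring train_tree) and collect token logits,
        # greedy symbols and semantic-alignment pairs; without a target, fall back to beam search and
        # return the logits and symbols of the best beam.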
batch_size = problem_output.size(0)
node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)]
padding_hidden = torch.FloatTensor([0.0 for _ in range(self.hidden_size)]).unsqueeze(0).to(self.device)
embeddings_stacks = [[] for _ in range(batch_size)]
left_childs = [None for _ in range(batch_size)]
token_logits=[]
outputs = []
all_sa_outputs = []
if target is not None:
            max_target_length = target.size(0)
for t in range(max_target_length):
num_score, op_score, current_embeddings, current_context, current_nums_embeddings = self.decoder(
node_stacks,
left_childs,
encoder_outputs,
all_nums_encoder_outputs,
padding_hidden,
seq_mask,
num_mask)
# all_leafs.append(p_leaf)
token_logit = torch.cat((op_score, num_score), 1)
output = torch.topk(token_logit,1,dim=-1)[1]
token_logits.append(token_logit)
outputs.append(output)
target_t, generate_input = self.generate_tree_input(target[t].tolist(), token_logit, nums_stack,
self.num_start, self.unk_token)
target[t] = target_t
if self.USE_CUDA:
generate_input = generate_input.cuda()
left_child, right_child, node_label = self.node_generater(current_embeddings, generate_input,
current_context)
left_childs = []
for idx, l, r, node_stack, i, o in zip(range(batch_size), left_child.split(1), right_child.split(1),
node_stacks, target[t].tolist(), embeddings_stacks):
if len(node_stack) != 0:
node = node_stack.pop()
else:
left_childs.append(None)
continue
if i < self.num_start:
node_stack.append(TreeNode(r))
node_stack.append(TreeNode(l, left_flag=True))
o.append(TreeEmbedding(node_label[idx].unsqueeze(0), False))
else:
current_num = current_nums_embeddings[idx, i - self.num_start].unsqueeze(0)
while len(o) > 0 and o[-1].terminal:
sub_stree = o.pop()
op = o.pop()
current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
if self.batch_first:
encoder_mapping, decoder_mapping = self.sa(current_num, encoder_outputs[idx])
else:
temp_encoder_outputs = encoder_outputs.transpose(0, 1)
encoder_mapping, decoder_mapping = self.sa(current_num, temp_encoder_outputs[idx])
all_sa_outputs.append((encoder_mapping, decoder_mapping))
o.append(TreeEmbedding(current_num, True))
if len(o) > 0 and o[-1].terminal:
left_childs.append(o[-1].embedding)
else:
left_childs.append(None)
if not self.batch_first:
target = target.transpose(0, 1).contiguous()
else:
beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_childs, [], [])]
max_gen_len = self.max_out_len
for t in range(max_gen_len):
current_beams = []
while len(beams) > 0:
b = beams.pop()
if len(b.node_stack[0]) == 0:
current_beams.append(b)
continue
left_childs = b.left_childs
num_score, op_score, current_embeddings, current_context, current_nums_embeddings = self.decoder(
b.node_stack,
left_childs,
encoder_outputs,
all_nums_encoder_outputs,
padding_hidden,
seq_mask,
num_mask)
token_logit = torch.cat((op_score, num_score), 1)
out_score = nn.functional.log_softmax(token_logit, dim=1)
# out_score = p_leaf * out_score
topv, topi = out_score.topk(self.beam_size)
for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):
current_node_stack = copy_list(b.node_stack)
current_left_childs = []
current_embeddings_stacks = copy_list(b.embedding_stack)
current_out = [tl for tl in b.out]
current_token_logit = [tl for tl in b.token_logit]
current_token_logit.append(token_logit)
out_token = int(ti)
current_out.append(torch.squeeze(ti,dim=1))
node = current_node_stack[0].pop()
if out_token < self.num_start:
generate_input = torch.LongTensor([out_token])
if self.USE_CUDA:
generate_input = generate_input.cuda()
left_child, right_child, node_label = self.node_generater(current_embeddings,
generate_input,
current_context)
current_node_stack[0].append(TreeNode(right_child))
current_node_stack[0].append(TreeNode(left_child, left_flag=True))
current_embeddings_stacks[0].append(TreeEmbedding(node_label[0].unsqueeze(0), False))
else:
current_num = current_nums_embeddings[0, out_token - self.num_start].unsqueeze(0)
while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
sub_stree = current_embeddings_stacks[0].pop()
op = current_embeddings_stacks[0].pop()
current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
current_embeddings_stacks[0].append(TreeEmbedding(current_num, True))
if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
current_left_childs.append(current_embeddings_stacks[0][-1].embedding)
else:
current_left_childs.append(None)
current_beams.append(
TreeBeam(b.score + float(tv), current_node_stack, current_embeddings_stacks,
current_left_childs, current_out,current_token_logit))
beams = sorted(current_beams, key=lambda x: x.score, reverse=True)
beams = beams[:self.beam_size]
flag = True
for b in beams:
if len(b.node_stack[0]) != 0:
flag = False
if flag:
break
token_logits = beams[0].token_logit
outputs = beams[0].out
token_logits = torch.stack(token_logits, dim=1) # B x S x N
outputs = torch.stack(outputs, dim=1) # B x S
all_layer_outputs = {}
if output_all_layers:
all_layer_outputs['token_logits'] = token_logits
all_layer_outputs['outputs'] = outputs
all_layer_outputs['target'] = target
all_layer_outputs['semantic_alignment_pair']=all_sa_outputs
return token_logits,outputs,all_layer_outputs
    def get_all_number_encoder_outputs(self, encoder_outputs, num_pos, batch_size, num_size, hidden_size):
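        """Gather the encoder hidden state at every number position.

        Returns a tensor of shape [batch_size, num_size, hidden_size]; slots beyond a problem's
        own number count are filled with zeros via the mask.
        """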
indices = list()
sen_len = encoder_outputs.size(0)
masked_index = []
temp_1 = [1 for _ in range(hidden_size)]
temp_0 = [0 for _ in range(hidden_size)]
for b in range(batch_size):
for i in num_pos[b]:
indices.append(i + b * sen_len)
masked_index.append(temp_0)
indices += [0 for _ in range(len(num_pos[b]), num_size)]
masked_index += [temp_1 for _ in range(len(num_pos[b]), num_size)]
indices = torch.LongTensor(indices)
masked_index = torch.BoolTensor(masked_index)
masked_index = masked_index.view(batch_size, num_size, hidden_size)
if self.USE_CUDA:
indices = indices.cuda()
masked_index = masked_index.cuda()
all_outputs = encoder_outputs.transpose(0, 1).contiguous()
all_embedding = all_outputs.view(-1, encoder_outputs.size(2)) # S x B x H -> (B x S) x H
all_num = all_embedding.index_select(0, indices)
all_num = all_num.view(batch_size, num_size, hidden_size)
return all_num.masked_fill_(masked_index, 0.0)
    def mse_loss(self, outputs, targets, mask=None):
# outputs : [batch_size,output_len,hidden_size]
# targets : [batch_size,output_len,hidden_size]
# mask : [batch_size,output_len]
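        # Euclidean distance between aligned representations, averaged over the unmasked steps of
        # each sample and summed over the batch.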
mask = mask.to(self.device)
x = torch.sqrt(torch.sum(torch.square((outputs - targets)), dim=-1)) # [batch_size,output_len]
y = torch.sum(x * mask, dim=-1) / torch.sum(mask, dim=-1) # [batch_size]
return torch.sum(y)
    def convert_idx2symbol(self, output, num_list, num_stack):
        """Convert an output token-id sequence into symbols (assumes batch_size = 1).

        NUM mask tokens are replaced by the corresponding value in ``num_list``; UNK tokens are
        resolved through ``num_stack`` when possible.
        """
seq_len = len(output)
num_len = len(num_list)
output_list = []
res = []
for s_i in range(seq_len):
idx = output[s_i]
if idx in [self.out_sos_token, self.out_eos_token, self.out_pad_token]:
break
symbol = self.out_idx2symbol[idx]
if "NUM" in symbol:
num_idx = self.mask_list.index(symbol)
if num_idx >= num_len:
res = []
break
res.append(num_list[num_idx])
elif symbol == SpecialTokens.UNK_TOKEN:
try:
pos_list = num_stack.pop()
c = num_list[pos_list[0]]
res.append(c)
except:
return None
else:
res.append(symbol)
output_list.append(res)
return output_list