# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/21 05:00:30
# @File: trnn.py
import copy
import random
from typing import Tuple

import torch
from torch import nn

from mwptoolkit.loss.nll_loss import NLLLoss
# from mwptoolkit.loss.cross_entropy_loss import CrossEntropyLoss
from mwptoolkit.module.Decoder.rnn_decoder import AttentionalRNNDecoder
from mwptoolkit.module.Embedder.basic_embedder import BasicEmbedder
from mwptoolkit.module.Embedder.bert_embedder import BertEmbedder
from mwptoolkit.module.Embedder.roberta_embedder import RobertaEmbedder
from mwptoolkit.module.Encoder.rnn_encoder import BasicRNNEncoder, SelfAttentionRNNEncoder
from mwptoolkit.module.Layer.tree_layers import RecursiveNN
from mwptoolkit.utils.data_structure import BinaryTree
from mwptoolkit.utils.enum_type import NumMask, SpecialTokens
class TRNN(nn.Module):
"""
Reference:
Wang et al. "Template-Based Math Word Problem Solvers with Recursive Neural Networks" in AAAI 2019.
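
    TRNN is a two-stage solver: a seq2seq module first maps the problem text to an
    operator template, and a recursive neural network (the answer module) then predicts
    the operator at each inner node of the template's expression tree.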
"""
def __init__(self, config, dataset):
super(TRNN, self).__init__()
self.device = config['device']
self.seq2seq_embedding_size = config["seq2seq_embedding_size"]
self.seq2seq_encode_hidden_size = config["seq2seq_encode_hidden_size"]
self.seq2seq_decode_hidden_size = config["seq2seq_decode_hidden_size"]
self.num_layers = config["seq2seq_num_layers"]
self.teacher_force_ratio = config["teacher_force_ratio"]
self.seq2seq_dropout_ratio = config['seq2seq_dropout_ratio']
self.ans_embedding_size = config["ans_embedding_size"]
self.ans_hidden_size = config["ans_hidden_size"]
self.ans_dropout_ratio = config["ans_dropout_ratio"]
self.ans_num_layers = config["ans_num_layers"]
self.encoder_rnn_cell_type = config["encoder_rnn_cell_type"]
self.decoder_rnn_cell_type = config["decoder_rnn_cell_type"]
self.max_gen_len = config["max_output_len"]
self.bidirectional = config["bidirectional"]
self.attention = True
self.share_vocab = config["share_vocab"]
self.embedding = config["embedding"]
self.mask_list = NumMask.number
        self.in_idx2word = dataset.in_idx2word
        self.in_word2idx = dataset.in_word2idx  # needed by convert_temp_idx_2_in_idx when share_vocab is set
        self.out_idx2symbol = dataset.out_idx2symbol
        self.temp_idx2symbol = dataset.temp_idx2symbol
        self.temp_symbol2idx = dataset.temp_symbol2idx  # needed by convert_in_idx_2_temp_idx
self.vocab_size = len(dataset.in_idx2word)
self.symbol_size = len(dataset.out_idx2symbol)
self.temp_symbol_size = len(dataset.temp_idx2symbol)
self.operator_nums = len(dataset.operator_list)
self.operator_list = dataset.operator_list
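        # constants the answer module may generate besides the numbers appearing in the text;
        # generate_idx holds their input-vocabulary indices so their embeddings can be looked
        # up in ans_module_forward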
self.generate_list = [SpecialTokens.UNK_TOKEN] + dataset.generate_list
self.generate_idx = [self.in_idx2word.index(num) for num in self.generate_list]
        if self.share_vocab:
            self.sos_token_idx = dataset.in_word2idx[SpecialTokens.SOS_TOKEN]
        else:
            self.sos_token_idx = dataset.out_symbol2idx[SpecialTokens.SOS_TOKEN]
        # these special tokens may be absent from a given vocabulary, so fall back to None
        try:
            self.out_sos_token = dataset.out_symbol2idx[SpecialTokens.SOS_TOKEN]
        except KeyError:
            self.out_sos_token = None
        try:
            self.out_eos_token = dataset.out_symbol2idx[SpecialTokens.EOS_TOKEN]
        except KeyError:
            self.out_eos_token = None
        try:
            self.out_pad_token = dataset.out_symbol2idx[SpecialTokens.PAD_TOKEN]
        except KeyError:
            self.out_pad_token = None
        try:
            self.temp_sos_token = dataset.temp_symbol2idx[SpecialTokens.SOS_TOKEN]
        except KeyError:
            self.temp_sos_token = None
        try:
            self.temp_eos_token = dataset.temp_symbol2idx[SpecialTokens.EOS_TOKEN]
        except KeyError:
            self.temp_eos_token = None
        try:
            self.temp_pad_token = dataset.temp_symbol2idx[SpecialTokens.PAD_TOKEN]
        except KeyError:
            self.temp_pad_token = None
# seq2seq module
if config['embedding'] == 'roberta':
self.seq2seq_in_embedder = RobertaEmbedder(self.vocab_size, config['pretrained_model_path'])
self.seq2seq_in_embedder.token_resize(self.vocab_size)
elif config['embedding'] == 'bert':
self.seq2seq_in_embedder = BertEmbedder(self.vocab_size, config['pretrained_model_path'])
self.seq2seq_in_embedder.token_resize(self.vocab_size)
else:
self.seq2seq_in_embedder = BasicEmbedder(self.vocab_size, self.seq2seq_embedding_size,
self.seq2seq_dropout_ratio)
if self.share_vocab:
self.seq2seq_out_embedder = self.seq2seq_in_embedder
else:
self.seq2seq_out_embedder = BasicEmbedder(self.temp_symbol_size, self.seq2seq_embedding_size,
self.seq2seq_dropout_ratio)
        self.seq2seq_encoder = BasicRNNEncoder(self.seq2seq_embedding_size, self.seq2seq_encode_hidden_size,
                                               self.num_layers, self.encoder_rnn_cell_type,
                                               self.seq2seq_dropout_ratio, self.bidirectional)
        self.seq2seq_decoder = AttentionalRNNDecoder(self.seq2seq_embedding_size, self.seq2seq_decode_hidden_size,
                                                     self.seq2seq_encode_hidden_size, self.num_layers,
                                                     self.decoder_rnn_cell_type, self.seq2seq_dropout_ratio)
self.seq2seq_gen_linear = nn.Linear(self.seq2seq_encode_hidden_size, self.temp_symbol_size)
# answer module
if config['embedding'] == 'roberta':
self.answer_in_embedder = RobertaEmbedder(self.vocab_size, config['pretrained_model_path'])
self.answer_in_embedder.token_resize(self.vocab_size)
elif config['embedding'] == 'bert':
self.answer_in_embedder = BertEmbedder(self.vocab_size, config['pretrained_model_path'])
self.answer_in_embedder.token_resize(self.vocab_size)
else:
self.answer_in_embedder = BasicEmbedder(self.vocab_size, self.ans_embedding_size, self.ans_dropout_ratio)
        self.answer_encoder = SelfAttentionRNNEncoder(self.ans_embedding_size, self.ans_hidden_size,
                                                      self.ans_embedding_size, self.num_layers,
                                                      self.encoder_rnn_cell_type, self.ans_dropout_ratio,
                                                      self.bidirectional)
self.answer_rnn = RecursiveNN(self.ans_embedding_size, self.operator_nums, self.operator_list)
        weight = torch.ones(self.temp_symbol_size).to(config["device"])
        # the seq2seq loss is computed over the template vocabulary, so the ignored
        # pad index must come from temp_symbol2idx rather than out_symbol2idx
        pad = dataset.temp_symbol2idx[SpecialTokens.PAD_TOKEN]
        self.seq2seq_loss = NLLLoss(weight, pad)
weight2 = torch.ones(self.operator_nums).to(config["device"])
self.ans_module_loss = NLLLoss(weight2, size_average=True)
# self.ans_module_loss=CrossEntropyLoss(weight2,size_average=True)
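        # counter for samples whose template fails to parse into a tree (not updated in this file)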
self.wrong = 0
    def forward(self, seq, seq_length, seq_mask, num_pos, template_target=None, equation_target=None,
                output_all_layers=False):
        seq2seq_token_logits, seq2seq_outputs, seq2seq_layer_outputs = self.seq2seq_forward(seq, seq_length,
                                                                                            template_target,
                                                                                            output_all_layers)
        if equation_target is not None:
            template = None
        else:
            template = self.convert_temp_idx2symbol(seq2seq_outputs)
ans_token_logits, ans_outputs, ans_module_layer_outputs = self.ans_module_forward(seq, seq_length, seq_mask,
template, num_pos,
equation_target,
output_all_layers)
model_all_outputs = {}
if output_all_layers:
model_all_outputs.update(seq2seq_layer_outputs)
model_all_outputs.update(ans_module_layer_outputs)
return (seq2seq_token_logits, ans_token_logits), (seq2seq_outputs, ans_outputs), model_all_outputs
    def calculate_loss(self, batch_data: dict) -> Tuple[float, float]:
        """Finish forward-propagating, calculating loss and back-propagation.

        :param batch_data: one batch data.
        :return: seq2seq module loss, answer module loss.
        """
# first stage:train seq2seq
seq2seq_loss = self.seq2seq_calculate_loss(batch_data)
# second stage: train answer module
answer_loss = self.ans_module_calculate_loss(batch_data)
return seq2seq_loss, answer_loss
    def model_test(self, batch_data: dict) -> tuple:
        """Model test.

        batch_data should include keywords 'question', 'ques len', 'equation', 'ques mask',
        'num pos', 'num list' and 'template'.

        :param batch_data: one batch data.
        :return: predicted equation, target equation.
        """
seq = torch.tensor(batch_data["question"]).to(self.device)
seq_length = torch.tensor(batch_data["ques len"]).long()
target = torch.tensor(batch_data["equation"]).to(self.device)
seq_mask = torch.BoolTensor(batch_data["ques mask"]).to(self.device)
num_pos = batch_data['num pos']
num_list = batch_data["num list"]
template_target = self.convert_temp_idx2symbol(torch.tensor(batch_data['template']))
_, output_template, _ = self.seq2seq_forward(seq, seq_length)
template = self.convert_temp_idx2symbol(output_template)
_, _, ans_module_layers = self.ans_module_forward(seq, seq_length, seq_mask, template, num_pos,
output_all_layers=True)
equations = ans_module_layers['ans_model_equation_outputs']
_, _, ans_module_layers = self.ans_module_forward(seq, seq_length, seq_mask, template_target, num_pos,
output_all_layers=True)
ans_module_test = ans_module_layers['ans_model_equation_outputs']
equations = self.mask2num(equations, num_list)
ans_module_test = self.mask2num(ans_module_test, num_list)
targets = self.convert_idx2symbol(target, num_list)
        return equations, targets, template, template_target, ans_module_test, targets
    def predict(self, batch_data: dict, output_all_layers=False):
        """Predict samples without target.

        :param dict batch_data: one batch data.
        :param bool output_all_layers: return all layer outputs of model.
        :return: token_logits, symbol_outputs, all_layer_outputs
        """
seq = torch.tensor(batch_data["question"]).to(self.device)
seq_length = torch.tensor(batch_data["ques len"]).long()
ques_mask = torch.BoolTensor(batch_data["ques mask"]).to(self.device)
num_pos = batch_data['num pos']
token_logits, symbol_outputs, model_all_outputs = self.forward(seq, seq_length, ques_mask, num_pos,
output_all_layers=output_all_layers)
return token_logits, symbol_outputs, model_all_outputs
    def seq2seq_calculate_loss(self, batch_data: dict) -> float:
        """Finish forward-propagating, calculating loss and back-propagation of seq2seq module.

        :param batch_data: one batch data.
        :return: loss value of seq2seq module.
        """
seq = torch.tensor(batch_data["question"]).to(self.device)
seq_length = torch.tensor(batch_data["ques len"]).long()
target = torch.tensor(batch_data["template"]).to(self.device)
# ques_mask = torch.BoolTensor(batch_data["ques mask"]).to(self.device)
token_logits, _, _ = self.seq2seq_forward(seq, seq_length, target)
if self.share_vocab:
target = self.convert_in_idx_2_temp_idx(target)
outputs = torch.nn.functional.log_softmax(token_logits, dim=-1)
self.seq2seq_loss.reset()
self.seq2seq_loss.eval_batch(outputs.view(-1, outputs.size(-1)), target.view(-1))
self.seq2seq_loss.backward()
return self.seq2seq_loss.get_loss()
    def ans_module_calculate_loss(self, batch_data):
        """Finish forward-propagating, calculating loss and back-propagation of answer module.

        :param batch_data: one batch data.
        :return: loss value of answer module.
        """
seq = torch.tensor(batch_data["question"]).to(self.device)
seq_length = torch.tensor(batch_data["ques len"]).long()
seq_mask = torch.BoolTensor(batch_data["ques mask"]).to(self.device)
num_pos = batch_data["num pos"]
equ_source = copy.deepcopy(batch_data["equ_source"])
for idx, equ in enumerate(equ_source):
equ_source[idx] = equ.split(" ")
template = equ_source
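        # the answer module is trained with gold templates (equation_target); templates
        # predicted by the seq2seq module are only consumed at test time (see model_test)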
token_logits, _, ans_module_layers = self.ans_module_forward(seq, seq_length, seq_mask, template, num_pos,
equation_target=template, output_all_layers=True)
target = ans_module_layers["ans_module_target"]
self.ans_module_loss.reset()
        for b_i in range(len(target)):
            if not isinstance(token_logits[b_i], list):
                output = torch.nn.functional.log_softmax(token_logits[b_i], dim=1)
                self.ans_module_loss.eval_batch(output, target[b_i].view(-1))
self.ans_module_loss.backward()
return self.ans_module_loss.get_loss()
    def seq2seq_generate_t(self, encoder_outputs, encoder_hidden, decoder_inputs):
with_t = random.random()
if with_t < self.teacher_force_ratio:
if self.attention:
decoder_outputs, decoder_states = self.seq2seq_decoder(decoder_inputs, encoder_hidden, encoder_outputs)
else:
decoder_outputs, decoder_states = self.seq2seq_decoder(decoder_inputs, encoder_hidden)
token_logits = self.seq2seq_gen_linear(decoder_outputs)
token_logits = token_logits.view(-1, token_logits.size(-1))
token_logits = torch.nn.functional.log_softmax(token_logits, dim=1)
else:
seq_len = decoder_inputs.size(1)
decoder_hidden = encoder_hidden
decoder_input = decoder_inputs[:, 0, :].unsqueeze(1)
token_logits = []
for idx in range(seq_len):
if self.attention:
decoder_output, decoder_hidden = self.seq2seq_decoder(decoder_input, decoder_hidden,
encoder_outputs)
else:
decoder_output, decoder_hidden = self.seq2seq_decoder(decoder_input, decoder_hidden)
# attn_list.append(attn)
step_output = decoder_output.squeeze(1)
token_logit = self.seq2seq_gen_linear(step_output)
predict = torch.nn.functional.log_softmax(token_logit, dim=1)
# predict=torch.log_softmax(token_logit,dim=1)
output = predict.topk(1, dim=1)[1]
token_logits.append(predict)
if self.share_vocab:
output = self.convert_temp_idx_2_in_idx(output)
decoder_input = self.seq2seq_out_embedder(output)
else:
decoder_input = self.seq2seq_out_embedder(output)
token_logits = torch.stack(token_logits, dim=1)
token_logits = token_logits.view(-1, token_logits.size(-1))
return token_logits
    def seq2seq_generate_without_t(self, encoder_outputs, encoder_hidden, decoder_input):
all_outputs = []
decoder_hidden = encoder_hidden
for idx in range(self.max_gen_len):
if self.attention:
decoder_output, decoder_hidden = self.seq2seq_decoder(decoder_input, decoder_hidden, encoder_outputs)
else:
decoder_output, decoder_hidden = self.seq2seq_decoder(decoder_input, decoder_hidden)
step_output = decoder_output.squeeze(1)
token_logits = self.seq2seq_gen_linear(step_output)
predict = torch.nn.functional.log_softmax(token_logits, dim=1)
output = predict.topk(1, dim=1)[1]
all_outputs.append(output)
if self.share_vocab:
output = self.convert_temp_idx_2_in_idx(output)
decoder_input = self.seq2seq_out_embedder(output)
else:
decoder_input = self.seq2seq_out_embedder(output)
all_outputs = torch.cat(all_outputs, dim=1)
return all_outputs
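    # NOTE: seq2seq_generate_t and seq2seq_generate_without_t above duplicate the decoding
    # logic of seq2seq_decoder_forward below; the forward path calls seq2seq_decoder_forward,
    # so these two helpers appear to be kept for backward compatibility.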
    def seq2seq_forward(self, seq, seq_length, target=None, output_all_layers=False):
batch_size = seq.size(0)
device = seq.device
seq_emb = self.seq2seq_in_embedder(seq)
encoder_outputs, encoder_hidden, encoder_layer_outputs = self.seq2seq_encoder_forward(seq_emb, seq_length,
output_all_layers)
decoder_inputs = self.init_seq2seq_decoder_inputs(target, device, batch_size)
token_logits, symbol_outputs, decoder_layer_outputs = self.seq2seq_decoder_forward(encoder_outputs,
encoder_hidden,
decoder_inputs, target,
output_all_layers)
seq2seq_all_outputs = {}
if output_all_layers:
seq2seq_all_outputs['seq2seq_inputs_embedding'] = seq_emb
seq2seq_all_outputs.update(encoder_layer_outputs)
seq2seq_all_outputs.update(decoder_layer_outputs)
return token_logits, symbol_outputs, seq2seq_all_outputs
    def ans_module_forward(self, seq, seq_length, seq_mask, template, num_pos, equation_target=None,
                           output_all_layers=False):
if self.embedding == 'roberta':
seq_emb = self.answer_in_embedder(seq, seq_mask)
else:
seq_emb = self.answer_in_embedder(seq)
encoder_output, encoder_hidden = self.answer_encoder(seq_emb, seq_length)
batch_size = encoder_output.size(0)
generate_num = torch.tensor(self.generate_idx).to(self.device)
if self.embedding == 'roberta':
generate_emb = self.answer_in_embedder(generate_num, None)
else:
generate_emb = self.answer_in_embedder(generate_num)
batch_prob = []
batch_target = []
outputs = []
equations = []
        input_template = equation_target if equation_target is not None else template
if equation_target is not None:
for b_i in range(batch_size):
try:
tree_i = self.template2tree(input_template[b_i])
except IndexError:
outputs.append([])
continue
look_up = self.generate_list + NumMask.number[:len(num_pos[b_i])]
num_encoding = seq_emb[b_i, num_pos[b_i]] + encoder_output[b_i, num_pos[b_i]]
num_embedding = torch.cat([generate_emb, num_encoding], dim=0)
assert len(look_up) == len(num_embedding)
prob, target = self.answer_rnn(tree_i.root, num_embedding, look_up, self.out_idx2symbol)
batch_prob.append(prob)
batch_target.append(target)
                if not isinstance(prob, list):
output = torch.topk(prob, 1)[1]
outputs.append(output)
else:
outputs.append([])
else:
for b_i in range(batch_size):
try:
tree_i = self.template2tree(input_template[b_i])
except IndexError:
outputs.append([])
continue
look_up = self.generate_list + NumMask.number[:len(num_pos[b_i])]
num_encoding = seq_emb[b_i, num_pos[b_i]] + encoder_output[b_i, num_pos[b_i]]
num_embedding = torch.cat([generate_emb, num_encoding], dim=0)
assert len(look_up) == len(num_embedding)
prob, output, node_pred = self.answer_rnn.test(tree_i.root, num_embedding, look_up, self.out_idx2symbol)
batch_prob.append(prob)
tree_i.root = node_pred
outputs.append(output)
equation = self.tree2equation(tree_i)
equations.append(equation)
all_layer_outputs = {}
if output_all_layers:
all_layer_outputs['ans_module_token_logits'] = batch_prob
all_layer_outputs['ans_module_target'] = batch_target
all_layer_outputs['ans_model_outputs'] = outputs
all_layer_outputs['ans_model_equation_outputs'] = equations
return batch_prob, outputs, all_layer_outputs
    def seq2seq_encoder_forward(self, seq_emb, seq_length, output_all_layers=False):
encoder_outputs, encoder_hidden = self.seq2seq_encoder(seq_emb, seq_length)
        if self.bidirectional:
            # sum the two directions so the output size matches the unidirectional case
            encoder_outputs = (encoder_outputs[:, :, self.seq2seq_encode_hidden_size:]
                               + encoder_outputs[:, :, :self.seq2seq_encode_hidden_size])
            if self.encoder_rnn_cell_type == 'lstm':
                # states are stacked [forward, backward] per layer; keep the forward ones
                encoder_hidden = (encoder_hidden[0][::2].contiguous(), encoder_hidden[1][::2].contiguous())
            else:
                encoder_hidden = encoder_hidden[::2].contiguous()
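        # bridge the encoder's final state to the decoder's cell type: duplicate a GRU/RNN
        # hidden state into an (h, c) pair for an LSTM decoder, or keep only h the other way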
if self.encoder_rnn_cell_type == self.decoder_rnn_cell_type:
pass
elif (self.encoder_rnn_cell_type == 'gru') and (self.decoder_rnn_cell_type == 'lstm'):
encoder_hidden = (encoder_hidden, encoder_hidden)
elif (self.encoder_rnn_cell_type == 'rnn') and (self.decoder_rnn_cell_type == 'lstm'):
encoder_hidden = (encoder_hidden, encoder_hidden)
elif (self.encoder_rnn_cell_type == 'lstm') and (
self.decoder_rnn_cell_type == 'gru' or self.decoder_rnn_cell_type == 'rnn'):
encoder_hidden = encoder_hidden[0]
else:
pass
all_layer_outputs = {}
if output_all_layers:
all_layer_outputs['seq2seq_encoder_outputs'] = encoder_outputs
all_layer_outputs['seq2seq_encoder_hidden'] = encoder_hidden
return encoder_outputs, encoder_hidden, all_layer_outputs
    def seq2seq_decoder_forward(self, encoder_outputs, encoder_hidden, decoder_inputs, target=None,
                                output_all_layers=False):
if target is not None and random.random() < self.teacher_force_ratio:
if self.attention:
decoder_outputs, decoder_states = self.seq2seq_decoder(decoder_inputs, encoder_hidden, encoder_outputs)
else:
decoder_outputs, decoder_states = self.seq2seq_decoder(decoder_inputs, encoder_hidden)
token_logits = self.seq2seq_gen_linear(decoder_outputs)
outputs = token_logits.topk(1, dim=-1)[1]
else:
seq_len = decoder_inputs.size(1) if target is not None else self.max_gen_len
decoder_hidden = encoder_hidden
decoder_input = decoder_inputs[:, 0, :].unsqueeze(1)
decoder_outputs = []
token_logits = []
outputs = []
for idx in range(seq_len):
if self.attention:
decoder_output, decoder_hidden = self.seq2seq_decoder(decoder_input, decoder_hidden,
encoder_outputs)
else:
decoder_output, decoder_hidden = self.seq2seq_decoder(decoder_input, decoder_hidden)
step_output = decoder_output.squeeze(1)
token_logit = self.seq2seq_gen_linear(step_output)
output = token_logit.topk(1, dim=-1)[1]
decoder_outputs.append(step_output)
token_logits.append(token_logit)
outputs.append(output)
if self.share_vocab:
output = self.convert_temp_idx_2_in_idx(output)
decoder_input = self.seq2seq_out_embedder(output)
else:
decoder_input = self.seq2seq_out_embedder(output)
decoder_outputs = torch.stack(decoder_outputs, dim=1)
token_logits = torch.stack(token_logits, dim=1)
outputs = torch.stack(outputs, dim=1)
all_layer_outputs = {}
if output_all_layers:
all_layer_outputs['seq2seq_decoder_outputs'] = decoder_outputs
all_layer_outputs['seq2seq_token_logits'] = token_logits
all_layer_outputs['seq2seq_outputs'] = outputs
return token_logits, outputs, all_layer_outputs
    def template2tree(self, template):
        """Parse an operator template (list of symbols) into a binary expression tree."""
        tree = BinaryTree()
        tree.equ2tree_(template)
        return tree

    def tree2equation(self, tree):
        """Serialize a binary expression tree back into an equation symbol list."""
        equation = tree.tree2equ(tree.root)
        return equation
    def convert_temp_idx_2_in_idx(self, output):
        """Map template-vocabulary indices to input-vocabulary indices (used when share_vocab is set)."""
device = output.device
batch_size = output.size(0)
seq_len = output.size(1)
decoded_output = []
for b_i in range(batch_size):
output_i = []
for s_i in range(seq_len):
output_i.append(self.in_word2idx[self.temp_idx2symbol[output[b_i, s_i]]])
decoded_output.append(output_i)
decoded_output = torch.tensor(decoded_output).to(device).view(batch_size, -1)
return decoded_output
    def convert_in_idx_2_temp_idx(self, output):
        """Map input-vocabulary indices to template-vocabulary indices (used when share_vocab is set)."""
device = output.device
batch_size = output.size(0)
seq_len = output.size(1)
decoded_output = []
for b_i in range(batch_size):
output_i = []
for s_i in range(seq_len):
output_i.append(self.temp_symbol2idx[self.in_idx2word[output[b_i, s_i]]])
decoded_output.append(output_i)
decoded_output = torch.tensor(decoded_output).to(device).view(batch_size, -1)
return decoded_output
    def convert_temp_idx2symbol(self, output):
        """Convert template indices to symbols, truncating at the first SOS/EOS/PAD token."""
batch_size = output.size(0)
seq_len = output.size(1)
symbol_list = []
for b_i in range(batch_size):
symbols = []
for s_i in range(seq_len):
idx = output[b_i][s_i]
if idx in [self.temp_sos_token, self.temp_eos_token, self.temp_pad_token]:
break
symbol = self.temp_idx2symbol[idx]
symbols.append(symbol)
symbol_list.append(symbols)
return symbol_list
    def convert_idx2symbol(self, output, num_list):
        """Convert output indices to symbols, replacing NUM masks with the numbers from num_list."""
batch_size = output.size(0)
seq_len = output.size(1)
output_list = []
for b_i in range(batch_size):
res = []
num_len = len(num_list[b_i])
for s_i in range(seq_len):
idx = output[b_i][s_i]
if idx in [self.out_sos_token, self.out_eos_token, self.out_pad_token]:
break
symbol = self.out_idx2symbol[idx]
if "NUM" in symbol:
num_idx = self.mask_list.index(symbol)
if num_idx >= num_len:
res.append(symbol)
else:
res.append(num_list[b_i][num_idx])
else:
res.append(symbol)
output_list.append(res)
return output_list
    def symbol2idx(self, symbols):
        r"""Convert equation symbols to equation indices, mapping unknown symbols to UNK."""
outputs = []
for symbol in symbols:
if symbol not in self.out_idx2symbol:
idx = self.out_idx2symbol.index(SpecialTokens.UNK_TOKEN)
else:
idx = self.out_idx2symbol.index(symbol)
outputs.append(idx)
return outputs
    def mask2num(self, output, num_list):
        """Replace NUM mask symbols in decoded sequences with the actual numbers from num_list."""
batch_size = len(output)
output_list = []
for b_i in range(batch_size):
res = []
seq_len = len(output[b_i])
num_len = len(num_list[b_i])
for s_i in range(seq_len):
symbol = output[b_i][s_i]
if "NUM" in symbol:
num_idx = self.mask_list.index(symbol)
if num_idx >= num_len:
res.append(symbol)
else:
res.append(num_list[b_i][num_idx])
else:
res.append(symbol)
output_list.append(res)
return output_list
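

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical; assumes `config`, `dataset`, and `batch_data` are
# prepared by the surrounding mwptoolkit pipeline -- the batch keys are the ones
# documented in model_test/predict above):
#
#     model = TRNN(config, dataset).to(config['device'])
#     optimizer = torch.optim.Adam(model.parameters())
#
#     # two-stage training: calculate_loss backpropagates through the seq2seq
#     # module and the answer module separately and returns both loss values
#     optimizer.zero_grad()
#     seq2seq_loss, answer_loss = model.calculate_loss(batch_data)
#     optimizer.step()
#
#     # evaluation: predicted equations vs. targets, plus the answer module run
#     # on gold templates for diagnosis
#     equations, targets, template, temp_t, ans_test, _ = model.model_test(batch_data)
# ---------------------------------------------------------------------------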