Source code for mwptoolkit.model.Seq2Seq.saligned

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/21 04:37:56
# @File: saligned.py

import math
from typing import Tuple, Dict, Any

import torch
import numpy as np
from torch import nn

from mwptoolkit.module.Embedder.basic_embedder import BasicEmbedder
from mwptoolkit.module.Encoder.rnn_encoder import SalignedEncoder
from mwptoolkit.module.Decoder.rnn_decoder import SalignedDecoder
from mwptoolkit.module.Environment.stack_machine import OPERATIONS, StackMachine
from mwptoolkit.utils.enum_type import SpecialTokens, NumMask, Operators


class Saligned(nn.Module):
    """
    Reference:
        Chiang et al. "Semantically-Aligned Equation Generation for Solving and Reasoning Math Word Problems".
    """

    def __init__(self, config, dataset):
        super(Saligned, self).__init__()
        self.device = config['device']
        self.operations = operations = OPERATIONS(dataset.out_symbol2idx)

        # parameters
        self._vocab_size = vocab_size = len(dataset.in_idx2word)
        self._dim_embed = dim_embed = config['embedding_size']
        self._dim_hidden = dim_hidden = config['hidden_size']
        self._dropout_rate = dropout_rate = config['dropout_ratio']
        self.max_gen_len = 40

        # operation indices used by the stack machine
        self.NOOP = operations.NOOP
        self.GEN_VAR = operations.GEN_VAR
        self.ADD = operations.ADD
        self.SUB = operations.SUB
        self.MUL = operations.MUL
        self.DIV = operations.DIV
        self.POWER = operations.POWER
        self.EQL = operations.EQL
        self.N_OPS = operations.N_OPS
        self.PAD = operations.PAD
        self._device = device = config["device"]

        # output vocabulary layout: operators, then constants, then number placeholders (NUM_0, NUM_1, ...)
        self.min_NUM = dataset.out_symbol2idx['NUM_0']
        self.POWER = dataset.out_symbol2idx['^']
        self.min_CON = self.N_OPS_out = self.POWER + 1
        # constant symbols sit between the last operator ('^') and the first number placeholder
        self.fix_constants = list(dataset.out_symbol2idx.keys())[self.min_CON:self.min_NUM]
        self.mask_list = NumMask.number

        self.out_symbol2idx = dataset.out_symbol2idx
        self.out_idx2symbol = dataset.out_idx2symbol
        try:
            self.out_sos_token = self.out_symbol2idx[SpecialTokens.SOS_TOKEN]
        except KeyError:
            self.out_sos_token = None
        try:
            self.out_eos_token = self.out_symbol2idx[SpecialTokens.EOS_TOKEN]
        except KeyError:
            self.out_eos_token = None
        try:
            self.out_pad_token = self.out_symbol2idx[SpecialTokens.PAD_TOKEN]
        except KeyError:
            self.out_pad_token = None

        # modules
        self.embedder = BasicEmbedder(vocab_size, dim_embed, dropout_rate)
        self.encoder = SalignedEncoder(dim_embed, dim_hidden, dim_hidden, dropout_rate)
        self.decoder = SalignedDecoder(operations, dim_hidden, dropout_rate, device)
        self.embedding_one = torch.nn.Parameter(torch.normal(torch.zeros(2 * dim_hidden), 0.01))
        self.embedding_pi = torch.nn.Parameter(torch.normal(torch.zeros(2 * dim_hidden), 0.01))
        self.encoder.initialize_fix_constant(len(self.fix_constants), self._device)

        # losses: per-instance operator loss (reduction='none' replaces the deprecated
        # size_average=False, reduce=False arguments) and a standard argument loss
        class_weights = torch.ones(operations.N_OPS + 1)
        self._op_loss = torch.nn.CrossEntropyLoss(class_weights, reduction='none', ignore_index=-1)
        self._arg_loss = torch.nn.CrossEntropyLoss()
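Instantiating the model only requires the config entries read above and a dataset object that exposes the vocabulary tables. A minimal sketch, assuming a preprocessed mwptoolkit `dataset` is already available; the hyper-parameter values are hypothetical:

config = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'embedding_size': 128,   # hypothetical value
    'hidden_size': 512,      # hypothetical value
    'dropout_ratio': 0.5,    # hypothetical value
}
model = Saligned(config, dataset).to(config['device'])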
    def forward(self, seq, seq_length, number_list, number_position, number_size, target=None, target_length=None,
                output_all_layers=False) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Dict[str, Any]]:
        """

        :param torch.Tensor seq: input token sequence, shape [batch_size, seq_length].
        :param torch.Tensor seq_length: length of each input sequence, shape [batch_size].
        :param list number_list: the numbers appearing in each problem text.
        :param list number_position: positions of those numbers in the input sequence.
        :param list number_size: number count of each instance.
        :param torch.Tensor | None target: target equation, shape [batch_size, target_length].
        :param torch.Tensor | None target_length: length of each target equation.
        :param bool output_all_layers: if True, also return the outputs of every layer.
        :return: token_logits: a pair (op_logits, arg_logits), each [batch_size, output_length, -1];
            symbol_outputs: [batch_size, output_length]; model_all_outputs.
        :rtype: tuple(tuple(torch.Tensor, torch.Tensor), torch.Tensor, dict)
        """
        constant_indices = number_position
        constants = number_list
        num_len = number_size
        seq_length = seq_length.long()
        batch_size = seq.size(0)

        # bottom element shared by all stack machines; it carries no gradient
        bottom = torch.zeros(self._dim_hidden * 2).to(self._device)
        bottom.requires_grad = False

        seq_emb = self.embedder(seq)

        encoder_outputs, encoder_hidden, operands, number_emb, encoder_layer_outputs = \
            self.encoder_forward(seq_emb, seq_length, constant_indices, output_all_layers)

        # one stack machine per instance, holding its numbers plus the fixed constants
        stacks = [
            StackMachine(self.operations, constants[b] + self.fix_constants, number_emb[b], bottom, dry_run=True)
            for b in range(batch_size)
        ]

        if target is not None:
            # target indices at or beyond N_OPS + this instance's number count are replaced by N_OPS
            operands_len = torch.LongTensor(self.N_OPS + np.array(num_len)).to(self._device)
            operands_len = operands_len.unsqueeze(1).repeat(1, target.size(1))
            target[(target >= operands_len)] = self.N_OPS

        token_logits, symbol_outputs, decoder_layer_outputs = \
            self.decoder_forward(encoder_outputs, encoder_hidden, seq_length, operands, stacks, number_emb,
                                 target, target_length, output_all_layers)

        model_all_outputs = {}
        if output_all_layers:
            model_all_outputs['inputs_embedding'] = seq_emb
            model_all_outputs.update(encoder_layer_outputs)
            model_all_outputs.update(decoder_layer_outputs)

        return token_logits, symbol_outputs, model_all_outputs
    def calculate_loss(self, batch_data: dict) -> float:
        """Run forward propagation, calculate the loss and back-propagate.

        :param batch_data: one batch data.
        :return: loss value.

        batch_data should include keywords 'question', 'ques len', 'equation', 'equ len',
        'num pos', 'num list' and 'num size'.
        """
        text = torch.tensor(batch_data["question"]).to(self.device)
        ops = torch.tensor(batch_data["equation"]).to(self.device)
        text_len = torch.tensor(batch_data["ques len"]).long()
        ops_len = torch.tensor(batch_data["equ len"]).long()
        constant_indices = batch_data["num pos"]
        constants = batch_data["num list"]
        num_len = batch_data["num size"]

        logits, _, all_layers = self.forward(text, text_len, constants, constant_indices, num_len, ops, ops_len,
                                             output_all_layers=True)
        (op_logits, arg_logits) = logits
        op_targets, arg_targets = all_layers['op_targets'], all_layers['arg_targets']

        batch_size = ops.size(0)
        loss = torch.zeros(batch_size).to(self._device)
        for t in range(max(ops_len)):
            # per-instance operator loss at step t
            loss += self._op_loss(op_logits[:, t, :], op_targets[:, t])
            for b in range(batch_size):
                # argument loss only when the target at step t is an operand, not an operator
                if self.NOOP <= arg_targets[b, t] < self.N_OPS:
                    continue
                loss[b] += self._arg_loss(arg_logits[b, t].unsqueeze(0),
                                          arg_targets[b, t].unsqueeze(0) - self.N_OPS)
        loss = (loss / max(ops_len)).mean()
        loss.backward()
        return loss.item()
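A minimal sketch of one training step built on the interface above; the dataloader and optimizer are assumptions, and each batch is expected to carry the keys listed in the docstring. Note that calculate_loss already calls backward(), so the caller only zeroes gradients and steps the optimizer:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)    # hypothetical optimizer and learning rate
for batch in train_dataloader:                               # yields dicts with 'question', 'ques len',
    optimizer.zero_grad()                                    # 'equation', 'equ len', 'num pos',
    loss = model.calculate_loss(batch)                       # 'num list', 'num size'
    optimizer.step()                                         # backward() already ran inside calculate_loss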
    def model_test(self, batch_data: dict) -> tuple:
        """Model test.

        :param batch_data: one batch data.
        :return: predicted equation, target equation.

        batch_data should include keywords 'question', 'ques len', 'equation', 'equ len',
        'num pos', 'num list' and 'num size'.
        """
        text = torch.tensor(batch_data['question']).to(self.device)
        text_len = torch.tensor(batch_data['ques len']).long()
        constant_indices = batch_data["num pos"]
        constants = batch_data["num list"]
        num_len = batch_data["num size"]
        target = torch.tensor(batch_data['equation'])

        _, outputs, _ = self.forward(text, text_len, constants, constant_indices, num_len)
        predicts = self.convert_idx2symbol(outputs, constants)
        targets = self.convert_idx2symbol(target, constants)
        return predicts, targets
    def predict(self, batch_data: dict, output_all_layers=False):
        """Predict samples without target.

        :param dict batch_data: one batch data.
        :param bool output_all_layers: return all layer outputs of model.
        :return: token_logits, symbol_outputs, all_layer_outputs
        """
        seq = torch.tensor(batch_data["question"]).to(self.device)
        seq_len = torch.tensor(batch_data["ques len"]).long()
        num_pos = batch_data["num pos"]
        num_list = batch_data["num list"]
        num_size = batch_data["num size"]
        token_logits, symbol_outputs, model_all_outputs = self.forward(seq, seq_len, num_list, num_pos, num_size,
                                                                       output_all_layers=output_all_layers)
        return token_logits, symbol_outputs, model_all_outputs
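A sketch of evaluation under the same batch-format assumption: model_test decodes greedily (no target is passed to forward) and maps index sequences back to symbols, while predict returns the raw logits and index sequences:

model.eval()
with torch.no_grad():
    predicts, targets = model.model_test(batch)       # lists of symbol sequences (numbers restored)
    _, symbol_outputs, _ = model.predict(batch)       # raw index sequences, shape [batch_size, output_length]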
    def encoder_forward(self, seq_emb, seq_length, constant_indices, output_all_layers=False):
        batch_size = seq_emb.size(0)
        encoder_outputs, encoder_hidden, operands = \
            self.encoder.forward(seq_emb, seq_length, constant_indices)
        # per-instance operand embeddings: the numbers found in the text extended with the fixed constants
        number_emb = [operands[b_i] + self.encoder.get_fix_constant() for b_i in range(batch_size)]

        all_layer_outputs = {}
        if output_all_layers:
            all_layer_outputs['encoder_outputs'] = encoder_outputs
            all_layer_outputs['encoder_hidden'] = encoder_hidden

        return encoder_outputs, encoder_hidden, operands, number_emb, all_layer_outputs
    def decoder_forward(self, encoder_outputs, encoder_hidden, inputs_length, operands, stacks, number_emb,
                        target=None, target_length=None, output_all_layers=False):
        batch_size = encoder_outputs.size(0)
        # the first "previous operation" is a dummy index of -1
        prev_op = (torch.zeros(batch_size).to(self._device) - 1).type(torch.LongTensor)
        prev_output = None
        prev_state = encoder_hidden

        decoder_outputs = []
        token_logits = []
        arg_logits = []
        outputs = []
        op_targets = []
        arg_targets = []

        if target is not None:
            # teacher forcing: feed the gold operation of the previous step
            for t in range(max(target_length)):
                op_logit, arg_logit, prev_output, prev_state = self.decoder(encoder_outputs, inputs_length, operands,
                                                                            stacks, prev_op, prev_output, prev_state,
                                                                            number_emb, self.N_OPS)
                prev_op = target[:, t]
                decoder_outputs.append(prev_output)
                token_logits.append(op_logit)
                arg_logits.append(arg_logit)

                # operator targets: padded steps become NOOP, operand indices are clamped to N_OPS
                op_target = target[:, t].clone().detach()
                op_target[(np.array(target_length) <= t)] = self.NOOP
                op_target[(op_target >= self.N_OPS)] = self.N_OPS
                op_targets.append(op_target)

                _, pred_op = torch.log(torch.nn.functional.softmax(op_logit, -1)).max(-1)
                _, pred_arg = torch.log(torch.nn.functional.softmax(arg_logit, -1)).max(-1)
                for b in range(batch_size):
                    # index N_OPS means "push an operand"; add the argument index to get the operand symbol
                    if pred_op[b] == self.N_OPS:
                        pred_op[b] += pred_arg[b]
                outputs.append(pred_op)
        else:
            # greedy decoding until every stack machine has emitted '=' or max_gen_len is reached
            finished = [False] * batch_size
            for t in range(self.max_gen_len):
                op_logit, arg_logit, prev_output, prev_state = self.decoder(encoder_outputs, inputs_length, operands,
                                                                            stacks, prev_op, prev_output, prev_state,
                                                                            number_emb, self.N_OPS)
                n_finished = 0
                for b in range(batch_size):
                    if len(stacks[b].stack_log_index) and stacks[b].stack_log_index[-1] == self.EQL:
                        finished[b] = True
                    if finished[b]:
                        # force finished instances to keep predicting PAD
                        op_logit[b, self.PAD] = math.inf
                        n_finished += 1

                op_loss, prev_op = torch.log(torch.nn.functional.softmax(op_logit, -1)).max(-1)
                arg_loss, prev_arg = torch.log(torch.nn.functional.softmax(arg_logit, -1)).max(-1)
                for b in range(batch_size):
                    if prev_op[b] == self.N_OPS:
                        prev_op[b] += prev_arg[b]

                if n_finished == batch_size:
                    break
                decoder_outputs.append(prev_output)
                token_logits.append(op_logit)
                arg_logits.append(arg_logit)
                outputs.append(prev_op)
                if n_finished == batch_size:
                    break

        decoder_outputs = torch.stack(decoder_outputs, dim=1)
        token_logits = torch.stack(token_logits, dim=1)
        arg_logits = torch.stack(arg_logits, dim=1)
        outputs = torch.stack(outputs, dim=1)

        if target is not None:
            op_targets = torch.stack(op_targets, dim=1)
            arg_targets = target.clone()

        all_layer_outputs = {}
        if output_all_layers:
            all_layer_outputs['decoder_outputs'] = decoder_outputs
            all_layer_outputs['op_logits'] = token_logits
            all_layer_outputs['arg_logits'] = arg_logits
            all_layer_outputs['outputs'] = outputs
            all_layer_outputs['op_targets'] = op_targets
            all_layer_outputs['arg_targets'] = arg_targets

        return (token_logits, arg_logits), outputs, all_layer_outputs
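The two decoder heads are combined into one symbol index: values below N_OPS denote operators, the value N_OPS means "push an operand", and the operand chosen by the argument head is encoded by adding its index on top of N_OPS. A toy illustration of that merging rule in isolation (the tensor values are hypothetical, not model outputs):

import torch

N_OPS = 6                                 # hypothetical operator count
pred_op = torch.tensor([2, 6, 6])         # 2 -> an operator; 6 == N_OPS -> "push an operand"
pred_arg = torch.tensor([0, 0, 3])        # operand index chosen by the argument head
merged = pred_op.clone()
merged[pred_op == N_OPS] += pred_arg[pred_op == N_OPS]
print(merged)                             # tensor([2, 6, 9]): operator 2, operand 0, operand 3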
    def convert_mask_num(self, batch_output, num_list):
        output_list = []
        for b_i, output in enumerate(batch_output):
            res = []
            num_len = len(num_list[b_i])
            for symbol in output:
                if "NUM" in symbol:
                    num_idx = self.mask_list.index(symbol)
                    if num_idx >= num_len:
                        res.append(symbol)
                    else:
                        res.append(num_list[b_i][num_idx])
                else:
                    res.append(symbol)
            output_list.append(res)
        return output_list
    def convert_idx2symbol(self, output, num_list):
        batch_size = output.size(0)
        seq_len = output.size(1)
        output_list = []
        for b_i in range(batch_size):
            res = []
            num_len = len(num_list[b_i])
            for s_i in range(seq_len):
                idx = output[b_i][s_i]
                if idx in [self.out_sos_token, self.out_eos_token, self.out_pad_token]:
                    break
                symbol = self.out_idx2symbol[idx]
                if "NUM" in symbol:
                    num_idx = self.mask_list.index(symbol)
                    if num_idx >= num_len:
                        res.append(symbol)
                    else:
                        res.append(num_list[b_i][num_idx])
                else:
                    res.append(symbol)
            output_list.append(res)
        return output_list
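Both converters rely on the same number-mask replacement: a placeholder such as NUM_1 is looked up in NumMask.number and replaced by the corresponding entry of the instance's number list, falling back to the placeholder itself when the index exceeds the number count. A self-contained toy sketch of that mapping (the lists below stand in for NumMask.number and a real number list):

mask_list = ['NUM_0', 'NUM_1', 'NUM_2']            # stands in for NumMask.number
num_list = ['3', '5']                              # numbers extracted from one problem text
symbols = ['NUM_1', '+', 'NUM_0', '=', 'x']
resolved = [num_list[mask_list.index(s)] if s in mask_list and mask_list.index(s) < len(num_list) else s
            for s in symbols]
print(resolved)                                    # ['5', '+', '3', '=', 'x']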