Source code for mwptoolkit.model.Seq2Seq.ept

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/21 04:35:42
# @File: ept.py

import torch
from torch import nn
from transformers import AutoModel

from mwptoolkit.module.Encoder.transformer_encoder import TransformerEncoder
from mwptoolkit.module.Decoder.transformer_decoder import TransformerDecoder
from mwptoolkit.module.Decoder.ept_decoder import VanillaOpTransformer, ExpressionTransformer, ExpressionPointerTransformer
from mwptoolkit.module.Embedder.position_embedder import PositionEmbedder
from mwptoolkit.module.Embedder.basic_embedder import BasicEmbedder
from mwptoolkit.module.Attention.self_attention import SelfAttentionMask
from mwptoolkit.module.Strategy.beam_search import Beam_Search_Hypothesis
from mwptoolkit.module.Strategy.sampling import topk_sampling
from mwptoolkit.module.Strategy.greedy import greedy_search
from mwptoolkit.loss.smoothed_cross_entropy_loss import SmoothCrossEntropyLoss
from mwptoolkit.utils.enum_type import EPT as EPT_CON


def Submodule_types(decoder_type):
    if "vall" in decoder_type:
        return VanillaOpTransformer
    elif 'gen' in decoder_type:
        return ExpressionTransformer
    elif 'ptr' in decoder_type:
        return ExpressionPointerTransformer
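# Usage sketch (illustrative, not part of the original module): the decoder type is
# matched by substring, so any configured decoder name containing 'vall', 'gen' or
# 'ptr' selects the corresponding class. The names below are hypothetical examples:
#
#   decoder_cls = Submodule_types("expr_ptr")   # -> ExpressionPointerTransformer
#   decoder_cls = Submodule_types("vall_op")    # -> VanillaOpTransformer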
class EPT(nn.Module):
    """
    Reference:
        Kim et al. "Point to the Expression: Solving Algebraic Word Problems using the Expression-Pointer Transformer Model" in EMNLP 2020.
    """

    def __init__(self, config, dataset):
        super(EPT, self).__init__()

        self.device = config["device"]
        self.max_output_len = config["max_output_len"]
        self.share_vocab = config["share_vocab"]
        self.decoding_strategy = config["decoding_strategy"]
        self.teacher_force_ratio = config["teacher_force_ratio"]
        self.task_type = config['task_type']

        try:
            self.in_pad_idx = dataset.in_word2idx["<pad>"]
        except:
            self.in_pad_idx = None

        self.in_word2idx = dataset.in_word2idx
        self.in_idx2word = dataset.in_idx2word
        self.mode = config["decoder"]

        if 'vall' in config["decoder"]:
            self.out_symbol2idx = dataset.out_symbol2idx
            self.out_idx2symbol = dataset.out_idx2symbol
            #self.out_pad_idx = self.in_pad_idx
            #self.out_sos_idx = config["in_word2idx"]["<SOS>"]
            self.decoder = VanillaOpTransformer(config, self.out_symbol2idx, self.out_idx2symbol)
        else:
            self.out_opsym2idx = dataset.out_opsym2idx
            self.out_idx2opsym = dataset.out_idx2opsymbol
            self.out_consym2idx = dataset.out_consym2idx
            self.out_idx2consym = dataset.out_idx2consymbol

            if 'gen' in config["decoder"]:
                self.decoder = ExpressionTransformer(config, self.out_opsym2idx, self.out_idx2opsym,
                                                     self.out_consym2idx, self.out_idx2consym)
            elif 'ptr' in config["decoder"]:
                self.decoder = ExpressionPointerTransformer(config, self.out_opsym2idx, self.out_idx2opsym,
                                                            self.out_consym2idx, self.out_idx2consym)
            #self.out_pad_idx = self.in_pad_idx
            #self.out_sos_idx = config["in_word2idx"]["<SOS>"]

        pretrained_model_path = config['pretrained_model'] if config['pretrained_model'] else config['transformers_pretrained_model']
        self.encoder = AutoModel.from_pretrained(pretrained_model_path)
        #self.encoder = TransformerEncoder(config["embedding_size"], config["ffn_size"], config["num_encoder_layers"], \
        #                                  config["num_heads"], config["attn_dropout_ratio"], \
        #                                  config["attn_weight_dropout_ratio"], config["ffn_dropout_ratio"])
        #self.decoder = TransformerDecoder(config["embedding_size"], config["ffn_size"], config["num_decoder_layers"], \
        #                                  config["num_heads"], config["attn_dropout_ratio"], \
        #                                  config["attn_weight_dropout_ratio"], config["ffn_dropout_ratio"])
        #self.decoder = Submodule_types(config["decoder"])(config)
        #self.out = nn.Linear(config["embedding_size"], config["symbol_size"])

        self.loss = SmoothCrossEntropyLoss()
    def forward(self, src, src_mask, num_pos, num_size, target=None, output_all_layers=False):
        """

        :param torch.Tensor src: input sequence.
        :param list src_mask: mask of input sequence.
        :param list num_pos: number positions of input sequence.
        :param list num_size: number of numbers of input sequence.
        :param torch.Tensor target: target, default None.
        :param bool output_all_layers: return output of all layers if output_all_layers is True, default False.
        :return: token_logits: [batch_size, output_length, output_size], symbol_outputs: [batch_size, output_length], model_all_outputs.
        """
        encoder_output, encoder_layer_outputs = self.encoder_forward(src, src_mask, output_all_layers)

        max_numbers = max(num_size)

        if num_pos is not None:
            text_num, text_numpad = self.gather_vectors(encoder_output, num_pos, max_len=max_numbers)
        else:
            text_num = text_numpad = None

        token_logits, outputs, decoder_layer_outputs = self.decoder_forward(encoder_output, text_num, text_numpad,
                                                                            src_mask, target, output_all_layers)

        model_all_outputs = {}
        if output_all_layers:
            model_all_outputs.update(encoder_layer_outputs)
            model_all_outputs.update(decoder_layer_outputs)

        return token_logits, outputs, model_all_outputs
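    # Illustrative sketch (not part of the original module): how `forward` is driven.
    # With a target the decoder is teacher-forced and `token_logits` is populated;
    # without a target the model decodes with beam=1 and the generated ids are
    # returned instead. Variable names below are assumptions for illustration.
    #
    #   # training: token_logits is a dict of per-head logits, symbol outputs are None
    #   token_logits, _, all_layers = model.forward(src, src_mask, num_pos, num_size,
    #                                               target, output_all_layers=True)
    #   # inference: decoded ids are returned, token_logits is None
    #   _, outputs, _ = model.forward(src, src_mask, num_pos, num_size)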
    def calculate_loss(self, batch_data: dict) -> float:
        """Finish forward-propagating, calculating loss and back-propagation.

        :param batch_data: one batch data.
        :return: loss value.

        batch_data should include keywords 'question', 'ques mask', 'equation', 'num pos' and 'num size'.
        """
        src = torch.tensor(batch_data["question"]).to(self.device)
        src_mask = torch.BoolTensor(batch_data["ques mask"]).to(self.device)
        num_pos = batch_data["num pos"]
        target = torch.tensor(batch_data["equation"]).to(self.device)
        num_size = batch_data["num size"]

        token_logits, _, all_layers = self.forward(src, src_mask, num_pos, num_size, target, output_all_layers=True)
        targets = all_layers['targets']

        self.loss.reset()
        for key, result in targets.items():
            predicted = token_logits[key].flatten(0, -2)
            result = self.shift_target(result)
            target = result.flatten()
            self.loss.eval_batch(predicted, target)
        self.loss.backward()
        batch_loss = self.loss.get_loss()
        return batch_loss
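    # Minimal batch sketch (assumed field layout, for illustration only):
    #
    #   batch_data = {
    #       "question": [[101, 2054, ...]],      # token ids of the problem text
    #       "ques mask": [[False, False, ...]],  # True at padded positions
    #       "equation": [[...]],                 # target expression tokens
    #       "num pos": ...,                      # positions of numbers in the text
    #       "num size": [2],                     # how many numbers each problem contains
    #   }
    #   loss = model.calculate_loss(batch_data)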
    def model_test(self, batch_data: dict) -> tuple:
        """Model test.

        :param batch_data: one batch data.
        :return: predicted equation, target equation.

        batch_data should include keywords 'question', 'ques mask', 'equation', 'num pos', 'num size' and 'num list'.
        """
        src = torch.tensor(batch_data["question"]).to(self.device)
        src_mask = torch.BoolTensor(batch_data["ques mask"]).to(self.device)
        num_pos = batch_data["num pos"]
        num_size = batch_data["num size"]

        _, symbol_outputs, _ = self.forward(src, src_mask, num_pos, num_size)

        all_outputs = self.convert_idx2symbol(symbol_outputs, batch_data["num list"])
        targets = self.convert_idx2symbol(batch_data["equation"], batch_data["num list"])
        return all_outputs, targets
    def predict(self, batch_data: dict, output_all_layers=False):
        """
        predict samples without target.

        :param dict batch_data: one batch data.
        :param bool output_all_layers: return all layer outputs of model.
        :return: token_logits, symbol_outputs, all_layer_outputs
        """
        raise NotImplementedError
    def encoder_forward(self, src, src_mask, output_all_layers=False):
        encoder_outputs = self.encoder(input_ids=src, attention_mask=(~src_mask).float())
        encoder_output = encoder_outputs[0]

        all_layer_outputs = {}
        if output_all_layers:
            all_layer_outputs['encoder_outputs'] = encoder_output
            all_layer_outputs['inputs_representation'] = encoder_output

        return encoder_output, all_layer_outputs
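    # Note (added for clarity): `src_mask` marks padded positions with True, while the
    # transformers `attention_mask` argument expects 1 for tokens that should be
    # attended to, hence the inversion `(~src_mask).float()` above.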
    def decoder_forward(self, encoder_output, text_num, text_numpad, src_mask, target=None, output_all_layers=False):
        if target is not None:
            # With a gold equation, the decoder returns per-head token logits and the aligned targets.
            token_logits, targets = self.decoder(text=encoder_output, text_num=text_num, text_numpad=text_numpad,
                                                 text_pad=src_mask, equation=target)
            outputs = None
        else:
            max_len = self.max_output_len
            outputs, _ = self.decoder(text=encoder_output, text_num=text_num, text_numpad=text_numpad,
                                      text_pad=src_mask, beam=1, max_len=max_len)
            token_logits = None

            # Pad the generated sequences with -1 up to max_len.
            shape = list(outputs.shape)
            seq_len = shape[2]
            if seq_len < max_len:
                shape[2] = max_len
                tensor = torch.full(shape, fill_value=-1, dtype=torch.long)
                tensor[:, :, :seq_len] = outputs.cpu()
                outputs = tensor
            outputs = outputs.squeeze(1)
            targets = None

        all_layer_outputs = {}
        if output_all_layers:
            all_layer_outputs['targets'] = targets

        return token_logits, outputs, all_layer_outputs
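    # Note (added for clarity, shape layout assumed): in generation mode the decoded
    # tensor is padded with -1 along its sequence dimension up to `max_output_len`,
    # so downstream post-processing can rely on a fixed length; the singleton beam
    # dimension produced by beam=1 is then squeezed out.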
    def decode(self, output):
        device = output.device

        batch_size = output.size(0)
        decoded_output = []
        for idx in range(batch_size):
            decoded_output.append(self.in_word2idx[self.out_idx2symbol[output[idx]]])
        decoded_output = torch.tensor(decoded_output).to(device).view(batch_size, -1)
        return decoded_output
    def gather_vectors(self, hidden: torch.Tensor, mask: torch.Tensor, max_len: int = 1):
        """
        Gather hidden states of indicated positions.

        :param torch.Tensor hidden:
            Float Tensor of hidden states. Shape [B, S, H],
            where B = batch size, S = length of sequence, and H = hidden dimension.
        :param torch.Tensor mask:
            Long Tensor which indicates number indices that we're interested in. Shape [B, S].
        :param int max_len: Expected maximum length of vectors per batch. 1 by default.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        :return: Tuple of Tensors:
            - [0]: Float Tensor of indicated hidden states. Shape [B, N, H], where N = max(number of interested positions, max_len)
            - [1]: Bool Tensor of padded positions. Shape [B, N].
        """
        # Compute the maximum number of indicated positions in the text
        max_len = max(mask.max().item(), max_len)
        batch_size, seq_len, hidden_size = hidden.shape

        # Storage for gathering hidden states
        gathered = torch.zeros(batch_size, max_len, hidden_size, dtype=hidden.dtype, device=hidden.device)
        pad_mask = torch.ones(batch_size, max_len, dtype=torch.bool, device=hidden.device)

        # Average hidden states for tokens representing a number
        for row in range(batch_size):
            for i in range(max_len):
                indices = (mask[row] == i).nonzero().view(-1).tolist()

                if len(indices) > 0:
                    begin = min(indices)
                    end = max(indices) + 1

                    # Copy masked positions. Take mean of number vectors.
                    gathered[row, i] = hidden[row, begin:end].mean(dim=0)
                    pad_mask[row, i] = False

        return gathered, pad_mask
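    # Worked sketch (illustrative values, not part of the original module): `mask`
    # holds, per token, the index of the number it belongs to; positions that never
    # match an index (e.g. -1) are treated as non-number tokens.
    #
    #   hidden = torch.randn(1, 6, 8)                    # [B=1, S=6, H=8]
    #   mask = torch.tensor([[-1, 0, 0, -1, 1, -1]])     # tokens 1-2 -> number 0, token 4 -> number 1
    #   vectors, pad = model.gather_vectors(hidden, mask, max_len=2)
    #   # vectors[0, 0] == hidden[0, 1:3].mean(dim=0)    # mean over number 0's tokens
    #   # vectors[0, 1] == hidden[0, 4]                  # single-token number 1
    #   # pad is all False here; it is True where a batch item has fewer numbers.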
    def shift_target(self, target: torch.Tensor, fill_value=-1) -> torch.Tensor:
        """
        Shift matrix to build generation targets.

        :param torch.Tensor target: Target tensor to build generation targets. Shape [B, T].
        :param fill_value: Value to be filled at the padded positions.
        :rtype: torch.Tensor
        :return: Tensor with shape [B, T], where the (i, j)-entry is the (i, j+1)-entry of the target tensor.
        """
        # Target does not require gradients.
        with torch.no_grad():
            pad_at_end = torch.full((target.shape[0], 1), fill_value=fill_value, dtype=target.dtype, device=target.device)
            return torch.cat([target[:, 1:], pad_at_end], dim=-1).contiguous()
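    # Worked sketch (illustrative values): each row is shifted left by one position
    # and the vacated last column is filled with `fill_value`, so position j is
    # trained to predict the token at position j + 1.
    #
    #   target = torch.tensor([[5, 7, 9, 2]])
    #   model.shift_target(target)        # -> tensor([[7, 9, 2, -1]])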
    def convert_idx2symbol(self, output, num_list):
        #batch_size=output.size(0)
        '''batch_size=1'''
        output_list = []
        if "vall" in self.mode:
            for id, single in enumerate(output):
                output_list.append(self.out_expression_op(single, num_list[id]))
        else:
            for id, single in enumerate(output):
                output_list.append(self.out_expression_expr(single, num_list[id]))
        return output_list
    def out_expression_op(self, item, num_list):
        equation = []
        # Tokens after PAD_ID will be ignored.
        for i, token in enumerate(item.tolist()):
            if token != EPT_CON.PAD_ID:
                token = self.out_idx2symbol[token]

                if token == EPT_CON.SEQ_NEW_EQN:
                    equation.clear()
                    continue
                elif token == EPT_CON.SEQ_END_EQN:
                    break
            else:
                break

            equation.append(token)
        return equation
    def out_expression_expr(self, item, num_list):
        expressions = []

        for token in item:
            # For each token in the item, the first index should be the operator.
            operator = self.out_idx2opsym[token[0]]

            if operator == EPT_CON.FUN_NEW_EQN:
                # If the operator is __NEW_EQN, we ignore the previously generated outputs.
                expressions.clear()
                continue

            if operator == EPT_CON.FUN_END_EQN:
                # If the operator is __END_EQN, we ignore the next outputs.
                break

            # Now, retrieve the operands.
            operands = []
            for i in range(1, len(token), 2):
                # For each argument, we build two values: source and value.
                src = token[i]
                if src != EPT_CON.PAD_ID:
                    # If source is not a padding, compute the value.
                    src = EPT_CON.ARG_TOKENS[src]
                    operand = token[i + 1]
                    if src == EPT_CON.ARG_CON or "gen" in self.mode:
                        operand = self.out_idx2consym[operand]

                    if type(operand) is str and operand.startswith(EPT_CON.MEM_PREFIX):
                        operands.append((EPT_CON.ARG_MEM, int(operand[2:])))
                    else:
                        operands.append((src, operand))

            # Append an expression.
            expressions.append((operator, operands))

        computation_history = []
        expression_used = []
        #print("expressions", expressions)
        for operator, operands in expressions:
            # For each expression.
            computation = []

            if operator == EPT_CON.FUN_NEW_VAR:
                # Generate a new variable whenever __NEW_VAR() appears.
                computation.append(EPT_CON.FORMAT_VAR % len(computation_history))
            else:
                # Otherwise, form an expression tree.
                for src, operand in operands:
                    # Find each operand from the specified sources.
                    if src == EPT_CON.ARG_NUM and "ptr" in self.mode:
                        # If this is a number pointer, then replace it with number indices.
                        computation.append(EPT_CON.FORMAT_NUM % operand)
                    elif src == EPT_CON.ARG_MEM:
                        # If this indicates the result of a prior expression, then replace it with the prior result.
                        if operand < len(computation_history):
                            computation += computation_history[operand]
                            # Mark the prior expression as used.
                            expression_used[operand] = True
                        else:
                            # Expression is not found, then use UNK.
                            computation.append(EPT_CON.ARG_UNK)
                    else:
                        # Otherwise, this is a constant: append the operand itself.
                        computation.append(operand)

                # To make it a postfix representation, append the operator at the end.
                computation.append(operator)

            # Save the current expression into the history.
            computation_history.append(computation)
            expression_used.append(False)

        # Find unused computation history. These are the top-level formulae.
        computation_history = [equation for used, equation in zip(expression_used, computation_history) if not used]
        result = sum(computation_history, [])

        replace_result = []
        for word in result:
            if 'N_' in word:
                replace_result.append(str(num_list[int(word[2:])]['value']))
            elif 'C_' in word:
                replace_result.append(str(word[2:].replace('_', '.')))
            else:
                replace_result.append(word)
        if '=' in replace_result[:-1]:
            replace_result.append("<BRG>")
        return replace_result
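    # Worked sketch (illustrative, symbol names assumed): each decoded item is an
    # (operator, operands) expression, and the method rebuilds a postfix sequence,
    # inlining earlier results and substituting number/constant symbols. Two
    # expressions roughly like
    #
    #   ('+', [(number, 0), (number, 1)])          # N_0 + N_1
    #   ('*', [(memory, 0), (constant, 'C_2')])    # (previous result) * 2
    #
    # would come out as the postfix list ['N_0', 'N_1', '+', 'C_2', '*'] before the
    # final loop replaces 'N_i' with the i-th value from `num_list` and 'C_2' with '2'.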
    def __str__(self) -> str:
        info = super().__str__()
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        parameters = "\ntotal parameters : {} \ntrainable parameters : {}".format(total, trainable)
        return info + parameters