Source code for mwptoolkit.model.PreTrain.gpt2

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/21 04:36:11
# @File: gpt2.py
from typing import Tuple, Dict, Any

import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Config, BertTokenizer, GPT2Tokenizer

from mwptoolkit.loss.nll_loss import NLLLoss
from mwptoolkit.utils.enum_type import SpecialTokens, NumMask, DatasetName


class GPT2(nn.Module):
    """
    Reference:
        Radford et al. "Language Models are Unsupervised Multitask Learners".
    """

    def __init__(self, config, dataset):
        super(GPT2, self).__init__()

        self.device = config["device"]
        self.max_out_len = config['max_output_len']
        self.max_input_len = config["max_len"]
        self.pretrained_model_path = config['pretrained_model'] if config['pretrained_model'] \
            else config['transformers_pretrained_model']

        self.tokenizer = dataset.tokenizer
        if config['dataset'] in [DatasetName.math23k, DatasetName.hmwp, DatasetName.ape200k]:
            # These datasets use a BERT-style tokenizer: [CLS] starts the
            # equation and [SEP] ends it.
            self.eos_token_id = self.tokenizer.sep_token_id
            self.eos_token = self.tokenizer.sep_token
            self.start_token = self.tokenizer.cls_token
        else:
            self.eos_token_id = self.tokenizer.eos_token_id
            self.eos_token = self.tokenizer.eos_token
            self.start_token = ''

        self.configuration = GPT2Config.from_pretrained(self.pretrained_model_path)
        self.decoder = GPT2LMHeadModel.from_pretrained(self.pretrained_model_path, config=self.configuration)
        self._pretrained_model_resize()

        self.loss = NLLLoss()

    def _pretrained_model_resize(self):
        # The dataset tokenizer may add special tokens (e.g. number masks),
        # so the embedding matrix must be resized to the new vocabulary size.
        self.decoder.resize_token_embeddings(len(self.tokenizer))
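
    # A minimal sketch of the config entries the constructor above reads. The
    # key names come from __init__ itself; the values are hypothetical:
    #
    #     config = {
    #         "device": "cuda:0",
    #         "max_output_len": 128,
    #         "max_len": 512,
    #         "pretrained_model": None,
    #         "transformers_pretrained_model": "gpt2",
    #         "dataset": DatasetName.math23k,
    #     }
    #     model = GPT2(config, dataset)  # dataset must expose a `tokenizer`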
    def forward(self, seq, target=None, output_all_layers=False) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        :param torch.Tensor seq: input sequence, shape: [batch_size, seq_length].
        :param torch.Tensor | None target: target sequence, shape: [batch_size, target_length].
        :param bool output_all_layers: return outputs of all layers if output_all_layers is True, defaults to False.
        :return: token_logits: [batch_size, output_length, output_size], symbol_outputs: [batch_size, output_length], model_all_outputs.
        :rtype: tuple(torch.Tensor, torch.Tensor, dict)
        """
        token_logits, symbol_outputs, decoder_layer_outputs = self.decoder_forward(seq, target, output_all_layers)

        model_all_outputs = {}
        if output_all_layers:
            model_all_outputs.update(decoder_layer_outputs)

        return token_logits, symbol_outputs, model_all_outputs
    def calculate_loss(self, batch_data: dict) -> float:
        """Run one forward pass, compute the loss and back-propagate.

        Args:
            batch_data (dict): one batch data.

        Returns:
            float: loss value.
        """
        seq, target = batch_data["question"], batch_data["equation"]
        seq = torch.LongTensor(seq).to(self.device)
        target = torch.LongTensor(target).to(self.device)

        token_logits, _, _ = self.forward(seq, target)
        token_logits = token_logits.view(-1, token_logits.size(-1))
        outputs = torch.nn.functional.log_softmax(token_logits, dim=1)

        self.loss.reset()
        self.loss.eval_batch(outputs, target.view(-1))
        self.loss.backward()
        return self.loss.get_loss()
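
    # calculate_loss() uses the standard log_softmax + NLL pattern, which is
    # equivalent to token-level cross-entropy. A self-contained sketch with
    # hypothetical shapes (batch_size=2, target_length=4, vocab_size=10):
    #
    #     logits = torch.randn(2, 4, 10)              # [batch, length, vocab]
    #     target = torch.randint(0, 10, (2, 4))       # [batch, length]
    #     flat = logits.view(-1, logits.size(-1))     # [batch*length, vocab]
    #     log_probs = torch.nn.functional.log_softmax(flat, dim=1)
    #     loss = torch.nn.functional.nll_loss(log_probs, target.view(-1))
    #     # ... gives the same value as F.cross_entropy(flat, target.view(-1))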
    def model_test(self, batch_data: dict) -> tuple:
        """Model test.

        Args:
            batch_data (dict): one batch data.

        Returns:
            tuple(list,list): predicted equation, target equation.
        """
        seq = batch_data["question"]
        num_list = batch_data['num list']
        target = batch_data['equation']

        seq = torch.LongTensor(seq).to(self.device)
        target = torch.LongTensor(target).to(self.device)

        _, outputs, _ = self.forward(seq)
        outputs = self.decode_(outputs)
        targets = self.decode_(target)

        outputs = self.convert_idx2symbol(outputs, num_list)
        targets = self.convert_idx2symbol(targets, num_list)
        return outputs, targets
    def predict(self, batch_data: dict, output_all_layers=False):
        """Predict samples without target.

        :param dict batch_data: one batch data.
        :param bool output_all_layers: return all layer outputs of model.
        :return: token_logits, symbol_outputs, all_layer_outputs
        """
        seq = torch.tensor(batch_data['question']).to(self.device)
        token_logits, symbol_outputs, model_all_outputs = self.forward(seq, output_all_layers=output_all_layers)
        return token_logits, symbol_outputs, model_all_outputs
    def list2str(self, x):
        return ''.join(x)
    def decoder_forward(self, seq, target=None, output_all_layers=False):
        if target is not None:
            # Teacher forcing: feed the question followed by the equation
            # shifted right, and predict the full equation.
            tgts_inputs_tensor = target[:, :-1]  # '[CLS] / * num_1 num_2 num_0 [SEP]'
            tgts_outputs_tensor = target         # '[CLS] / * num_1 num_2 num_0 [SEP] [SEP]'
            # NOTE: seq_mask is built but currently unused; the decoder is
            # called without an attention mask.
            seq_mask = (tgts_inputs_tensor != self.eos_token_id).float()
            seq_mask = torch.cat([torch.ones(seq_mask.shape[0], 1, device=seq_mask.device), seq_mask], 1)

            inputs = torch.cat([seq, tgts_inputs_tensor], 1)
            logits = self.decoder(inputs)[0]
            logits = logits[:, -tgts_outputs_tensor.shape[1]:, :].contiguous()
            outputs = torch.topk(logits, 1, dim=-1)[1]
        else:
            # Greedy decoding: repeatedly append the most probable next token.
            outputs = []
            logits = []
            inputs = seq
            for _ in range(self.max_out_len):
                decoder_outputs = self.decoder(inputs)
                token_logit = decoder_outputs[0][:, -1, :]
                tokens = token_logit.topk(1, dim=1)[1]
                logits.append(token_logit)
                outputs.append(tokens)
                inputs = torch.cat((inputs, tokens), dim=1)
            logits = torch.stack(logits, dim=1)
            outputs = torch.cat(outputs, dim=1)

        all_layer_outputs = {}
        if output_all_layers:
            all_layer_outputs['token_logits'] = logits
            all_layer_outputs['outputs'] = outputs

        return logits, outputs, all_layer_outputs
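
    # How decoder_forward() lays out its inputs, on hypothetical sizes: with
    # a question of length S and an equation of length T, teacher forcing
    # feeds [seq ; target[:, :-1]] (length S+T-1) and keeps the last T logit
    # positions, so position i predicts target token i:
    #
    #     inputs  : q_1 ... q_S t_1 ... t_{T-1}
    #     predict :         t_1 t_2 ...     t_T    (last T logit positions)
    #
    # Greedy decoding instead appends one argmax token at a time and always
    # runs for max_output_len steps; eos is only handled later, in decode_().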
    def decode_(self, outputs):
        batch_size = outputs.size(0)
        all_outputs = []
        for b in range(batch_size):
            symbols = self.tokenizer.decode(outputs[b])
            symbols = self.tokenizer.tokenize(symbols)
            symbols_ = []
            for token in symbols:
                if token == self.start_token:
                    continue
                if 'Ġ' in token:
                    # GPT-2 BPE marks a leading space with 'Ġ'; strip it.
                    symbols_.append(token[1:])
                elif token == self.eos_token:
                    break
                else:
                    symbols_.append(token)
            all_outputs.append(symbols_)
        return all_outputs
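
    # decode_() relies on the GPT-2 byte-level BPE convention that a token
    # starting a new (space-preceded) word is prefixed with 'Ġ', e.g. for a
    # hypothetical equation string:
    #
    #     tokenizer.tokenize('x = 1 + 2')  ->  ['x', 'Ġ=', 'Ġ1', 'Ġ+', 'Ġ2']
    #
    # so stripping the leading 'Ġ' recovers the plain symbols.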
    def encode_(self, inputs):
        outputs = []
        for seq in inputs:
            outputs.append(self.tokenizer.encode(seq))
        # Pad every sequence with the eos token up to the longest length + 1,
        # so each encoded equation ends with at least one eos token.
        output_length = max(len(out) for out in outputs) + 1
        for i in range(len(outputs)):
            outputs[i] += (output_length - len(outputs[i])) * [self.eos_token_id]
        return torch.LongTensor(outputs)
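
    # encode_() pads with the eos token up to (longest length + 1), so every
    # row ends with at least one eos. E.g. with hypothetical eos_token_id=102
    # and encoded lengths 3 and 5:
    #
    #     [[a, b, c], [d, e, f, g, h]]
    #     -> [[a, b, c, 102, 102, 102],
    #         [d, e, f, g, h, 102]]          # both rows now have length 6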
    def convert_idx2symbol(self, outputs, num_lists):
        batch_size = len(outputs)
        output_list = []
        for b_i in range(batch_size):
            num_len = len(num_lists[b_i])
            res = []
            if isinstance(outputs[b_i], str):
                output = outputs[b_i].split()
            else:
                output = outputs[b_i]
            for symbol in output:
                if "NUM" in symbol:
                    # Map a number placeholder (e.g. 'NUM_0') back to the
                    # actual number from the problem text.
                    num_idx = NumMask.number.index(symbol)
                    if num_idx >= num_len:
                        res.append(symbol)
                    else:
                        res.append(num_lists[b_i][num_idx])
                else:
                    res.append(symbol)
            output_list.append(res)
        return output_list
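

# A hedged end-to-end usage sketch (not part of the original module). It
# assumes the conventions used above: `dataset` exposes a `tokenizer`, and
# batches are dicts with 'question', 'equation' and (for testing) 'num list'
# keys; the optimizer choice is illustrative only.
def _usage_sketch(config, dataset, train_batch, test_batch):
    model = GPT2(config, dataset).to(config["device"])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # One training step: calculate_loss() runs forward, computes the loss
    # and back-propagates; the caller still performs the optimizer step.
    model.train()
    optimizer.zero_grad()
    loss = model.calculate_loss(train_batch)
    optimizer.step()

    # One evaluation step: model_test() greedily decodes and maps 'NUM_x'
    # placeholders back to the numbers of each problem.
    model.eval()
    with torch.no_grad():
        predictions, references = model.model_test(test_batch)
    return loss, predictions, references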