Source code for mwptoolkit.data.dataset.pretrain_dataset

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 11:33:33
# @File: pretrain_dataset.py
import json
import os
import copy
import warnings
from logging import getLogger

import torch
from transformers import RobertaTokenizer, BertTokenizer

from mwptoolkit.config.configuration import Config
from mwptoolkit.data.dataset.abstract_dataset import AbstractDataset
from mwptoolkit.utils.enum_type import DatasetName, MaskSymbol, NumMask, TaskType, FixType, Operators, SpecialTokens
from mwptoolkit.utils.preprocess_tools import id_reedit
from mwptoolkit.utils.preprocess_tool.equation_operator import from_infix_to_multi_way_tree
from mwptoolkit.utils.preprocess_tool.equation_operator import from_infix_to_postfix, from_infix_to_prefix, \
    from_postfix_to_infix, from_postfix_to_prefix, from_prefix_to_infix, from_prefix_to_postfix
from mwptoolkit.utils.preprocess_tool.sentence_operator import deprel_tree_to_file, get_group_nums_, \
    span_level_deprel_tree_to_file, get_span_level_deprel_tree_, get_deprel_tree_
from mwptoolkit.utils.preprocess_tool.number_transfer import number_transfer
from mwptoolkit.utils.utils import read_json_data, write_json_data


[docs]class PretrainDataset(AbstractDataset): """dataset class for pre-train model. """ def __init__(self, config): """ Args: config (mwptoolkit.config.configuration.Config) expected that config includes these parameters below: task_type (str): [single_equation | multi_equation], the type of task. embedding (str|None): embedding module name, use pre-train model as embedding module, if None, not to use pre-train model. rule1 (bool): convert equation according to rule 1. rule2 (bool): convert equation according to rule 2. parse_tree_file_name (str|None): the name of the file to save parse tree information. pretrained_model or transformers_pretrained_model (str|None): road path or name of pretrained model. model (str): model name. dataset (str): dataset name. equation_fix (str): [infix | postfix | prefix], convert equation to specified format. dataset_dir or dataset_path (str): the road path of dataset folder. language (str): a property of dataset, the language of dataset. single (bool): a property of dataset, the equation of dataset is single or not. linear (bool): a property of dataset, the equation of dataset is linear or not. source_equation_fix (str): [infix | postfix | prefix], a property of dataset, the source format of equation of dataset. rebuild (bool): when loading additional dataset information, this can decide to build information anew or load information built before. validset_divide (bool): whether to split validset. if True, the dataset is split to trainset-validset-testset. if False, the dataset is split to trainset-testset. mask_symbol (str): [NUM | number], the symbol to mask numbers in equation. min_word_keep (int): in dataset, words that count greater than the value, will be kept in input vocabulary. min_generate_keep (int): generate number that count greater than the value, will be kept in output symbols. symbol_for_tree (bool): build output symbols for tree or not. share_vocab (bool): encoder and decoder of the model share the same vocabulary, often seen in Seq2Seq models. k_fold (int|None): if it's an integer, it indicates to run k-fold cross validation. if it's None, it indicates to run trainset-validset-testset split. read_local_folds (bool): when running k-fold cross validation, if True, then loading split folds from dataset folder. if False, randomly split folds. shuffle (bool): whether to shuffle trainset before training. device (torch.device): resume_training or resume (bool): """ super().__init__(config) self.task_type = config['task_type'] self.rule1 = config['rule1'] self.rule2 = config['rule2'] self.vocab_level = config['vocab_level'] self.embedding = config['embedding'] self.pretrained_model_path = config["transformers_pretrained_model"] if config[ "transformers_pretrained_model"] else config["pretrained_model"] self.parse_tree_path = config['parse_tree_file_name'] self.add_num_symbol = config['add_num_symbol'] if self.parse_tree_path is not None: self.parse_tree_path = os.path.join(self.dataset_path,self.parse_tree_path + '.json') if not os.path.isabs(self.parse_tree_path): self.parse_tree_path = os.path.join(os.getcwd(), self.parse_tree_path) if self.vocab_level is None: self.vocab_level = 'word' def _preprocess(self): if self.dataset in [DatasetName.hmwp]: self.trainset, self.validset, self.testset = id_reedit(self.trainset, self.validset, self.testset) transfer = number_transfer self.trainset, generate_list, train_copy_nums, unk_symbol = transfer(self.trainset, self.dataset, self.task_type, self.mask_symbol, self.min_generate_keep,self.linear, ";", self.vocab_level) self.validset, _g, valid_copy_nums, _ = transfer(self.validset, self.dataset, self.task_type, self.mask_symbol, self.min_generate_keep,self.linear, ";", self.vocab_level) self.testset, _g, test_copy_nums, _ = transfer(self.testset, self.dataset, self.task_type, self.mask_symbol, self.min_generate_keep,self.linear, ";", self.vocab_level) target_equation_fix = self.equation_fix if self.equation_fix else FixType.Infix source_equation_fix = self.source_equation_fix if self.source_equation_fix else FixType.Infix if self.rule1: if source_equation_fix != FixType.Infix: warnings.warn("non-infix-equation datasets may not support EN rule1 process, already ignored it. ") elif self.linear and self.single: self.en_rule1_process(k=max([train_copy_nums, valid_copy_nums, test_copy_nums])) else: warnings.warn( "non-linear or non-single datasets may not support EN rule1 process, already ignored it. ") # raise Warning("non-linear or non-single datasets may not surport en rule1 process, already ignored it. ") if self.rule2: if source_equation_fix != FixType.Infix: warnings.warn("non-infix-equation datasets may not support EN rule2 process, already ignored it. ") elif self.linear and self.single: self.en_rule2_process() else: warnings.warn( "non-linear or non-single datasets may not support EN rule2 process, already ignored it. ") # raise Warning("non-linear or non-single datasets may not surport en rule2 process, already ignored it. ") if source_equation_fix == target_equation_fix: fix = None elif source_equation_fix == FixType.Infix and target_equation_fix == FixType.Prefix: fix = from_infix_to_prefix elif source_equation_fix == FixType.Infix and target_equation_fix == FixType.Postfix: fix = from_infix_to_postfix elif source_equation_fix == FixType.Prefix and target_equation_fix == FixType.Postfix: fix = from_prefix_to_postfix elif source_equation_fix == FixType.Prefix and target_equation_fix == FixType.Infix: fix = from_prefix_to_infix elif source_equation_fix == FixType.Postfix and target_equation_fix == FixType.Infix: fix = from_postfix_to_infix elif source_equation_fix == FixType.Postfix and target_equation_fix == FixType.Prefix: fix = from_postfix_to_prefix elif source_equation_fix == FixType.Infix and target_equation_fix == FixType.MultiWayTree: fix = from_infix_to_multi_way_tree else: raise NotImplementedError("the type of equation fix ({}) is not implemented.".format(self.equation_fix)) self.fix_process(fix) self.operator_mask_process() generate_list = unk_symbol + generate_list if self.symbol_for_tree: copy_nums = max([train_copy_nums, valid_copy_nums, test_copy_nums]) else: copy_nums = train_copy_nums if self.task_type == TaskType.SingleEquation: operator_list = copy.deepcopy(Operators.Single) elif self.task_type == TaskType.MultiEquation: operator_list = copy.deepcopy(Operators.Multi) else: raise NotImplementedError if self.dataset in [DatasetName.mawps]: operator_list.append('=') operator_nums = len(operator_list) # graph preprocess use_gpu = True if self.device == torch.device('cuda') else False if self.model.lower() in ['graph2treeibm']: if os.path.exists(self.parse_tree_path) and not self.rebuild: logger = getLogger() logger.info("read deprel tree infomation from {} ...".format(self.parse_tree_path)) self.trainset, self.validset, self.testset, token_list = \ get_deprel_tree_(self.trainset, self.validset, self.testset, self.parse_tree_path) else: logger = getLogger() logger.info("build deprel tree infomation to {} ...".format(self.parse_tree_path)) deprel_tree_to_file(self.trainset, self.validset, self.testset, \ self.parse_tree_path, self.language, use_gpu) self.trainset, self.validset, self.testset, token_list = \ get_deprel_tree_(self.trainset, self.validset, self.testset, self.parse_tree_path) if self.model.lower() in ['hms']: if os.path.exists(self.parse_tree_path) and not self.rebuild: logger = getLogger() logger.info("read span-level deprel tree infomation from {} ...".format(self.parse_tree_path)) self.trainset, self.validset, self.testset, self.max_span_size = \ get_span_level_deprel_tree_(self.trainset, self.validset, self.testset, self.parse_tree_path) else: logger = getLogger() logger.info("build span-level deprel tree infomation to {} ...".format(self.parse_tree_path)) span_level_deprel_tree_to_file(self.trainset, self.validset, self.testset, \ self.parse_tree_path, self.language, use_gpu) self.trainset, self.validset, self.testset, self.max_span_size = \ get_span_level_deprel_tree_(self.trainset, self.validset, self.testset, self.parse_tree_path) if self.model.lower() in ['graph2tree']: if os.path.exists(self.parse_tree_path) and not self.rebuild: logger = getLogger() logger.info("read deprel tree infomation from {} ...".format(self.parse_tree_path)) self.trainset, self.validset, self.testset = \ get_group_nums_(self.trainset, self.validset, self.testset, self.parse_tree_path) else: logger = getLogger() logger.info("build deprel tree infomation to {} ...".format(self.parse_tree_path)) deprel_tree_to_file(self.trainset, self.validset, self.testset, \ self.parse_tree_path, self.language, use_gpu) self.trainset, self.validset, self.testset = \ get_group_nums_(self.trainset, self.validset, self.testset, self.parse_tree_path) return {'generate_list': generate_list, 'copy_nums': copy_nums, 'operator_list': operator_list, 'operator_nums': operator_nums} def _build_vocab(self): if self.embedding == 'bert': tokenizer = BertTokenizer.from_pretrained(self.pretrained_model_path) elif self.embedding == 'roberta': tokenizer = RobertaTokenizer.from_pretrained(self.pretrained_model_path) else: raise NotImplementedError if self.add_num_symbol: if self.mask_symbol == MaskSymbol.NUM: tokenizer.add_tokens(['NUM'],True) elif self.mask_symbol == MaskSymbol.number: tokenizer.add_tokens(NumMask.number[:self.copy_nums],True) #tokenizer.add_special_tokens(token_dict) if self.model.lower() in ['trnn']: tokenizer.add_tokens(self.generate_list) # tokenizer.add_special_tokens({'additional_special_token':["NUM"]}) # global SpecialTokens SpecialTokens.PAD_TOKEN = tokenizer.pad_token SpecialTokens.SOS_TOKEN = tokenizer.bos_token SpecialTokens.EOS_TOKEN = tokenizer.eos_token SpecialTokens.UNK_TOKEN = tokenizer.unk_token if self.embedding == 'bert': SpecialTokens.SOS_TOKEN = tokenizer.cls_token SpecialTokens.EOS_TOKEN = tokenizer.sep_token in_idx2word = list(tokenizer.get_vocab().keys()) if self.symbol_for_tree: equ_dict = self._build_symbol_for_tree() temp_dict = self._build_template_symbol_for_tree() elif self.equation_fix == FixType.MultiWayTree: equ_dict = self._build_symbol_for_multi_way_tree() temp_dict = self._build_template_symbol_for_multi_way_tree() else: equ_dict = self._build_symbol() temp_dict = self._build_template_symbol() out_idx2symbol = equ_dict['out_idx2symbol'] temp_idx2symbol = temp_dict['temp_idx2symbol'] num_start = equ_dict['num_start'] temp_num_start = temp_dict['temp_num_start'] in_word2idx = {} out_symbol2idx = {} temp_symbol2idx = {} for idx, word in enumerate(in_idx2word): in_word2idx[word] = idx for idx, symbol in enumerate(out_idx2symbol): out_symbol2idx[symbol] = idx for idx, symbol in enumerate(temp_idx2symbol): temp_symbol2idx[symbol] = idx return {'in_idx2word': in_idx2word, 'in_word2idx': in_word2idx, 'out_idx2symbol': out_idx2symbol, 'temp_idx2symbol': temp_idx2symbol, 'out_symbol2idx': out_symbol2idx, 'temp_symbol2idx': temp_symbol2idx, 'num_start': num_start, 'temp_num_start': temp_num_start, 'tokenizer': tokenizer} def _build_symbol_for_tree(self): out_idx2symbol = copy.deepcopy(self.operator_list) num_start = len(out_idx2symbol) out_idx2symbol += self.generate_list if self.mask_symbol == MaskSymbol.NUM: mask_list = NumMask.number try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) elif self.mask_symbol == MaskSymbol.alphabet: mask_list = NumMask.alphabet try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "alphabet may not enough to mask {} numbers, changing the mask_symbol from alphabet to number may solve the problem.".format( self.copy_nums)) elif self.mask_symbol == MaskSymbol.number: mask_list = NumMask.number try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) else: raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol)) out_idx2symbol += [SpecialTokens.UNK_TOKEN] return {'out_idx2symbol': out_idx2symbol, 'num_start': num_start} def _build_symbol_for_multi_way_tree(self): out_idx2symbol = [SpecialTokens.PAD_TOKEN, SpecialTokens.SOS_TOKEN, SpecialTokens.EOS_TOKEN, SpecialTokens.NON_TOKEN] out_idx2symbol += Operators.Single num_start = len(out_idx2symbol) out_idx2symbol += self.generate_list if self.mask_symbol == MaskSymbol.NUM: mask_list = NumMask.number try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) elif self.mask_symbol == MaskSymbol.alphabet: mask_list = NumMask.alphabet try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "alphabet may not enough to mask {} numbers, changing the mask_symbol from alphabet to number may solve the problem.".format( self.copy_nums)) elif self.mask_symbol == MaskSymbol.number: mask_list = NumMask.number try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) else: raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol)) out_idx2symbol += [SpecialTokens.UNK_TOKEN] return {'out_idx2symbol': out_idx2symbol, 'num_start': num_start} def _build_symbol(self): if self.share_vocab: out_idx2symbol = [SpecialTokens.PAD_TOKEN] + [SpecialTokens.EOS_TOKEN] + self.operator_list else: out_idx2symbol = [SpecialTokens.PAD_TOKEN] + [SpecialTokens.SOS_TOKEN] + [ SpecialTokens.EOS_TOKEN] + self.operator_list if self.model.lower() in ['hms']: out_idx2symbol = [SpecialTokens.PAD_TOKEN] + [SpecialTokens.EOS_TOKEN] + self.operator_list num_start = len(out_idx2symbol) out_idx2symbol += self.generate_list if self.model.lower() in ['hms']: out_idx2symbol += [SpecialTokens.UNK_TOKEN] if self.mask_symbol == MaskSymbol.NUM: mask_list = NumMask.number try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.generate_list)) elif self.mask_symbol == MaskSymbol.alphabet: mask_list = NumMask.alphabet try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "alphabet may not enough to mask {} numbers, changing the mask_symbol from alphabet to number may solve the problem.".format( self.copy_nums)) elif self.mask_symbol == MaskSymbol.number: mask_list = NumMask.number try: out_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.generate_list)) else: raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol)) for data in self.trainset: words_list = data["equation"] for word in words_list: if word in out_idx2symbol: continue elif word[0].isdigit(): continue elif (word[0].isalpha() or word[0].isdigit()) is not True: out_idx2symbol.insert(num_start, word) num_start += 1 continue else: out_idx2symbol.append(word) if self.model.lower() in ['hms']: return {'out_idx2symbol': out_idx2symbol, 'num_start': num_start} out_idx2symbol += [SpecialTokens.UNK_TOKEN] return {'out_idx2symbol': out_idx2symbol, 'num_start': num_start} def _build_template_symbol(self): if self.share_vocab: temp_idx2symbol = [SpecialTokens.PAD_TOKEN] + [SpecialTokens.EOS_TOKEN] + [SpecialTokens.OPT_TOKEN] else: temp_idx2symbol = [SpecialTokens.PAD_TOKEN] + [SpecialTokens.SOS_TOKEN] + [SpecialTokens.EOS_TOKEN] + [ SpecialTokens.OPT_TOKEN] temp_num_start = len(temp_idx2symbol) temp_idx2symbol += self.generate_list if self.mask_symbol == MaskSymbol.NUM: mask_list = NumMask.number try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) elif self.mask_symbol == MaskSymbol.alphabet: mask_list = NumMask.alphabet try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "alphabet may not enough to mask {} numbers, changing the mask_symbol from alphabet to number may solve the problem.".format( self.copy_nums)) elif self.mask_symbol == MaskSymbol.number: mask_list = NumMask.number try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) else: raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol)) for data in self.trainset: words_list = data["template"] for word in words_list: if word in temp_idx2symbol: continue elif word[0].isdigit(): continue elif (word[0].isalpha() or word[0].isdigit()) is not True: temp_idx2symbol.insert(temp_num_start, word) temp_num_start += 1 continue else: temp_idx2symbol.append(word) temp_idx2symbol += [SpecialTokens.UNK_TOKEN] return {'temp_idx2symbol': temp_idx2symbol, 'temp_num_start': temp_num_start} def _build_template_symbol_for_multi_way_tree(self): temp_idx2symbol = [SpecialTokens.PAD_TOKEN, SpecialTokens.SOS_TOKEN, SpecialTokens.EOS_TOKEN, SpecialTokens.NON_TOKEN] temp_idx2symbol += [SpecialTokens.OPT_TOKEN] temp_num_start = len(temp_idx2symbol) temp_idx2symbol += self.generate_list if self.mask_symbol == MaskSymbol.NUM: mask_list = NumMask.number try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) elif self.mask_symbol == MaskSymbol.alphabet: mask_list = NumMask.alphabet try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "alphabet may not enough to mask {} numbers, changing the mask_symbol from alphabet to number may solve the problem.".format( self.copy_nums)) elif self.mask_symbol == MaskSymbol.number: mask_list = NumMask.number try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) else: raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol)) temp_idx2symbol += [SpecialTokens.UNK_TOKEN] return {'temp_idx2symbol': temp_idx2symbol, 'temp_num_start': temp_num_start} def _build_template_symbol_for_tree(self): temp_idx2symbol = [SpecialTokens.OPT_TOKEN] temp_num_start = len(temp_idx2symbol) temp_idx2symbol += self.generate_list if self.mask_symbol == MaskSymbol.NUM: mask_list = NumMask.number try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) elif self.mask_symbol == MaskSymbol.alphabet: mask_list = NumMask.alphabet try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError( "alphabet may not enough to mask {} numbers, changing the mask_symbol from alphabet to number may solve the problem.".format( self.copy_nums)) elif self.mask_symbol == MaskSymbol.number: mask_list = NumMask.number try: temp_idx2symbol += [mask_list[i] for i in range(self.copy_nums)] except IndexError: raise IndexError("{} numbers is not enough to mask {} numbers ".format(len(mask_list), self.copy_nums)) else: raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol)) temp_idx2symbol += [SpecialTokens.UNK_TOKEN] return {'temp_idx2symbol': temp_idx2symbol, 'temp_num_start': temp_num_start} def _update_vocab(self, vocab_list): index = len(self.in_idx2word) for word in vocab_list: if word not in self.in_idx2word: self.in_idx2word.append(word) self.in_word2idx[word] = index index += 1
[docs] def get_vocab_size(self): """ Returns: (tuple(int, int)): the length of input vocabulary and output symbols """ return len(self.in_idx2word), len(self.out_idx2symbol)
[docs] def save_dataset(self, save_dir: str): """ save dataset parameters to file. :param save_dir: (str) folder which saves the parameter file :return: """ if not os.path.exists(save_dir): os.mkdir(save_dir) input_vocab_file = os.path.join(save_dir, 'input_vocab.json') write_json_data( {'in_idx2word': self.in_idx2word}, input_vocab_file ) output_vocab_file = os.path.join(save_dir, 'output_vocab.json') write_json_data( { 'out_idx2symbol': self.out_idx2symbol, 'temp_idx2symbol': self.temp_idx2symbol }, output_vocab_file ) data_id_file = os.path.join(save_dir, 'data_split.json') write_json_data( { 'trainset_id': self.trainset_id, 'validset_id': self.validset_id, 'testset_id': self.testset_id, 'folds_id': self.folds_id }, data_id_file ) json_encoder = json.encoder.JSONEncoder() parameters_dict = self.parameters_to_dict() not_support_json = [] not_to_save = ['in_idx2word', 'out_idx2symbol', 'temp_idx2symbol', 'in_word2idx', 'out_symbol2idx', 'temp_symbol2idx', 'folds', 'trainset', 'testset', 'validset', 'datas', 'trainset_id', 'validset_id', 'testset_id', 'folds_id'] for key, value in parameters_dict.items(): try: json_encoder.encode({key: value}) except TypeError: not_support_json.append(key) for key in not_support_json: del parameters_dict[key] for key in not_to_save: del parameters_dict[key] parameter_file = os.path.join(save_dir, 'dataset.json') write_json_data(parameters_dict, parameter_file)
[docs] @classmethod def load_from_pretrained(cls, pretrained_dir: str, resume_training=False): """ load dataset parameters from file. :param pretrained_dir: (str) folder which saved the parameter file :param resume_training: (bool) load parameter for resuming training or not. :return: an instantiated object """ config = Config.load_from_pretrained(pretrained_dir) dataset = PretrainDataset(config) input_vocab_file = os.path.join(pretrained_dir, 'input_vocab.json') output_vocab_file = os.path.join(pretrained_dir, 'output_vocab.json') parameter_file = os.path.join(pretrained_dir, 'dataset.json') data_id_file = os.path.join(pretrained_dir, 'data_split.json') input_vocab = read_json_data(input_vocab_file) output_vocab = read_json_data(output_vocab_file) parameter_dict = read_json_data(parameter_file) data_id_dict = read_json_data(data_id_file) in_idx2word = input_vocab['in_idx2word'] out_idx2symbol = output_vocab['out_idx2symbol'] temp_idx2symbol = output_vocab['temp_idx2symbol'] in_word2idx = {} out_symbol2idx = {} temp_symbol2idx = {} for idx, word in enumerate(in_idx2word): in_word2idx[word] = idx for idx, symbol in enumerate(out_idx2symbol): out_symbol2idx[symbol] = idx for idx, symbol in enumerate(temp_idx2symbol): temp_symbol2idx[symbol] = idx setattr(dataset, 'in_idx2word', in_idx2word) setattr(dataset, 'out_idx2symbol', out_idx2symbol) setattr(dataset, 'temp_idx2symbol', temp_idx2symbol) setattr(dataset, 'in_word2idx', in_word2idx) setattr(dataset, 'out_symbol2idx', out_symbol2idx) setattr(dataset, 'temp_symbol2idx', temp_symbol2idx) for key, value in parameter_dict.items(): setattr(dataset, key, value) for key, value in data_id_dict.items(): setattr(dataset, key, value) if resume_training: if config['k_fold']: setattr(dataset, 'fold_t', config['fold_t']) setattr(dataset, 'the_fold_t', config['fold_t'] - 1) setattr(dataset, 'from_pretrained', False) setattr(dataset, 'pretrained_dir', pretrained_dir) setattr(dataset, 'resume_training', resume_training) else: setattr(dataset, 'from_pretrained', False) setattr(dataset, 'pretrained_dir', pretrained_dir) setattr(dataset, 'resume_training', resume_training) else: setattr(dataset, 'from_pretrained', True) setattr(dataset, 'pretrained_dir', pretrained_dir) if dataset.embedding == 'bert': tokenizer = BertTokenizer.from_pretrained(dataset.pretrained_model_path) elif dataset.embedding == 'roberta': tokenizer = RobertaTokenizer.from_pretrained(dataset.pretrained_model_path) else: raise NotImplementedError if dataset.add_num_symbol: if dataset.mask_symbol == MaskSymbol.NUM: tokenizer.add_tokens(['NUM'], True) elif dataset.mask_symbol == MaskSymbol.number: tokenizer.add_tokens(NumMask.number[:dataset.copy_nums], True) if dataset.model.lower() in ['trnn']: tokenizer.add_tokens(dataset.generate_list) setattr(dataset, 'tokenizer', tokenizer) dataset.reset_dataset() return dataset
def __load_pretrained_parameters(self): if self.k_fold: load_dir = os.path.join(self.pretrained_dir, 'fold{}'.format(self.fold_t)) else: load_dir = self.pretrained_dir input_vocab_file = os.path.join(load_dir, 'input_vocab.json') output_vocab_file = os.path.join(load_dir, 'output_vocab.json') parameter_file = os.path.join(load_dir, 'dataset.json') input_vocab = read_json_data(input_vocab_file) output_vocab = read_json_data(output_vocab_file) parameter_dict = read_json_data(parameter_file) in_idx2word = input_vocab['in_idx2word'] out_idx2symbol = output_vocab['out_idx2symbol'] temp_idx2symbol = output_vocab['temp_idx2symbol'] in_word2idx = {} out_symbol2idx = {} temp_symbol2idx = {} for idx, word in enumerate(in_idx2word): in_word2idx[word] = idx for idx, symbol in enumerate(out_idx2symbol): out_symbol2idx[symbol] = idx for idx, symbol in enumerate(temp_idx2symbol): temp_symbol2idx[symbol] = idx setattr(self, 'in_idx2word', in_idx2word) setattr(self, 'out_idx2symbol', out_idx2symbol) setattr(self, 'temp_idx2symbol', temp_idx2symbol) setattr(self, 'in_word2idx', in_word2idx) setattr(self, 'out_symbol2idx', out_symbol2idx) setattr(self, 'temp_symbol2idx', temp_symbol2idx) for key, value in parameter_dict.items(): setattr(self, key, value)