Source code for mwptoolkit.data.dataset.dataset_multiencdec

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 11:33:11
# @File: dataset_multiencdec.py
import json
import os
import copy
from logging import getLogger
import stanza

from mwptoolkit.config.configuration import Config
from mwptoolkit.data.dataset.template_dataset import TemplateDataset
from mwptoolkit.utils.enum_type import NumMask, SpecialTokens, FixType, Operators, MaskSymbol, SPECIAL_TOKENS, \
    DatasetName, TaskType
from mwptoolkit.utils.preprocess_tool.equation_operator import from_infix_to_postfix, from_infix_to_prefix, \
    from_postfix_to_infix, from_postfix_to_prefix, from_prefix_to_infix, from_prefix_to_postfix
from mwptoolkit.utils.preprocess_tools import id_reedit, dataset_drop_duplication
from mwptoolkit.utils.preprocess_tool.number_transfer import number_transfer
from mwptoolkit.utils.utils import read_json_data, write_json_data


class DatasetMultiEncDec(TemplateDataset):
    """dataset class for the deep-learning model MultiE&D"""

    def __init__(self, config):
        """
        Args:
            config (mwptoolkit.config.configuration.Config): it is expected that config
                includes the parameters below:

            task_type (str): [single_equation | multi_equation], the type of task.
            parse_tree_file_name (str|None): the name of the file in which to save parse tree information.
            ltp_model_dir or ltp_model_path (str|None): the path of the ltp model.
            model (str): the model name.
            dataset (str): the dataset name.
            equation_fix (str): [infix | postfix | prefix], convert the equation to the specified format.
            dataset_dir or dataset_path (str): the path of the dataset folder.
            language (str): a property of the dataset, the language of the dataset.
            single (bool): a property of the dataset, whether its equations are single equations.
            linear (bool): a property of the dataset, whether its equations are linear.
            source_equation_fix (str): [infix | postfix | prefix], a property of the dataset, the source format of its equations.
            rebuild (bool): when loading additional dataset information, whether to build the information anew or to load information built before.
            validset_divide (bool): whether to split off a validset. If True, the dataset is split into trainset-validset-testset; if False, into trainset-testset.
            mask_symbol (str): [NUM | alphabet | number], the symbol used to mask numbers in the equation.
            min_word_keep (int): words whose count is greater than this value are kept in the input vocabulary.
            min_generate_keep (int): generated numbers whose count is greater than this value are kept in the output symbols.
            symbol_for_tree (bool): whether to build output symbols for a tree decoder.
            share_vocab (bool): whether the encoder and decoder of the model share the same vocabulary, as is common in Seq2Seq models.
            k_fold (int|None): if an integer, run k-fold cross validation; if None, run a trainset-validset-testset split.
            read_local_folds (bool): when running k-fold cross validation, if True, load the split folds from the dataset folder; if False, split the folds randomly.
            shuffle (bool): whether to shuffle the trainset before training.
            device (torch.device):
            resume_training or resume (bool):
        """
        super().__init__(config)
        self.task_type = config['task_type']
        self.parse_tree_path = config['parse_tree_file_name']
        if self.parse_tree_path is not None:
            self.parse_tree_path = os.path.join(self.dataset_path, self.parse_tree_path + '.json')
            if not os.path.isabs(self.parse_tree_path):
                self.parse_tree_path = os.path.join(os.getcwd(), self.parse_tree_path)
        self.ltp_model_path = config['ltp_model_dir'] if config['ltp_model_dir'] else config['ltp_model_path']
        if self.ltp_model_path and not os.path.isabs(self.ltp_model_path):
            self.ltp_model_path = os.path.join(os.getcwd(), self.ltp_model_path)

    def _preprocess(self):
        if self.dataset in [DatasetName.hmwp]:
            self.trainset, self.validset, self.testset = id_reedit(self.trainset, self.validset, self.testset)
        if self.dataset in [DatasetName.draw]:
            self.trainset, self.validset, self.testset = dataset_drop_duplication(self.trainset, self.validset, self.testset)
        transfer = number_transfer
        self.trainset, generate_list, train_copy_nums, unk_symbol = transfer(self.trainset, self.dataset, self.task_type, self.mask_symbol, self.min_generate_keep, self.linear, ";")
        self.validset, _g, valid_copy_nums, _ = transfer(self.validset, self.dataset, self.task_type, self.mask_symbol, self.min_generate_keep, self.linear, ";")
        self.testset, _g, test_copy_nums, _ = transfer(self.testset, self.dataset, self.task_type, self.mask_symbol, self.min_generate_keep, self.linear, ";")

        source_equation_fix = self.source_equation_fix if self.source_equation_fix else FixType.Infix
        if source_equation_fix == FixType.Infix:
            to_infix = None
            to_prefix = from_infix_to_prefix
            to_postfix = from_infix_to_postfix
        elif source_equation_fix == FixType.Prefix:
            to_infix = from_prefix_to_infix
            to_prefix = None
            to_postfix = from_prefix_to_postfix
        elif source_equation_fix == FixType.Postfix:
            to_infix = from_postfix_to_infix
            to_prefix = from_postfix_to_prefix
            to_postfix = None
        else:
            raise NotImplementedError()

        # all three splits get the same treatment: attach infix, postfix and
        # prefix views of each equation.
        for dataset_split in [self.trainset, self.validset, self.testset]:
            for data in dataset_split:
                data["infix equation"] = to_infix(data["equation"]) if to_infix else data["equation"]
                data["postfix equation"] = to_postfix(data["equation"]) if to_postfix else data["equation"]
                data["prefix equation"] = to_prefix(data["equation"]) if to_prefix else data["equation"]

        generate_list = unk_symbol + generate_list
        if self.symbol_for_tree:
            copy_nums = max([train_copy_nums, valid_copy_nums, test_copy_nums])
        else:
            copy_nums = train_copy_nums

        if self.task_type == TaskType.SingleEquation:
            operator_list = copy.deepcopy(Operators.Single)
            if self.dataset in [DatasetName.mawps]:
                operator_list.append('=')
            operator_nums = len(operator_list)
        elif self.task_type == TaskType.MultiEquation:
            operator_nums = len(Operators.Multi)
            operator_list = copy.deepcopy(Operators.Multi)
        else:
            raise NotImplementedError

        if os.path.exists(self.parse_tree_path) and not self.rebuild:
            logger = getLogger()
            logger.info('read pos information from {} ...'.format(self.parse_tree_path))
            self.read_pos_from_file(self.parse_tree_path)
        else:
            logger = getLogger()
            logger.info('build pos information to {} ...'.format(self.parse_tree_path))
            if self.language == 'zh':
                try:
                    import pyltp
                    self.build_pos_to_file_with_pyltp(self.parse_tree_path)
                except Exception:
                    # fall back to stanza when pyltp or its models are unavailable
                    self.build_pos_to_file_with_stanza(self.parse_tree_path)
            else:
                self.build_pos_to_file_with_stanza(self.parse_tree_path)
            self.read_pos_from_file(self.parse_tree_path)
        return {'generate_list': generate_list, 'copy_nums': copy_nums,
                'operator_list': operator_list, 'operator_nums': operator_nums}

    def _build_vocab(self):
        # input vocabulary 1: question words
        words_count_1 = {}
        for data in self.trainset:
            words_list = data["question"]
            for word in words_list:
                words_count_1[word] = words_count_1.get(word, 0) + 1
        in_idx2word_1 = [SpecialTokens.PAD_TOKEN, SpecialTokens.UNK_TOKEN]
        for key, value in words_count_1.items():
            if value > self.min_word_keep or "NUM" in key:
                in_idx2word_1.append(key)

        # input vocabulary 2: POS tags
        words_count_2 = {}
        for data in self.trainset:
            words_list = data["pos"]
            for word in words_list:
                words_count_2[word] = words_count_2.get(word, 0) + 1
        in_idx2word_2 = [SpecialTokens.PAD_TOKEN, SpecialTokens.UNK_TOKEN]
        for key, value in words_count_2.items():
            if value > self.min_word_keep:
                in_idx2word_2.append(key)

        equ_dict_2 = self._build_symbol()
        equ_dict_1 = self._build_symbol_for_tree()
        out_idx2symbol_2 = equ_dict_2['out_idx2symbol_2']
        out_idx2symbol_1 = equ_dict_1['out_idx2symbol_1']
        num_start1 = equ_dict_1['num_start1']
        num_start2 = equ_dict_2['num_start2']

        in_word2idx_1 = {word: idx for idx, word in enumerate(in_idx2word_1)}
        in_word2idx_2 = {word: idx for idx, word in enumerate(in_idx2word_2)}
        out_symbol2idx_1 = {symbol: idx for idx, symbol in enumerate(out_idx2symbol_1)}
        out_symbol2idx_2 = {symbol: idx for idx, symbol in enumerate(out_idx2symbol_2)}

        return {'in_idx2word_1': in_idx2word_1, 'in_idx2word_2': in_idx2word_2,
                'in_word2idx_1': in_word2idx_1, 'in_word2idx_2': in_word2idx_2,
                'out_idx2symbol_1': out_idx2symbol_1, 'out_symbol2idx_1': out_symbol2idx_1,
                'out_idx2symbol_2': out_idx2symbol_2, 'out_symbol2idx_2': out_symbol2idx_2,
                'num_start1': num_start1, 'num_start2': num_start2}

    def _build_symbol(self):
        # output vocabulary for the sequence decoder; the shared-vocab and
        # non-shared cases use the same layout here.
        out_idx2symbol_2 = [SpecialTokens.PAD_TOKEN] + [SpecialTokens.EOS_TOKEN] + self.operator_list
        num_start2 = len(out_idx2symbol_2)
        out_idx2symbol_2 += self.generate_list
        if self.mask_symbol == MaskSymbol.NUM:
            mask_list = NumMask.number
            try:
                out_idx2symbol_2 += [mask_list[i] for i in range(self.copy_nums)]
            except IndexError:
                raise IndexError("{} numbers is not enough to mask {} numbers".format(len(mask_list), self.copy_nums))
        elif self.mask_symbol == MaskSymbol.alphabet:
            mask_list = NumMask.alphabet
            try:
                out_idx2symbol_2 += [mask_list[i] for i in range(self.copy_nums)]
            except IndexError:
                raise IndexError("the alphabet may not be enough to mask {} numbers; changing the mask_symbol from alphabet to number may solve the problem.".format(self.copy_nums))
        elif self.mask_symbol == MaskSymbol.number:
            mask_list = NumMask.number
            try:
                out_idx2symbol_2 += [mask_list[i] for i in range(self.copy_nums)]
            except IndexError:
                raise IndexError("{} numbers is not enough to mask {} numbers".format(len(mask_list), self.copy_nums))
        else:
            raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol))
        out_idx2symbol_2 += [SpecialTokens.SOS_TOKEN]
        out_idx2symbol_2 += [SpecialTokens.UNK_TOKEN]
        return {'out_idx2symbol_2': out_idx2symbol_2, 'num_start2': num_start2}

    def _build_symbol_for_tree(self):
        # output vocabulary for the tree decoder
        out_idx2symbol_1 = copy.deepcopy(self.operator_list)
        num_start1 = len(out_idx2symbol_1)
        out_idx2symbol_1 += self.generate_list
        if self.mask_symbol == MaskSymbol.NUM:
            mask_list = NumMask.number
            try:
                out_idx2symbol_1 += [mask_list[i] for i in range(self.copy_nums)]
            except IndexError:
                raise IndexError("{} numbers is not enough to mask {} numbers".format(len(mask_list), self.copy_nums))
        elif self.mask_symbol == MaskSymbol.alphabet:
            mask_list = NumMask.alphabet
            try:
                out_idx2symbol_1 += [mask_list[i] for i in range(self.copy_nums)]
            except IndexError:
                raise IndexError("the alphabet may not be enough to mask {} numbers; changing the mask_symbol from alphabet to number may solve the problem.".format(self.copy_nums))
        elif self.mask_symbol == MaskSymbol.number:
            mask_list = NumMask.number
            try:
                out_idx2symbol_1 += [mask_list[i] for i in range(self.copy_nums)]
            except IndexError:
                raise IndexError("{} numbers is not enough to mask {} numbers".format(len(mask_list), self.copy_nums))
        else:
            raise NotImplementedError("the type of masking number ({}) is not implemented".format(self.mask_symbol))
        out_idx2symbol_1 += [SpecialTokens.UNK_TOKEN]
        return {'out_idx2symbol_1': out_idx2symbol_1, 'num_start1': num_start1}
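
For orientation, a minimal sketch of the two output-vocabulary layouts built above, with hypothetical operator, generate and mask tokens (the token names are illustrative, not the library's real constants); num_start1/num_start2 mark the first non-operator index in each vocabulary:

    # tree decoder vocabulary:     [operators] [generate numbers] [NUM masks] [UNK]
    # sequence decoder vocabulary: [PAD] [EOS] [operators] [generate numbers] [NUM masks] [SOS] [UNK]
    operator_list = ['+', '-', '*', '/', '^']
    generate_list = ['1', '3.14']
    mask_list = ['NUM_0', 'NUM_1', 'NUM_2']        # copy_nums == 3

    out_idx2symbol_1 = operator_list + generate_list + mask_list + ['<UNK>']
    num_start1 = len(operator_list)                # == 5, index of the first number symbol

    out_idx2symbol_2 = ['<PAD>', '<EOS>'] + operator_list + generate_list + mask_list + ['<SOS>', '<UNK>']
    num_start2 = 2 + len(operator_list)            # == 7, index of the first number symbol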
    def build_pos_to_file_with_stanza(self, path):
        nlp = stanza.Pipeline(self.language, processors='depparse,tokenize,pos,lemma',
                              tokenize_pretokenized=True, logging_level='error')
        new_datas = []
        # the same tagging is applied to all three splits
        for dataset_split in [self.trainset, self.validset, self.testset]:
            for data in dataset_split:
                doc = nlp(data["ques source 1"])
                token_list = doc.to_dict()[0]
                pos = []
                parse_tree = []
                for token in token_list:
                    # token['xpos'] would give treebank-specific tags; upos is used instead
                    pos.append(token['upos'])
                    parse_tree.append(token['head'] - 1)  # 0-based head; the root becomes -1
                new_datas.append({'id': data['id'], 'pos': pos, 'parse tree': parse_tree})
        write_json_data(new_datas, path)
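
A minimal sketch of the stanza calls above in isolation, assuming an English pipeline and a short pretokenized question (the sentence and tag values are illustrative); stanza downloads its models on first use:

    import stanza

    nlp = stanza.Pipeline('en', processors='depparse,tokenize,pos,lemma',
                          tokenize_pretokenized=True, logging_level='error')
    doc = nlp("Dan has NUM apples .")
    tokens = doc.to_dict()[0]                     # token dicts of the first (only) sentence
    pos = [t['upos'] for t in tokens]             # e.g. ['PROPN', 'VERB', 'NUM', 'NOUN', 'PUNCT']
    parse_tree = [t['head'] - 1 for t in tokens]  # stanza heads are 1-based, 0 for root,
                                                  # so the root token is mapped to -1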
    def build_pos_to_file_with_pyltp(self, path):
        from pyltp import Postagger, Parser
        pos_model_path = os.path.join(self.ltp_model_path, "pos.model")
        par_model_path = os.path.join(self.ltp_model_path, 'parser.model')
        postagger = Postagger()
        postagger.load(pos_model_path)
        parser = Parser()
        parser.load(par_model_path)

        new_datas = []
        # the same tagging is applied to all three splits
        for dataset_split in [self.trainset, self.validset, self.testset]:
            for data in dataset_split:
                words = data["ques source 1"].split(' ')
                postags = list(postagger.postag(words))
                arcs = parser.parse(words, postags)
                parse_tree = [arc.head - 1 for arc in arcs]  # pyltp heads are 1-based, 0 for root
                new_datas.append({'id': data['id'], 'pos': postags, 'parse tree': parse_tree})
        write_json_data(new_datas, path)
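
A minimal sketch of the pyltp calls above in isolation, assuming an LTP 3.x model directory (the path and sentence are hypothetical). As in the stanza branch, `head - 1` maps the root token to -1, so both builders write the same head convention:

    from pyltp import Postagger, Parser

    postagger = Postagger()
    postagger.load('ltp_data_v3.4.0/pos.model')    # hypothetical model path
    parser = Parser()
    parser.load('ltp_data_v3.4.0/parser.model')

    words = '小明 有 NUM 个 苹果'.split(' ')
    postags = list(postagger.postag(words))
    arcs = parser.parse(words, postags)
    parse_tree = [arc.head - 1 for arc in arcs]    # root token gets -1

    postagger.release()
    parser.release()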
    def read_pos_from_file(self, path):
        pos_datas = read_json_data(path)
        for dataset_split in [self.trainset, self.validset, self.testset]:
            for data in dataset_split:
                for pos_data in pos_datas:
                    if pos_data['id'] != data['id']:
                        continue
                    data['pos'] = pos_data['pos']
                    data['parse tree'] = pos_data['parse tree']
                    # each record is used once, so drop it to shrink later scans
                    pos_datas.remove(pos_data)
                    break
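
The file consumed here is the one written by the two builders above: a list of records keyed by problem id. A sketch of one hypothetical record, plus an id-indexed lookup that would replace the quadratic scan (a possible simplification, not what the library does):

    # Hypothetical record in <parse_tree_file_name>.json:
    #   {"id": "1", "pos": ["PROPN", "VERB", "NUM", "NOUN"], "parse tree": [1, -1, 3, 1]}
    pos_datas = read_json_data('datasets/mawps/parse_tree.json')   # path illustrative
    by_id = {d['id']: d for d in pos_datas}
    for data in dataset.trainset:
        data['pos'] = by_id[data['id']]['pos']
        data['parse tree'] = by_id[data['id']]['parse tree']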
    def save_dataset(self, save_dir: str):
        """
        Save dataset parameters to file.

        :param save_dir: (str) folder in which to save the parameter files
        :return:
        """
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        input_vocab_file = os.path.join(save_dir, 'input_vocab.json')
        write_json_data(
            {
                'in_idx2word_1': self.in_idx2word_1,
                'in_idx2word_2': self.in_idx2word_2
            },
            input_vocab_file
        )
        output_vocab_file = os.path.join(save_dir, 'output_vocab.json')
        write_json_data(
            {
                'out_idx2symbol_1': self.out_idx2symbol_1,
                'out_idx2symbol_2': self.out_idx2symbol_2
            },
            output_vocab_file
        )
        data_id_file = os.path.join(save_dir, 'data_split.json')
        write_json_data(
            {
                'trainset_id': self.trainset_id,
                'validset_id': self.validset_id,
                'testset_id': self.testset_id,
                'folds_id': self.folds_id
            },
            data_id_file
        )

        # drop everything that is saved separately or cannot be JSON-encoded,
        # then persist the remaining scalar parameters.
        json_encoder = json.encoder.JSONEncoder()
        parameters_dict = self.parameters_to_dict()
        not_support_json = []
        not_to_save = ['in_idx2word_1', 'in_idx2word_2', 'out_idx2symbol_1', 'out_idx2symbol_2',
                       'in_word2idx_1', 'in_word2idx_2', 'out_symbol2idx_1', 'out_symbol2idx_2',
                       'folds', 'trainset', 'testset', 'validset', 'datas',
                       'trainset_id', 'validset_id', 'testset_id', 'folds_id']
        for key, value in parameters_dict.items():
            try:
                json_encoder.encode({key: value})
            except TypeError:
                not_support_json.append(key)
        for key in not_support_json:
            del parameters_dict[key]
        for key in not_to_save:
            del parameters_dict[key]
        parameter_file = os.path.join(save_dir, 'dataset.json')
        write_json_data(parameters_dict, parameter_file)
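
A minimal usage sketch, assuming a fully built dataset object (the directory name is hypothetical):

    dataset.save_dataset('checkpoints/mawps_multiencdec')
    # writes input_vocab.json, output_vocab.json, data_split.json and
    # dataset.json into the given folder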
    @classmethod
    def load_from_pretrained(cls, pretrained_dir: str, resume_training=False):
        """
        Load dataset parameters from file.

        :param pretrained_dir: (str) folder which saved the parameter files
        :param resume_training: (bool) whether to load parameters for resuming training.
        :return: an instantiated object
        """
        config = Config.load_from_pretrained(pretrained_dir)
        dataset = DatasetMultiEncDec(config)

        input_vocab_file = os.path.join(pretrained_dir, 'input_vocab.json')
        output_vocab_file = os.path.join(pretrained_dir, 'output_vocab.json')
        parameter_file = os.path.join(pretrained_dir, 'dataset.json')
        data_id_file = os.path.join(pretrained_dir, 'data_split.json')

        input_vocab = read_json_data(input_vocab_file)
        output_vocab = read_json_data(output_vocab_file)
        parameter_dict = read_json_data(parameter_file)
        data_id_dict = read_json_data(data_id_file)

        in_idx2word_1 = input_vocab['in_idx2word_1']
        in_idx2word_2 = input_vocab['in_idx2word_2']
        out_idx2symbol_1 = output_vocab['out_idx2symbol_1']
        out_idx2symbol_2 = output_vocab['out_idx2symbol_2']

        # rebuild the reverse (token -> index) mappings from the saved lists
        in_word2idx_1 = {word: idx for idx, word in enumerate(in_idx2word_1)}
        in_word2idx_2 = {word: idx for idx, word in enumerate(in_idx2word_2)}
        out_symbol2idx_1 = {symbol: idx for idx, symbol in enumerate(out_idx2symbol_1)}
        out_symbol2idx_2 = {symbol: idx for idx, symbol in enumerate(out_idx2symbol_2)}

        setattr(dataset, 'in_idx2word_1', in_idx2word_1)
        setattr(dataset, 'in_idx2word_2', in_idx2word_2)
        setattr(dataset, 'out_idx2symbol_1', out_idx2symbol_1)
        setattr(dataset, 'out_idx2symbol_2', out_idx2symbol_2)
        setattr(dataset, 'in_word2idx_1', in_word2idx_1)
        setattr(dataset, 'in_word2idx_2', in_word2idx_2)
        setattr(dataset, 'out_symbol2idx_1', out_symbol2idx_1)
        setattr(dataset, 'out_symbol2idx_2', out_symbol2idx_2)
        for key, value in parameter_dict.items():
            setattr(dataset, key, value)
        for key, value in data_id_dict.items():
            setattr(dataset, key, value)

        if resume_training:
            if config['k_fold']:
                setattr(dataset, 'fold_t', config['fold_t'])
                setattr(dataset, 'the_fold_t', config['fold_t'] - 1)
            setattr(dataset, 'from_pretrained', False)
            setattr(dataset, 'pretrained_dir', pretrained_dir)
            setattr(dataset, 'resume_training', resume_training)
        else:
            setattr(dataset, 'from_pretrained', True)
            setattr(dataset, 'pretrained_dir', pretrained_dir)
        dataset.reset_dataset()
        return dataset
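
And the corresponding loading sketch, using the same hypothetical directory as above:

    # rebuild the dataset for inference on the saved split and vocabularies:
    dataset = DatasetMultiEncDec.load_from_pretrained('checkpoints/mawps_multiencdec')

    # or pick up an interrupted training run where it left off:
    dataset = DatasetMultiEncDec.load_from_pretrained('checkpoints/mawps_multiencdec',
                                                      resume_training=True)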
    def __load_pretrained_parameters(self):
        if self.k_fold:
            load_dir = os.path.join(self.pretrained_dir, 'fold{}'.format(self.fold_t))
        else:
            load_dir = self.pretrained_dir

        input_vocab_file = os.path.join(load_dir, 'input_vocab.json')
        output_vocab_file = os.path.join(load_dir, 'output_vocab.json')
        parameter_file = os.path.join(load_dir, 'dataset.json')

        input_vocab = read_json_data(input_vocab_file)
        output_vocab = read_json_data(output_vocab_file)
        parameter_dict = read_json_data(parameter_file)

        in_idx2word_1 = input_vocab['in_idx2word_1']
        in_idx2word_2 = input_vocab['in_idx2word_2']
        out_idx2symbol_1 = output_vocab['out_idx2symbol_1']
        out_idx2symbol_2 = output_vocab['out_idx2symbol_2']

        # rebuild the reverse (token -> index) mappings from the saved lists
        in_word2idx_1 = {word: idx for idx, word in enumerate(in_idx2word_1)}
        in_word2idx_2 = {word: idx for idx, word in enumerate(in_idx2word_2)}
        out_symbol2idx_1 = {symbol: idx for idx, symbol in enumerate(out_idx2symbol_1)}
        out_symbol2idx_2 = {symbol: idx for idx, symbol in enumerate(out_idx2symbol_2)}

        setattr(self, 'in_idx2word_1', in_idx2word_1)
        setattr(self, 'in_idx2word_2', in_idx2word_2)
        setattr(self, 'out_idx2symbol_1', out_idx2symbol_1)
        setattr(self, 'out_idx2symbol_2', out_idx2symbol_2)
        setattr(self, 'in_word2idx_1', in_word2idx_1)
        setattr(self, 'in_word2idx_2', in_word2idx_2)
        setattr(self, 'out_symbol2idx_1', out_symbol2idx_1)
        setattr(self, 'out_symbol2idx_2', out_symbol2idx_2)
        for key, value in parameter_dict.items():
            setattr(self, key, value)
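
When k_fold is set, this private loader expects one subfolder per fold, selected by fold_t; a hypothetical on-disk layout (fold numbering illustrative):

    # checkpoints/mawps_multiencdec/
    #     fold0/input_vocab.json
    #     fold0/output_vocab.json
    #     fold0/dataset.json
    #     fold1/...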