# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2022/2/16 22:02
# @File: configuration.py
# @Update Time: 2022/2/16 22:02
import copy
import sys
import os
import re
import json
import warnings
from logging import getLogger
from enum import Enum
import torch
from mwptoolkit.utils.utils import read_json_data, get_model, write_json_data
class Config(object):
    """Load, merge and expose the parameters used by the toolkit.

    Parameters are collected from five sources:

    * the internal config file ``config.json`` located next to this module,
    * the dataset config file,
    * the model config file (or the "best config" file when the internal
      parameter ``load_best_config`` is truthy),
    * the config dictionary passed to the constructor, e.g.
      ``Config(config_dict=config_dict)``,
    * the command line, using the template ``--param_name=param_value``.

    The dataset config, model config and config dictionary together are called
    the external config. By default the dataset config is read from
    ``../properties/dataset/<dataset_name>.json`` and the model config from
    ``../properties/model/<model_name>.json`` (relative to this module). The
    parameters ``model_config_file``, ``best_config_file`` and
    ``dataset_config_file`` can be supplied through the config dictionary or
    the command line to point to other JSON files.

    When the same parameter is defined more than once, the priority is::

        cmd line > external config > internal config

    and, inside the external config::

        config dictionary > model config > dataset config
    """

    # Types a string parameter may be converted to. ``type(None)`` is used
    # because ``None`` itself is not a type and makes ``isinstance`` raise.
    _SUPPORTED_VALUE_TYPES = (str, int, float, list, tuple, dict, bool, Enum, type(None))

    def __init__(self, model_name=None, dataset_name=None, task_type=None, config_dict=None):
        """
        Args:
            model_name (str): the model name, default is None. If it is None,
                the parameter 'model' is taken from the external input.
            dataset_name (str): the dataset name, default is None. If it is None,
                the parameter 'dataset' is taken from the external input.
            task_type (str): the task type, default is None. If it is None,
                the parameter 'task_type' is taken from the external input.
            config_dict (dict): the external parameter dictionary, default is
                None (treated as an empty dictionary).
        """
        super().__init__()
        # internal config
        self.internal_config_dict = {}
        self.path_config_dict = {}
        # external config
        self.external_config_dict = {}
        self.model_config_dict = {}
        self.dataset_config_dict = {}
        # cmd config
        self.cmd_config_dict = {}
        # final config
        self.final_config_dict = {}

        # The order below matters: the command line must be parsed before the
        # paths are built (it may supply 'model'/'dataset'), and the paths must
        # exist before the model and dataset config files can be read.
        self._load_internal_config()
        self._init_external_config(model_name, dataset_name, task_type, config_dict)
        self._load_cmd_line()
        self._build_path_config()
        self._load_model_config()
        self._load_dataset_config()
        # merge model and dataset config into the external config
        self._merge_external_config_dict()
        # merge internal, external and cmd line config into the final config
        self._build_final_config_dict()
        self._init_device()

    @staticmethod
    def _override_existing_keys(target, *sources):
        """Overwrite values of keys already in ``target`` with values from ``sources``.

        Keys that are missing from ``target`` are never added. Sources are
        applied in the given order, so a later source wins over an earlier one.
        """
        for key in target:
            for source in sources:
                if key in source:
                    target[key] = source[key]

    @staticmethod
    def _read_optional_config(config_path, description):
        """Read a JSON config file, returning ``{}`` with a warning if it is missing.

        Args:
            config_path (str): path of the file; a relative path is resolved
                against the current working directory.
            description (str): word used in the warning message
                ('model' or 'dataset').
        """
        if not os.path.isabs(config_path):
            config_path = os.path.join(os.getcwd(), config_path)
        try:
            return read_json_data(config_path)
        except FileNotFoundError:
            warnings.warn('{} config file is not exist, file path : {}'.format(description, config_path))
            return {}

    def _load_internal_config(self):
        """Load the internal config from ``config.json`` located next to this module."""
        base_dir = os.path.dirname(os.path.realpath(__file__))
        self.internal_config_dict = read_json_data(os.path.join(base_dir, 'config.json'))

    def _init_external_config(self, model_name, dataset_name, task_type, config_dict):
        """Seed the external config with the constructor arguments.

        Entries of ``config_dict`` are applied last, so they override the
        explicit ``model_name``/``dataset_name``/``task_type`` arguments.
        """
        self.external_config_dict['model'] = model_name
        self.external_config_dict['dataset'] = dataset_name
        self.external_config_dict['task_type'] = task_type
        if config_dict is not None:
            self.external_config_dict.update(config_dict)

    @classmethod
    def _parse_str_value(cls, param):
        """Convert one string parameter to the Python value it spells out.

        Returns the evaluated value when it has a supported type, ``True`` /
        ``False`` / ``None`` for the case-insensitive words "true" / "false" /
        "none", and the unchanged string otherwise.
        """
        try:
            # NOTE(security): eval() executes arbitrary expressions, so the
            # parameter sources (command line, config dictionary) must be
            # trusted. It is evaluated in an isolated namespace so that names
            # local to this module (e.g. a previously converted value) cannot
            # leak into the result; only builtins are visible.
            value = eval(param, {}, {})
        except Exception:
            # Any failure means the text is not a Python expression; treat it
            # as a plain word instead of aborting config loading.
            lowered = param.lower()
            if lowered == "true":
                return True
            if lowered == "false":
                return False
            if lowered == "none":
                return None
            return param
        if isinstance(value, cls._SUPPORTED_VALUE_TYPES):
            return value
        # Evaluated to something unsupported (function, type, set, ...):
        # keep the original text.
        return param

    def _convert_config_dict(self, config_dict):
        r"""Convert the str parameters to their original type.

        The dictionary is modified in place and also returned. Values that are
        not strings are left untouched.
        """
        for key, param in config_dict.items():
            if isinstance(param, str):
                # Reassigning an existing key does not resize the dict, so it
                # is safe while iterating.
                config_dict[key] = self._parse_str_value(param)
        return config_dict

    def _load_cmd_line(self):
        r"""Read parameters from the command line and convert them from str.

        Only arguments of the form ``--name=value`` (exactly one ``=``) are
        accepted. Other arguments are reported with a warning, except those
        starting with ``--search_parameter`` which are skipped silently.
        Parsed values also overwrite the matching keys that already exist in
        the external and internal config.

        Returns:
            dict: the parameters parsed from the command line.

        Raises:
            SyntaxError: if the same argument is given twice with different values.
        """
        cmd_config_dict = {}
        unrecognized_args = []
        argv = sys.argv
        # Inside a Jupyter kernel argv belongs to the kernel launcher, not to the user.
        if argv and "ipykernel_launcher" not in argv[0]:
            for arg in argv[1:]:
                parts = arg[2:].split("=")
                if not arg.startswith("--") or len(parts) != 2:
                    if not arg.startswith("--search_parameter"):
                        unrecognized_args.append(arg)
                    continue
                cmd_arg_name, cmd_arg_value = parts
                if cmd_arg_name in cmd_config_dict and cmd_arg_value != cmd_config_dict[cmd_arg_name]:
                    raise SyntaxError("There are duplicate commend arg '%s' with different value." % arg)
                cmd_config_dict[cmd_arg_name] = cmd_arg_value
        if unrecognized_args:
            logger = getLogger()
            logger.warning('command line args [{}] will not be used in Mwptoolkit'.format(' '.join(unrecognized_args)))
        cmd_config_dict = self._convert_config_dict(cmd_config_dict)

        self.cmd_config_dict.update(cmd_config_dict)
        self._override_existing_keys(self.external_config_dict, self.cmd_config_dict)
        self._override_existing_keys(self.internal_config_dict, self.cmd_config_dict)
        return cmd_config_dict

    def _load_model_config(self):
        """Load the model config (or the best config) file.

        Values for keys present in the file are then overridden by the
        external config and, with higher priority, by the command line.
        """
        if self.internal_config_dict["load_best_config"]:
            model_config_path = self.path_config_dict["best_config_file"]
        else:
            model_config_path = self.path_config_dict["model_config_file"]
        self.model_config_dict = self._read_optional_config(model_config_path, 'model')
        self._override_existing_keys(self.model_config_dict, self.external_config_dict, self.cmd_config_dict)

    def _load_dataset_config(self):
        """Load the dataset config file.

        Values for keys present in the file are then overridden by the
        external config and, with higher priority, by the command line.
        """
        dataset_config_file = self.path_config_dict["dataset_config_file"]
        self.dataset_config_dict = self._read_optional_config(dataset_config_file, 'dataset')
        self._override_existing_keys(self.dataset_config_dict, self.external_config_dict, self.cmd_config_dict)

    def _build_path_config(self):
        """Build the default file/directory paths from the model and dataset names.

        Each default can be overridden by a same-named key in the external
        config or, with higher priority, on the command line. The result is
        stored in ``path_config_dict`` and merged into the internal config.

        Raises:
            KeyError: if the model or dataset name is given neither through
                the external config nor on the command line.
        """
        base_dir = os.path.dirname(os.path.realpath(__file__))
        model_name = self.external_config_dict['model']
        dataset_name = self.external_config_dict['dataset']
        if model_name is None:
            model_name = self.cmd_config_dict["model"]
        if dataset_name is None:
            dataset_name = self.cmd_config_dict["dataset"]

        model_config_file = os.path.join(base_dir, "../properties/model/{}.json".format(model_name))
        best_config_file = os.path.join(base_dir, "../properties/best_config/{}_{}.json".format(model_name, dataset_name))
        dataset_config_file = os.path.join(base_dir, "../properties/dataset/{}.json".format(dataset_name))

        cwd = os.getcwd()
        path_config_dict = {
            # config files are stored relative to the current working directory
            "model_config_file": os.path.relpath(model_config_file, cwd),
            "best_config_file": os.path.relpath(best_config_file, cwd),
            "dataset_config_file": os.path.relpath(dataset_config_file, cwd),
            "dataset_dir": "dataset/{}".format(dataset_name),
            "checkpoint_file": 'checkpoint/' + '{}-{}.pth'.format(model_name, dataset_name),
            "trained_model_dir": 'trained_model/' + '{}-{}'.format(model_name, dataset_name),
            "log_file": 'log/' + '{}-{}.log'.format(model_name, dataset_name),
            "output_dir": 'result/{}-{}'.format(model_name, dataset_name),
            "checkpoint_dir": 'checkpoint/' + '{}-{}'.format(model_name, dataset_name),
        }
        self._override_existing_keys(path_config_dict, self.external_config_dict, self.cmd_config_dict)
        self.path_config_dict = path_config_dict

        # merge path config into internal config
        self.internal_config_dict.update(self.path_config_dict)

    def _init_model_path(self):
        """Rebuild checkpoint/model/log paths so that they include 'equation_fix'.

        Reads 'model', 'dataset' and 'equation_fix' from the final config and
        writes the resulting paths to both the path config and the final
        config. Not called by ``__init__``.
        """
        model_name = self.final_config_dict["model"]
        dataset_name = self.final_config_dict["dataset"]
        fix = self.final_config_dict["equation_fix"]
        path_config_dict = {
            "checkpoint_file": 'checkpoint/' + '{}-{}-{}.pth'.format(model_name, dataset_name, fix),
            "trained_model_dir": 'trained_model/' + '{}-{}-{}'.format(model_name, dataset_name, fix),
            "log_file": 'log/' + '{}-{}-{}.log'.format(model_name, dataset_name, fix),
        }
        self._override_existing_keys(path_config_dict, self.external_config_dict, self.cmd_config_dict)
        self.path_config_dict.update(path_config_dict)
        self.final_config_dict.update(path_config_dict)

    def _merge_external_config_dict(self):
        """Merge dataset, model and user-supplied config into the external config.

        Later updates win: config dictionary > model config > dataset config.
        """
        external_config_dict = dict()
        external_config_dict.update(self.dataset_config_dict)
        external_config_dict.update(self.model_config_dict)
        external_config_dict.update(self.external_config_dict)
        self.external_config_dict = external_config_dict

    def _build_final_config_dict(self):
        """Merge all sources into the final config: cmd line > external > internal."""
        self.final_config_dict.update(self.internal_config_dict)
        self.final_config_dict.update(self.external_config_dict)
        self.final_config_dict.update(self.cmd_config_dict)

    def _init_device(self):
        """Resolve 'gpu_id' and record device information in the final config.

        Sets the ``CUDA_VISIBLE_DEVICES`` environment variable from 'gpu_id'
        and stores 'device', 'map_location' and 'gpu_nums' in the final config.
        """
        config = self.final_config_dict
        if config["gpu_id"] is None:
            if torch.cuda.is_available() and config["use_gpu"]:
                config["gpu_id"] = "0"
            else:
                config["gpu_id"] = ""
        elif config["use_gpu"] != True:  # noqa: E712 - keep the original equality semantics
            # A GPU id was given but GPU usage is not enabled: expose no GPU.
            config["gpu_id"] = ""
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config["gpu_id"])
        config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config["map_location"] = "cuda" if torch.cuda.is_available() else "cpu"
        config['gpu_nums'] = torch.cuda.device_count()

    def _update_internal_config(self, key, value):
        """Update ``key`` in the internal and path config, only where it already exists."""
        if key in self.internal_config_dict:
            self.internal_config_dict[key] = value
        if key in self.path_config_dict:
            self.path_config_dict[key] = value

    def _update_external_config(self, key, value):
        """Update ``key`` in the external, model and dataset config, only where it already exists."""
        if key in self.external_config_dict:
            self.external_config_dict[key] = value
        if key in self.model_config_dict:
            self.model_config_dict[key] = value
        if key in self.dataset_config_dict:
            self.dataset_config_dict[key] = value

    @classmethod
    def load_from_pretrained(cls, pretrained_dir):
        """Rebuild a config object from ``<pretrained_dir>/config.json``.

        Every top-level entry of the saved file is restored as an attribute of
        the new object; afterwards the command line is re-read, the paths and
        the final config are rebuilt, and the device is re-initialized.

        Args:
            pretrained_dir (str): directory containing the saved ``config.json``.

        Returns:
            Config: the restored config object.
        """
        config_file = os.path.join(pretrained_dir, 'config.json')
        config_dict = read_json_data(config_file)
        saved_final = config_dict['final_config_dict']
        # Use ``cls`` so that subclasses get an instance of their own type.
        config = cls(saved_final['model'], saved_final['dataset'], saved_final['task_type'])
        for key, value in config_dict.items():
            setattr(config, key, value)
        config._load_cmd_line()
        config._build_path_config()
        config._build_final_config_dict()
        config._init_device()
        return config

    def save_config(self, trained_dir):
        """Write the config to ``<trained_dir>/config.json``.

        Entries whose value cannot be encoded as JSON are dropped from the
        saved copy; the config object itself is not modified.

        Args:
            trained_dir (str): directory the file is written to.
        """
        json_encoder = json.encoder.JSONEncoder()
        config_file = os.path.join(trained_dir, 'config.json')
        config_dict = self.to_dict()
        for section in config_dict.values():
            # Collect first, delete afterwards: a dict must not change size
            # while it is being iterated.
            not_support_json = []
            for key, value in section.items():
                try:
                    json_encoder.encode({key: value})
                except TypeError:
                    not_support_json.append(key)
            for key in not_support_json:
                del section[key]
        write_json_data(config_dict, config_file)

    def to_dict(self):
        """Return a deep copy of all non-callable, non-dunder instance attributes.

        Returns:
            dict: maps attribute name (e.g. 'final_config_dict') to a copy of its value.
        """
        config_dict = {}
        for name, value in vars(self).items():
            if callable(value) or re.match('__.*?__', name):
                continue
            config_dict[name] = copy.deepcopy(value)
        return config_dict

    def __setitem__(self, key, value):
        """Set a parameter in the final config and in every source config that already has it.

        String values are converted to their original type first.

        Raises:
            TypeError: if ``key`` is not a str.
        """
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        value = self._convert_config_dict({key: value})[key]
        self.final_config_dict[key] = value
        self._update_internal_config(key, value)
        self._update_external_config(key, value)

    def __getitem__(self, item):
        """Return the parameter from the final config, or None if it is not defined."""
        return self.final_config_dict.get(item)

    def __delitem__(self, key):
        """Remove a parameter from the final config and from every source config.

        A parameter normally lives in only some of the source dictionaries, so
        it is removed wherever it is present instead of requiring it everywhere.

        Raises:
            KeyError: if ``key`` is not in the final config.
        """
        if key not in self.final_config_dict:
            raise KeyError(key)
        for config in (
            self.final_config_dict,
            self.external_config_dict,
            self.model_config_dict,
            self.dataset_config_dict,
            self.internal_config_dict,
            self.path_config_dict,
        ):
            config.pop(key, None)

    def __str__(self):
        """Return the final config as 'name=value' lines followed by a blank line."""
        args_info = '\n'.join(["{}={}".format(arg, value) for arg, value in self.final_config_dict.items()])
        return args_info + '\n\n'