# Source code for mwptoolkit.hyper_search

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/18 11:36:19
# @File: hyper_search.py


import os
import sys
from functools import partial
from logging import getLogger

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, AsyncHyperBandScheduler

from mwptoolkit.config.configuration import Config
from mwptoolkit.evaluate.evaluator import AbstractEvaluator, InfixEvaluator, PostfixEvaluator, PrefixEvaluator, MultiWayTreeEvaluator
from mwptoolkit.evaluate.evaluator import MultiEncDecEvaluator
from mwptoolkit.data.utils import create_dataset, create_dataloader
from mwptoolkit.utils.utils import get_model, init_seed, get_trainer, read_json_data, write_json_data
from mwptoolkit.utils.enum_type import SpecialTokens, FixType
from mwptoolkit.utils.logger import init_logger

from mwptoolkit.quick_start import run_toolkit

# Put the current working directory (normalized to an absolute path) at the
# front of the module search path. This runs after the package imports above,
# so it only affects imports performed later (e.g. by code run from here).
sys.path.insert(0, os.path.abspath(os.curdir))

def _build_evaluator(configs):
    """Return the evaluator matching ``configs['model']`` and ``configs['equation_fix']``.

    A ``MultiEncDec`` model always gets :class:`MultiEncDecEvaluator`; every
    other model is dispatched on its equation notation.

    Raises:
        NotImplementedError: if the model is not MultiEncDec and
            ``configs['equation_fix']`` is not one of the notations handled here.
    """
    # Checked first so a MultiEncDec model neither builds a notation-based
    # evaluator only to discard it, nor is rejected for a notation it ignores.
    if configs['model'].lower() in ['multiencdec']:
        return MultiEncDecEvaluator(configs)

    equation_fix = configs["equation_fix"]
    if equation_fix == FixType.Prefix:
        return PrefixEvaluator(configs)
    if equation_fix in (FixType.Nonfix, FixType.Infix):
        return InfixEvaluator(configs)
    if equation_fix == FixType.Postfix:
        return PostfixEvaluator(configs)
    if equation_fix == FixType.MultiWayTree:
        return MultiWayTreeEvaluator(configs)
    raise NotImplementedError


def train_process(search_parameter, configs):
    """Build and run one hyper-parameter search trial.

    Args:
        search_parameter: mapping of hyper-parameter values for this trial (the
            ``config`` Ray Tune passes to the trainable); each entry overrides
            the same key in ``configs``.
        configs: the toolkit ``Config`` for the run. It is modified in place
            with the values from ``search_parameter``.

    Raises:
        NotImplementedError: if no evaluator matches the configured model /
            equation notation.
    """
    for key, value in search_parameter.items():
        configs[key] = value

    dataset = create_dataset(configs)
    dataset.dataset_load()
    dataloader = create_dataloader(configs)(configs, dataset)
    model = get_model(configs["model"])(configs, dataset).to(configs["device"])
    evaluator = _build_evaluator(configs)

    trainer = get_trainer(configs)(configs, model, dataloader, evaluator)
    # NOTE(review): presumably param_search() reports the "accuracy" metric
    # back to Ray Tune (hyper_search_process optimizes on it) — confirm in trainer.
    trainer.param_search()
def _resolve_path(path):
    """Return ``path`` unchanged if absolute, otherwise joined onto the current working directory."""
    if os.path.isabs(path):
        return path
    return os.path.join(os.getcwd(), path)


def hyper_search_process(model_name, dataset_name, task_type, search_parameter, config_dict=None):
    """Run a Ray Tune hyper-parameter search, save the best config, then train with it.

    Args:
        model_name: model name passed to ``Config`` and ``run_toolkit``.
        dataset_name: dataset name passed to ``Config`` and ``run_toolkit``.
        task_type: task type passed to ``Config`` and ``run_toolkit``.
        search_parameter: Ray Tune search space, given to ``tune.run`` as ``config``.
        config_dict: optional extra configuration overrides. It is copied, so the
            caller's dict is never modified.

    Side effects:
        Initializes (and afterwards shuts down) Ray. Writes the model config file
        merged with the best hyper-parameters to ``configs["best_config_file"]``.
        Finally calls ``run_toolkit`` with the best hyper-parameters applied.

    Raises:
        RuntimeError: if the search produced no best configuration (for example,
            every trial failed).
    """
    # Copy so that neither a shared default nor the caller's dict is mutated by
    # the ``update(best_config)`` below.
    config_dict = {} if config_dict is None else dict(config_dict)

    configs = Config(model_name, dataset_name, task_type, config_dict)
    init_seed(configs['random_seed'], True)
    init_logger(configs)
    logger = getLogger()
    logger.info(configs)

    ray.init(num_gpus=configs['gpu_nums'])
    try:
        scheduler = AsyncHyperBandScheduler(metric="accuracy", mode="max")
        result = tune.run(
            partial(train_process, configs=configs),
            resources_per_trial={"cpu": configs['cpu_per_trial'], "gpu": configs['gpu_per_trial']},
            config=search_parameter,
            scheduler=scheduler,
            num_samples=configs["samples"],
            raise_on_failed_trial=False,
        )
        best_config = result.get_best_config(metric="accuracy", mode="max")
    finally:
        # Release Ray so this function can be called again in the same process.
        ray.shutdown()

    # With raise_on_failed_trial=False the search can finish without any usable
    # trial, in which case there is no best config to apply.
    if best_config is None:
        raise RuntimeError(
            "hyper-parameter search finished without a best config (did every trial fail?)"
        )
    logger.info("best config:%s", best_config)
    config_dict.update(best_config)

    model_config = read_json_data(_resolve_path(configs["model_config_file"]))
    model_config.update(best_config)
    best_config_path = _resolve_path(configs["best_config_file"])
    write_json_data(model_config, best_config_path)
    logger.info("best config saved at %s", best_config_path)

    run_toolkit(model_name, dataset_name, task_type, config_dict)