from utils import USE_GPU, init_experiment
from pytorch_lightning import Trainer
from test_tube import Experiment, HyperOptArgumentParser
from models.baseline import CNN as Network
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os

# Default run configuration; presets below override it when selected via --config.
config = {
    'epochs': 1,
    'patience': 50,
    'earlystopping_metric': 'val_loss',  # 'val_prauc'
    'earlystopping_mode': 'min'          # 'max'
}


def epochs_500():
    global config
    config['epochs'] = 500


def epochs_100():
    global config
    config['epochs'] = 100
    config['patience'] = 20


def epochs_20():
    global config
    config['epochs'] = 20


def run(hparams):
    init_experiment(comment=hparams.experiment_name)
    from utils import CURR_RUN_PATH, logger  # import these after init_experiment
    logger.info(CURR_RUN_PATH)
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(save_dir=CURR_RUN_PATH)

    def setup_config():
        def print_config():
            global config
            st = '---------CONFIG--------\n'
            for k in config.keys():
                st += k + ':' + str(config.get(k)) + '\n'
            return st

        # Look up and apply the preset named by --config (e.g. epochs_500).
        conf = hparams.config
        if conf is not None:
            try:
                conf_func = globals()[conf]
                conf_func()
            except KeyError:
                logger.error(f"Config {conf} not defined")
        logger.info(print_config())

    setup_config()
    global config

    # parameters used in the baseline (read from main.py and solver.py)
    # n_epochs = 500
    # lr = 1e-4
    # num_class = 56
    # batch_size = 32

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(CURR_RUN_PATH, 'best.ckpt'),
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    if USE_GPU:
        trainer = Trainer(gpus=[0],
                          distributed_backend='ddp',
                          experiment=exp,
                          max_nb_epochs=config['epochs'],
                          train_percent_check=hparams.train_percent,
                          fast_dev_run=False,
                          early_stop_callback=early_stop,
                          checkpoint_callback=checkpoint_callback,
                          nb_sanity_val_steps=0)  # don't run sanity validation run
    else:
        # CPU fallback: single fast dev pass over a small slice of the data
        trainer = Trainer(experiment=exp,
                          max_nb_epochs=1,
                          train_percent_check=0.01,
                          fast_dev_run=True)

    model = Network(num_class=56)  # TODO num_class
    print(model)

    trainer.fit(model)
    trainer.test()


if __name__ == '__main__':
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parent_parser.add_argument('--experiment_name', type=str, default='pt_lightning_exp_a',
                               help='test tube exp name')
    parent_parser.add_argument('--config', type=str, help='config function to run')
    # TODO: multiple arguments for --config using nargs='+' is not working with the
    # test_tube implementation of the argument parser
    parent_parser.add_argument('--train_percent', type=float, default=1.0,
                               help='how much train data to use')
    parser = Network.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()

    run(hyperparams)