"""Train the midlevel model (``models.midlevel_vgg.ModelMidlevel``).

Run as a script, e.g.::

    python <this file> --experiment_name my_exp --config epochs_100

``--config`` selects one of the presets registered in ``CONFIG_PRESETS``;
each preset overrides entries of the module-level ``config`` dict before
the trainer is built.
"""
from utils import USE_GPU, init_experiment, set_paths
from pytorch_lightning import Trainer
from test_tube import Experiment, HyperOptArgumentParser
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
from models.midlevel_vgg import ModelMidlevel as Network
import torch

# Default training settings. Presets below mutate this dict in place.
config = {
    'epochs': 1,
    'patience': 50,
    'earlystopping_metric': 'val_loss',  # 'val_prauc'
    'earlystopping_mode': 'min',  # 'max'
}


def epochs_500():
    """Preset: train for up to 500 epochs."""
    config['epochs'] = 500


def epochs_100():
    """Preset: train for up to 100 epochs with an early-stopping patience of 20."""
    config['epochs'] = 100
    config['patience'] = 20


def epochs_20():
    """Preset: train for up to 20 epochs."""
    config['epochs'] = 20


# Explicit registry of the presets selectable through ``--config``.
# Looking names up here (instead of in ``globals()``) means a command-line
# value can only ever trigger one of these functions.
CONFIG_PRESETS = {
    'epochs_500': epochs_500,
    'epochs_100': epochs_100,
    'epochs_20': epochs_20,
}


def _format_config():
    """Return the current ``config`` as a multi-line ``key:value`` listing."""
    header = '---------CONFIG--------\n'
    return header + ''.join(f"{key}:{value}\n" for key, value in config.items())


def _apply_preset(name):
    """Apply the preset called *name* to ``config``.

    Args:
        name: Preset name, or ``None`` to keep the defaults.

    Returns:
        ``True`` if no preset was requested or the preset was applied,
        ``False`` if *name* is not a registered preset (``config`` is left
        untouched in that case).
    """
    if name is None:
        return True
    preset = CONFIG_PRESETS.get(name)
    if preset is None:
        return False
    # Errors raised by a preset propagate; they are real bugs, not
    # "preset not defined".
    preset()
    return True


def pretrain_midlevel(hparams):
    """Train the midlevel network, then reload it from the checkpoint path.

    Args:
        hparams: Parsed command-line arguments. This function reads
            ``hparams.config`` (preset name or ``None``) and
            ``hparams.train_percent`` (passed to the trainer as
            ``train_percent_check`` on the GPU path).
    """
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger, streamlog  # import these after init_experiment

    streamlog.info("Training midlevel...")
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)

    # An unknown preset name is reported and training continues with the
    # current (default) settings.
    if not _apply_preset(hparams.config):
        logger.error(f"Config {hparams.config} not defined")
    logger.info(_format_config())

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode'],
    )

    chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
    checkpoint_callback = ModelCheckpoint(
        filepath=chkpt_dir,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min',
    )

    if USE_GPU:
        trainer = Trainer(
            gpus=[0],
            distributed_backend='ddp',
            experiment=exp,
            max_nb_epochs=config['epochs'],
            train_percent_check=hparams.train_percent,
            fast_dev_run=False,
            early_stop_callback=early_stop,
            checkpoint_callback=checkpoint_callback,
        )
    else:
        # Without a GPU the run is fixed to one epoch on 1% of the training
        # data and ignores ``config`` / early stopping.
        # NOTE(review): presumably a smoke-test mode - confirm.
        trainer = Trainer(
            experiment=exp,
            max_nb_epochs=1,
            train_percent_check=0.01,
            fast_dev_run=False,
            checkpoint_callback=checkpoint_callback,
        )

    model = Network(num_targets=7)
    print(model)

    trainer.fit(model)

    # streamlog.info("Running test")
    # trainer.test()

    # Rebuild the network from the checkpoint path to confirm it loads.
    logger.info(f"Loading model from {chkpt_dir}")
    model = Network(num_targets=7, on_gpu=USE_GPU, load_from=chkpt_dir)
    logger.info("Loaded model successfully")


def run(hparams):
    """Initialise the experiment named by ``hparams.experiment_name`` and train."""
    init_experiment(comment=hparams.experiment_name)
    pretrain_midlevel(hparams)


def main():
    """Parse command-line arguments and launch training."""
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parent_parser.add_argument('--experiment_name', type=str,
                               default='pt_lightning_exp_a', help='test tube exp name')
    parent_parser.add_argument('--config', type=str, help='config function to run')
    # TODO: multiple arguments for --config using nargs='+' is not working with
    # the test_tube implementation of argument parser
    parent_parser.add_argument('--train_percent', type=float,
                               default=1.0, help='how much train data to use')
    parser = Network.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()

    run(hyperparams)


if __name__ == '__main__':
    main()