# experiment_baseline.py — baseline CNN experiment runner (pytorch-lightning + test_tube)

from utils import USE_GPU, init_experiment
3
from pytorch_lightning import Trainer
4
from test_tube import Experiment, HyperOptArgumentParser
5
from models.baseline import CNN as Network
6
7
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
8

9
# Default experiment configuration; the epochs_* preset functions below
# mutate this dict in place before the Trainer is built.
config = dict(
    epochs=1,
    patience=50,                      # early-stopping patience, in epochs
    earlystopping_metric='val_loss',  # alternative: 'val_prauc'
    earlystopping_mode='min',         # use 'max' for score-like metrics
)

16

17
18
19
def epochs_500():
    """Config preset: extend training to 500 epochs (mutates module-level `config`)."""
    global config
    config.update(epochs=500)

def epochs_100():
    """Config preset: 100 epochs with early-stopping patience tightened to 20."""
    global config
    config.update(epochs=100, patience=20)


def epochs_20():
    """Config preset: a short 20-epoch run (mutates module-level `config`)."""
    global config
    config.update(epochs=20)
def run(hparams):
    """Set up the experiment run, build callbacks and a Trainer, then train and
    test the baseline CNN.

    :param hparams: parsed test_tube arguments; uses .experiment_name, .config
                    (name of a config preset function in this module, or None)
                    and .train_percent, plus whatever model-specific args
                    Network.add_model_specific_args contributed.
    """
    init_experiment(comment=hparams.experiment_name)
    # CURR_RUN_PATH and logger are only valid after init_experiment has run,
    # hence the deliberately late import.
    from utils import CURR_RUN_PATH, logger
    logger.info(CURR_RUN_PATH)
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")

    exp = Experiment(save_dir=CURR_RUN_PATH)

    def setup_config():
        # Apply the config preset named by --config (if any), then log the result.
        def print_config():
            # Render the module-level config dict as a readable multi-line string.
            entries = [f"{k}:{config.get(k)}" for k in config]
            return '---------CONFIG--------\n' + '\n'.join(entries) + '\n'

        conf = hparams.config
        if conf is not None:
            # BUG FIX: the previous bare `except:` wrapped only conf_func(), so a
            # genuinely missing preset crashed with an uncaught KeyError at
            # globals()[conf], while any error raised *inside* a valid preset was
            # mislabelled as "Config ... not defined". Check existence explicitly
            # and let real errors from the preset propagate.
            conf_func = globals().get(conf)
            if conf_func is None:
                logger.error(f"Config {conf} not defined")
            else:
                conf_func()

        logger.info(print_config())

    setup_config()

    # parameters used in the baseline (read from main.py and solver.py)
    # n_epochs = 500
    # lr = 1e-4
    # num_class = 56
    # batch_size = 32

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )

    # Persist only the best model (by val_loss) into the run directory.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(CURR_RUN_PATH, 'best.ckpt'),
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    if USE_GPU:
        trainer = Trainer(gpus=[0], distributed_backend='ddp',
                          experiment=exp, max_nb_epochs=config['epochs'],
                          train_percent_check=hparams.train_percent,
                          fast_dev_run=False,
                          early_stop_callback=early_stop,
                          checkpoint_callback=checkpoint_callback
                          )
    else:
        # CPU fallback: a single fast dev epoch on 1% of the data, for smoke tests.
        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
                          fast_dev_run=True)

    model = Network(num_class=56)  # TODO num_class — should come from the dataset/config

    print(model)

    trainer.fit(model)
    trainer.test()


if __name__ == '__main__':
    # CLI: experiment name, optional config preset, and fraction of train data.
    base_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    base_parser.add_argument('--experiment_name', type=str,
                             default='pt_lightning_exp_a', help='test tube exp name')
    # TODO : multiple arguments for --config using nargs='+' is not working with
    # the test_tube implementation of argument parser
    base_parser.add_argument('--config', type=str, help='config function to run')
    base_parser.add_argument('--train_percent', type=float,
                             default=1.0, help='how much train data to use')
    # Let the model contribute its own hyperparameter arguments before parsing.
    full_parser = Network.add_model_specific_args(base_parser)
    run(full_parser.parse_args())