# experiment_midlevel.py

from utils import USE_GPU, init_experiment, set_paths
from pytorch_lightning import Trainer
from test_tube import Experiment, HyperOptArgumentParser
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
from models.midlevel_vgg import ModelMidlevel as Network
import torch

# Global training configuration. Mutated in place by the preset functions
# below (epochs_*/midlevel_configs/mtg_configs) before a Trainer is built.
config = {
    'epochs': 1,
    'patience': 50,
    'earlystopping_metric': 'val_loss', # 'val_prauc'
    'earlystopping_mode': 'min' # 'max'
}

def epochs_500():
    """Preset: run training for 500 epochs."""
    config.update(epochs=500)


def epochs_100():
    """Preset: 100 epochs with a shorter early-stopping patience."""
    config.update(epochs=100, patience=20)


def epochs_20():
    """Preset: run training for 20 epochs."""
    config.update(epochs=20)


def midlevel_configs():
    """Preset used for the midlevel pretraining stage."""
    config.update(epochs=2)

def mtg_configs():
    """Preset used for the MTG-Jamendo fine-tuning stage."""
    config.update(epochs=1)


def pretrain_midlevel(hparams):
    """Pretrain the network on the midlevel dataset.

    Applies the config preset named by ``hparams.config`` (if any), builds
    early-stopping and checkpointing callbacks from the global ``config``,
    and fits the model. The best checkpoint lands in
    ``<CURR_RUN_PATH>/midlevel.ckpt``.

    Args:
        hparams: parsed hyperparameters; uses ``config`` and ``train_percent``.
    """
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger, streamlog  # import these after init_experiment

    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)

    def setup_config():
        # Apply the requested preset function, then log the resulting config.
        def print_config():
            st = '---------CONFIG--------\n'
            for k in config.keys():
                st += k+':'+str(config.get(k))+'\n'
            return st

        conf = hparams.config
        if conf is not None:
            try:
                # BUGFIX: the globals() lookup must be inside the try block —
                # previously an unknown config name raised an uncaught KeyError
                # and the error message below was never logged. Also narrowed
                # the bare `except:` to KeyError.
                conf_func = globals()[conf]
                conf_func()
            except KeyError:
                logger.error(f"Config {conf} not defined")

        logger.info(print_config())

    setup_config()

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )

    chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
    checkpoint_callback = ModelCheckpoint(
        filepath=chkpt_dir,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    if USE_GPU:
        trainer = Trainer(
            gpus=[0], distributed_backend='ddp',
            # BUGFIX: honour config['epochs'] (set by the preset functions)
            # instead of a hard-coded 20 — otherwise the whole config
            # machinery had no effect on training length.
            experiment=exp, max_nb_epochs=config['epochs'],
            train_percent_check=hparams.train_percent,
            fast_dev_run=False,
            early_stop_callback=early_stop,
            checkpoint_callback=checkpoint_callback
        )
    else:
        # CPU fallback: deliberately tiny run (1 epoch, 1% of the data) —
        # presumably a local smoke-test path; left unchanged.
        trainer = Trainer(
            experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
            fast_dev_run=False, checkpoint_callback=checkpoint_callback
        )

    model = Network(num_targets=7)

    print(model)

    logger.info("Training midlevel")
    trainer.fit(model)
    logger.info("Training midlevel completed")

    # streamlog.info("Running test")
    # trainer.test()

def train_mtgjamendo(hparams, midlevel_chkpt_dir):
    """Fine-tune a midlevel-pretrained model on the MTG-Jamendo dataset.

    Loads the pretrained weights, applies the ``mtg_configs`` preset, and
    trains with early stopping; the best checkpoint lands in
    ``<CURR_RUN_PATH>/mtg.ckpt``.

    Args:
        hparams: parsed hyperparameters; uses ``train_percent``.
        midlevel_chkpt_dir: path to the pretrained midlevel checkpoint.
    """
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger, streamlog  # import these after init_experiment
    chkpt_dir = os.path.join(CURR_RUN_PATH, 'mtg.ckpt')

    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)
    mtg_configs()
    logger.info(f"Loading model from {midlevel_chkpt_dir}")
    # NOTE(review): num_targets=7 matches the pretraining head — confirm this
    # is the intended output size for the mtgjamendo task.
    model = Network(num_targets=7, dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
    logger.info("Loaded model successfully")  # was an f-string with no placeholders

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )

    checkpoint_callback = ModelCheckpoint(
        filepath=chkpt_dir,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    if USE_GPU:
        trainer = Trainer(
            gpus=[0], distributed_backend='ddp',
            # BUGFIX: honour config['epochs'] (set by mtg_configs() above)
            # instead of a hard-coded 20 epochs.
            experiment=exp, max_nb_epochs=config['epochs'],
            train_percent_check=hparams.train_percent,
            fast_dev_run=False,
            early_stop_callback=early_stop,
            checkpoint_callback=checkpoint_callback
        )
    else:
        # CPU fallback: tiny debug run, unchanged from the original.
        trainer = Trainer(
            experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
            fast_dev_run=False, checkpoint_callback=checkpoint_callback
        )

    logger.info("Training mtgjamendo")
    trainer.fit(model)
    logger.info("Training mtgjamendo completed")


def run(hparams):
    """Entry point: pretrain on midlevel (or reuse a checkpoint), then fine-tune.

    If ``--pretrained_midlevel`` is not given, a fresh midlevel pretraining
    run produces the checkpoint; otherwise the supplied path is used directly.
    """
    init_experiment(comment=hparams.experiment_name)
    if hparams.pretrained_midlevel is None:
        pretrain_midlevel(hparams)
        from utils import CURR_RUN_PATH
        midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
    else:
        from utils import logger
        midlevel_chkpt_dir = hparams.pretrained_midlevel
        # BUGFIX: the path was passed as a stray positional arg to logger.info
        # with no %s placeholder in the message, so it was never rendered and
        # logging raised an internal formatting error.
        logger.info(f"Using pretrained model {midlevel_chkpt_dir}")
    train_mtgjamendo(hparams, midlevel_chkpt_dir)


if __name__ == '__main__':
    # CLI definition: shared experiment options first, then the options the
    # model class registers for itself.
    cli = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    cli.add_argument('--experiment_name', type=str, default='pt_lightning_exp_a',
                     help='test tube exp name')
    # TODO : multiple arguments for --config using nargs='+' is not working with the test_tube
    # implementation of argument parser
    cli.add_argument('--config', type=str, help='config function to run')
    cli.add_argument('--train_percent', type=float, default=1.0,
                     help='how much train data to use')
    cli.add_argument(
        '--pretrained_midlevel', type=str,
        default="/home/verena/experiments/moodwalk/runs/ca601 - pretrain_midlevel 20 epochs/midlevel.ckpt/")
    run(Network.add_model_specific_args(cli).parse_args())