# experiment_baseline.py

from utils import USE_GPU, init_experiment
3
from pytorch_lightning import Trainer
4
from test_tube import Experiment, HyperOptArgumentParser
5
6
from models.baseline import CNN as Network

7
8
9
# Module-level experiment settings. Config functions (selected with
# --config) overwrite entries in this dict in place; run() reads
# config['epochs'] when training on the GPU.
config = dict(epochs=1)
10

11
12
13
def epochs_500():
    """Config function: set the number of training epochs to 500.

    Chosen on the command line with ``--config epochs_500``. It mutates the
    module-level ``config`` dict in place, so no ``global`` statement is
    needed.
    """
    config.update(epochs=500)
14
15
16
17

def run(hparams):
    """Train and then test the baseline CNN for a single experiment run.

    Initialises the experiment via ``init_experiment``, applies the config
    function named by ``hparams.config`` (if any) to the module-level
    ``config`` dict, builds a ``Trainer`` and fits/tests ``Network``.

    Args:
        hparams: parsed command-line namespace. Attributes read here:
            ``experiment_name`` (passed as the experiment comment),
            ``config`` (name of a module-level config function such as
            ``epochs_500``, or ``None`` to keep the defaults) and
            ``train_percent`` (fraction of training data, GPU runs only).
    """
    init_experiment(comment=hparams.experiment_name)
    # Import these only after init_experiment() has run, as the original
    # author required.
    from utils import CURR_RUN_PATH, logger

    logger.info(CURR_RUN_PATH)
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")

    exp = Experiment(save_dir=CURR_RUN_PATH)

    def setup_config():
        """Call the config function named by ``hparams.config``, if any."""
        conf = hparams.config
        if conf is None:
            # No --config given: keep the defaults already in ``config``.
            return
        # .get() so that an unknown name reaches the error log below
        # instead of raising an uncaught KeyError.
        conf_func = globals().get(conf)
        if not callable(conf_func):
            logger.error(f"Config {conf} not defined")
            return
        # Errors raised inside a real config function are genuine bugs and
        # are intentionally not swallowed (the old bare ``except:`` hid
        # them behind a misleading "not defined" message).
        conf_func()

    setup_config()

    # Parameters used by the reference baseline (read from its main.py and
    # solver.py): n_epochs = 500, lr = 1e-4, num_class = 56, batch_size = 32.

    if USE_GPU:
        trainer = Trainer(gpus=[0], distributed_backend='ddp',
                          experiment=exp, max_nb_epochs=config['epochs'],
                          train_percent_check=hparams.train_percent,
                          fast_dev_run=False)
    else:
        # The CPU branch hard-codes a very short run (1 epoch, 1% of the
        # training data, fast_dev_run=True) and ignores both config['epochs']
        # and hparams.train_percent -- presumably meant for local debugging.
        trainer = Trainer(experiment=exp, max_nb_epochs=1,
                          train_percent_check=0.01, fast_dev_run=True)

    model = Network(num_class=56)  # TODO num_class
    print(model)

    trainer.fit(model)
    trainer.test()


if __name__ == '__main__':
    # Arguments shared by every experiment. add_help=False leaves --help to
    # the parser returned by Network.add_model_specific_args below.
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parent_parser.add_argument('--experiment_name', type=str,
                               default='pt_lightning_exp_a', help='test tube exp name')
    # Name of a module-level config function to apply (e.g. epochs_500).
    # TODO: accept several configs via nargs='+'; this does not work with
    # test_tube's argument-parser implementation.
    parent_parser.add_argument('--config', type=str, help='config function to run')
    parent_parser.add_argument('--train_percent', type=float,
                               default=1.0, help='how much train data to use')

    # Presumably extends the parser with the model's own hyper-parameters.
    parser = Network.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()
    run(hyperparams)