# experiment_baseline.py

from utils import USE_GPU, init_experiment
3
from pytorch_lightning import Trainer
4
from test_tube import Experiment, HyperOptArgumentParser
5
6
7
from models.baseline import CNN as Network


8
9
10
11

def run(hparams):
    """Train and then test the baseline network for one experiment run.

    Sets up the run via ``init_experiment``, creates a test-tube ``Experiment``
    in the current run directory, builds a pytorch-lightning ``Trainer`` whose
    configuration depends on ``USE_GPU``, fits the model and runs the test loop.

    Args:
        hparams: parsed command-line hyperparameters. Only
            ``hparams.experiment_name`` is read here; it is passed to
            ``init_experiment`` as the run comment.
    """
    init_experiment(comment=hparams.experiment_name)
    # Import these only after init_experiment has run. Presumably it (re)binds
    # CURR_RUN_PATH/logger in utils, so a module-level import would see stale
    # values. TODO confirm against utils.
    from utils import CURR_RUN_PATH, logger

    logger.info(CURR_RUN_PATH)
    exp = Experiment(save_dir=CURR_RUN_PATH)

    # parameters used in the baseline (read from main.py and solver.py)
    # n_epochs = 500
    # lr = 1e-4
    # num_class = 56
    # batch_size = 32

    if USE_GPU:
        # GPU 0, whole training set (train_percent_check=1.0).
        # NOTE(review): max_nb_epochs=1 differs from the baseline's 500 epochs.
        trainer = Trainer(gpus=[0], distributed_backend='ddp',
                          experiment=exp, max_nb_epochs=1, train_percent_check=1.0,
                          fast_dev_run=False)
    else:
        # CPU: debugging configuration (fast_dev_run=True, 1% of training data),
        # not a real training run.
        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
                          fast_dev_run=True)

    # TODO: num_class is hard-coded; 56 matches the baseline setting noted above.
    model = Network(num_class=56)

    print(model)

    trainer.fit(model)
    trainer.test()


if __name__ == '__main__':
    # Shared experiment-level arguments; add_help=False so the model-specific
    # parser built from it can own the --help flag.
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parent_parser.add_argument('--experiment_name', type=str,
                               default='pt_lightning_exp_a', help='test tube exp name')
    # Let the network register its own hyperparameters on top of the shared ones.
    parser = Network.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()
    run(hyperparams)