# experiment_baseline.py
from utils import CURR_RUN_PATH, USE_GPU, logger
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.baseline import CNN as Network


def run():
    """Train and then test the baseline CNN with PyTorch Lightning.

    Logs the current run directory and creates a test_tube ``Experiment``
    that saves there. It then builds a ``Trainer``: GPU 0 with the 'ddp'
    backend when ``USE_GPU`` is truthy, otherwise a trainer with no GPUs
    requested and ``fast_dev_run=True``. Finally it fits and tests the model.

    The trainer is currently set up for a short run (``max_nb_epochs=1``,
    ``train_percent_check=0.1``), not the full baseline schedule recorded
    in the comment below.
    """
    logger.info(CURR_RUN_PATH)
    exp = Experiment(save_dir=CURR_RUN_PATH)

    # parameters used in the baseline (read from main.py and solver.py)
    # n_epochs = 500
    # lr = 1e-4
    # num_class = 56
    # batch_size = 32

    # Trainer settings shared by the GPU and non-GPU configurations.
    common_trainer_args = dict(
        experiment=exp,
        max_nb_epochs=1,
        train_percent_check=0.1,
    )

    if USE_GPU:
        # NOTE(review): 'ddp' with a single device (gpus=[0]) -- confirm a
        # distributed backend is actually intended here.
        trainer = Trainer(gpus=[0], distributed_backend='ddp',
                          fast_dev_run=False, **common_trainer_args)
    else:
        trainer = Trainer(fast_dev_run=True, **common_trainer_args)

    model = Network(num_class=56)  # TODO num_class

    print(model)

    trainer.fit(model)
    trainer.test()


# Script entry point: only start training when executed directly, not on import.
if __name__ == '__main__':
    run()