# experiment_baseline.py
from utils import CURR_RUN_PATH, USE_GPU, logger
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.baseline import CNN as Network


def _trainer_kwargs(use_gpu):
    """Return the mode-specific keyword arguments for ``Trainer``.

    The shared ``experiment=`` argument is added by the caller.

    Args:
        use_gpu: truthy for a full run on GPU 0, falsy for a CPU debug run.

    Returns:
        dict of keyword arguments for ``pytorch_lightning.Trainer``.
    """
    if use_gpu:
        # Full run: 10 epochs over all training data on GPU 0 with the
        # 'ddp' distributed backend.
        return dict(gpus=[0], distributed_backend='ddp', max_nb_epochs=10,
                    train_percent_check=1.0, fast_dev_run=False)
    # CPU fallback: 1 epoch on 10% of the training data with fast_dev_run on
    # (presumably intended as a quick smoke test, not a real training run).
    return dict(max_nb_epochs=1, train_percent_check=0.1, fast_dev_run=True)


def run(num_class=56):
    """Train and then test the baseline CNN, recording results in CURR_RUN_PATH.

    Args:
        num_class: number of output classes passed to the network.
            Defaults to 56, the value previously hard-coded here
            (TODO: confirm it matches the dataset).
    """
    logger.info(CURR_RUN_PATH)
    exp = Experiment(save_dir=CURR_RUN_PATH)

    # TODO fill their training parameters
    trainer = Trainer(experiment=exp, **_trainer_kwargs(USE_GPU))

    model = Network(num_class=num_class)
    print(model)

    trainer.fit(model)
    # No model argument: test() is run on the trainer used for fit() above.
    trainer.test()


if __name__ == '__main__':
    # Only start training when executed as a script, not on import.
    run()