"""Train and test the baseline CNN with pytorch_lightning and test_tube.

The test_tube ``Experiment`` is created with ``save_dir=CURR_RUN_PATH``
(both the path and the GPU switch come from the project ``utils`` module).
"""
from utils import CURR_RUN_PATH, USE_GPU, logger
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.baseline import CNN as Network


def run(num_class=56):
    """Build a Trainer and the baseline ``Network``, then fit and test it.

    Args:
        num_class: Number of output classes passed to ``Network``. Defaults
            to 56, the value recorded below as the original baseline's.
    """
    logger.info(CURR_RUN_PATH)
    exp = Experiment(save_dir=CURR_RUN_PATH)

    # parameters used in the baseline (read from main.py and solver.py)
    # n_epochs = 500
    # lr = 1e-4
    # num_class = 56
    # batch_size = 32

    # Trainer settings shared by the GPU and CPU configurations; only the
    # device options and ``fast_dev_run`` differ between the two branches.
    # NOTE(review): max_nb_epochs=1 and train_percent_check=0.1 look like a
    # short debug configuration rather than the baseline's 500 epochs —
    # confirm before a real run.
    trainer_kwargs = dict(
        experiment=exp,
        max_nb_epochs=1,
        train_percent_check=0.1,
    )
    if USE_GPU:
        trainer = Trainer(
            gpus=[0],
            distributed_backend='ddp',
            fast_dev_run=False,
            **trainer_kwargs,
        )
    else:
        trainer = Trainer(fast_dev_run=True, **trainer_kwargs)

    model = Network(num_class=num_class)  # TODO confirm num_class
    # Log the model summary alongside the run path instead of printing it.
    logger.info(str(model))

    trainer.fit(model)
    trainer.test()


if __name__ == '__main__':
    run()