Commit 62e525c6 authored by Verena Praher's avatar Verena Praher
Browse files

setup experiment files for baseline and ResNets with some todos

parent 146aac96
from utils import CURR_RUN_PATH, USE_GPU, logger
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.baseline import CNN as Network
def run():
logger.info(CURR_RUN_PATH)
exp = Experiment(save_dir=CURR_RUN_PATH)
# TODO fill their training parameters
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=10, train_percent_check=1.0,
fast_dev_run=False)
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
model = Network() # TODO num_class
print(model)
trainer.fit(model)
trainer.test()
if __name__=='__main__':
run()
\ No newline at end of file
from utils import CURR_RUN_PATH, USE_GPU, logger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from test_tube import Experiment
from models.resnet18 import Network
def run():
logger.info(CURR_RUN_PATH)
exp = Experiment(save_dir=CURR_RUN_PATH)
# TODO other training parameters?
# callbacks
early_stop = EarlyStopping(
monitor='val_loss', # TODO: check if this exists
patience=50,
verbose=True,
mode='min' # TODO: check if correct
)
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=500, train_percent_check=1.0,
fast_dev_run=False, early_stop_callback=early_stop)
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
model = Network() # TODO num_tags
print(model)
trainer.fit(model)
# TODO log testing results
trainer.test()
if __name__=='__main__':
run()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment