Commit 2f72c41d authored by Verena Praher's avatar Verena Praher
Browse files

add early stopping and checkpoint to baseline

parent 48deb5bb
......@@ -2,7 +2,8 @@ from utils import CURR_RUN_PATH, USE_GPU, logger
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.baseline import CNN as Network
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
def run():
logger.info(CURR_RUN_PATH)
......@@ -14,10 +15,28 @@ def run():
# num_class = 56
# batch_size = 32
early_stop = EarlyStopping(
monitor='val_loss',
patience=50,
verbose=True,
mode='min'
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(CURR_RUN_PATH, 'best.ckpt'),
save_best_only=True,
verbose=True,
monitor='val_loss',
mode='min'
)
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=500, train_percent_check=1.0,
fast_dev_run=False)
experiment=exp, max_nb_epochs=100, train_percent_check=1.0,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
)
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment