Commit e033f5df authored by Verena Praher's avatar Verena Praher
Browse files

minor formatting/parameter changes

parent ce850b60
from utils import USE_GPU, init_experiment
from pytorch_lightning import Trainer
from test_tube import Experiment, HyperOptArgumentParser
......
......@@ -21,7 +21,8 @@ def epochs_500():
def epochs_100():
global config
config['epochs'] = 100
config['mtg_epochs'] = 100
config['mid_epochs'] = 100
config['patience'] = 20
......@@ -30,13 +31,13 @@ def epochs_20():
config['epochs'] = 20
def midlevel_configs():
global config
config['epochs'] = 2
#def midlevel_configs():
# global config
# config['epochs'] = 2
def mtg_configs():
global config
config['epochs'] = 1
#def mtg_configs():
# global config
# config['epochs'] = 1
def pretrain_midlevel(hparams):
......@@ -85,7 +86,7 @@ def pretrain_midlevel(hparams):
if USE_GPU:
trainer = Trainer(
gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
experiment=exp, max_nb_epochs=config['mid_epochs'], train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
......@@ -113,7 +114,7 @@ def train_mtgjamendo(hparams, midlevel_chkpt_dir):
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)
mtg_configs()
# mtg_configs()
logger.info(f"Loading model from {midlevel_chkpt_dir}")
model = Network(num_targets=7, dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
logger.info(f"Loaded model successfully")
......@@ -136,7 +137,7 @@ def train_mtgjamendo(hparams, midlevel_chkpt_dir):
if USE_GPU:
trainer = Trainer(
gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
experiment=exp, max_nb_epochs=config['mtg_epochs'], train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
......
......@@ -16,7 +16,7 @@ def run(hparams):
# callbacks
early_stop = EarlyStopping(
monitor='val_loss',
patience=20,
patience=50,
verbose=True,
mode='min'
)
......@@ -31,7 +31,7 @@ def run(hparams):
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend=None,
experiment=exp, max_nb_epochs=100, train_percent_check=1.0,
experiment=exp, max_nb_epochs=500, train_percent_check=1.0,
fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment