Commit e033f5df authored by Verena Praher

minor formatting/parameter changes

parent ce850b60
 from utils import USE_GPU, init_experiment
 from pytorch_lightning import Trainer
 from test_tube import Experiment, HyperOptArgumentParser
...
@@ -21,7 +21,8 @@ def epochs_500():
 
 
 def epochs_100():
     global config
-    config['epochs'] = 100
+    config['mtg_epochs'] = 100
+    config['mid_epochs'] = 100
     config['patience'] = 20
 
@@ -30,13 +31,13 @@ def epochs_20():
     config['epochs'] = 20
 
 
-def midlevel_configs():
-    global config
-    config['epochs'] = 2
+#def midlevel_configs():
+#    global config
+#    config['epochs'] = 2
 
 
-def mtg_configs():
-    global config
-    config['epochs'] = 1
+#def mtg_configs():
+#    global config
+#    config['epochs'] = 1
 
 def pretrain_midlevel(hparams):
@@ -85,7 +86,7 @@ def pretrain_midlevel(hparams):
     if USE_GPU:
         trainer = Trainer(
             gpus=[0], distributed_backend='ddp',
-            experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
+            experiment=exp, max_nb_epochs=config['mid_epochs'], train_percent_check=hparams.train_percent,
             fast_dev_run=False,
             early_stop_callback=early_stop,
             checkpoint_callback=checkpoint_callback
@@ -113,7 +114,7 @@ def train_mtgjamendo(hparams, midlevel_chkpt_dir):
     logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
     exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)
 
-    mtg_configs()
+    # mtg_configs()
     logger.info(f"Loading model from {midlevel_chkpt_dir}")
     model = Network(num_targets=7, dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
     logger.info(f"Loaded model successfully")
@@ -136,7 +137,7 @@ def train_mtgjamendo(hparams, midlevel_chkpt_dir):
     if USE_GPU:
         trainer = Trainer(
             gpus=[0], distributed_backend='ddp',
-            experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
+            experiment=exp, max_nb_epochs=config['mtg_epochs'], train_percent_check=hparams.train_percent,
            fast_dev_run=False,
             early_stop_callback=early_stop,
             checkpoint_callback=checkpoint_callback
...
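The first file replaces the single hardcoded epoch budget with per-stage config keys, so the two `Trainer` calls no longer share `max_nb_epochs=20`. A minimal sketch of the pattern, reusing the `config` dict and key names from the diff; everything outside the diffed lines is illustrative scaffolding:

```python
# Sketch of the per-stage epoch budgets introduced above; `config` and the
# key names come from the diff, the rest of this snippet is hypothetical.
config = {}

def epochs_100():
    global config
    config['mtg_epochs'] = 100   # budget for MTG-Jamendo fine-tuning
    config['mid_epochs'] = 100   # budget for midlevel pretraining
    config['patience'] = 20

epochs_100()
# Each Trainer now reads its own key instead of a hardcoded 20:
#   pretrain_midlevel: max_nb_epochs=config['mid_epochs']
#   train_mtgjamendo:  max_nb_epochs=config['mtg_epochs']
assert config['mid_epochs'] == config['mtg_epochs'] == 100
```

With the budgets coming from `config`, the `midlevel_configs()`/`mtg_configs()` overrides are commented out rather than called per stage.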
@@ -16,7 +16,7 @@ def run(hparams):
     # callbacks
     early_stop = EarlyStopping(
         monitor='val_loss',
-        patience=20,
+        patience=50,
         verbose=True,
         mode='min'
     )
@@ -31,7 +31,7 @@ def run(hparams):
 
     if USE_GPU:
         trainer = Trainer(gpus=[0], distributed_backend=None,
-                          experiment=exp, max_nb_epochs=100, train_percent_check=1.0,
+                          experiment=exp, max_nb_epochs=500, train_percent_check=1.0,
                           fast_dev_run=False, early_stop_callback=early_stop,
                           checkpoint_callback=checkpoint_callback)
     else:
...
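The second file raises both limits at once: `patience` 20 → 50 and `max_nb_epochs` 100 → 500. A minimal sketch of how the two interact, written against the same old pytorch_lightning/test_tube API this repo uses; the `EarlyStopping` import path, the `save_dir`, and the `None` checkpoint callback are assumptions, not taken from the diff:

```python
from test_tube import Experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping  # assumed import path

exp = Experiment(name='demo', save_dir='/tmp/demo')  # placeholder run dir

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=50,   # raised from 20: tolerate 50 epochs without improvement
    verbose=True,
    mode='min'
)

# max_nb_epochs (raised from 100 to 500) is only a ceiling; the early-stop
# callback ends the run as soon as val_loss stalls for 50 epochs, so most
# runs finish well before epoch 500.
trainer = Trainer(gpus=[0], distributed_backend=None,
                  experiment=exp, max_nb_epochs=500, train_percent_check=1.0,
                  fast_dev_run=False, early_stop_callback=early_stop,
                  checkpoint_callback=None)  # the real run passes its own callback
```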