Commit c070c7d9 authored by Verena Praher

add option to reload pretrained midlevel model

parent 6ca30b70
@@ -85,7 +85,7 @@ def pretrain_midlevel(hparams):
     if USE_GPU:
         trainer = Trainer(
             gpus=[0], distributed_backend='ddp',
-            experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
+            experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
             fast_dev_run=False,
             early_stop_callback=early_stop,
             checkpoint_callback=checkpoint_callback
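For reference, early_stop and checkpoint_callback are built earlier in pretrain_midlevel and are not part of this diff. A minimal sketch of how that setup usually looks with the pytorch-lightning / test_tube API used above; the monitored metric, patience, and checkpoint directory name are assumptions, not taken from this commit:

import os
from test_tube import Experiment
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from utils import CURR_RUN_PATH

# Assumed checkpoint directory, mirroring the 'midlevel.ckpt' name used elsewhere in the file.
chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)
# Stop when the validation loss stops improving (metric and patience are assumed values).
early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min', verbose=True)
# Keep the best checkpoint, judged by the same metric.
checkpoint_callback = ModelCheckpoint(filepath=chkpt_dir, monitor='val_loss',
                                      save_best_only=True, mode='min', verbose=True)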
@@ -106,11 +106,10 @@ def pretrain_midlevel(hparams):
     # streamlog.info("Running test")
     # trainer.test()


-def train_mtgjamendo(hparams):
+def train_mtgjamendo(hparams, midlevel_chkpt_dir):
     set_paths('midlevel')
     from utils import CURR_RUN_PATH, logger, streamlog # import these after init_experiment
     chkpt_dir = os.path.join(CURR_RUN_PATH, 'mtg.ckpt')
-    midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
     logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
     exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)
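This hunk only changes train_mtgjamendo to take the checkpoint location as a parameter instead of deriving it from CURR_RUN_PATH; the code that actually restores the mid-level weights is outside the hunk. A hedged sketch of one way those weights could be loaded, assuming midlevel_chkpt_dir is a directory of Lightning-style checkpoints with a 'state_dict' entry (the helper name and file layout are hypothetical):

import glob
import os
import torch

def load_midlevel_weights(model, midlevel_chkpt_dir):
    # Assumed layout: a directory of .ckpt files written by ModelCheckpoint;
    # pick the most recently written one.
    ckpts = sorted(glob.glob(os.path.join(midlevel_chkpt_dir, '*.ckpt')),
                   key=os.path.getmtime)
    if not ckpts:
        raise FileNotFoundError(f"no checkpoint found in {midlevel_chkpt_dir}")
    checkpoint = torch.load(ckpts[-1], map_location='cpu')
    # Lightning checkpoints keep the model weights under 'state_dict';
    # strict=False tolerates layers that differ between the two tasks.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    return model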
@@ -137,7 +136,7 @@ def train_mtgjamendo(hparams):
     if USE_GPU:
         trainer = Trainer(
             gpus=[0], distributed_backend='ddp',
-            experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
+            experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
             fast_dev_run=False,
             early_stop_callback=early_stop,
             checkpoint_callback=checkpoint_callback
@@ -155,8 +154,15 @@ def train_mtgjamendo(hparams):

 def run(hparams):
     init_experiment(comment=hparams.experiment_name)
-    pretrain_midlevel(hparams)
-    train_mtgjamendo(hparams)
+    if hparams.pretrained_midlevel is None:
+        pretrain_midlevel(hparams)
+        from utils import CURR_RUN_PATH
+        midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
+    else:
+        from utils import logger
+        midlevel_chkpt_dir = hparams.pretrained_midlevel
+        logger.info(f"Using pretrained model {midlevel_chkpt_dir}")
+    train_mtgjamendo(hparams, midlevel_chkpt_dir)


 if __name__=='__main__':
@@ -168,6 +174,8 @@ if __name__=='__main__':
     # implementation of argument parser
     parent_parser.add_argument('--train_percent', type=float,
                                default=1.0, help='how much train data to use')
+    parent_parser.add_argument('--pretrained_midlevel',
+                               default="/home/verena/experiments/moodwalk/runs/ca601 - pretrain_midlevel 20 epochs/midlevel.ckpt/", type=str)
     parser = Network.add_model_specific_args(parent_parser)
     hyperparams = parser.parse_args()
     run(hyperparams)
\ No newline at end of file
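With the new --pretrained_midlevel option, run either pretrains the mid-level model first or skips straight to MTG-Jamendo training with an existing checkpoint. A hypothetical call for both paths (only arguments visible in this diff are set; the model-specific arguments added by Network.add_model_specific_args and the hard-coded default path are omitted):

from argparse import Namespace

# Path 1: no pretrained checkpoint given -> pretrain the mid-level model, then train on MTG-Jamendo.
run(Namespace(experiment_name='from_scratch', train_percent=1.0,
              pretrained_midlevel=None))

# Path 2: reuse an existing mid-level checkpoint directory and train on MTG-Jamendo only.
run(Namespace(experiment_name='reuse_midlevel', train_percent=1.0,
              pretrained_midlevel='/path/to/runs/some_run/midlevel.ckpt/'))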