Commit c070c7d9 authored by Verena Praher's avatar Verena Praher
Browse files

add option to reload pretrained midlevel model

parent 6ca30b70
......@@ -85,7 +85,7 @@ def pretrain_midlevel(hparams):
if USE_GPU:
trainer = Trainer(
gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
......@@ -106,11 +106,10 @@ def pretrain_midlevel(hparams):
# streamlog.info("Running test")
# trainer.test()
def train_mtgjamendo(hparams):
def train_mtgjamendo(hparams, midlevel_chkpt_dir):
set_paths('midlevel')
from utils import CURR_RUN_PATH, logger, streamlog # import these after init_experiment
chkpt_dir = os.path.join(CURR_RUN_PATH, 'mtg.ckpt')
midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)
......@@ -137,7 +136,7 @@ def train_mtgjamendo(hparams):
if USE_GPU:
trainer = Trainer(
gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
experiment=exp, max_nb_epochs=20, train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
......@@ -155,8 +154,15 @@ def train_mtgjamendo(hparams):
def run(hparams):
init_experiment(comment=hparams.experiment_name)
pretrain_midlevel(hparams)
train_mtgjamendo(hparams)
if hparams.pretrained_midlevel is None:
pretrain_midlevel(hparams)
from utils import CURR_RUN_PATH
midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
else:
from utils import logger
midlevel_chkpt_dir = hparams.pretrained_midlevel
logger.info("Using pretrained model", midlevel_chkpt_dir)
train_mtgjamendo(hparams, midlevel_chkpt_dir)
if __name__=='__main__':
......@@ -168,6 +174,8 @@ if __name__=='__main__':
# implementation of argument parser
parent_parser.add_argument('--train_percent', type=float,
default=1.0, help='how much train data to use')
parent_parser.add_argument('--pretrained_midlevel',
default="/home/verena/experiments/moodwalk/runs/ca601 - pretrain_midlevel 20 epochs/midlevel.ckpt/", type=str)
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
run(hyperparams)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment