Commit cab8a9f2 authored by Verena Praher's avatar Verena Praher

setup experiment for joint midlevel/mtg training

parent 49d7ad1e
from utils import USE_GPU, init_experiment, exit_experiment
from pytorch_lightning import Trainer
from test_tube import Experiment, HyperOptArgumentParser
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
from models.midlevel_mtg_vgg import ModelMidlevel as Network
model_config = {
'data_source':'mtgjamendo',
'validation_metrics':['rocauc', 'prauc'],
'test_metrics':['rocauc', 'prauc']
}
initialized = False
trial_counter = 0
def run(hparams):
global initialized, trial_counter
trial_counter += 1
if not initialized:
init_experiment(comment=hparams.experiment_name)
from utils import CURR_RUN_PATH, logger # import these after init_experiment
if not initialized:
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
initialized = True
trial_name = f"trial_{trial_counter}"
logger.info(trial_name)
logger.info(hparams)
exp = Experiment(name=trial_name, save_dir=CURR_RUN_PATH)
# exp.tag(hparams)
# callbacks
early_stop = EarlyStopping(
monitor='val_loss',
patience=20,
verbose=True,
mode='min'
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(CURR_RUN_PATH, trial_name, 'best.ckpt'),
save_best_only=True,
verbose=True,
monitor='prauc',
mode='max'
)
trainer = Trainer(gpus=[0], distributed_backend=None,
experiment=exp, max_nb_epochs=hparams.max_epochs,
train_percent_check=hparams.train_percent,
fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback,
nb_sanity_val_steps=0) # don't run sanity validation run
model = Network(config=model_config, hparams=hparams, num_targets=56)
print(model)
try:
trainer.fit(model)
except KeyboardInterrupt:
logger.info("Training interrupted")
except:
logger.exception(msg="Error occurred during train!")
exit_experiment('failed', exp)
try:
logger.info("Starting test...")
trainer.test()
except KeyboardInterrupt:
logger.info("Exiting...")
exit_experiment('stopped', exp)
except:
logger.exception(msg="Error occurred during test!")
exit_experiment('failed', exp)
if __name__=='__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
parent_parser.add_argument('--max_epochs', type=int,
default=10, help='maximum number of epochs for training')
parent_parser.add_argument('--train_percent', type=float,
default=1.0, help='how much train data to use')
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
run(hyperparams)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment