Commit 39dc905f authored by Verena Praher's avatar Verena Praher

remove singlerun_crnn

parent 615f22df
from utils import USE_GPU, init_experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment, HyperOptArgumentParser
from models.crnn import CRNN as Network
import os
model_config = {
'data_source':'mtgjamendo',
'validation_metrics':['rocauc', 'prauc'],
'test_metrics':['rocauc', 'prauc']
}
initialized = False # TODO: Find a better way to do this
trial_counter = 0
def run(hparams):
global initialized, trial_counter
trial_counter += 1
if not initialized:
init_experiment(comment=hparams.experiment_name)
from utils import CURR_RUN_PATH, logger # import these after init_experiment
if not initialized:
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
initialized = True
trial_name = f"trial_{trial_counter}"
logger.info(trial_name)
logger.info(hparams)
exp = Experiment(name=trial_name, save_dir=CURR_RUN_PATH)
# exp.tag(hparams)
# callbacks
early_stop = EarlyStopping(
monitor='val_loss',
patience=20,
verbose=True,
mode='min'
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(CURR_RUN_PATH, trial_name, 'best.ckpt'),
save_best_only=True,
verbose=True,
monitor='prauc',
mode='max'
)
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend=None,
experiment=exp, max_nb_epochs=hparams.max_epochs,
train_percent_check=hparams.train_percent,
fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback,
nb_sanity_val_steps=0) # don't run sanity validation run
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
model = Network(num_class=56, config=model_config, hparams=hparams)
print(model)
try:
trainer.fit(model)
except KeyboardInterrupt:
pass
trainer.test()
if __name__=='__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
parent_parser.add_argument('--train_percent', type=float,
default=1.0, help='how much train data to use')
parent_parser.add_argument('--max_epochs', type=int,
default=10, help='maximum number of epochs')
#parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node')
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
# run(hyperparams)
#gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5)
print(hyperparams)
run(hyperparams)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment