Commit 16cdda9f authored by Verena Praher

cleanup pretrained midlevel code

parent bc052327
@@ -5,40 +5,13 @@ from test_tube import Experiment, HyperOptArgumentParser
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
from models.midlevel_vgg import ModelMidlevel as Network
import torch
-config = {
-'epochs': 1,
-'patience': 50,
-'earlystopping_metric': 'val_loss', # 'val_prauc'
-'earlystopping_mode': 'min' # 'max'
+model_config = {
+'data_source':'mtgjamendo',
+'validation_metrics':['rocauc', 'prauc'],
+'test_metrics':['rocauc', 'prauc']
}
-def epochs_500():
-global config
-config['epochs'] = 500
-def epochs_100():
-global config
-config['mtg_epochs'] = 100
-config['mid_epochs'] = 100
-config['patience'] = 20
-def epochs_20():
-global config
-config['epochs'] = 20
-#def midlevel_configs():
-# global config
-# config['epochs'] = 2
-#def mtg_configs():
-# global config
-# config['epochs'] = 1
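The removed epochs_*() helpers mutated a module-level config dict; the same knobs now arrive as command-line flags parsed into hparams (the flags are defined near the bottom of this file). A minimal sketch of the equivalent usage, assuming the standard test_tube parsing pattern:

    from test_tube import HyperOptArgumentParser

    parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parser.add_argument('--max_pretrain_epochs', type=int, default=10)
    parser.add_argument('--max_finetune_epochs', type=int, default=10)
    hparams = parser.parse_args()  # e.g. run with: --max_pretrain_epochs 100 --max_finetune_epochs 100
    # hparams.max_pretrain_epochs later feeds Trainer(max_nb_epochs=...) in pretrain_midlevel()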
def pretrain_midlevel(hparams):
set_paths('midlevel')
@@ -47,31 +20,12 @@ def pretrain_midlevel(hparams):
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)
-def setup_config():
-def print_config():
-global config
-st = '---------CONFIG--------\n'
-for k in config.keys():
-st += k+':'+str(config.get(k))+'\n'
-return st
-conf = hparams.config
-if conf is not None:
-conf_func = globals()[conf]
-try:
-conf_func()
-except:
-logger.error(f"Config {conf} not defined")
-logger.info(print_config())
-setup_config()
early_stop = EarlyStopping(
-monitor=config['earlystopping_metric'],
-patience=config['patience'],
+monitor='val_loss',
+patience=50,
verbose=True,
-mode=config['earlystopping_mode']
+mode='min'
)
chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
@@ -86,7 +40,7 @@ def pretrain_midlevel(hparams):
if USE_GPU:
trainer = Trainer(
gpus=[0], distributed_backend='ddp',
-experiment=exp, max_nb_epochs=config['mid_epochs'], train_percent_check=hparams.train_percent,
+experiment=exp, max_nb_epochs=hparams.max_pretrain_epochs, train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
@@ -97,7 +51,7 @@ def pretrain_midlevel(hparams):
fast_dev_run=False, checkpoint_callback=checkpoint_callback
)
-model = Network(num_targets=7)
+model = Network(model_config, hparams, num_targets=7)
print(model)
@@ -107,6 +61,7 @@ def pretrain_midlevel(hparams):
# streamlog.info("Running test")
# trainer.test()
def train_mtgjamendo(hparams, midlevel_chkpt_dir):
set_paths('midlevel')
from utils import CURR_RUN_PATH, logger, streamlog # import these after init_experiment
@@ -120,24 +75,24 @@ def train_mtgjamendo(hparams, midlevel_chkpt_dir):
logger.info(f"Loaded model successfully")
early_stop = EarlyStopping(
-monitor=config['earlystopping_metric'],
-patience=config['patience'],
+monitor='val_prauc',
+patience=50,
verbose=True,
-mode=config['earlystopping_mode']
+mode='max'
)
checkpoint_callback = ModelCheckpoint(
filepath=chkpt_dir,
save_best_only=True,
verbose=True,
-monitor='val_loss',
-mode='min'
+monitor='val_prauc',
+mode='max'
)
if USE_GPU:
trainer = Trainer(
gpus=[0], distributed_backend='ddp',
-experiment=exp, max_nb_epochs=config['mtg_epochs'], train_percent_check=hparams.train_percent,
+experiment=exp, max_nb_epochs=hparams.max_finetune_epochs, train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
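Note the asymmetry between the two stages: mid-level pretraining (above) early-stops and checkpoints on val_loss with mode 'min', while MTG-Jamendo finetuning monitors val_prauc with mode 'max'. A small summary sketch, with values copied from the callback definitions in this file:

    # monitored quantity and direction per training stage
    stage_monitor = {
        'midlevel_pretrain':   ('val_loss',  'min'),   # regression loss on mid-level targets
        'mtgjamendo_finetune': ('val_prauc', 'max'),   # precision-recall AUC for tag prediction
    }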
@@ -170,9 +125,10 @@ if __name__=='__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
-parent_parser.add_argument('--config', type=str, help='config function to run')
-#TODO : multiple arguments for --config using nargs='+' is not working with the test_tube
-# implementation of argument parser
+parent_parser.add_argument('--max_pretrain_epochs', type=int,
+default=10, help='maximum number of epochs for pretraining')
+parent_parser.add_argument('--max_finetune_epochs', type=int,
+default=10, help='maximum number of epochs for finetuning')
parent_parser.add_argument('--train_percent', type=float,
default=1.0, help='how much train data to use')
parent_parser.add_argument('--pretrained_midlevel',
......
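A minimal sketch of how the full parser is assembled, assuming the usual test_tube pattern (the actual __main__ parse-and-dispatch code is collapsed in this diff):

    from test_tube import HyperOptArgumentParser
    from models.midlevel_vgg import ModelMidlevel as Network

    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    # ... the add_argument calls shown above ...
    parser = Network.add_model_specific_args(parent_parser)  # adds slicing_mode and input_size (see below)
    hparams = parser.parse_args()

The remaining hunks are from models/midlevel_vgg.py, the module imported above as Network.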
import torch.nn as nn
from datasets.midlevel import df_get_midlevel_set
-from models.shared_stuff import *
+from models.shared_stuff import BasePtlModel
+from test_tube import HyperOptArgumentParser
from utils import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from datasets.midlevel import df_get_midlevel_set
from datasets.mtgjamendo import df_get_mtg_set
from sklearn import metrics
def initialize_weights(module):
if isinstance(module, nn.Conv2d):
@@ -21,9 +20,10 @@ def initialize_weights(module):
elif isinstance(module, nn.Linear):
module.bias.data.zero_()
-class ModelMidlevel(pl.LightningModule):
-def __init__(self, num_targets, initialize=True, dataset='midlevel', load_from=None, on_gpu=None, map_location=None):
-super(ModelMidlevel, self).__init__()
+class ModelMidlevel(BasePtlModel):
+def __init__(self, config, hparams, num_targets, initialize=True, dataset='midlevel', load_from=None, on_gpu=None, map_location=None):
+super(ModelMidlevel, self).__init__(config, hparams)
self.dataset = dataset
if dataset=='midlevel':
@@ -118,9 +118,8 @@ class ModelMidlevel(pl.LightningModule):
self._load_model(load_from, map_location, on_gpu)
if dataset == 'mtgjamendo':
-self.fc_mtg1 = nn.Linear(256, 10)
-self.fc_mtg2 = nn.Linear(10, 56)
+self.fc_mtg1 = nn.Linear(512, 56)
+# self.fc_mtg2 = nn.Linear(10, 56)
for name, param in self.named_parameters():
if 'mtg' in name:
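The old two-layer MTG head (Linear(256, 10) followed by Linear(10, 56)) is collapsed into a single Linear(512, 56) applied to the flattened trunk output. A quick, self-contained shape check of the new head (batch size 4 is arbitrary):

    import torch
    import torch.nn as nn

    head = nn.Linear(512, 56)            # new single-layer MTG-Jamendo head
    feats = torch.randn(4, 512)          # flattened features for a batch of 4 clips
    assert head(feats).shape == (4, 56)  # one logit per MTG-Jamendo tag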
@@ -140,14 +139,15 @@ class ModelMidlevel(pl.LightningModule):
x = self.conv6(x) # 39 * 18 * 256
x = self.conv7(x) # 39 * 18 * 384
x = self.conv7b(x) # 39 * 18 * 384
-x = self.conv11(x) # 2 * 2 * 256
-x = x.view(x.size(0), -1)
-# ml = self.fc_ml(x)
+# x = self.conv11(x) # 2 * 2 * 256
if self.dataset == 'midlevel':
+x = self.conv11(x) # 2 * 2 * 256
+x = x.view(x.size(0), -1)
x = self.fc_ml(x)
if self.dataset=='mtgjamendo':
-x = self.fc_mtg1(x)
-logit = nn.Sigmoid()(self.fc_mtg2(x))
+x = x.view(x.size(0), -1)
+#x = self.fc_mtg1(x)
+logit = nn.Sigmoid()(self.fc_mtg1(x))
return logit
return x
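Only the mid-level branch still goes through conv11; the MTG-Jamendo branch flattens the shared trunk output directly into fc_mtg1. A condensed sketch of the reworked forward() (conv_trunk is a stand-in for the conv1..conv7b sequence, not a name from this file; shapes follow the inline comments):

    def forward(self, x):
        x = self.conv_trunk(x)                        # shared convolutional trunk
        if self.dataset == 'midlevel':
            x = self.conv11(x)                        # extra conv block before the regression head
            x = x.view(x.size(0), -1)
            return self.fc_ml(x)                      # mid-level regression targets
        if self.dataset == 'mtgjamendo':
            x = x.view(x.size(0), -1)
            return nn.Sigmoid()(self.fc_mtg1(x))      # 56 tag probabilities
        return x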
@@ -186,36 +186,22 @@ class ModelMidlevel(pl.LightningModule):
if self.dataset=='midlevel':
return F.mse_loss(y_hat, y)
else:
-return my_loss(y_hat, y)
-def forward_full_song(self, x, y):
-# print(x.shape)
-#TODO full song???
-return self.forward(x[:, :, :, :512])
-# y_hat = torch.zeros((x.shape[0], 56), requires_grad=True).cuda()
-# hop_size = 256
-# i=0
-# count = 0
-# while i < x.shape[-1]:
-# y_hat += self.forward(x[:,:,:,i:i+512])
-# i += hop_size
-# count += 1
-# return y_hat/count
+return super(ModelMidlevel, self).loss(y_hat, y)
def training_step(self, data_batch, batch_nb):
if self.dataset=='midlevel':
x, _, y = data_batch
-y_hat = self.forward_full_song(x, y)
+y_hat = self.forward(x)
y = y.float()
y_hat = y_hat.float()
return {'loss':self.my_loss(y_hat, y)}
else:
-return training_step(self, data_batch, batch_nb)
+return super(ModelMidlevel, self).training_step(data_batch, batch_nb)
def validation_step(self, data_batch, batch_nb):
if self.dataset=='midlevel':
x, _, y = data_batch
-y_hat = self.forward_full_song(x, y)
+y_hat = self.forward(x)
y = y.float()
y_hat = y_hat.float()
return {'val_loss': self.my_loss(y_hat, y),
@@ -223,7 +209,7 @@ class ModelMidlevel(pl.LightningModule):
'y_hat': y_hat.cpu().numpy(),
}
else:
-return validation_step(self, data_batch, batch_nb)
+return super(ModelMidlevel, self).validation_step(data_batch, batch_nb)
def validation_end(self, outputs):
@@ -240,12 +226,12 @@ class ModelMidlevel(pl.LightningModule):
return {'val_loss': avg_loss}
else:
-return validation_end(outputs)
+return super(ModelMidlevel, self).validation_end(outputs)
def test_step(self, data_batch, batch_nb):
if self.dataset == 'midlevel':
x, _, y = data_batch
-y_hat = self.forward_full_song(x, y)
+y_hat = self.forward(x)
y = y.float()
y_hat = y_hat.float()
return {'test_loss': self.my_loss(y_hat, y),
@@ -253,7 +239,7 @@ class ModelMidlevel(pl.LightningModule):
'y_hat': y_hat.cpu().numpy(),
}
else:
-return test_step(self,data_batch,batch_nb)
+return super(ModelMidlevel, self).test_step(data_batch,batch_nb)
def test_end(self, outputs):
if self.dataset == 'midlevel':
@@ -262,7 +248,7 @@ class ModelMidlevel(pl.LightningModule):
self.experiment.log(test_metrics)
return test_metrics
else:
-return test_end(outputs)
+return super(ModelMidlevel, self).test_end(outputs)
def configure_optimizers(self):
@@ -283,5 +269,26 @@ class ModelMidlevel(pl.LightningModule):
@staticmethod
def add_model_specific_args(parent_parser):
-return parent_parser
-pass
+"""Parameters defined here will be available to your model through self.hparams
+"""
+parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
+# network params
+# parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=True)
+#parser.opt_list('--learning_rate', default=0.0001, type=float,
+# options=[0.00001, 0.0005, 0.001],
+# tunable=True)
+parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=False)
+parser.opt_list('--input_size', default=1024, options=[512, 1024], type=int, tunable=True)
+# training params (opt)
+#parser.opt_list('--optimizer_name', default='adam', type=str,
+# options=['adam'], tunable=False)
+# if using 2 nodes with 4 gpus each the batch size here
+# (256) will be 256 / (2*8) = 16 per gpu
+#parser.opt_list('--batch_size', default=32, type=int,
+# options=[16, 32], tunable=False,
+# help='batch size will be divided over all gpus being used across all nodes')
+return parser
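A short usage sketch for the tunable options this parser exposes; the trial-generation loop follows the standard test_tube pattern and is an assumption, since no caller is shown in this commit:

    parent = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parser = ModelMidlevel.add_model_specific_args(parent)
    hparams = parser.parse_args()
    # only tunable opt_list entries (here: input_size) vary across generated trials
    for trial in hparams.trials(2):
        print(trial.input_size, trial.slicing_mode)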