Commit 176d4816 authored by Shreyan Chowdhury's avatar Shreyan Chowdhury
Browse files

implement pretrained model loading and fine-tuning

parent eb2252a1
......@@ -30,11 +30,19 @@ def epochs_20():
config['epochs'] = 20
def midlevel_configs():
    """Apply run-configuration overrides for midlevel pretraining.

    Mutates the module-level ``config`` dict in place (2 training epochs).
    """
    global config
    config['epochs'] = 2
def mtg_configs():
    """Apply run-configuration overrides for MTG-Jamendo fine-tuning.

    Mutates the module-level ``config`` dict in place (1 training epoch).
    """
    global config
    config['epochs'] = 1
def pretrain_midlevel(hparams):
set_paths('midlevel')
from utils import CURR_RUN_PATH, logger, streamlog # import these after init_experiment
streamlog.info("Training midlevel...")
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)
......@@ -92,19 +100,63 @@ def pretrain_midlevel(hparams):
print(model)
logger.info("Training midlevel")
trainer.fit(model)
logger.info("Training midlevel completed")
# streamlog.info("Running test")
# trainer.test()
logger.info(f"Loading model from {chkpt_dir}")
model = Network(num_targets=7, on_gpu=USE_GPU, load_from=chkpt_dir)
def train_mtgjamendo(hparams):
    """Fine-tune a midlevel-pretrained network on the MTG-Jamendo tagging task.

    Loads the checkpoint written by ``pretrain_midlevel`` (midlevel.ckpt),
    builds a PyTorch-Lightning ``Trainer`` (GPU with early stopping, or a
    tiny CPU smoke-run fallback) and fits the model, checkpointing the best
    epoch by validation loss to mtg.ckpt.

    :param hparams: parsed CLI arguments; only ``hparams.train_percent`` is
        read here (fraction of the training set to use on GPU runs).
    """
    # NOTE(review): paths are set to 'midlevel' even though this trains on
    # mtgjamendo — confirm this is intentional (the model itself selects the
    # mtgjamendo data via dataset='mtgjamendo' below).
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger, streamlog  # import these after init_experiment
    chkpt_dir = os.path.join(CURR_RUN_PATH, 'mtg.ckpt')
    midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)

    mtg_configs()

    # Load the pretrained midlevel weights; the mtg head is added on top and
    # everything except the 'mtg' layers is frozen inside the model.
    logger.info(f"Loading model from {midlevel_chkpt_dir}")
    model = Network(num_targets=7, dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
    logger.info("Loaded model successfully")  # fixed: was an f-string with no placeholders, followed by a dead `pass`

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )

    checkpoint_callback = ModelCheckpoint(
        filepath=chkpt_dir,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    if USE_GPU:
        trainer = Trainer(
            gpus=[0], distributed_backend='ddp',
            experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
            fast_dev_run=False,
            early_stop_callback=early_stop,
            checkpoint_callback=checkpoint_callback
        )
    else:
        # CPU fallback: single epoch on 1% of the data, as a smoke test.
        trainer = Trainer(
            experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
            fast_dev_run=False, checkpoint_callback=checkpoint_callback
        )

    logger.info("Training mtgjamendo")
    trainer.fit(model)
    logger.info("Training mtgjamendo completed")
def run(hparams):
    """Execute the full experiment pipeline.

    Initializes the experiment run directory, then runs midlevel
    pretraining followed by MTG-Jamendo fine-tuning.
    """
    init_experiment(comment=hparams.experiment_name)
    for stage in (pretrain_midlevel, train_mtgjamendo):
        stage(hparams)
if __name__=='__main__':
......
......@@ -23,14 +23,25 @@ def initialize_weights(module):
module.bias.data.zero_()
class ModelMidlevel(pl.LightningModule):
def __init__(self, num_targets, initialize=True, load_from=None, on_gpu=None, map_location=None):
def __init__(self, num_targets, initialize=True, dataset='midlevel', load_from=None, on_gpu=None, map_location=None):
super(ModelMidlevel, self).__init__()
data_root, audio_path, csvs_path = get_paths()
cache_x_name = '_ap_midlevel44k'
from torch.utils.data import random_split
dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'), audio_path, cache_x_name)
self.trainset, self.validationset, self.testset = random_split(dataset, [int(i*dataset_length) for i in [0.7, 0.2, 0.1]])
self.dataset = dataset
if dataset=='midlevel':
data_root, audio_path, csvs_path = get_paths('midlevel')
cache_x_name = '_ap_midlevel44k'
from torch.utils.data import random_split
dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'), audio_path, cache_x_name)
self.trainset, self.validationset, self.testset = random_split(dataset, [int(i*dataset_length) for i in [0.7, 0.2, 0.1]])
elif dataset=='mtgjamendo':
data_root, audio_path, csvs_path = get_paths('mtgjamendo')
cache_x_name = "_ap_mtgjamendo44k"
train_csv = os.path.join(csvs_path, 'train_processed.tsv')
validation_csv = os.path.join(csvs_path, 'validation_processed.tsv')
test_csv = os.path.join(csvs_path, 'test_processed.tsv')
self.trainset = df_get_mtg_set('mtgjamendo', train_csv, audio_path, cache_x_name)
self.validationset = df_get_mtg_set('mtgjamendo_val', validation_csv, audio_path, cache_x_name)
self.testset = df_get_mtg_set('mtgjamendo_test', test_csv, audio_path, cache_x_name)
self.num_targets = num_targets
......@@ -100,12 +111,22 @@ class ModelMidlevel(pl.LightningModule):
)
self.fc_ml = nn.Linear(256, 7)
if initialize:
self.apply(initialize_weights)
if load_from:
self._load_model(load_from, map_location, on_gpu)
if dataset == 'mtgjamendo':
self.fc_mtg1 = nn.Linear(7, 10)
self.fc_mtg2 = nn.Linear(10, 56)
for name, param in self.named_parameters():
if 'mtg' in name:
param.requires_grad = True
else:
param.requires_grad = False
def forward(self, x):
# 313 * 149 * 1
......@@ -122,6 +143,10 @@ class ModelMidlevel(pl.LightningModule):
x = self.conv11(x) # 2 * 2 * 256
x = x.view(x.size(0), -1)
ml = self.fc_ml(x)
if self.dataset=='mtgjamendo':
x = self.fc_mtg1(ml)
logit = nn.Sigmoid()(self.fc_mtg2(x))
return logit
return ml
def _load_model(self, load_from, map_location=None, on_gpu=True):
......@@ -156,7 +181,10 @@ class ModelMidlevel(pl.LightningModule):
self.load_state_dict(checkpoint['state_dict'])
def my_loss(self, y_hat, y):
    """Dataset-dependent loss.

    Mid-level perceptual targets are regressed with MSE; other datasets
    delegate to a module-level loss helper.

    Fixed: a stale ``return F.mse_loss(y_hat, y)`` left above the dataset
    dispatch made the branch below unreachable.
    """
    if self.dataset == 'midlevel':
        return F.mse_loss(y_hat, y)
    else:
        # NOTE(review): bare ``my_loss`` resolves to a module-level name, not
        # this method — if no such helper is defined/imported this raises
        # NameError. Confirm the intended loss (e.g. BCE for mtg tagging).
        return my_loss(y_hat, y)
def forward_full_song(self, x, y):
# print(x.shape)
......@@ -173,51 +201,66 @@ class ModelMidlevel(pl.LightningModule):
# return y_hat/count
def training_step(self, data_batch, batch_nb):
    """Lightning training step; dispatches on the active dataset.

    Fixed: stale pre-diff body left above the dataset dispatch made the
    ``if self.dataset`` branch unreachable (the method always ran the
    midlevel path).
    """
    if self.dataset == 'midlevel':
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'loss': self.my_loss(y_hat, y)}
    else:
        # NOTE(review): bare ``training_step`` resolves to a module-level
        # function, not this method — if none exists this is a NameError
        # (or unbounded recursion if it shadows this). Confirm the intended
        # mtgjamendo step implementation.
        return training_step(self, data_batch, batch_nb)
def validation_step(self, data_batch, batch_nb):
    """Lightning validation step; dispatches on the active dataset.

    For midlevel, returns the MSE validation loss plus raw targets and
    predictions (as numpy arrays) for aggregation in ``validation_end``.

    Fixed: stale pre-diff body left above the dataset dispatch made the
    branch below unreachable.
    """
    if self.dataset == 'midlevel':
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'val_loss': self.my_loss(y_hat, y),
                'y': y.cpu().numpy(),
                'y_hat': y_hat.cpu().numpy(),
                }
    else:
        # NOTE(review): bare ``validation_step`` resolves to a module-level
        # function — confirm such a helper exists for mtgjamendo, otherwise
        # this raises NameError.
        return validation_step(self, data_batch, batch_nb)
def validation_end(self, outputs):
    """Aggregate per-batch validation outputs into epoch metrics.

    Fixed: the diff left duplicated stale lines interleaved with the new
    dataset-dispatching body (double ``np.concatenate`` / double return);
    reconstructed the intended single code path.
    """
    if self.dataset == 'midlevel':
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        y = []
        y_hat = []
        for output in outputs:
            y.append(output['y'])
            y_hat.append(output['y_hat'])

        # concatenated targets/predictions are built but currently unused
        # beyond the loss — presumably kept for correlation metrics later.
        y = np.concatenate(y)
        y_hat = np.concatenate(y_hat)

        return {'val_loss': avg_loss}
    else:
        # NOTE(review): ``validation_end(outputs)`` is called without
        # ``self`` and resolves to a module-level name — if none exists this
        # raises NameError. Confirm intended mtgjamendo aggregation.
        return validation_end(outputs)
def test_step(self, data_batch, batch_nb):
    """Lightning test step; dispatches on the active dataset.

    Fixed: stale pre-diff body left above the dataset dispatch made the
    branch below unreachable.
    """
    if self.dataset == 'midlevel':
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'test_loss': self.my_loss(y_hat, y),
                'y': y.cpu().numpy(),
                'y_hat': y_hat.cpu().numpy(),
                }
    else:
        # NOTE(review): bare ``test_step`` resolves to a module-level
        # function — confirm such a helper exists for mtgjamendo, otherwise
        # this raises NameError.
        return test_step(self, data_batch, batch_nb)
def test_end(self, outputs):
    """Aggregate per-batch test outputs, log them, and return test metrics.

    Fixed: stale pre-diff body left above the dataset dispatch made the
    branch below unreachable.
    """
    if self.dataset == 'midlevel':
        avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        test_metrics = {"test_loss": avg_test_loss}
        self.experiment.log(test_metrics)
        return test_metrics
    else:
        # NOTE(review): ``test_end(outputs)`` is called without ``self`` and
        # resolves to a module-level name — if none exists this raises
        # NameError. Confirm intended mtgjamendo aggregation.
        return test_end(outputs)
def configure_optimizers(self):
......
......@@ -30,7 +30,7 @@ data_roots = {
"shreyan-All-Series": "/mnt/2tb/datasets/MTG-Jamendo"
},
"midlevel":{
"rechenknecht3.cp.jku.at": "/media/rk3/shared/datasets/midlevel",
"rechenknecht3.cp.jku.at": "/media/rk3/shared/midlevel",
"rechenknecht2.cp.jku.at": "/media/rk2/shared/datasets/midlevel",
"rechenknecht1.cp.jku.at": "/media/rk1/shared/datasets/midlevel",
"hermine":"",
......@@ -98,7 +98,17 @@ def set_paths(dataset_name):
PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations')
def get_paths(dataset_name):
    """Resolve data root, audio, and annotations paths for this host.

    Fixed: a stale zero-argument ``def get_paths():`` line from the old
    diff version was left immediately above this definition (a syntax
    error / shadowing definition); removed it.

    :param dataset_name: key into ``data_roots`` ('midlevel', 'mtgjamendo',
        or any other dataset with the default audio/annotations layout).
    :return: tuple ``(PATH_DATA_ROOT, PATH_AUDIO, PATH_ANNOTATIONS)``.
    """
    PATH_DATA_ROOT = data_roots[dataset_name][hostname]
    if dataset_name == 'midlevel':
        PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'audio')
        PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'metadata_annotations')
    elif dataset_name == 'mtgjamendo':
        PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_audio')
        PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations')
    else:
        # Fallback layout for any other dataset.
        PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'audio')
        PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'annotations')
    return PATH_DATA_ROOT, PATH_AUDIO, PATH_ANNOTATIONS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment