Commit 176d4816 authored by Shreyan Chowdhury's avatar Shreyan Chowdhury
Browse files

implement pretrained model loading and fine-tuning

parent eb2252a1
...@@ -30,11 +30,19 @@ def epochs_20(): ...@@ -30,11 +30,19 @@ def epochs_20():
config['epochs'] = 20 config['epochs'] = 20
def midlevel_configs():
    """Apply run-config overrides for midlevel pretraining (short 2-epoch run)."""
    global config
    config.update(epochs=2)
def mtg_configs():
    """Apply run-config overrides for MTG-Jamendo fine-tuning (single epoch)."""
    global config
    config.update(epochs=1)
def pretrain_midlevel(hparams): def pretrain_midlevel(hparams):
set_paths('midlevel') set_paths('midlevel')
from utils import CURR_RUN_PATH, logger, streamlog # import these after init_experiment from utils import CURR_RUN_PATH, logger, streamlog # import these after init_experiment
streamlog.info("Training midlevel...")
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}") logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH) exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)
...@@ -92,19 +100,63 @@ def pretrain_midlevel(hparams): ...@@ -92,19 +100,63 @@ def pretrain_midlevel(hparams):
print(model) print(model)
logger.info("Training midlevel")
trainer.fit(model) trainer.fit(model)
logger.info("Training midlevel completed")
# streamlog.info("Running test") # streamlog.info("Running test")
# trainer.test() # trainer.test()
def train_mtgjamendo(hparams):
    """Fine-tune a midlevel-pretrained network on the MTG-Jamendo tagging task.

    Loads the checkpoint written by ``pretrain_midlevel`` and trains only the
    MTG head (the backbone is frozen inside the model when dataset='mtgjamendo').

    :param hparams: parsed CLI hyperparameters (uses ``train_percent``).
    """
    # NOTE(review): this sets paths for 'midlevel' even though we train on
    # mtgjamendo — presumably so CURR_RUN_PATH matches the pretraining run and
    # the midlevel checkpoint can be found; confirm this is intentional.
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger  # import these after init_experiment

    chkpt_dir = os.path.join(CURR_RUN_PATH, 'mtg.ckpt')
    midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)

    mtg_configs()

    logger.info(f"Loading model from {midlevel_chkpt_dir}")
    model = Network(num_targets=7, dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
    logger.info("Loaded model successfully")  # was an f-string with no placeholders; stray `pass` removed

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )

    checkpoint_callback = ModelCheckpoint(
        filepath=chkpt_dir,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    if USE_GPU:
        trainer = Trainer(
            gpus=[0], distributed_backend='ddp',
            experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
            fast_dev_run=False,
            early_stop_callback=early_stop,
            checkpoint_callback=checkpoint_callback
        )
    else:
        # CPU fallback: 1 epoch on 1% of the data for smoke testing.
        # NOTE(review): no early_stop_callback here — confirm that is intended.
        trainer = Trainer(
            experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
            fast_dev_run=False, checkpoint_callback=checkpoint_callback
        )

    logger.info("Training mtgjamendo")
    trainer.fit(model)
    logger.info("Training mtgjamendo completed")
def run(hparams):
    """Full pipeline: init the experiment, pretrain on midlevel, fine-tune on MTG-Jamendo."""
    init_experiment(comment=hparams.experiment_name)
    pretrain_midlevel(hparams)
    train_mtgjamendo(hparams)
if __name__=='__main__': if __name__=='__main__':
......
...@@ -23,14 +23,25 @@ def initialize_weights(module): ...@@ -23,14 +23,25 @@ def initialize_weights(module):
module.bias.data.zero_() module.bias.data.zero_()
class ModelMidlevel(pl.LightningModule): class ModelMidlevel(pl.LightningModule):
def __init__(self, num_targets, initialize=True, load_from=None, on_gpu=None, map_location=None): def __init__(self, num_targets, initialize=True, dataset='midlevel', load_from=None, on_gpu=None, map_location=None):
super(ModelMidlevel, self).__init__() super(ModelMidlevel, self).__init__()
data_root, audio_path, csvs_path = get_paths() self.dataset = dataset
cache_x_name = '_ap_midlevel44k' if dataset=='midlevel':
from torch.utils.data import random_split data_root, audio_path, csvs_path = get_paths('midlevel')
dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'), audio_path, cache_x_name) cache_x_name = '_ap_midlevel44k'
self.trainset, self.validationset, self.testset = random_split(dataset, [int(i*dataset_length) for i in [0.7, 0.2, 0.1]]) from torch.utils.data import random_split
dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'), audio_path, cache_x_name)
self.trainset, self.validationset, self.testset = random_split(dataset, [int(i*dataset_length) for i in [0.7, 0.2, 0.1]])
elif dataset=='mtgjamendo':
data_root, audio_path, csvs_path = get_paths('mtgjamendo')
cache_x_name = "_ap_mtgjamendo44k"
train_csv = os.path.join(csvs_path, 'train_processed.tsv')
validation_csv = os.path.join(csvs_path, 'validation_processed.tsv')
test_csv = os.path.join(csvs_path, 'test_processed.tsv')
self.trainset = df_get_mtg_set('mtgjamendo', train_csv, audio_path, cache_x_name)
self.validationset = df_get_mtg_set('mtgjamendo_val', validation_csv, audio_path, cache_x_name)
self.testset = df_get_mtg_set('mtgjamendo_test', test_csv, audio_path, cache_x_name)
self.num_targets = num_targets self.num_targets = num_targets
...@@ -100,12 +111,22 @@ class ModelMidlevel(pl.LightningModule): ...@@ -100,12 +111,22 @@ class ModelMidlevel(pl.LightningModule):
) )
self.fc_ml = nn.Linear(256, 7) self.fc_ml = nn.Linear(256, 7)
if initialize: if initialize:
self.apply(initialize_weights) self.apply(initialize_weights)
if load_from: if load_from:
self._load_model(load_from, map_location, on_gpu) self._load_model(load_from, map_location, on_gpu)
if dataset == 'mtgjamendo':
self.fc_mtg1 = nn.Linear(7, 10)
self.fc_mtg2 = nn.Linear(10, 56)
for name, param in self.named_parameters():
if 'mtg' in name:
param.requires_grad = True
else:
param.requires_grad = False
def forward(self, x): def forward(self, x):
# 313 * 149 * 1 # 313 * 149 * 1
...@@ -122,6 +143,10 @@ class ModelMidlevel(pl.LightningModule): ...@@ -122,6 +143,10 @@ class ModelMidlevel(pl.LightningModule):
x = self.conv11(x) # 2 * 2 * 256 x = self.conv11(x) # 2 * 2 * 256
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
ml = self.fc_ml(x) ml = self.fc_ml(x)
if self.dataset=='mtgjamendo':
x = self.fc_mtg1(ml)
logit = nn.Sigmoid()(self.fc_mtg2(x))
return logit
return ml return ml
def _load_model(self, load_from, map_location=None, on_gpu=True): def _load_model(self, load_from, map_location=None, on_gpu=True):
...@@ -156,7 +181,10 @@ class ModelMidlevel(pl.LightningModule): ...@@ -156,7 +181,10 @@ class ModelMidlevel(pl.LightningModule):
self.load_state_dict(checkpoint['state_dict']) self.load_state_dict(checkpoint['state_dict'])
def my_loss(self, y_hat, y):
    """Loss selector: MSE regression loss for midlevel, delegated loss otherwise."""
    if self.dataset == 'midlevel':
        return F.mse_loss(y_hat, y)
    # NOTE(review): bare `my_loss` resolves to a module-level global, not this
    # method — confirm such a helper is imported, otherwise this raises NameError.
    return my_loss(y_hat, y)
def forward_full_song(self, x, y): def forward_full_song(self, x, y):
# print(x.shape) # print(x.shape)
...@@ -173,51 +201,66 @@ class ModelMidlevel(pl.LightningModule): ...@@ -173,51 +201,66 @@ class ModelMidlevel(pl.LightningModule):
# return y_hat/count # return y_hat/count
def training_step(self, data_batch, batch_nb):
    """One optimization step; full-song forward + MSE for midlevel batches."""
    if self.dataset != 'midlevel':
        # NOTE(review): bare `training_step` resolves to a module-level helper,
        # not this method — confirm it is imported elsewhere in the file.
        return training_step(self, data_batch, batch_nb)
    x, _, y = data_batch
    y_hat = self.forward_full_song(x, y).float()
    return {'loss': self.my_loss(y_hat, y.float())}
def validation_step(self, data_batch, batch_nb):
    """Validation step; returns loss plus numpy targets/predictions for aggregation."""
    if self.dataset != 'midlevel':
        # NOTE(review): delegates to a module-level `validation_step` helper —
        # confirm it is imported, else this raises NameError.
        return validation_step(self, data_batch, batch_nb)
    x, _, y = data_batch
    y_hat = self.forward_full_song(x, y).float()
    y = y.float()
    return {'val_loss': self.my_loss(y_hat, y),
            'y': y.cpu().numpy(),
            'y_hat': y_hat.cpu().numpy(),
            }
def validation_end(self, outputs):
    """Aggregate per-batch validation outputs into a single mean loss."""
    if self.dataset != 'midlevel':
        # NOTE(review): delegates to a module-level `validation_end` — confirm import.
        return validation_end(outputs)
    avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
    # y / y_hat are gathered but currently unused — metric computation
    # appears to have been removed; kept to preserve behavior.
    y = np.concatenate([out['y'] for out in outputs])
    y_hat = np.concatenate([out['y_hat'] for out in outputs])
    return {'val_loss': avg_loss}
def test_step(self, data_batch, batch_nb):
    """Test step; mirrors validation_step but reports under 'test_loss'."""
    if self.dataset != 'midlevel':
        # NOTE(review): delegates to a module-level `test_step` helper — confirm import.
        return test_step(self, data_batch, batch_nb)
    x, _, y = data_batch
    y_hat = self.forward_full_song(x, y).float()
    y = y.float()
    return {'test_loss': self.my_loss(y_hat, y),
            'y': y.cpu().numpy(),
            'y_hat': y_hat.cpu().numpy(),
            }
def test_end(self, outputs):
    """Aggregate test outputs, log the mean loss to the experiment, and return it."""
    if self.dataset != 'midlevel':
        # NOTE(review): delegates to a module-level `test_end` — confirm import.
        return test_end(outputs)
    test_metrics = {"test_loss": torch.stack([out['test_loss'] for out in outputs]).mean()}
    self.experiment.log(test_metrics)
    return test_metrics
def configure_optimizers(self): def configure_optimizers(self):
......
...@@ -30,7 +30,7 @@ data_roots = { ...@@ -30,7 +30,7 @@ data_roots = {
"shreyan-All-Series": "/mnt/2tb/datasets/MTG-Jamendo" "shreyan-All-Series": "/mnt/2tb/datasets/MTG-Jamendo"
}, },
"midlevel":{ "midlevel":{
"rechenknecht3.cp.jku.at": "/media/rk3/shared/datasets/midlevel", "rechenknecht3.cp.jku.at": "/media/rk3/shared/midlevel",
"rechenknecht2.cp.jku.at": "/media/rk2/shared/datasets/midlevel", "rechenknecht2.cp.jku.at": "/media/rk2/shared/datasets/midlevel",
"rechenknecht1.cp.jku.at": "/media/rk1/shared/datasets/midlevel", "rechenknecht1.cp.jku.at": "/media/rk1/shared/datasets/midlevel",
"hermine":"", "hermine":"",
...@@ -98,7 +98,17 @@ def set_paths(dataset_name): ...@@ -98,7 +98,17 @@ def set_paths(dataset_name):
PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations') PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations')
def get_paths(dataset_name):
    """Resolve (data_root, audio_path, annotations_path) for a dataset on this host.

    :param dataset_name: key into ``data_roots`` ('midlevel', 'mtgjamendo', ...);
        an unknown host or dataset raises KeyError from the lookup.
    """
    root = data_roots[dataset_name][hostname]
    if dataset_name == 'midlevel':
        audio_dir, annot_dir = 'audio', 'metadata_annotations'
    elif dataset_name == 'mtgjamendo':
        audio_dir, annot_dir = 'MTG-Jamendo_audio', 'MTG-Jamendo_annotations'
    else:
        audio_dir, annot_dir = 'audio', 'annotations'
    return root, os.path.join(root, audio_dir), os.path.join(root, annot_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment