Commit ddf6c7cc authored by Shreyan Chowdhury's avatar Shreyan Chowdhury Committed by Shreyan Chowdhury

add paths for shreyan-HP, minor refactoring

parent bc052327
import torch
class PadSequence:
def __call__(self, batch):
# print("PadSequence is called")
......
......@@ -17,7 +17,7 @@ def sample_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache = {}
def t2_parse_labels(csvf):
def midlevel_parse_labels(csvf):
global t2_parse_labels_cache
if t2_parse_labels_cache.get(csvf) is not None:
return t2_parse_labels_cache.get(csvf)
......@@ -91,7 +91,7 @@ def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name):
print("loading dataset from '{}'".format(name))
def getdatset():
files, labels = t2_parse_labels(midlevel_files_csv)
files, labels = midlevel_parse_labels(midlevel_files_csv)
return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)
df_trset = H5FCachedDataset(getdatset, name, slicing_function=sample_slicing_function,
......
......@@ -31,7 +31,7 @@ def full_song_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache = {}
def t2_parse_labels(csvf):
def mtgjamendo_parse_labels(csvf):
global t2_parse_labels_cache
if t2_parse_labels_cache.get(csvf) is not None:
return t2_parse_labels_cache.get(csvf)
......@@ -111,7 +111,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
print("loading dataset from '{}'".format(name))
def getdatset():
files, labels, label_encoder = t2_parse_labels(mtg_files_csv)
files, labels, label_encoder = mtgjamendo_parse_labels(mtg_files_csv)
return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)
if slicing_func is None:
......
......@@ -64,5 +64,21 @@ if hostname == 'shreyan-All-Series':
path_mtgjamendo_annotations_test = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname == 'shreyan-HP':
    rk = 'rk2'  # remote rechenknecht host whose home dir is mounted locally
    # NOTE(review): this cache path is local while every dataset path below
    # goes through the rk2 mount — confirm the cache really lives on this machine.
    path_data_cache = '/mnt/2tb/datasets/data_caches'
    # midlevel
    path_midlevel_annotations_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/metadata_annotations'
    path_midlevel_annotations = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/metadata_annotations/annotations.csv'
    path_midlevel_audio_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/audio'
    # mtgjamendo
    path_mtgjamendo_annotations_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
    path_mtgjamendo_annotations_train = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
    path_mtgjamendo_annotations_val = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
    path_mtgjamendo_annotations_test = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
    path_mtgjamendo_audio_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'

# Fail fast on any machine we have no paths defined for.
if hostname not in ['rechenknecht1.cp.jku.at', 'rechenknecht2.cp.jku.at', 'rechenknecht3.cp.jku.at', 'shreyan-All-Series', 'shreyan-HP']:
    raise Exception(f"Paths not defined for {hostname}")
......@@ -6,14 +6,15 @@ from models.crnn import CRNN as Network
import os
# Run configuration: which dataset to train on and which metrics to
# compute during validation and testing.
model_config = {
    'data_source': 'mtgjamendo',
    'validation_metrics': ['rocauc', 'prauc'],
    'test_metrics': ['rocauc', 'prauc']
}
# Module-level state shared across hyperparameter trials.
initialized = False  # TODO: Find a better way to do this
trial_counter = 0  # incremented once per trial inside run()
def run(hparams):
global initialized, trial_counter
trial_counter += 1
......@@ -56,7 +57,7 @@ def run(hparams):
train_percent_check=hparams.train_percent,
fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback,
nb_sanity_val_steps=0) # don't run sanity validation run
nb_sanity_val_steps=0) # don't run sanity validation run
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
......@@ -68,7 +69,7 @@ def run(hparams):
trainer.test()
if __name__=='__main__':
if __name__ == '__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
......@@ -76,14 +77,14 @@ if __name__=='__main__':
default=1.0, help='how much train data to use')
parent_parser.add_argument('--max_epochs', type=int,
default=10, help='maximum number of epochs')
#parent_parser.add_argument('--gpus', type=list, default=[0,1],
# parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node')
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
# run(hyperparams)
#gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5)
# gpus = ['cuda:0', 'cuda:1']
# hyperparams.optimize_parallel_gpu(run, gpus, 5)
# run(hyperparams)
for hparam_trial in hyperparams.trials(20):
run(hparam_trial)
......@@ -15,10 +15,10 @@ from datasets.collate import PadSequence
# example config dict: default data source and the metrics tracked
# during training, validation and testing.
base_model_config = {
    'data_source': 'mtgjamendo',
    'training_metrics': ['loss'],
    'validation_metrics': ['loss', 'prauc', 'rocauc'],
    'test_metrics': ['loss', 'prauc', 'rocauc']
}
......@@ -42,7 +42,7 @@ class BasePtlModel(pl.LightningModule):
self.validation_metrics = config.get('validation_metrics')
self.test_metrics = config.get('test_metrics')
if self.data_source=='midlevel':
if self.data_source == 'midlevel':
dataset, dataset_length = df_get_midlevel_set('midlevel',
path_midlevel_annotations,
path_midlevel_audio_dir,
......@@ -177,27 +177,26 @@ class BasePtlModel(pl.LightningModule):
@pl.data_loader
def tng_dataloader(self):
if self.data_source=='mtgjamendo':
if self.data_source == 'mtgjamendo':
dataset = df_get_mtg_set('mtgjamendo',
path_mtgjamendo_annotations_train,
path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size)
elif self.data_source=='midlevel':
elif self.data_source == 'midlevel':
dataset = self.midlevel_trainset
else:
raise Exception(f"Data source {self.data_source} not defined")
if self.slicing_mode == 'full':
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=PadSequence())
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
else:
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
@pl.data_loader
def val_dataloader(self):
......@@ -207,10 +206,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size)
elif self.data_source == 'midlevel':
dataset = self.midlevel_valset
else:
raise Exception(f"Data source {self.data_source} not defined")
......@@ -219,10 +216,10 @@ class BasePtlModel(pl.LightningModule):
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=PadSequence())
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
else:
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
@pl.data_loader
def test_dataloader(self):
......@@ -232,10 +229,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size)
elif self.data_source == 'midlevel':
dataset = self.midlevel_testset
else:
raise Exception(f"Data source {self.data_source} not defined")
......@@ -244,7 +239,7 @@ class BasePtlModel(pl.LightningModule):
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=PadSequence())
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
else:
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
......@@ -70,6 +70,8 @@ elif hostname == 'verena-830g5': # Laptop Verena
USE_GPU = False
elif hostname == 'shreyan-HP': # Laptop Shreyan
USE_GPU = False
PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk2/shared/kofta_cached_datasets'
else:
# PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
PATH_DATA_CACHE = '/mnt/2tb/datasets/data_caches'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment