Commit ddf6c7cc authored by Shreyan Chowdhury's avatar Shreyan Chowdhury Committed by Shreyan Chowdhury
Browse files

add paths for shreyan-HP, minor refactoring

parent bc052327
import torch import torch
class PadSequence: class PadSequence:
def __call__(self, batch): def __call__(self, batch):
# print("PadSequence is called") # print("PadSequence is called")
......
...@@ -17,7 +17,7 @@ def sample_slicing_function(h5data, idx, xlen): ...@@ -17,7 +17,7 @@ def sample_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache = {} t2_parse_labels_cache = {}
def t2_parse_labels(csvf): def midlevel_parse_labels(csvf):
global t2_parse_labels_cache global t2_parse_labels_cache
if t2_parse_labels_cache.get(csvf) is not None: if t2_parse_labels_cache.get(csvf) is not None:
return t2_parse_labels_cache.get(csvf) return t2_parse_labels_cache.get(csvf)
...@@ -91,7 +91,7 @@ def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name): ...@@ -91,7 +91,7 @@ def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name):
print("loading dataset from '{}'".format(name)) print("loading dataset from '{}'".format(name))
def getdatset(): def getdatset():
files, labels = t2_parse_labels(midlevel_files_csv) files, labels = midlevel_parse_labels(midlevel_files_csv)
return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor) return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)
df_trset = H5FCachedDataset(getdatset, name, slicing_function=sample_slicing_function, df_trset = H5FCachedDataset(getdatset, name, slicing_function=sample_slicing_function,
......
...@@ -31,7 +31,7 @@ def full_song_slicing_function(h5data, idx, xlen): ...@@ -31,7 +31,7 @@ def full_song_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache = {} t2_parse_labels_cache = {}
def t2_parse_labels(csvf): def mtgjamendo_parse_labels(csvf):
global t2_parse_labels_cache global t2_parse_labels_cache
if t2_parse_labels_cache.get(csvf) is not None: if t2_parse_labels_cache.get(csvf) is not None:
return t2_parse_labels_cache.get(csvf) return t2_parse_labels_cache.get(csvf)
...@@ -111,7 +111,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N ...@@ -111,7 +111,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
print("loading dataset from '{}'".format(name)) print("loading dataset from '{}'".format(name))
def getdatset(): def getdatset():
files, labels, label_encoder = t2_parse_labels(mtg_files_csv) files, labels, label_encoder = mtgjamendo_parse_labels(mtg_files_csv)
return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor) return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)
if slicing_func is None: if slicing_func is None:
......
...@@ -64,5 +64,21 @@ if hostname == 'shreyan-All-Series': ...@@ -64,5 +64,21 @@ if hostname == 'shreyan-All-Series':
path_mtgjamendo_annotations_test = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv' path_mtgjamendo_annotations_test = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio' path_mtgjamendo_audio_dir = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname not in ['rechenknecht1.cp.jku.at', 'rechenknecht2.cp.jku.at', 'rechenknecht3.cp.jku.at', 'shreyan-All-Series']: if hostname == 'shreyan-HP':
raise Exception(f"Paths not defined for {hostname}") rk = 'rk2'
\ No newline at end of file path_data_cache = '/mnt/2tb/datasets/data_caches'
# midlevel
path_midlevel_annotations_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/metadata_annotations'
path_midlevel_annotations = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/metadata_annotations/annotations.csv'
path_midlevel_audio_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/audio'
# mtgjamendo
path_mtgjamendo_annotations_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
path_mtgjamendo_annotations_train = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
path_mtgjamendo_annotations_val = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
path_mtgjamendo_annotations_test = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname not in ['rechenknecht1.cp.jku.at', 'rechenknecht2.cp.jku.at', 'rechenknecht3.cp.jku.at', 'shreyan-All-Series', 'shreyan-HP']:
raise Exception(f"Paths not defined for {hostname}")
...@@ -6,14 +6,15 @@ from models.crnn import CRNN as Network ...@@ -6,14 +6,15 @@ from models.crnn import CRNN as Network
import os import os
model_config = { model_config = {
'data_source':'mtgjamendo', 'data_source': 'mtgjamendo',
'validation_metrics':['rocauc', 'prauc'], 'validation_metrics': ['rocauc', 'prauc'],
'test_metrics':['rocauc', 'prauc'] 'test_metrics': ['rocauc', 'prauc']
} }
initialized = False # TODO: Find a better way to do this initialized = False # TODO: Find a better way to do this
trial_counter = 0 trial_counter = 0
def run(hparams): def run(hparams):
global initialized, trial_counter global initialized, trial_counter
trial_counter += 1 trial_counter += 1
...@@ -56,7 +57,7 @@ def run(hparams): ...@@ -56,7 +57,7 @@ def run(hparams):
train_percent_check=hparams.train_percent, train_percent_check=hparams.train_percent,
fast_dev_run=False, early_stop_callback=early_stop, fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback, checkpoint_callback=checkpoint_callback,
nb_sanity_val_steps=0) # don't run sanity validation run nb_sanity_val_steps=0) # don't run sanity validation run
else: else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1, trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True) fast_dev_run=True)
...@@ -68,7 +69,7 @@ def run(hparams): ...@@ -68,7 +69,7 @@ def run(hparams):
trainer.test() trainer.test()
if __name__=='__main__': if __name__ == '__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False) parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str, parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name') default='pt_lightning_exp_a', help='test tube exp name')
...@@ -76,14 +77,14 @@ if __name__=='__main__': ...@@ -76,14 +77,14 @@ if __name__=='__main__':
default=1.0, help='how much train data to use') default=1.0, help='how much train data to use')
parent_parser.add_argument('--max_epochs', type=int, parent_parser.add_argument('--max_epochs', type=int,
default=10, help='maximum number of epochs') default=10, help='maximum number of epochs')
#parent_parser.add_argument('--gpus', type=list, default=[0,1], # parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.' # help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node') # ' value -1 uses all the gpus on the node')
parser = Network.add_model_specific_args(parent_parser) parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args() hyperparams = parser.parse_args()
# run(hyperparams) # run(hyperparams)
#gpus = ['cuda:0', 'cuda:1'] # gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5) # hyperparams.optimize_parallel_gpu(run, gpus, 5)
# run(hyperparams) # run(hyperparams)
for hparam_trial in hyperparams.trials(20): for hparam_trial in hyperparams.trials(20):
run(hparam_trial) run(hparam_trial)
...@@ -15,10 +15,10 @@ from datasets.collate import PadSequence ...@@ -15,10 +15,10 @@ from datasets.collate import PadSequence
# example config dict # example config dict
base_model_config = { base_model_config = {
'data_source':'mtgjamendo', 'data_source': 'mtgjamendo',
'training_metrics':['loss'], 'training_metrics': ['loss'],
'validation_metrics':['loss', 'prauc', 'rocauc'], 'validation_metrics': ['loss', 'prauc', 'rocauc'],
'test_metrics':['loss', 'prauc', 'rocauc'] 'test_metrics': ['loss', 'prauc', 'rocauc']
} }
...@@ -42,7 +42,7 @@ class BasePtlModel(pl.LightningModule): ...@@ -42,7 +42,7 @@ class BasePtlModel(pl.LightningModule):
self.validation_metrics = config.get('validation_metrics') self.validation_metrics = config.get('validation_metrics')
self.test_metrics = config.get('test_metrics') self.test_metrics = config.get('test_metrics')
if self.data_source=='midlevel': if self.data_source == 'midlevel':
dataset, dataset_length = df_get_midlevel_set('midlevel', dataset, dataset_length = df_get_midlevel_set('midlevel',
path_midlevel_annotations, path_midlevel_annotations,
path_midlevel_audio_dir, path_midlevel_audio_dir,
...@@ -177,27 +177,26 @@ class BasePtlModel(pl.LightningModule): ...@@ -177,27 +177,26 @@ class BasePtlModel(pl.LightningModule):
@pl.data_loader @pl.data_loader
def tng_dataloader(self): def tng_dataloader(self):
if self.data_source=='mtgjamendo': if self.data_source == 'mtgjamendo':
dataset = df_get_mtg_set('mtgjamendo', dataset = df_get_mtg_set('mtgjamendo',
path_mtgjamendo_annotations_train, path_mtgjamendo_annotations_train,
path_mtgjamendo_audio_dir, path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer, "_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size) slice_len=self.input_size)
elif self.data_source == 'midlevel':
elif self.data_source=='midlevel':
dataset = self.midlevel_trainset dataset = self.midlevel_trainset
else: else:
raise Exception(f"Data source {self.data_source} not defined") raise Exception(f"Data source {self.data_source} not defined")
if self.slicing_mode == 'full': if self.slicing_mode == 'full':
return DataLoader(dataset=dataset, return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size, batch_size=self.hparams.batch_size,
shuffle=True, shuffle=True,
collate_fn=PadSequence()) collate_fn=PadSequence())
else:
return DataLoader(dataset=dataset, return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size, batch_size=self.hparams.batch_size,
shuffle=True) shuffle=True)
@pl.data_loader @pl.data_loader
def val_dataloader(self): def val_dataloader(self):
...@@ -207,10 +206,8 @@ class BasePtlModel(pl.LightningModule): ...@@ -207,10 +206,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir, path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer, "_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size) slice_len=self.input_size)
elif self.data_source == 'midlevel': elif self.data_source == 'midlevel':
dataset = self.midlevel_valset dataset = self.midlevel_valset
else: else:
raise Exception(f"Data source {self.data_source} not defined") raise Exception(f"Data source {self.data_source} not defined")
...@@ -219,10 +216,10 @@ class BasePtlModel(pl.LightningModule): ...@@ -219,10 +216,10 @@ class BasePtlModel(pl.LightningModule):
batch_size=self.hparams.batch_size, batch_size=self.hparams.batch_size,
shuffle=True, shuffle=True,
collate_fn=PadSequence()) collate_fn=PadSequence())
else:
return DataLoader(dataset=dataset, return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size, batch_size=self.hparams.batch_size,
shuffle=True) shuffle=True)
@pl.data_loader @pl.data_loader
def test_dataloader(self): def test_dataloader(self):
...@@ -232,10 +229,8 @@ class BasePtlModel(pl.LightningModule): ...@@ -232,10 +229,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir, path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer, "_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size) slice_len=self.input_size)
elif self.data_source == 'midlevel': elif self.data_source == 'midlevel':
dataset = self.midlevel_testset dataset = self.midlevel_testset
else: else:
raise Exception(f"Data source {self.data_source} not defined") raise Exception(f"Data source {self.data_source} not defined")
...@@ -244,7 +239,7 @@ class BasePtlModel(pl.LightningModule): ...@@ -244,7 +239,7 @@ class BasePtlModel(pl.LightningModule):
batch_size=self.hparams.batch_size, batch_size=self.hparams.batch_size,
shuffle=True, shuffle=True,
collate_fn=PadSequence()) collate_fn=PadSequence())
else:
return DataLoader(dataset=dataset, return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size, batch_size=self.hparams.batch_size,
shuffle=True) shuffle=True)
...@@ -70,6 +70,8 @@ elif hostname == 'verena-830g5': # Laptop Verena ...@@ -70,6 +70,8 @@ elif hostname == 'verena-830g5': # Laptop Verena
USE_GPU = False USE_GPU = False
elif hostname == 'shreyan-HP': # Laptop Shreyan elif hostname == 'shreyan-HP': # Laptop Shreyan
USE_GPU = False USE_GPU = False
PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk2/shared/kofta_cached_datasets'
else: else:
# PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets' # PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
PATH_DATA_CACHE = '/mnt/2tb/datasets/data_caches' PATH_DATA_CACHE = '/mnt/2tb/datasets/data_caches'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment