Commit 6c53d1cb authored by Shreyan Chowdhury's avatar Shreyan Chowdhury

Merge branch 'full_song'

parents 9c1ab51c 792c6a21
......@@ -2,20 +2,26 @@ import torch
class PadSequence:
    """Collate callable for DataLoader batches backed by a shared HDF5 array.

    Each batch element is expected to be a tuple whose first item is
    ``(h5data, start_index, length)`` and whose third item is the label.
    No padding happens here any more: the shared h5 handle plus per-item
    offsets and lengths are returned so the model can slice full songs
    lazily in its forward pass.
    """

    def __call__(self, batch):
        # Every item carries the same underlying h5 dataset handle, so a
        # single copy is taken from the first element.
        h5data = batch[0][0][0]
        # Per-item start offsets and sequence lengths into h5data.
        idx = [item[0][1] for item in batch]
        lengths = [item[0][2] for item in batch]
        # item[1] is unused here; labels sit at position 2 of each tuple.
        labels = [item[2] for item in batch]
        return h5data, idx, lengths, labels
......@@ -74,139 +74,6 @@ class MelSpecDataset(Dataset):
return tagslist
class HDF5Dataset(Dataset):
    """Represents an abstract HDF5 dataset.

    Input params:
        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
        recursive: If True, searches for h5 files in subdirectories.
        load_data: If True, loads all the data immediately into RAM. Use this if
            the dataset fits into memory. Otherwise, leave this at False and
            the data will load lazily.
        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
        transform: PyTorch transform to apply to every data instance (default=None).
    """

    def __init__(self, file_path, recursive, load_data, data_cache_size=3, transform=None):
        super().__init__()
        self.data_info = []    # one entry per dataset found, in file order
        self.data_cache = {}   # file_path -> list of loaded dataset arrays
        self.data_cache_size = data_cache_size
        self.transform = transform

        # Search for all h5 files
        p = Path(file_path)
        assert (p.is_dir())
        if recursive:
            files = sorted(p.glob('**/*.h5'))
        else:
            files = sorted(p.glob('train.h5'))
        if len(files) < 1:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)

    def __getitem__(self, index):
        # get data
        x = self.get_data("data", index)
        reqd_len = MAX_LENGTH
        spec_len = x.shape[1]
        # Crop spectrograms longer than reqd_len; wrap-pad shorter ones so
        # every item comes out with exactly reqd_len frames.
        x = x[:, :reqd_len] if spec_len > reqd_len else np.pad(x, ((0, 0), (0, reqd_len - spec_len)), mode='wrap')
        if self.transform:
            x = self.transform(x)
        else:
            x = torch.from_numpy(x)

        # get label
        y = self.get_data("label", index)
        y = torch.from_numpy(y)
        return (x, y)

    def __len__(self):
        return len(self.get_data_infos('data'))

    def _add_data_infos(self, file_path, load_data):
        """Index all datasets of one HDF5 file (optionally loading them)."""
        # Open read-only with an explicit mode; note that h5py >= 3 removed
        # `Dataset.value`, so datasets are read with `ds[()]` instead.
        with h5py.File(file_path, 'r') as h5_file:
            # Walk through all groups, extracting datasets
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # if data is not loaded its cache index is -1
                    idx = -1
                    if load_data:
                        # add data to the data cache
                        idx = self._add_to_cache(ds[()], file_path)

                    # type is derived from the name of the dataset; we expect the dataset
                    # name to have a name such as 'data' or 'label' to identify its type
                    # we also store the shape of the data in case we need it
                    self.data_info.append(
                        {'file_path': file_path, 'type': dname, 'shape': ds.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load data to the cache given the file
        path and update the cache index in the
        data_info structure.
        """
        with h5py.File(file_path, 'r') as h5_file:
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # add data to the data cache and retrieve
                    # the cache index
                    idx = self._add_to_cache(ds[()], file_path)

                    # find the beginning index of the hdf5 file we are looking for
                    file_idx = next(i for i, v in enumerate(self.data_info) if v['file_path'] == file_path)

                    # the data info should have the same index since we loaded it in the same way
                    self.data_info[file_idx + idx]['cache_idx'] = idx

        # remove an element from data cache if size was exceeded
        if len(self.data_cache) > self.data_cache_size:
            # remove one item from the cache at random (first key that is
            # not the file we just loaded)
            removal_keys = list(self.data_cache)
            removal_keys.remove(file_path)
            self.data_cache.pop(removal_keys[0])
            # remove invalid cache_idx for the evicted file's entries
            self.data_info = [
                {'file_path': di['file_path'], 'type': di['type'], 'shape': di['shape'], 'cache_idx': -1}
                if di['file_path'] == removal_keys[0] else di
                for di in self.data_info
            ]

    def _add_to_cache(self, data, file_path):
        """Adds data to the cache and returns its index. There is one cache
        list for every file_path, containing all datasets in that file.
        """
        if file_path not in self.data_cache:
            self.data_cache[file_path] = [data]
        else:
            self.data_cache[file_path].append(data)
        return len(self.data_cache[file_path]) - 1

    def get_data_infos(self, type):
        """Get data infos belonging to a certain type of data.
        """
        data_info_type = [di for di in self.data_info if di['type'] == type]
        return data_info_type

    def get_data(self, type, i):
        """Call this function anytime you want to access a chunk of data from the
        dataset. This will make sure that the data is loaded in case it is
        not part of the data cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)

        # get new cache_idx assigned by _load_data
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]
class AudioPreprocessDataset(Dataset):
"""A bases preprocessing dataset representing a Dataset of files that are loaded and preprossessed on the fly.
......@@ -293,8 +160,8 @@ class H5FCachedDataset(Dataset):
torch.utils.data.DataLoader(getDataset(), batch_size=1, shuffle=False, num_workers=36)):
# fixing the shapes
x = x[0].numpy().transpose(2, 0, 1)
y = y[0]
z = z[0]
y = y[0] # audio filepath relative
z = z[0] # labels
x = x.reshape(x.shape[0], -1)
if d is None:
d = f.create_dataset('data', (0, x.shape[1]), maxshape=(None, x.shape[1]), dtype='f', chunks=True)
......@@ -314,7 +181,7 @@ class H5FCachedDataset(Dataset):
def __getitem__(self, index):
cpath = os.path.join(self.cache_path, str(index) + "_meta.pt")
# x.transpose(2,0,1).transpose(1,2,0) store columns first
(idx, xlen), y, z = torch.load(cpath)
(idx, xlen), y, z = torch.load(cpath) # y: audio filepath relative, z: labels
x = self.slicing_function(self.h5data, idx, xlen)
if self.augment_options is not None:
......
......@@ -17,7 +17,7 @@ def sample_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache = {}
def t2_parse_labels(csvf):
def midlevel_parse_labels(csvf):
global t2_parse_labels_cache
if t2_parse_labels_cache.get(csvf) is not None:
return t2_parse_labels_cache.get(csvf)
......@@ -91,7 +91,7 @@ def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name):
print("loading dataset from '{}'".format(name))
def getdatset():
files, labels = t2_parse_labels(midlevel_files_csv)
files, labels = midlevel_parse_labels(midlevel_files_csv)
return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)
df_trset = H5FCachedDataset(getdatset, name, slicing_function=sample_slicing_function,
......
......@@ -15,23 +15,12 @@ def sample_slicing_function(h5data, idx, xlen):
return torch.from_numpy(x.transpose(1, 0).reshape(1, 256, timeframes))
def full_song_slicing_function(h5data, idx, xlen):
    """Identity slicer for full-song mode.

    Defers the actual slicing of the spectrogram to the model's forward
    pass: returns the shared h5 dataset handle together with the song's
    start offset and length, untouched.
    """
    return (h5data, idx, xlen)
t2_parse_labels_cache = {}
def t2_parse_labels(csvf):
def mtgjamendo_parse_labels(csvf):
global t2_parse_labels_cache
if t2_parse_labels_cache.get(csvf) is not None:
return t2_parse_labels_cache.get(csvf)
......@@ -111,7 +100,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
print("loading dataset from '{}'".format(name))
def getdatset():
files, labels, label_encoder = t2_parse_labels(mtg_files_csv)
files, labels, label_encoder = mtgjamendo_parse_labels(mtg_files_csv)
return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)
if slicing_func is None:
......
......@@ -64,5 +64,21 @@ if hostname == 'shreyan-All-Series':
path_mtgjamendo_annotations_test = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname == 'shreyan-HP':
    rk = 'rk2'
    path_data_cache = '/mnt/2tb/datasets/data_caches'

    # midlevel
    path_midlevel_annotations_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/metadata_annotations'
    path_midlevel_annotations = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/metadata_annotations/annotations.csv'
    path_midlevel_audio_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/midlevel/audio'

    # mtgjamendo
    path_mtgjamendo_annotations_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
    path_mtgjamendo_annotations_train = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
    path_mtgjamendo_annotations_val = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
    path_mtgjamendo_annotations_test = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
    path_mtgjamendo_audio_dir = f'/home/shreyan/mounts/home@{rk}/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'

# Fail fast on unknown machines so missing-path errors surface immediately.
if hostname not in ['rechenknecht1.cp.jku.at', 'rechenknecht2.cp.jku.at', 'rechenknecht3.cp.jku.at',
                    'shreyan-All-Series', 'shreyan-HP']:
    raise Exception(f"Paths not defined for {hostname}")
from utils import USE_GPU, init_experiment
from utils import USE_GPU, init_experiment, exit_experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment, HyperOptArgumentParser
......@@ -6,14 +6,15 @@ from models.crnn import CRNN as Network
import os
# Configuration for the mtgjamendo CRNN experiment: data source plus the
# metrics computed during validation and test.
model_config = {
    'data_source': 'mtgjamendo',
    'validation_metrics': ['rocauc', 'prauc'],
    'test_metrics': ['rocauc', 'prauc']
}

initialized = False  # TODO: Find a better way to do this
trial_counter = 0  # incremented by run() once per hyperparameter trial
def run(hparams):
global initialized, trial_counter
trial_counter += 1
......@@ -56,19 +57,35 @@ def run(hparams):
train_percent_check=hparams.train_percent,
fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback,
nb_sanity_val_steps=0) # don't run sanity validation run
nb_sanity_val_steps=0) # don't run sanity validation run
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
fast_dev_run=True, nb_sanity_val_steps=0)
model = Network(num_class=56, config=model_config, hparams=hparams)
print(model)
trainer.fit(model)
trainer.test()
try:
trainer.fit(model)
except KeyboardInterrupt:
logger.info("Training interrupted")
except:
logger.exception(msg="Error occurred during train!")
exit_experiment('failed', exp)
try:
logger.info("Starting test...")
trainer.test()
except KeyboardInterrupt:
logger.info("Exiting...")
exit_experiment('stopped', exp)
except:
logger.exception(msg="Error occurred during test!")
exit_experiment('failed', exp)
if __name__=='__main__':
if __name__ == '__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
......@@ -76,14 +93,14 @@ if __name__=='__main__':
default=1.0, help='how much train data to use')
parent_parser.add_argument('--max_epochs', type=int,
default=10, help='maximum number of epochs')
#parent_parser.add_argument('--gpus', type=list, default=[0,1],
# parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node')
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
# run(hyperparams)
#gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5)
# run(hyperparams)
for hparam_trial in hyperparams.trials(20):
run(hparam_trial)
# gpus = ['cuda:0', 'cuda:1']
# hyperparams.optimize_parallel_gpu(run, gpus, 5)
run(hyperparams)
# for hparam_trial in hyperparams.trials(20):
# run(hparam_trial)
......@@ -48,15 +48,78 @@ class CRNN(BasePtlModel):
self.dropout = nn.Dropout(self.hparams.drop_prob)
def forward_full_song(self, batch):
    """Forward pass over whole songs.

    `batch` is the output of PadSequence: (h5data, idx_list,
    x_lengths_list, labels_list). Each song is read from the shared h5
    array in `input_size`-frame windows, each window is encoded by the
    conv stack, the per-window feature sequences are padded into one
    batch, and a GRU followed by a sigmoid dense head produces the
    per-song output.
    """
    def cnn_forward(x):
        # Conv feature extractor over one spectrogram window.
        # init bn
        x = self.bn_init(x)
        # layer 1
        x = self.mp_1(nn.ELU()(self.bn_1(self.conv_1(x))))
        # layer 2 (max-pool deliberately skipped here — see commented line)
        x = nn.ELU()(self.bn_2(self.conv_2(x)))
        # x = self.mp_2(nn.ELU()(self.bn_2(self.conv_2(x))))
        # layer 3
        x = self.mp_3(nn.ELU()(self.bn_3(self.conv_3(x))))
        # layer 4
        x = self.mp_4(nn.ELU()(self.bn_4(self.conv_4(x))))
        # layer 5
        x = self.mp_5(nn.ELU()(self.bn_5(self.conv_5(x))))
        # classifier
        # Flatten conv output to (steps, batch, 32) for the RNN stage.
        # NOTE(review): assumes the conv output has 32 channels — confirm.
        x = x.view(-1, x.size(0), 32)
        return x

    def rnn_forward(x):
        # Recurrent head: take the final hidden state of the GRU's last
        # layer ([1][1] indexes into the (output, h_n) tuple).
        x = x.squeeze()
        x = self.gru1(x)[1][1]  # TODO: Check if this is correct
        x = self.dropout(x)
        logit = nn.Sigmoid()(self.dense(x))
        return logit

    def extract_features(song_idx, song_length):
        # print(song_idx, song_length)
        # NOTE(review): hard-codes the song length, overriding the
        # argument that the caller passes in — confirm this is intended.
        song_length = 2560
        song_feats = []
        for l in range(song_length//self.input_size + 1):
            # Slice one window of up to input_size frames from the shared
            # h5 array and transpose to (features, frames).
            data = h5data[song_idx + l*self.input_size:song_idx + min(song_length, (l + 1) * self.input_size)].transpose()
            # Skip windows shorter than a quarter of input_size.
            if data.shape[1] < self.input_size*0.25:
                continue
            # Wrap-pad the window to exactly input_size frames.
            data = np.pad(data, ((0, 0), (0, self.input_size-data.shape[1])), mode='wrap')
            try:
                song_feats.append(cnn_forward(torch.tensor([[data]], device=torch.device('cuda'))))
            except AssertionError:
                # CPU fallback when CUDA is unavailable (torch raises
                # AssertionError in that case).
                # print(song_idx, song_length)
                song_feats.append(cnn_forward(torch.tensor([[data]], device=torch.device('cpu'))))
        # print("song feats", song_feats.__len__(), song_feats[0].shape)
        return torch.cat(song_feats)

    h5data, idx_list, x_lengths_list, labels_list = batch
    sequences = []
    for n, ind in enumerate(idx_list):
        sequences.append(extract_features(ind, x_lengths_list[n]))
    # print("sequences", sequences.__len__(), sequences[0].shape)
    # Pad variable-length per-song feature sequences into (T, B, F).
    sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False)
    result = rnn_forward(sequences_padded)
    return result
def forward(self, batch):
#print("batch", batch)
x, x_lengths, _ = batch
#print("x", x)
#print("xlen", x_lengths)
if self.slicing_mode == 'full':
print("before pack", x, x_lengths)
x = pack_padded_sequence(x, x_lengths, batch_first=True)
if self.slicing_mode=='full':
logit = self.forward_full_song(batch)
return logit
x, _, _ = batch # xs, xlens, labels
# init bn
x = self.bn_init(x)
......@@ -89,16 +152,16 @@ class CRNN(BasePtlModel):
return logit
def training_step(self, data_batch, batch_i):
    """Single training step: forward the batch and return the loss dict.

    `data_batch` comes from PadSequence, so the labels are its last
    element — a list of per-song label tensors.
    """
    y = data_batch[-1]
    y_hat = self.forward(data_batch)
    # Labels arrive as a list of tensors; stack them into one float tensor.
    y = torch.stack(y).float()
    y_hat = y_hat.float()
    return {'loss': self.loss(y_hat, y)}
def validation_step(self, data_batch, batch_i):
x, _, y = data_batch
y = data_batch[-1]
y_hat = self.forward(data_batch)
y = y.float()
y = torch.stack(y).float()
y_hat = y_hat.float()
return {
'val_loss': self.loss(y_hat, y),
......@@ -107,9 +170,9 @@ class CRNN(BasePtlModel):
}
def test_step(self, data_batch, batch_i):
x, _, y = data_batch
y = data_batch[-1]
y_hat = self.forward(data_batch)
y = y.float()
y = torch.stack(y).float()
y_hat = y_hat.float()
return {
'test_loss': self.loss(y_hat, y),
......@@ -126,13 +189,11 @@ class CRNN(BasePtlModel):
# network params
parser.add_argument('--gru_hidden_size', default=320, type=int)
parser.add_argument('--gru_num_layers', default=2, type=int)
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=True)
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
parser.opt_list('--learning_rate', default=0.0001, type=float,
options=[0.00001, 0.0005, 0.001],
tunable=True)
parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=False)
parser.opt_list('--input_size', default=1024, options=[512, 1024], type=int, tunable=True)
options=[0.00001, 0.0005, 0.001], tunable=False)
parser.opt_list('--slicing_mode', default='full', options=['full', 'slice'], type=str, tunable=True)
parser.opt_list('--input_size', default=512, options=[512, 1024], type=int, tunable=False)
# training params (opt)
parser.opt_list('--optimizer_name', default='adam', type=str,
......@@ -140,7 +201,7 @@ class CRNN(BasePtlModel):
# if using 2 nodes with 4 gpus each the batch size here
# (256) will be 256 / (2*8) = 16 per gpu
parser.opt_list('--batch_size', default=32, type=int,
options=[16, 32], tunable=False,
parser.opt_list('--batch_size', default=16, type=int,
options=[16, 8], tunable=True,
help='batch size will be divided over all gpus being used across all nodes')
return parser
......@@ -15,10 +15,10 @@ from datasets.collate import PadSequence
# example config dict
# Default configuration shared by all PTL models: data source plus the
# metrics tracked in each phase.
base_model_config = {
    'data_source': 'mtgjamendo',
    'training_metrics': ['loss'],
    'validation_metrics': ['loss', 'prauc', 'rocauc'],
    'test_metrics': ['loss', 'prauc', 'rocauc']
}
......@@ -42,7 +42,7 @@ class BasePtlModel(pl.LightningModule):
self.validation_metrics = config.get('validation_metrics')
self.test_metrics = config.get('test_metrics')
if self.data_source=='midlevel':
if self.data_source == 'midlevel':
dataset, dataset_length = df_get_midlevel_set('midlevel',
path_midlevel_annotations,
path_midlevel_audio_dir,
......@@ -177,27 +177,26 @@ class BasePtlModel(pl.LightningModule):
@pl.data_loader
def tng_dataloader(self):
    """Build the training DataLoader for the configured data source.

    NOTE(review): after the full_song merge both slicing modes return an
    identical DataLoader (the PadSequence collate_fn was removed from the
    'full' branch) — confirm whether the branch can be collapsed.
    """
    if self.data_source == 'mtgjamendo':
        dataset = df_get_mtg_set('mtgjamendo',
                                 path_mtgjamendo_annotations_train,
                                 path_mtgjamendo_audio_dir,
                                 "_ap_mtgjamendo44k", slicing_func=self.slicer,
                                 slice_len=self.input_size)
    elif self.data_source == 'midlevel':
        dataset = self.midlevel_trainset
    else:
        raise Exception(f"Data source {self.data_source} not defined")

    if self.slicing_mode == 'full':
        return DataLoader(dataset=dataset,
                          batch_size=self.hparams.batch_size,
                          shuffle=True)
    else:
        return DataLoader(dataset=dataset,
                          batch_size=self.hparams.batch_size,
                          shuffle=True)
@pl.data_loader
def val_dataloader(self):
......@@ -207,10 +206,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size)
elif self.data_source == 'midlevel':
dataset = self.midlevel_valset
else:
raise Exception(f"Data source {self.data_source} not defined")
......@@ -219,10 +216,10 @@ class BasePtlModel(pl.LightningModule):
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=PadSequence())
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
else:
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
@pl.data_loader
def test_dataloader(self):
......@@ -232,10 +229,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k", slicing_func=self.slicer,
slice_len=self.input_size)
elif self.data_source == 'midlevel':
dataset = self.midlevel_testset
else:
raise Exception(f"Data source {self.data_source} not defined")
......@@ -244,7 +239,7 @@ class BasePtlModel(pl.LightningModule):
batch_size=self.hparams.batch_size,