Commit 4183d2b7 authored by Shreyan Chowdhury's avatar Shreyan Chowdhury

Add HDF5 caching of spectrograms (h5py) and an HDF5Dataset loader

parent ca903fab
......@@ -4,7 +4,7 @@ from processors.spectrogram_processors import make_framed_spec
class MelSpecDataset(Dataset):
def __init__(self, phase='train', ann_root=None, spec_root=None, length=MAX_FRAMES, framed=True):
def __init__(self, phase='train', ann_root=None, spec_root=None, length=MAX_LENGTH, framed=True):
assert ann_root is not None, logger.error("ann_root (root directory containing annotation files) required")
assert spec_root is not None, logger.error("spec_root (root directory of spectrograms) required")
assert phase in ['train', 'validation', 'test'], \
......@@ -69,11 +69,153 @@ class MelSpecDataset(Dataset):
return tagslist
import h5py
import numpy as np
from pathlib import Path
import torch
from torch.utils import data
class HDF5Dataset(Dataset):
"""Represents an abstract HDF5 dataset.
Input params:
file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
(A variant that also accepts a single .h5 file is sketched after this class.)
recursive: If True, searches for h5 files in subdirectories.
load_data: If True, loads all the data immediately into RAM. Use this if
the dataset fits into memory. Otherwise, leave this as False and
the data will be loaded lazily.
data_cache_size: Maximum number of HDF5 files kept in the in-memory cache (default=3).
transform: PyTorch transform to apply to every data instance (default=None).
"""
def __init__(self, file_path, recursive, load_data, data_cache_size=3, transform=None):
super().__init__()
self.data_info = []
self.data_cache = {}
self.data_cache_size = data_cache_size
self.transform = transform
# Search for all h5 files
p = Path(file_path)
assert (p.is_dir())
if recursive:
files = sorted(p.glob('**/*.h5'))
else:
# non-recursive mode currently only looks for a file named 'train.h5' inside file_path
files = sorted(p.glob('train.h5'))
if len(files) < 1:
raise RuntimeError('No hdf5 datasets found')
for h5dataset_fp in files:
self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
# self._add_data_infos(file_path, load_data)
def __getitem__(self, index):
# get the spectrogram and force it to a fixed number of frames:
# crop if it is longer than MAX_LENGTH, wrap-pad if it is shorter
x = self.get_data("data", index)
reqd_len = MAX_LENGTH
spec_len = x.shape[1]
x = x[:, :reqd_len] if spec_len > reqd_len else np.pad(x, ((0, 0), (0, reqd_len - spec_len)), mode='wrap')
if self.transform:
x = self.transform(x)
else:
x = torch.from_numpy(x)
# get label
y = self.get_data("label", index)
y = torch.from_numpy(y)
return (x, y)
def __len__(self):
return len(self.get_data_infos('data'))
def _add_data_infos(self, file_path, load_data):
with h5py.File(file_path, 'r') as h5_file:
# Walk through all groups, extracting datasets
for gname, group in h5_file.items():
for dname, ds in group.items():
# if data is not loaded its cache index is -1
idx = -1
if load_data:
# add data to the data cache
idx = self._add_to_cache(ds[()], file_path)  # ds[()] instead of the deprecated ds.value
# type is derived from the name of the dataset; we expect the dataset
# name to have a name such as 'data' or 'label' to identify its type
# we also store the shape of the data in case we need it (ds.shape avoids
# reading the whole dataset just to record its shape)
self.data_info.append(
{'file_path': file_path, 'type': dname, 'shape': ds.shape, 'cache_idx': idx})
def _load_data(self, file_path):
"""Load data to the cache given the file
path and update the cache index in the
data_info structure.
"""
with h5py.File(file_path, 'r') as h5_file:
for gname, group in h5_file.items():
for dname, ds in group.items():
# add data to the data cache and retrieve
# the cache index
idx = self._add_to_cache(ds[()], file_path)  # ds[()] instead of the deprecated ds.value
# find the beginning index of the hdf5 file we are looking for
file_idx = next(i for i, v in enumerate(self.data_info) if v['file_path'] == file_path)
# the data info should have the same index since we loaded it in the same way
self.data_info[file_idx + idx]['cache_idx'] = idx
# remove an element from data cache if size was exceeded
if len(self.data_cache) > self.data_cache_size:
# evict the first cached file that is not the one just loaded (roughly FIFO)
removal_keys = list(self.data_cache)
removal_keys.remove(file_path)
self.data_cache.pop(removal_keys[0])
# remove invalid cache_idx
self.data_info = [
{'file_path': di['file_path'], 'type': di['type'], 'shape': di['shape'], 'cache_idx': -1}
if di['file_path'] == removal_keys[0] else di
for di in self.data_info
]
def _add_to_cache(self, data, file_path):
"""Adds data to the cache and returns its index. There is one cache
list for every file_path, containing all datasets in that file.
"""
if file_path not in self.data_cache:
self.data_cache[file_path] = [data]
else:
self.data_cache[file_path].append(data)
return len(self.data_cache[file_path]) - 1
def get_data_infos(self, type):
"""Get data infos belonging to a certain type of data.
"""
data_info_type = [di for di in self.data_info if di['type'] == type]
return data_info_type
def get_data(self, type, i):
"""Call this function anytime you want to access a chunk of data from the
dataset. This will make sure that the data is loaded in case it is
not part of the data cache.
"""
fp = self.get_data_infos(type)[i]['file_path']
if fp not in self.data_cache:
self._load_data(fp)
# get new cache_idx assigned by _load_data_info
cache_idx = self.get_data_infos(type)[i]['cache_idx']
return self.data_cache[fp][cache_idx]
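Note: the class above asserts that file_path is a directory, while the Lightning dataloaders added later in this commit pass paths to single files such as train.h5. The following is a minimal sketch, not part of the commit, of an __init__ that would accept either a directory or a single .h5 file under that assumption.

def __init__(self, file_path, recursive=False, load_data=False, data_cache_size=3, transform=None):
    super().__init__()
    self.data_info = []
    self.data_cache = {}
    self.data_cache_size = data_cache_size
    self.transform = transform

    # accept either a single HDF5 file or a directory of HDF5 files
    p = Path(file_path)
    if p.is_file() and p.suffix == '.h5':
        files = [p]
    elif p.is_dir():
        pattern = '**/*.h5' if recursive else '*.h5'
        files = sorted(p.glob(pattern))
    else:
        raise RuntimeError('{} is neither an HDF5 file nor a directory'.format(file_path))
    if len(files) < 1:
        raise RuntimeError('No hdf5 datasets found')
    for h5dataset_fp in files:
        self._add_data_infos(str(h5dataset_fp.resolve()), load_data)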
if __name__=='__main__':
# Tests
torch.manual_seed(6)
dataset = MelSpecDataset(phase='train', ann_root=PATH_ANNOTATIONS,
spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED, framed=True)
# dataset = MelSpecDataset(phase='train', ann_root=PATH_ANNOTATIONS,
# spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED, framed=True)
dataset = HDF5Dataset('/mnt/2tb/datasets/MTG-Jamendo/HDF5Cache_spectrograms/', recursive=False, load_data=False)
train_loader = DataLoader(dataset=dataset,
batch_size=32,
shuffle=True)
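A quick sanity check that could follow the loader above (assuming the HDF5 cache directory exists and the spectrograms have 96 mel bins, as INPUT_SIZE suggests): draw one batch and confirm every spectrogram was cropped or wrap-padded to MAX_LENGTH frames.

x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape)  # expected (32, n_mels, MAX_LENGTH), e.g. (32, 96, 10000)
print(y_batch.shape)  # expected (32, n_tags), one-hot tag vectors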
......
......@@ -89,7 +89,7 @@ def correlations():
if __name__=='__main__':
# stats = compute_duration_stats(os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv'))
# stats = compute_melspec_length_stats(os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv'))
# print(stats)
stats = compute_melspec_length_stats(os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv'))
print(stats)
# compute_tag_stats()
correlations()
\ No newline at end of file
# correlations()
\ No newline at end of file
import torch.nn as nn
from utils import *
from datasets.datasets import MelSpecDataset
from datasets.datasets import HDF5Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -81,4 +81,24 @@ class CNN(pl.LightningModule):
return F.binary_cross_entropy(y_hat, y)
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=1e-4)] # from their code
\ No newline at end of file
return [torch.optim.Adam(self.parameters(), lr=1e-4)] # from their code
@pl.data_loader
def tng_dataloader(self):
trainset = HDF5Dataset(os.path.join(PATH_MELSPEC_DOWNLOADED_HDF5, 'train.h5'), recursive=False, load_data=False)
return DataLoader(dataset=trainset, batch_size=32, shuffle=True)
@pl.data_loader
def val_dataloader(self):
validationset = HDF5Dataset(os.path.join(PATH_MELSPEC_DOWNLOADED_HDF5, 'val.h5'), recursive=False, load_data=False)
return DataLoader(dataset=validationset, batch_size=128, shuffle=True)
@pl.data_loader
def test_dataloader(self):
testset = HDF5Dataset(os.path.join(PATH_MELSPEC_DOWNLOADED_HDF5, 'test.h5'), recursive=False, load_data=False)
return DataLoader(dataset=testset, batch_size=32, shuffle=True)
@staticmethod
def add_model_specific_args(parent_parser, root_dir):
return parent_parser
pass
\ No newline at end of file
......@@ -119,6 +119,8 @@ class Network(pl.LightningModule):
# self.optimizer.zero_grad()
# loss.backward()
# self.optimizer.step()
#
# TODO response: they are taken care of in __run_tng_batch() inside Trainer
x, y = data_batch
y_hat = self.forward_full_song(x, y)
......
from utils import *
import h5py
def trim_silence(spec, thresh=0.1):
"""
......@@ -90,3 +90,49 @@ def preprocess_specs(source_root, destination_root, frame_length=256, hop=1.0):
spec = trim_silence(spec)
framed_spec = make_framed_spec(spec, frame_length=frame_length, hop=hop)
np.save(destination, framed_spec)
def hdf_cache_specs(source_root, destination_root, annotations_file='all', output_filename=None):
"""
Builds an HDF5 cache of mel spectrograms: one group per track, each holding a
'data' dataset (silence-trimmed spectrogram) and a 'label' dataset (one-hot tag
vector aligned with tagslist.npy).
"""
if annotations_file == 'all':
ann_tr = pd.read_csv(os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv'), sep='\t')
ann_val = pd.read_csv(os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv'), sep='\t')
ann_test = pd.read_csv(os.path.join(PATH_ANNOTATIONS, 'test_processed.tsv'), sep='\t')
annotations = pd.concat([ann_tr, ann_val, ann_test], axis=0, ignore_index=True)
else:
annotations = pd.read_csv(annotations_file, sep='\t')
if not os.path.exists(destination_root):
os.mkdir(destination_root)
if output_filename is None:
output_filename = 'dataset.h5'
hdf_filepath = os.path.join(destination_root, output_filename)
tagslist = np.load(os.path.join(PATH_PROJECT_ROOT, 'tagslist.npy'))
with h5py.File(hdf_filepath, 'w') as hdf:
for idx in tqdm(annotations.index):
filename = annotations.PATH.iloc[idx].split('.')[0] # discard '.mp3' extension
labels_str = annotations.TAGS.iloc[idx] # get labels in string format
labels_onehot = np.array([int(i in labels_str) for i in tagslist]) # convert to onehot encoding
filepath = os.path.join(source_root, filename+'.npy') # path of the melspectrogram stored in .npy format
spec = np.load(filepath) # load the spec from .npy
spec = trim_silence(spec) # trim silence
song = hdf.create_group(filename.split('/')[1]) # group name = track id (second component of the PATH column)
song.create_dataset('data', data=spec)
song.create_dataset('label', data=labels_onehot)
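For illustration, this is how one track could be read back from a cache produced by the function above, matching the layout HDF5Dataset expects (one group per track with 'data' and 'label' datasets). The track id used here is made up; real ids come from the PATH column of the annotations.

def inspect_cached_track(hdf_filepath, track_id='1000002'):
    # read one group back from the cache; '1000002' is a hypothetical track id
    with h5py.File(hdf_filepath, 'r') as hdf:
        spec = hdf[track_id]['data'][()]    # silence-trimmed mel spectrogram, shape (n_mels, n_frames)
        label = hdf[track_id]['label'][()]  # one-hot tag vector aligned with tagslist.npy
    return spec, label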
if __name__=='__main__':
# pass
hdf_cache_specs(PATH_MELSPEC_DOWNLOADED, os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms'),
annotations_file=os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv'), output_filename='train.h5')
hdf_cache_specs(PATH_MELSPEC_DOWNLOADED, os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms'),
annotations_file=os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv'), output_filename='val.h5')
hdf_cache_specs(PATH_MELSPEC_DOWNLOADED, os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms'),
annotations_file=os.path.join(PATH_ANNOTATIONS, 'test_processed.tsv'), output_filename='test.h5')
# pass
\ No newline at end of file
......@@ -32,10 +32,15 @@ def oversampling_test():
def reconstruct():
pass
def check_datasets():
for filename in os.listdir('/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_melspec_downloaded/00'):
f = np.load(os.path.join('/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_melspec_downloaded/00', filename))
if len(f.shape) > 2:
print(filename)
if __name__=='__main__':
check_datasets()
pass
......
......@@ -17,7 +17,7 @@ plt.rcParams["figure.dpi"] = 288 # increase dpi for clearer plots
# PARAMS =======================
INPUT_SIZE = (96, 256)
MAX_FRAMES = 40
MAX_LENGTH = 10000
# CONFIG =======================
......@@ -44,6 +44,7 @@ PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_audio')
PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations')
PATH_MELSPEC_DOWNLOADED = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_melspec_downloaded')
PATH_MELSPEC_DOWNLOADED_FRAMED = os.path.join(PATH_MELSPEC_DOWNLOADED, 'framed')
PATH_MELSPEC_DOWNLOADED_HDF5 = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
PATH_RESULTS = os.path.join(PATH_PROJECT_ROOT, 'results')
TRAINED_MODELS_PATH = ''
......