Commit b04cc45e authored by Verena Praher

add H5FCachedDataset from kofta, adapt/simplify it

parent 4183d2b7
import os
import time
from os.path import expanduser
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from utils import *
from processors.spectrogram_processors import make_framed_spec


class MelSpecDataset(Dataset):
    def __init__(self, phase='train', ann_root=None, spec_root=None, length=MAX_LENGTH, framed=True):
@@ -69,11 +73,8 @@ class MelSpecDataset(Dataset):
        return tagslist
class HDF5Dataset(Dataset):
@@ -209,18 +210,149 @@ class HDF5Dataset(Dataset):
        return self.data_cache[fp][cache_idx]
class AudioPreprocessDataset(Dataset):
    """A base preprocessing dataset representing a dataset of files that are loaded and preprocessed on the fly.

    Access elements via __getitem__ to return: preprocessor(x), sample_id, label,
    supporting integer indexing in the range from 0 to len(self) exclusive.
    """

    def __init__(self, files, labels, label_encoder, base_dir, preprocessor, return_tensor=True):
        self.files = files
        self.labels = labels
        self.label_encoder = label_encoder
        self.base_dir = base_dir
        self.preprocessor = preprocessor
        self.return_tensor = return_tensor

    def __getitem__(self, index):
        x = self.preprocessor(self.base_dir + self.files[index])
        if self.return_tensor and not isinstance(x, torch.Tensor):
            x = torch.from_numpy(x)
        return x, self.files[index], self.labels[index]

    def get_ordered_ids(self):
        return self.files

    def get_ordered_labels(self):
        return self.labels

    def __len__(self):
        return len(self.files)
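
# A hedged usage sketch for AudioPreprocessDataset (hypothetical files, labels and
# preprocessor; not part of this repo):
#   dataset = AudioPreprocessDataset(
#       files=["a.mp3", "b.mp3"], labels=[0, 1], label_encoder=None,
#       base_dir="/data/audio/",
#       preprocessor=lambda path: np.zeros((1, 256, 10), dtype=np.float32))
#   x, sample_id, label = dataset[0]  # x becomes a torch.Tensor since return_tensor=True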
class H5FCachedDataset(Dataset):
    def __init__(self, get_dataset_func, dataset_name, slicing_function, x_name="", y_name="",
                 cache_path="~/shared/kofta_cached_datasets/",
                 ):
        """
        :param get_dataset_func: callable returning the dataset to cache (called lazily, at most once)
        :param dataset_name: name used to build the cache directory
        :param slicing_function: takes the h5py dataset, start_idx and xlen, and returns the sample
        :param x_name: useful for large datasets with the same y but different x
        :param y_name: useful for large datasets with the same x but different y
        :param cache_path: root directory for the cache files
        """
        self.dataset = None

        def getDataset():
            # lazily create the underlying dataset, at most once
            if self.dataset is None:
                self.dataset = get_dataset_func()
            return self.dataset

        self.get_dataset_func = getDataset
        self.x_name = x_name + y_name
        cache_path = expanduser(cache_path)
        self.cache_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name)
        self.cache_file_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name + ".hdf5")
        self.slicing_function = slicing_function
        try:
            original_umask = os.umask(0)
            os.makedirs(self.cache_path, exist_ok=True)
        finally:
            os.umask(original_umask)
        try:
            self.f = h5py.File(self.cache_file_path, 'r')
            self.h5data = self.f['data']
        except OSError:
            print("Caches not found at ", self.cache_file_path)
            f = h5py.File(self.cache_file_path, 'w')
            d = None
            starttime = time.time()
            for i, (x, y, z) in enumerate(
                    torch.utils.data.DataLoader(getDataset(), batch_size=1, shuffle=False, num_workers=36)):
                # fixing the shapes: drop the batch dimension, move time to the first
                # axis, and flatten each frame to one row of the cache
                x = x[0].numpy().transpose(2, 0, 1)
                y = y[0]
                z = z[0]
                x = x.reshape(x.shape[0], -1)
                if d is None:
                    d = f.create_dataset('data', (0, x.shape[1]), maxshape=(None, x.shape[1]), dtype='f', chunks=True)
                idx = d.shape[0]
                xlen = x.shape[0]
                d.resize((idx + xlen, x.shape[1]))
                d[idx:] = x
                # store (offset, length) plus label and sample id per sample
                cpath = os.path.join(self.cache_path, str(i) + "_meta.pt")
                torch.save(((idx, xlen), y, z), cpath)
                if i % 1000 == 0:
                    print("cached ", i, " in ", time.time() - starttime)
            print("Done caching ", dataset_name, " cached:", i, " in ", time.time() - starttime)
            f.close()
            self.f = h5py.File(self.cache_file_path, 'r')
            self.h5data = self.f['data']

    def __getitem__(self, index):
        cpath = os.path.join(self.cache_path, str(index) + "_meta.pt")
        # x.transpose(2, 0, 1).transpose(1, 2, 0) stores columns first
        (idx, xlen), y, z = torch.load(cpath)
        x = self.slicing_function(self.h5data, idx, xlen)
        return x, y, z

    def get_ordered_labels(self):
        return self.get_dataset_func().get_ordered_labels()

    def get_ordered_ids(self):
        return self.get_dataset_func().get_ordered_ids()

    def get_xcache_path(self):
        return os.path.join(self.cache_path, self.x_name + "_x.pt")

    def get_ycache_path(self):
        return os.path.join(self.cache_path, self.y_name + "_y.pt")

    def get_sidcache_path(self):
        return os.path.join(self.cache_path, self.y_name + "_sid.pt")

    def __len__(self):
        return len(self.get_dataset_func())
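
# A minimal construction sketch for H5FCachedDataset (hypothetical names; the real
# wiring for MTG-Jamendo is df_get_mtg_set in datasets/mtgjamendo.py below):
#   def first_crop(h5data, idx, xlen):
#       # assumption: take the first up-to-600 cached rows of this sample
#       return torch.from_numpy(h5data[idx:idx + min(xlen, 600)])
#   cached = H5FCachedDataset(lambda: base_dataset, "toy_dataset", first_crop)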
if __name__ == '__main__':
    # Tests
    torch.manual_seed(6)
    # dataset = MelSpecDataset(phase='train', ann_root=PATH_ANNOTATIONS,
    #                          spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED, framed=True)
    # dataset = HDF5Dataset('/mnt/2tb/datasets/MTG-Jamendo/HDF5Cache_spectrograms/', recursive=False, load_data=False)
    train_files_csv = "~/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv"
    audio_path = "~/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio/"
    cache_x_name = "_ap_mtgjamendo44k"
    from datasets.mtgjamendo import df_get_mtg_set
    # name, mtg_files_csv, audio_path, cache_x_name
    dataset = df_get_mtg_set('mtgjamendo', train_files_csv, audio_path, cache_x_name)

    train_loader = DataLoader(dataset=dataset,
                              batch_size=32,
                              shuffle=True)
    for i, data in enumerate(train_loader, 0):
        spec, _, labels = data
        pass
\ No newline at end of file
import os
from datasets.datasets import H5FCachedDataset, AudioPreprocessDataset
import torch
import librosa
import numpy as np
import pandas as pd
def sample_slicing_function(h5data, idx, xlen):
    # randomly crop 600 consecutive time frames and restore the (1, 256, time) layout
    timeframes = 600
    k = torch.randint(xlen - timeframes + 1, (1,))[0].item()
    x = h5data[idx + k:idx + k + timeframes]
    return torch.from_numpy(x.transpose(1, 0).reshape(1, 256, timeframes))
t2_parse_labels_cache = {}


def t2_parse_labels(csvf):
    global t2_parse_labels_cache
    if t2_parse_labels_cache.get(csvf) is not None:
        return t2_parse_labels_cache.get(csvf)
    df = pd.read_csv(csvf, sep='\t')
    files = df['PATH'].values
    labels = []
    for l in df['TAGS'].values:
        labels.append(set(l.split(",")))
    # binarize the tag sets into a multi-hot matrix
    from sklearn.preprocessing import MultiLabelBinarizer
    mlb = MultiLabelBinarizer()
    bins = mlb.fit_transform(labels)
    t2_parse_labels_cache[csvf] = files, bins, mlb
    return t2_parse_labels_cache[csvf]
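
# Sketch of the label encoding (toy tags, not real MTG-Jamendo data): for TAGS rows
# "rock,pop" and "jazz", MultiLabelBinarizer().fit_transform([{"rock", "pop"}, {"jazz"}])
# yields the multi-hot matrix [[0, 1, 1], [1, 0, 0]] with classes_ == ['jazz', 'pop', 'rock'].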
def processor_mtgjamendo44k(file_path):
    n_fft = 2048  # 2048
    sr = 44100  # 22050 # 44100 # 32000
    mono = True  # @todo ask mattias
    log_spec = False
    n_mels = 256
    hop_length = 512
    fmax = None
    dpath, filename = os.path.split(file_path)
    # file_path2 = dpath + "/../audio22k/" + filename
    if mono:
        # this is the slowest part: resampling
        sig, sr = librosa.load(file_path, sr=sr, mono=True)
        sig = sig[np.newaxis]
    else:
        sig, sr = librosa.load(file_path, sr=sr, mono=False)
        # sig, sf_sr = sf.read(file_path)
        # sig = np.transpose(sig, (1, 0))
        # sig = np.asarray([librosa.resample(s, sf_sr, sr) for s in sig])
    spectrograms = []
    for y in sig:
        # compute the STFT
        stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=None, window='hann', center=True,
                            pad_mode='reflect')
        # keep only the amplitudes
        stft = np.abs(stft)
        # spectrogram weighting
        if log_spec:
            stft = np.log10(stft + 1)
        else:
            freqs = librosa.core.fft_frequencies(sr=sr, n_fft=n_fft)
            stft = librosa.perceptual_weighting(stft ** 2, freqs, ref=1.0, amin=1e-10, top_db=80.0)
        # apply the mel filterbank
        spectrogram = librosa.feature.melspectrogram(S=stft, sr=sr, n_mels=n_mels, fmax=fmax)
        # keep the spectrogram
        spectrograms.append(np.asarray(spectrogram))
    spectrograms = np.asarray(spectrograms, dtype=np.float32)
    return torch.from_numpy(spectrograms)
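
# Shape check (hypothetical input path): with mono=True, n_mels=256 and hop_length=512,
# processor_mtgjamendo44k("some_track.mp3") returns a float32 tensor of shape
# (1, 256, n_frames); H5FCachedDataset then flattens this to (n_frames, 256) rows when caching.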
audio_processor = processor_mtgjamendo44k
label_encoder = None
def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name):
    audio_path = os.path.expanduser(audio_path)
    global label_encoder
    print("loading dataset from '{}'".format(name))

    def getdataset():
        # note: without this global declaration the assignment below would only
        # bind a local variable and never update the module-level label_encoder
        global label_encoder
        files, labels, label_encoder = t2_parse_labels(mtg_files_csv)
        return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)

    df_trset = H5FCachedDataset(getdataset, name, slicing_function=sample_slicing_function,
                                x_name=cache_x_name,
                                )
    return df_trset
from utils import *
from pytorch_lightning import Trainer
from test_tube import Experiment
from models import vgg_basic
def run():
    logger.info(CURR_RUN_PATH)
    exp = Experiment(
        save_dir=CURR_RUN_PATH
    )
    exp.save()
    # model = cp_resnet.Network(model_config)
    model = vgg_basic.Network()
    # TODO: deal with this later
    # model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
    # checkpoint = ModelCheckpoint(
    #     filepath=model_save_path,
    #     save_best_only=True,
    #     verbose=True,
    #     monitor='rocauc',
    #     mode='max'
    # )
    if USE_GPU:
        trainer = Trainer(gpus=[0], distributed_backend='ddp',
                          experiment=exp, max_nb_epochs=10, train_percent_check=1.0,
                          fast_dev_run=False)
    else:
        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
                          fast_dev_run=True)
    trainer.fit(model)


if __name__ == '__main__':
    run()
\ No newline at end of file
from utils import *
# from datasets.datasets import MelSpecDataset
from datasets.mtgjamendo import df_get_mtg_set
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -14,6 +15,13 @@ class Network(pl.LightningModule):
        super(Network, self).__init__()
        self.num_tags = num_tags
        self.train_files_csv = "~/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv"
        self.test_files_csv = "~/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv"
        self.val_files_csv = "~/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv"
        self.audio_path = "~/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio/"
        self.cache_x_name = "_ap_mtgjamendo44k"
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, 5, 2, 2),  # (in_channels, out_channels, kernel_size, stride, padding)
            nn.BatchNorm2d(64),
@@ -114,22 +122,17 @@ class Network(pl.LightningModule):
        return yy / frames_to_process
    def training_step(self, data_batch, batch_nb):
        # TODO: this is not doing a backward pass?? something like this:
        # self.optimizer.zero_grad()
        # loss.backward()
        # self.optimizer.step()
        #
        # TODO response: they are taken care of in __run_tng_batch() inside Trainer
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'loss': self.my_loss(y_hat, y)}
    def validation_step(self, data_batch, batch_nb):
        print("data_batch", data_batch)
        x, _, y = data_batch
        print("x", x)
        print("y", y)
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
@@ -148,20 +151,25 @@ class Network(pl.LightningModule):
    @pl.data_loader
    def tng_dataloader(self):
        # trainset = MelSpecDataset(phase='train', ann_root=PATH_ANNOTATIONS,
        #                           spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED, framed=True)
        # def df_get_mtg_set(name, mtg_files_csv, audio_path,
        #                    cache_x_name):
        trainset = df_get_mtg_set('mtgjamendo', self.train_files_csv, self.audio_path, self.cache_x_name)
        return DataLoader(dataset=trainset, batch_size=32, shuffle=True)
    @pl.data_loader
    def val_dataloader(self):
        # validationset = MelSpecDataset(phase='validation', ann_root=PATH_ANNOTATIONS,
        #                                spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED, framed=True)
        validationset = df_get_mtg_set('mtgjamendo_val', self.val_files_csv, self.audio_path, self.cache_x_name)
        return DataLoader(dataset=validationset, batch_size=32, shuffle=True)
    @pl.data_loader
    def test_dataloader(self):
        # testset = MelSpecDataset(phase='test', ann_root=PATH_ANNOTATIONS,
        #                          spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED, framed=True)
        testset = df_get_mtg_set('mtgjamendo_test', self.test_files_csv, self.audio_path, self.cache_x_name)
        return DataLoader(dataset=testset, batch_size=32, shuffle=True)
    @staticmethod
...