Commit 333e9359 authored by Shreyan Chowdhury

implement grid search in crnn, make input_size tunable

parent 14691251
@@ -6,14 +6,28 @@ import numpy as np
 import pandas as pd
 from utils import PATH_DATA_CACHE
 
+slice_length = 512 #TODO: Find a better way
+
 def sample_slicing_function(h5data, idx, xlen):
-    timeframes = 2048
+    timeframes = slice_length
     k = torch.randint(xlen - timeframes + 1, (1,))[0].item()
     x = h5data[idx + k:idx + k + timeframes]
     return torch.from_numpy(x.transpose(1, 0).reshape(1, 256, timeframes))
 
+
+def full_song_slicing_function(h5data, idx, xlen):
+    #TODO: not working, make it work if possible.
+    maxlen = 2048
+    if xlen > maxlen:
+        k = torch.randint(xlen - maxlen + 1, (1,))[0].item()
+        x = h5data[idx + k:idx + k + maxlen]
+        print(x.shape)
+    else:
+        x = h5data[idx:idx+xlen]
+        x = np.pad(x, ((0, maxlen - xlen), (0, 0)), mode='wrap')
+        print(x.shape)
+    return torch.from_numpy(x.transpose((1, 0)).reshape((1, 256, -1)))
+
 
 t2_parse_labels_cache = {}
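
The new full_song_slicing_function relies on numpy's mode='wrap' padding to stretch songs shorter than maxlen. A minimal standalone sketch of that padding behaviour, with array shapes assumed for illustration rather than taken from the repository:

import numpy as np

maxlen = 2048
x = np.random.rand(1500, 256)                          # a short (time, mel-bins) spectrogram
padded = np.pad(x, ((0, maxlen - 1500), (0, 0)), mode='wrap')
print(padded.shape)                                    # (2048, 256)
assert np.allclose(padded[1500:], x[:maxlen - 1500])   # tail is the song repeated from its start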
@@ -89,7 +103,9 @@ audio_processor = processor_mtgjamendo44k
 label_encoder = None
 
-def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, augment_options=None):
+def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=None, slice_len=512, augment_options=None):
+    global slice_length
+    slice_length = slice_len
     audio_path = os.path.expanduser(audio_path)
     global label_encoder
     print("loading dataset from '{}'".format(name))
@@ -98,7 +114,10 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, augment_options=None):
         files, labels, label_encoder = t2_parse_labels(mtg_files_csv)
         return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)
 
-    df_trset = H5FCachedDataset(getdatset, name, slicing_function=sample_slicing_function,
+    if slicing_func is None:
+        slicing_func = sample_slicing_function
+    df_trset = H5FCachedDataset(getdatset, name, slicing_function=slicing_func,
                                 x_name=cache_x_name,
                                 cache_path=PATH_DATA_CACHE,
                                 augment_options=augment_options
@@ -73,5 +73,6 @@ if __name__=='__main__':
     # run(hyperparams)
     #gpus = ['cuda:0', 'cuda:1']
     #hyperparams.optimize_parallel_gpu(run, gpus, 5)
     # run(hyperparams)
+    for hparam_trial in hyperparams.trials(20):
+        run(hparam_trial)
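
The hyperparams.trials(20) loop above is what implements the grid search named in the commit message. A hedged sketch of how the pieces fit, assuming hyperparams comes from test_tube's HyperOptArgumentParser (the parser that provides opt_list and trials); the option values mirror the CRNN hunk below:

from test_tube import HyperOptArgumentParser

parser = HyperOptArgumentParser(strategy='grid_search')
parser.opt_list('--learning_rate', default=0.0001, type=float,
                options=[0.00001, 0.0005, 0.001], tunable=True)
parser.opt_list('--input_size', default=1024, type=int,
                options=[512, 1024], tunable=True)
hyperparams = parser.parse_args()

# Each trial is a namespace holding one combination of the tunable options
# (non-tunable options keep their defaults); run() is called once per trial.
for hparam_trial in hyperparams.trials(20):
    print(hparam_trial.learning_rate, hparam_trial.input_size)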
@@ -45,7 +45,7 @@ class CRNN(BasePtlModel):
     def forward(self, x):
-        x = x[:, :, :, :512]
+        # x = x[:, :, :, :512]
 
         # init bn
         x = self.bn_init(x)
@@ -88,6 +88,8 @@ class CRNN(BasePtlModel):
         parser.opt_list('--learning_rate', default=0.0001, type=float,
                         options=[0.00001, 0.0005, 0.001],
                         tunable=True)
+        parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=False)
+        parser.opt_list('--input_size', default=1024, options=[512, 1024], type=int, tunable=True)
 
         # training params (opt)
@@ -97,6 +99,6 @@ class CRNN(BasePtlModel):
         # if using 2 nodes with 4 gpus each the batch size here
         # (256) will be 256 / (2*8) = 16 per gpu
         parser.opt_list('--batch_size', default=32, type=int,
-                        options=[32, 64, 128, 256], tunable=False,
+                        options=[16, 32], tunable=False,
                         help='batch size will be divided over all gpus being used across all nodes')
         return parser
@@ -7,7 +7,9 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from sklearn import metrics
 import os
-from datasets.mtgjamendo import df_get_mtg_set
+from datasets.mtgjamendo import df_get_mtg_set, \
+    sample_slicing_function, \
+    full_song_slicing_function
 import numpy as np
 import pytorch_lightning as pl
 from datasets.shared_data_utils import *
@@ -163,6 +165,15 @@ class BasePtlModel(pl.LightningModule):
         self.data_source = config.get('data_source')
         self.hparams = hparams
 
+        if hparams.slicing_mode == 'full':
+            self.slicer = full_song_slicing_function
+        elif hparams.slicing_mode == 'slice':
+            self.slicer = sample_slicing_function
+        else:
+            raise Exception(f"Invalid slicing mode {hparams.slicing_mode}")
+
+        self.input_size = hparams.input_size
+
         self.training_metrics = config.get('training_metrics')
         self.validation_metrics = config.get('validation_metrics')
         self.test_metrics = config.get('test_metrics')
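
The constructor above picks the slicing function from hparams.slicing_mode and stores hparams.input_size; the dataloader hunks below then forward both to df_get_mtg_set. A hedged usage sketch of the extended df_get_mtg_set signature, with placeholder paths rather than the real config values:

from datasets.mtgjamendo import df_get_mtg_set, \
    sample_slicing_function, full_song_slicing_function

slicing_mode, input_size = 'slice', 512      # e.g. taken from hparams
slicer = full_song_slicing_function if slicing_mode == 'full' else sample_slicing_function

dataset = df_get_mtg_set('mtgjamendo',
                         'path/to/annotations_train.tsv',   # placeholder path
                         'path/to/audio_dir',               # placeholder path
                         '_ap_mtgjamendo44k',
                         slicing_func=slicer,
                         slice_len=input_size)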
@@ -306,7 +317,8 @@
             dataset = df_get_mtg_set('mtgjamendo',
                                      path_mtgjamendo_annotations_train,
                                      path_mtgjamendo_audio_dir,
-                                     "_ap_mtgjamendo44k")
+                                     "_ap_mtgjamendo44k", slicing_func=self.slicer,
+                                     slice_len=self.input_size)
         elif self.data_source=='midlevel':
             dataset = self.midlevel_trainset
@@ -324,7 +336,8 @@
             dataset = df_get_mtg_set('mtgjamendo_val',
                                      path_mtgjamendo_annotations_val,
                                      path_mtgjamendo_audio_dir,
-                                     "_ap_mtgjamendo44k")
+                                     "_ap_mtgjamendo44k", slicing_func=self.slicer,
+                                     slice_len=self.input_size)
         elif self.data_source == 'midlevel':
             dataset = self.midlevel_valset
@@ -342,7 +355,8 @@
             dataset = df_get_mtg_set('mtgjamendo_test',
                                      path_mtgjamendo_annotations_test,
                                      path_mtgjamendo_audio_dir,
-                                     "_ap_mtgjamendo44k")
+                                     "_ap_mtgjamendo44k", slicing_func=self.slicer,
+                                     slice_len=self.input_size)
         elif self.data_source == 'midlevel':
             dataset = self.midlevel_testset