Commit b3a8c6f7 authored by Shreyan Chowdhury

Implement logic for full-song RNN; CUDA runs out of memory, but it works fine on CPU.

parent 346ec177
import torch
class PadSequence:
def __call__(self, batch):
# print("PadSequence is called")
# Let's assume that each element in "batch" is a tuple (data, label).
# Sort the batch in the descending order
sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
# Get each sequence and pad it
sequences = [x[0] for x in sorted_batch]
sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
# Also need to store the length of each sequence
# This is later needed in order to unpad the sequences
lengths = torch.LongTensor([len(x) for x in sequences])
# # print("PadSequence is called")
# # Let's assume that each element in "batch" is a tuple (data, label).
# # Sort the batch in the descending order
# sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
# # Get each sequence and pad it
# sequences = [x[0] for x in sorted_batch]
# sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
# # Also need to store the length of each sequence
# # This is later needed in order to unpad the sequences
# lengths = torch.LongTensor([len(x) for x in sequences])
#
# # Don't forget to grab the labels of the *sorted* batch
# labels = [x[2] for x in sorted_batch]
# # labels = torch.LongTensor((map(lambda x: x[1], sorted_batch)))
# # print(labels)
# # labels = torch.LongTensor(labels)
# return sequences_padded, lengths, labels
h5data = batch[0][0][0]
idx = [i[0][1] for i in batch]
lengths = [i[0][2] for i in batch]
labels = [i[2] for i in batch]
# Don't forget to grab the labels of the *sorted* batch
labels = [x[2] for x in sorted_batch]
# labels = torch.LongTensor((map(lambda x: x[1], sorted_batch)))
# print(labels)
# labels = torch.LongTensor(labels)
return sequences_padded, lengths, labels
return h5data, idx, lengths, labels
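A minimal usage sketch of the collate function above, assuming a dataset whose items are ((h5data, idx, xlen), filepath, labels) as H5FCachedDataset produces below; the dataset variable and the batch size are illustrative:
loader = torch.utils.data.DataLoader(
    dataset,                    # an H5FCachedDataset configured with full-song slicing (illustrative)
    batch_size=4,
    shuffle=True,
    collate_fn=PadSequence(),   # each batch becomes (h5data, idx_list, lengths, labels)
)
for h5data, idx_list, lengths, labels in loader:
    pass  # consumed by CRNN.forward_full_song further down in this commit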
......@@ -160,8 +160,8 @@ class H5FCachedDataset(Dataset):
torch.utils.data.DataLoader(getDataset(), batch_size=1, shuffle=False, num_workers=36)):
# fixing the shapes
x = x[0].numpy().transpose(2, 0, 1)
y = y[0] # audio filepath relative
z = z[0] # labels
x = x.reshape(x.shape[0], -1)
if d is None:
d = f.create_dataset('data', (0, x.shape[1]), maxshape=(None, x.shape[1]), dtype='f', chunks=True)
......@@ -181,7 +181,7 @@ class H5FCachedDataset(Dataset):
def __getitem__(self, index):
cpath = os.path.join(self.cache_path, str(index) + "_meta.pt")
# x.transpose(2,0,1).transpose(1,2,0) store columns first
(idx, xlen), y, z = torch.load(cpath) # y: audio filepath relative, z: labels
x = self.slicing_function(self.h5data, idx, xlen)
if self.augment_options is not None:
......
......@@ -15,18 +15,7 @@ def sample_slicing_function(h5data, idx, xlen):
return torch.from_numpy(x.transpose(1, 0).reshape(1, 256, timeframes))
def full_song_slicing_function(h5data, idx, xlen):
# Previous (eager) version -- marked TODO as not working -- sliced or wrap-padded
# the whole song to a fixed maxlen here; kept for reference:
# maxlen = 2048
# if xlen > maxlen:
#     k = torch.randint(xlen - maxlen + 1, (1,))[0].item()
#     x = h5data[idx + k:idx + k + maxlen]
# else:
#     x = h5data[idx:idx + xlen]
#     x = np.pad(x, ((0, maxlen - xlen), (0, 0)), mode='wrap')
# return torch.from_numpy(x.transpose((1, 0)).reshape((1, 256, -1)))
#
# Current version: defer slicing entirely; hand back the HDF5 handle, the song's
# start offset, and its length, and let the model read chunks lazily.
return (h5data, idx, xlen)
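Since the function now only hands back a handle and offsets, no full song is materialised in the data pipeline. The sketch below shows roughly how one fixed-size chunk is later read from the returned tuple, mirroring extract_features in the CRNN further down; read_chunk, the chunk_size default, and the assumption that numpy is imported as np in this module are illustrative:
def read_chunk(h5data, idx, xlen, chunk_no, chunk_size=512):
    # Illustrative helper: lazily read chunk number chunk_no of a song that starts
    # at row idx of h5data and is xlen frames long, wrap-padding the final chunk.
    start = idx + chunk_no * chunk_size
    end = idx + min(xlen, (chunk_no + 1) * chunk_size)
    chunk = h5data[start:end].transpose()  # shape: (256, <= chunk_size)
    if chunk.shape[1] < chunk_size:
        chunk = np.pad(chunk, ((0, 0), (0, chunk_size - chunk.shape[1])), mode='wrap')
    return chunk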
t2_parse_labels_cache = {}
......
......@@ -85,6 +85,6 @@ if __name__ == '__main__':
# run(hyperparams)
# gpus = ['cuda:0', 'cuda:1']
# hyperparams.optimize_parallel_gpu(run, gpus, 5)
run(hyperparams)
# for hparam_trial in hyperparams.trials(20):
# run(hparam_trial)
......@@ -48,15 +48,69 @@ class CRNN(BasePtlModel):
self.dropout = nn.Dropout(self.hparams.drop_prob)
def forward_full_song(self, batch):
def cnn_forward(x):
# init bn
x = self.bn_init(x)
# layer 1
x = self.mp_1(nn.ELU()(self.bn_1(self.conv_1(x))))
# layer 2
x = nn.ELU()(self.bn_2(self.conv_2(x)))
# x = self.mp_2(nn.ELU()(self.bn_2(self.conv_2(x))))
# layer 3
x = self.mp_3(nn.ELU()(self.bn_3(self.conv_3(x))))
# layer 4
x = self.mp_4(nn.ELU()(self.bn_4(self.conv_4(x))))
# layer 5
x = self.mp_5(nn.ELU()(self.bn_5(self.conv_5(x))))
# classifier
x = x.view(-1, x.size(0), 32)
return x
def rnn_forward(x):
x = self.gru1(x)[1][1] # TODO: Check if this is correct
x = self.dropout(x)
logit = nn.Sigmoid()(self.dense(x))
return logit
def extract_features(song_idx, song_length):
    # Run the CNN over one song in fixed-size chunks of self.input_size frames.
    song_feats = []
    num_chunks = (song_length + self.input_size - 1) // self.input_size  # ceil(song_length / input_size)
    for l in range(num_chunks):
        start = song_idx + l * self.input_size
        end = song_idx + min(song_length, (l + 1) * self.input_size)
        data = h5data[start:end].transpose()
        if data.shape[1] < self.input_size:
            # wrap-pad the last, shorter chunk up to input_size frames
            data = np.pad(data, ((0, 0), (0, self.input_size - data.shape[1])), mode='wrap')
        try:
            song_feats.append(cnn_forward(torch.tensor([[data]], device=torch.device('cuda'))))
        except (RuntimeError, AssertionError):
            # CUDA out of memory or CUDA unavailable -- fall back to the CPU for this chunk
            song_feats.append(cnn_forward(torch.tensor([[data]], device=torch.device('cpu'))))
    return torch.cat(song_feats)
h5data, idx_list, x_lengths_list, labels_list = batch
sequences = []
for n, ind in enumerate(idx_list):
sequences.append(extract_features(ind, x_lengths_list[n]))
sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
result = rnn_forward(sequences_padded)
return result
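# Worked example of the chunking above (illustrative numbers): with
# input_size = 512, a 1300-frame song gives ceil(1300 / 512) = 3 chunks; the
# last chunk holds 276 frames and is wrap-padded to 512 before the CNN. The
# per-song feature sequences are then padded across the batch and fed to the
# GRU. Long songs can exhaust CUDA memory here, which is why extract_features
# falls back to the CPU (see the commit message).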
def forward(self, batch):
#print("batch", batch)
x, x_lengths, _ = batch
#print("x", x)
#print("xlen", x_lengths)
if self.slicing_mode == 'full':
print("before pack", x, x_lengths)
x = pack_padded_sequence(x, x_lengths, batch_first=True)
if self.slicing_mode=='full':
logit = self.forward_full_song(batch)
return logit
x, _, _ = batch # xs, xlens, labels
# init bn
x = self.bn_init(x)
......@@ -89,14 +143,14 @@ class CRNN(BasePtlModel):
return logit
def training_step(self, data_batch, batch_i):
y = data_batch[-1]  # labels are the last element of the batch tuple
y_hat = self.forward(data_batch)
y = y.float()
y_hat = y_hat.float()
return {'loss': self.loss(y_hat, y)}
def validation_step(self, data_batch, batch_i):
y = data_batch[-1]
y_hat = self.forward(data_batch)
y = y.float()
y_hat = y_hat.float()
......@@ -107,7 +161,7 @@ class CRNN(BasePtlModel):
}
def test_step(self, data_batch, batch_i):
y = data_batch[-1]
y_hat = self.forward(data_batch)
y = y.float()
y_hat = y_hat.float()
......@@ -130,8 +184,8 @@ class CRNN(BasePtlModel):
parser.opt_list('--learning_rate', default=0.0001, type=float,
options=[0.00001, 0.0005, 0.001],
tunable=True)
# previous defaults: slicing_mode='slice', input_size=1024
parser.opt_list('--slicing_mode', default='full', options=['full', 'slice'], type=str, tunable=False)
parser.opt_list('--input_size', default=512, options=[512, 1024], type=int, tunable=True)
# training params (opt)
......
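For context, a sketch of how these options are typically consumed with test-tube, mirroring the __main__ block earlier in this commit; the parser strategy and the name of the method that registers the opt_list calls above are assumptions:
from test_tube import HyperOptArgumentParser

parser = HyperOptArgumentParser(strategy='random_search')
# CRNN.add_model_specific_args(parser)  # assumed name for the method containing the opt_list calls above
hyperparams = parser.parse_args()

run(hyperparams)                        # single configuration, as __main__ does now
# for hparam_trial in hyperparams.trials(20):  # or: sample 20 trials over the tunable options
#     run(hparam_trial)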