Commit 320bf3d1 authored by Verena Praher

start implementing collate_fn for slicing_mode=full

parent 8db5855f
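
This commit wires a PadSequence collate (from datasets/collate.py, which is not part of this diff) into the full-song dataloaders below. As a hedged sketch only: a minimal pad-and-sort collate could look like the following, assuming each dataset item is an (x, y) pair with x shaped (time, features); the slicer below actually emits (1, 256, time), so the real class may transpose or otherwise differ:

import torch
from torch.nn.utils.rnn import pad_sequence

class PadSequence:
    def __call__(self, batch):
        # longest sequence first: pack_padded_sequence requires descending
        # lengths on PyTorch < 1.1 (no enforce_sorted=False back then)
        batch = sorted(batch, key=lambda item: item[0].shape[0], reverse=True)
        xs = [torch.as_tensor(item[0]) for item in batch]
        lengths = torch.tensor([x.shape[0] for x in xs])
        # zero-pad every sequence up to the longest one in the batch
        padded = pad_sequence(xs, batch_first=True)
        ys = torch.stack([torch.as_tensor(item[1]) for item in batch])
        # matches the (x, x_lengths, y) unpacking in CRNN.forward below
        return padded, lengths, ys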
......@@ -24,7 +24,7 @@ def full_song_slicing_function(h5data, idx, xlen):
else:
x = h5data[idx:idx+xlen]
x = np.pad(x, ((0, maxlen - xlen), (0, 0)), mode='wrap')
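# mode='wrap' tiles the clip from its start, so songs shorter than maxlen are repeated until they fit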
print(x.shape)
# print(x.shape)
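# -> (1, 256, time): channel-first layout for the conv frontend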
return torch.from_numpy(x.transpose((1, 0)).reshape((1, 256, -1)))
......@@ -126,3 +126,28 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
return df_trset
if __name__ == '__main__':
from datasets.shared_data_utils import path_mtgjamendo_annotations_train, path_mtgjamendo_audio_dir
from torch.utils.data import DataLoader
import math
dataset = df_get_mtg_set('mtgjamendo',
path_mtgjamendo_annotations_train,
path_mtgjamendo_audio_dir,
"_ap_mtgjamendo44k",
slice_len=512)
from datasets.collate import PadSequence
dataloader = DataLoader(dataset, 32, shuffle=True, collate_fn=PadSequence())
min_x = math.inf
max_x = -math.inf
batch = next(iter(dataloader))
# print(batch)
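# scan every batch once and track the global min/max feature values; note that zero padding from PadSequence will show up in these stats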
for _, (x, _, _) in enumerate(dataloader):
# print(torch.min(x))
if torch.min(x) < min_x:
min_x = torch.min(x)
if torch.max(x) > max_x:
max_x = torch.max(x)
print("min x", min_x)
print("max x", max_x)
\ No newline at end of file
from test_tube import HyperOptArgumentParser
from utils import *
from models.shared_stuff import BasePtlModel
from torch.nn.utils.rnn import pack_padded_sequence
class CRNN(BasePtlModel):
def __init__(self, config, num_class, hparams):
......@@ -37,15 +38,25 @@ class CRNN(BasePtlModel):
# recurrent layer
self.gru1 = nn.GRU(input_size=32,
hidden_size=hparams.gru_hidden_size,
num_layers=hparams.gru_num_layers)
num_layers=hparams.gru_num_layers
#,
#batch_first=True  # TODO: forward() now packs with batch_first=True; the flag is ignored for PackedSequence input but matters if padded tensors are ever fed directly
)
# classifier
self.dense = nn.Linear(hparams.gru_hidden_size, num_class)
self.dropout = nn.Dropout(self.hparams.drop_prob)
def forward(self, x):
# x = x[:, :, :, :512]
def forward(self, batch):
#print("batch", batch)
x, x_lengths, _ = batch
#print("x", x)
#print("xlen", x_lengths)
if self.slicing_mode == 'full':
print("before pack", x, x_lengths)
x = pack_padded_sequence(x, x_lengths, batch_first=True)
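# NOTE: pack_padded_sequence assumes x_lengths are sorted in decreasing order,
# which the PadSequence collate must guarantee (or pass enforce_sorted=False on PyTorch >= 1.1)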
# init bn
x = self.bn_init(x)
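# NOTE: when slicing_mode == 'full', x is a PackedSequence at this point, which the
# batch-norm/conv stack cannot consume; packing most likely belongs just before the GRU
# instead (this commit is a work in progress)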
......@@ -68,6 +79,8 @@ class CRNN(BasePtlModel):
# classifier
x = x.view(-1, x.size(0), 32)
# output, hidden = self.gru(x_pack)
x = self.gru1(x)[1][1]  # gru1 returns (output, h_n); h_n[1] is the final hidden state of the second layer (assumes gru_num_layers >= 2); TODO: Check if this is correct
x = self.dropout(x)
......@@ -75,6 +88,35 @@ class CRNN(BasePtlModel):
return logit
def training_step(self, data_batch, batch_i):
_, _, y = data_batch
y_hat = self.forward(data_batch)
y = y.float()
y_hat = y_hat.float()
return {'loss': self.loss(y_hat, y)}
def validation_step(self, data_batch, batch_i):
x, _, y = data_batch
y_hat = self.forward(data_batch)
y = y.float()
y_hat = y_hat.float()
return {
'val_loss': self.loss(y_hat, y),
'y': y.cpu().numpy(),
'y_hat': y_hat.cpu().numpy()
}
def test_step(self, data_batch, batch_i):
x, _, y = data_batch
y_hat = self.forward(data_batch)
y = y.float()
y_hat = y_hat.float()
return {
'test_loss': self.loss(y_hat, y),
'y': y.cpu().numpy(),
'y_hat': y_hat.cpu().numpy()
}
@staticmethod
def add_model_specific_args(parent_parser):
"""Parameters defined here will be available to your model through self.hparams
......
from datasets.midlevel import df_get_midlevel_set
from torch import optim
from torch.utils.data.dataset import random_split
from utils import PATH_ANNOTATIONS, PATH_AUDIO
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn import metrics
import os
from datasets.mtgjamendo import df_get_mtg_set, \
sample_slicing_function, \
full_song_slicing_function
import numpy as np
import pytorch_lightning as pl
from datasets.shared_data_utils import *
def my_loss(y_hat, y):
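# element-wise binary cross-entropy over the tag vector (multi-label); expects y_hat already in [0, 1]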
return F.binary_cross_entropy(y_hat, y)
def training_step(model, data_batch, batch_nb):
x, _, y = data_batch
y_hat = model.forward(x)
y = y.float()
y_hat = y_hat.float()
return {'loss': model.my_loss(y_hat, y)}
def validation_step(model, data_batch, batch_nb):
# print("data_batch", data_batch)
x, _, y = data_batch
# print("x", x)
# print("y", y)
y_hat = model.forward(x)
y = y.float()
y_hat = y_hat.float()
#print("y", y)
#print("y_hat", y_hat)
#rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu(), average='macro')
#prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu(), average='macro')
# _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), y_hat.t().cpu())
#fscore = 0.
return {'val_loss': model.my_loss(y_hat, y),
'y': y.cpu().numpy(),
'y_hat': y_hat.cpu().numpy(),
#'val_rocauc': rocauc,
#'val_prauc': prauc,
#'val_fscore': fscore
}
def validation_end(outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
y = []
y_hat = []
for output in outputs:
y.append(output['y'])
y_hat.append(output['y_hat'])
y = np.concatenate(y)
y_hat = np.concatenate(y_hat)
#print(y[0:10])
#print(y_hat[0:10])
rocauc = metrics.roc_auc_score(y, y_hat, average='macro')
prauc = metrics.average_precision_score(y, y_hat, average='macro')
#_, _, fscore, _ = metrics.precision_recall_fscore_support(y, y_hat, average='macro')
fscore = 0.  # placeholder: precision_recall_fscore_support needs thresholded labels, not raw scores
#print('metrics', rocauc, prauc, fscore)
#avg_auc = torch.stack([torch.tensor([x['val_rocauc']]) for x in outputs]).mean()
#avg_prauc = torch.stack([torch.tensor([x['val_prauc']]) for x in outputs]).mean()
#avg_fscore = torch.stack([torch.tensor([x['val_fscore']]) for x in outputs]).mean()
return {'val_loss': avg_loss,
'val_rocauc': rocauc,
'val_prauc': prauc,
'val_fscore': fscore}
def test_step(model, data_batch, batch_nb):
# print("data_batch", data_batch)
x, _, y = data_batch
# print("x", x)
# print("y", y)
y_hat = model.forward(x)
y = y.float()
y_hat = y_hat.float()
# per-batch metrics are unused here (test_end recomputes them over the whole epoch)
#rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu(), average='macro')
#prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu(), average='macro')
# _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), y_hat.t().cpu())
#fscore = 0.
return {'test_loss': model.my_loss(y_hat, y),
'y': y.cpu(),
'y_hat': y_hat.cpu(),
#'test_rocauc': rocauc,
#'test_prauc': prauc,
#'test_fscore': fscore
}
def test_end(outputs):
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
y = []
y_hat = []
for output in outputs:
y.append(output['y'])
y_hat.append(output['y_hat'])
y = np.concatenate(y)
y_hat = np.concatenate(y_hat)
#print(y[0:10])
#print(y_hat[0:10])
rocauc = metrics.roc_auc_score(y, y_hat, average='macro')
prauc = metrics.average_precision_score(y, y_hat, average='macro')
#_, _, fscore, _ = metrics.precision_recall_fscore_support(y, y_hat, average='macro')
fscore = 0.  # placeholder: precision_recall_fscore_support needs thresholded labels, not raw scores
return {'test_loss': avg_loss,
'test_rocauc': rocauc,
'test_prauc': prauc,
'test_fscore': fscore}
def tng_dataloader(batch_size=32, augment_options=None):
train_csv = os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv')
cache_x_name = "_ap_mtgjamendo44k"
dataset = df_get_mtg_set('mtgjamendo', train_csv, PATH_AUDIO, cache_x_name, augment_options)
return DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)
def val_dataloader(batch_size=32, augment_options=None):
validation_csv = os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv')
cache_x_name = "_ap_mtgjamendo44k"
dataset = df_get_mtg_set('mtgjamendo_val', validation_csv, PATH_AUDIO, cache_x_name, augment_options)
return DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)
def test_dataloader(batch_size=32):
test_csv = os.path.join(PATH_ANNOTATIONS, 'test_processed.tsv')
cache_x_name = "_ap_mtgjamendo44k"
dataset = df_get_mtg_set('mtgjamendo_test', test_csv, PATH_AUDIO, cache_x_name)
return DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)
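# NOTE: shuffle=True on the validation/test loaders above is harmless here, since
# validation_end/test_end aggregate metrics over the whole epoch, but shuffle=False
# would be more conventional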
from datasets.collate import PadSequence
# example config dict
base_model_config = {
......@@ -165,6 +28,7 @@ class BasePtlModel(pl.LightningModule):
self.data_source = config.get('data_source')
self.hparams = hparams
self.slicing_mode = hparams.slicing_mode
if hparams.slicing_mode == 'full':
self.slicer = full_song_slicing_function
elif hparams.slicing_mode == 'slice':
......@@ -325,6 +189,11 @@ class BasePtlModel(pl.LightningModule):
else:
raise Exception(f"Data source {self.data_source} not defined")
if self.slicing_mode == 'full':
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=PadSequence())
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
......@@ -345,6 +214,12 @@ class BasePtlModel(pl.LightningModule):
else:
raise Exception(f"Data source {self.data_source} not defined")
if self.slicing_mode == 'full':
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=PadSequence())
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
......@@ -364,6 +239,12 @@ class BasePtlModel(pl.LightningModule):
else:
raise Exception(f"Data source {self.data_source} not defined")
if self.slicing_mode == 'full':
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=PadSequence())
return DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True)
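
The same slicing_mode branch now appears in all three loaders. A possible refactor, sketched here only and not part of this commit (_make_loader is a hypothetical name), would be:

def _make_loader(self, dataset):
    kwargs = dict(dataset=dataset,
                  batch_size=self.hparams.batch_size,
                  shuffle=True)
    if self.slicing_mode == 'full':
        # full songs vary in length, so batches need the padding collate
        kwargs['collate_fn'] = PadSequence()
    return DataLoader(**kwargs)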