Commit abf2d01d authored by Verena Praher
Browse files

make dataloaders and validation_step reusable

parent fdf03577
import torch.nn as nn
from datasets.mtgjamendo import df_get_mtg_set
from utils import *
from datasets.dataset import HDF5Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from models.shared_stuff import *
from sklearn import metrics
# TODO pr-auc
......@@ -88,7 +82,7 @@ class CNN(pl.LightningModule):
return logit
def my_loss(self, y_hat, y):
    """Binary cross-entropy loss; delegates to the shared `my_loss` helper.

    NOTE(review): the diff showed both the old inline body and the new
    delegating return; only the delegation is kept (the old body was
    unreachable dead code after it).
    """
    # Unqualified name resolves to the module-level my_loss (class scope is
    # not visible inside method bodies), i.e. the shared implementation.
    return my_loss(y_hat, y)
def forward_full_song(self, x, y):
# print(x.shape)
......@@ -129,44 +123,22 @@ class CNN(pl.LightningModule):
'fscore':fscore}
def validation_end(self, outputs):
    """Epoch-end aggregation; delegates to the shared `validation_end` helper.

    NOTE(review): the diff showed the old inline aggregation body followed by
    the new delegating return; the old body made the delegation unreachable,
    so only the delegation is kept.
    """
    # Resolves to the module-level validation_end (shared implementation).
    return validation_end(outputs)
def configure_optimizers(self):
    """Return the single Adam optimizer (lr taken from the reference code)."""
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    return [optimizer]
@pl.data_loader
def tng_dataloader(self):
    """Training DataLoader; delegates to the shared `tng_dataloader` helper.

    NOTE(review): the diff showed the old inline loader construction plus the
    new delegating return; the old body shadowed the delegation, so only the
    delegation is kept.
    """
    # Resolves to the module-level tng_dataloader (shared implementation).
    return tng_dataloader()
@pl.data_loader
def val_dataloader(self):
    """Validation DataLoader; delegates to the shared `val_dataloader` helper.

    NOTE(review): diff artifact removed — the old inline body made the new
    delegating return unreachable; only the delegation is kept.
    """
    # Resolves to the module-level val_dataloader (shared implementation).
    return val_dataloader()
@pl.data_loader
def test_dataloader(self):
    """Test DataLoader; delegates to the shared `test_dataloader` helper.

    NOTE(review): diff artifact removed — the old inline body made the new
    delegating return unreachable; only the delegation is kept.
    """
    # Resolves to the module-level test_dataloader (shared implementation).
    return test_dataloader()
@staticmethod
def add_model_specific_args(parent_parser, root_dir):
......
from utils import PATH_ANNOTATIONS, PATH_AUDIO
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn import metrics
import os
from datasets.mtgjamendo import df_get_mtg_set
def my_loss(y_hat, y):
    """Binary cross-entropy between predictions *y_hat* and targets *y*."""
    loss = F.binary_cross_entropy(y_hat, y)
    return loss
def training_step(model, data_batch, batch_nb):
    """Run one training step: forward *model* on the batch, return its loss.

    The batch is a 3-tuple; the middle element is ignored here.
    """
    x, _, y = data_batch
    prediction = model.forward(x).float()
    target = y.float()
    return {'loss': model.my_loss(prediction, target)}
def validation_step(model, data_batch, batch_nb):
    """Evaluate *model* on one validation batch.

    Returns a dict with the BCE validation loss plus ROC-AUC and PR-AUC from
    sklearn; 'fscore' is currently a placeholder (always 0.).
    """
    x, _, y = data_batch
    target = y.float()
    prediction = model.forward(x).float()
    # NOTE(review): scores are transposed before scoring — presumably to get
    # a (tags, batch) layout for the sklearn metrics; confirm against callers.
    y_true = target.t().cpu()
    y_score = prediction.t().cpu()
    rocauc = metrics.roc_auc_score(y_true, y_score)
    prauc = metrics.average_precision_score(y_true, y_score)
    # Placeholder: precision_recall_fscore_support needs thresholded labels.
    fscore = 0.
    return {'val_loss': model.my_loss(prediction, target),
            'rocauc': rocauc,
            'prauc': prauc,
            'fscore': fscore}
def validation_end(outputs):
    """Aggregate per-batch validation dicts into epoch-level mean metrics."""
    def _scalar_mean(key):
        # Wrap plain floats in 1-element tensors before stacking/averaging.
        return torch.stack([torch.tensor([o[key]]) for o in outputs]).mean()

    epoch_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
    return {'val_loss': epoch_loss,
            'rocauc': _scalar_mean('rocauc'),
            'prauc': _scalar_mean('prauc'),
            'fscore': _scalar_mean('fscore')}
def tng_dataloader(batch_size=32, shuffle=True):
    """Build the DataLoader for the MTG-Jamendo training split.

    Args:
        batch_size: samples per batch (default 32, matching the old
            hard-coded value).
        shuffle: whether to reshuffle every epoch (default True, as before).
    """
    train_csv = os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv')
    cache_x_name = "_ap_mtgjamendo44k"
    dataset = df_get_mtg_set('mtgjamendo', train_csv, PATH_AUDIO, cache_x_name)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle)
def val_dataloader(batch_size=32, shuffle=True):
    """Build the DataLoader for the MTG-Jamendo validation split.

    Args:
        batch_size: samples per batch (default 32, matching the old
            hard-coded value).
        shuffle: whether to reshuffle every epoch. NOTE(review): shuffling a
            validation set is unusual; default kept True to preserve the
            original behavior — consider passing False.
    """
    validation_csv = os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv')
    cache_x_name = "_ap_mtgjamendo44k"
    dataset = df_get_mtg_set('mtgjamendo_val', validation_csv, PATH_AUDIO, cache_x_name)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle)
def test_dataloader(batch_size=32, shuffle=True):
    """Build the DataLoader for the MTG-Jamendo test split.

    Args:
        batch_size: samples per batch (default 32, matching the old
            hard-coded value).
        shuffle: whether to reshuffle every epoch. NOTE(review): shuffling a
            test set is unusual; default kept True to preserve the original
            behavior — consider passing False.
    """
    test_csv = os.path.join(PATH_ANNOTATIONS, 'test_processed.tsv')
    cache_x_name = "_ap_mtgjamendo44k"
    dataset = df_get_mtg_set('mtgjamendo_test', test_csv, PATH_AUDIO, cache_x_name)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment