Commit 4869d2e5 authored by Shreyan Chowdhury's avatar Shreyan Chowdhury

change function name train_dataloader to go with change in ptl, add...

change function name train_dataloader to go with change in ptl, add correlation metrics, other refactoring
parent 2e3be6de
from datasets.midlevel import df_get_midlevel_set
from scipy.stats import pearsonr
from torch import optim
from torch.utils.data.dataset import random_split
import torch
......@@ -28,15 +29,22 @@ class BasePtlModel(pl.LightningModule):
self.data_source = config.get('data_source')
self.hparams = hparams
self.slicing_mode = hparams.slicing_mode
if hparams.slicing_mode == 'full':
try:
self.slicing_mode = hparams.slicing_mode
except:
self.slicing_mode = 'slice'
if self.slicing_mode == 'full':
self.slicer = full_song_slicing_function
elif hparams.slicing_mode == 'slice':
elif self.slicing_mode == 'slice':
self.slicer = sample_slicing_function
else:
raise Exception(f"Invalid slicing mode {hparams.slicing_mode}")
self.input_size = hparams.input_size
try:
self.input_size = hparams.input_size
except:
pass
self.training_metrics = config.get('training_metrics')
self.validation_metrics = config.get('validation_metrics')
......@@ -159,6 +167,22 @@ class BasePtlModel(pl.LightningModule):
metrics_res[metric] = metrics.average_precision_score(Y, Y_hat, average='macro')
if metric == 'prauc-micro':
metrics_res[metric] = metrics.average_precision_score(Y, Y_hat, average='micro')
if metric == 'corr_avg':
corr, pval = [], []
for i in range(7):
c, p = pearsonr(Y[:,i], Y_hat[:,i])
corr.append(c)
pval.append(p)
metrics_res['corr_avg'] = np.mean(corr)
metrics_res['pval_avg'] = np.mean(pval)
if metric == 'corr':
corr, pval = [], []
for i in range(7):
c, p = pearsonr(Y[:, i], Y_hat[:, i])
corr.append(c)
pval.append(p)
metrics_res['corr'] = corr
metrics_res['pval'] = pval
return metrics_res
......@@ -176,7 +200,7 @@ class BasePtlModel(pl.LightningModule):
raise Exception(f"Loss not implemented for {self.data_source}")
@pl.data_loader
def tng_dataloader(self):
def train_dataloader(self):
if self.data_source == 'mtgjamendo':
dataset = df_get_mtg_set('mtgjamendo',
path_mtgjamendo_annotations_train,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment