Commit b3ba542d authored by Shreyan Chowdhury's avatar Shreyan Chowdhury
Browse files

khaled's model working

parent 4479fab9
__pycache__/*
*/__pycache__/*
*__pycache__/*
results/runs*
.idea*
......@@ -2,6 +2,7 @@ from utils import *
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.vgg_basic import MultiTagger
from models import cp_resnet
def run():
......@@ -16,7 +17,8 @@ def run():
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
model = MultiTagger(num_tags=56)
from strategies import model_config
model = cp_resnet.Network(model_config)
trainer.fit(model)
......
......@@ -7,6 +7,10 @@ import torch.nn.functional as F
from librosa.filters import mel as librosa_mel_fn
from utils import *
from datasets import MelSpecDataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
def initialize_weights(module):
if isinstance(module, nn.Conv2d):
......@@ -199,7 +203,7 @@ class BottleneckBlock(nn.Module):
return y
class Network(nn.Module):
class Network(pl.LightningModule):
def __init__(self, config):
super(Network, self).__init__()
......@@ -332,6 +336,7 @@ class Network(nn.Module):
if first_RUN: print("stage3:", x.size())
return x
def forward(self, x):
global first_RUN
if self.use_raw_spectograms:
......@@ -351,3 +356,86 @@ class Network(nn.Module):
if first_RUN: print("logit:", logit.size())
first_RUN = False
return logit
def my_loss(self, y_hat, y):
    """Multi-label loss: per-tag binary cross-entropy computed on raw logits."""
    loss = F.binary_cross_entropy_with_logits(y_hat, y)
    return loss
def forward_full_song(self, x, y):
# yy = []
# for i in range(2):
# yy.append(self.forward(torch.unsqueeze(x[:, i, :, :], dim=1))[1])
# return torch.stack(yy).mean(dim=0)
if USE_GPU:
yy = torch.zeros(y.shape, device=y.get_device())
else:
yy = torch.zeros(y.shape)
frames_to_process = 10
for i in range(frames_to_process):
yy += self.forward(torch.unsqueeze(x[:, i, :, :], dim=1))[1]
return yy/frames_to_process
def training_step(self, data_batch, batch_nb):
x, y = data_batch
y_hat = self.forward_full_song(x, y)
y = y.float()
y_hat = y_hat.float()
return {'loss':self.my_loss(y_hat, y)}
def validation_step(self, data_batch, batch_nb):
    """Validation step: batch loss plus ROC-AUC and micro-F1.

    Fixes vs. the previous revision:
    - ``torch.sigmoid`` instead of ``F.softmax``: the model is multi-label
      (trained with per-tag binary cross-entropy), so tag probabilities are
      independent; softmax would force the tags to compete and sum to 1.
    - ``f1_score`` now receives the thresholded binary predictions;
      sklearn's f1_score rejects continuous probabilities.
    """
    x, y = data_batch
    y_hat = self.forward_full_song(x, y)
    y = y.float()
    y_hat = y_hat.float()
    # Independent per-tag probabilities for the multi-label problem.
    y_hat_probs = torch.sigmoid(y_hat)
    y_hat_binary = (y_hat_probs > 0.5).type(torch.int)
    # NOTE(review): metrics are computed on transposed (tag x batch)
    # matrices — presumably intentional for per-tag scoring; confirm.
    rocauc = roc_auc_score(y.t().cpu(), y_hat_probs.t().cpu())
    fscore = f1_score(y.t().cpu(), y_hat_binary.t().cpu(), average='micro')
    return {'val_loss': self.my_loss(y_hat, y),
            'rocauc': rocauc,
            'fscore': fscore}
def validation_end(self, outputs):
    """Aggregate per-batch validation metrics into epoch-level means."""
    losses = [out['val_loss'] for out in outputs]
    # rocauc/fscore arrive as plain floats from sklearn; wrap them in
    # tensors so they can be stacked and averaged uniformly.
    aucs = [torch.tensor([out['rocauc']]) for out in outputs]
    fscores = [torch.tensor([out['fscore']]) for out in outputs]
    return {
        'val_loss': torch.stack(losses).mean(),
        'rocauc': torch.stack(aucs).mean(),
        'fscore': torch.stack(fscores).mean(),
    }
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.02)]
@pl.data_loader
def tng_dataloader(self):
    """Training loader over pre-framed mel spectrograms (batch 32, shuffled)."""
    dataset = MelSpecDataset(phase='train',
                             ann_root=PATH_ANNOTATIONS,
                             spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED)
    return DataLoader(dataset=dataset, batch_size=32, shuffle=True)
@pl.data_loader
def val_dataloader(self):
    """Validation loader over pre-framed mel spectrograms (batch 128).

    NOTE(review): shuffle=True is unusual for validation — epoch metrics
    averaged per batch in validation_end will vary with batch composition.
    """
    dataset = MelSpecDataset(phase='validation',
                             ann_root=PATH_ANNOTATIONS,
                             spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED)
    return DataLoader(dataset=dataset, batch_size=128, shuffle=True)
@pl.data_loader
def test_dataloader(self):
    """Test loader over pre-framed mel spectrograms (batch 32).

    NOTE(review): shuffle=True on the test split is harmless for aggregate
    metrics but makes per-example output ordering non-deterministic.
    """
    dataset = MelSpecDataset(phase='test',
                             ann_root=PATH_ANNOTATIONS,
                             spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED)
    return DataLoader(dataset=dataset, batch_size=32, shuffle=True)
if __name__ == '__main__':
    # Smoke test: instantiate a CP-ResNet for 56 tags from a hand-written
    # config (mirrors the config used by the training entry point).
    model_config = {
        "input_shape": [1, 1, 256, 96],  # [batch, channels, time, freq]
        "n_classes": 56,
        "depth": 26,
        "base_channels": 128,
        "n_blocks_per_stage": [3, 1, 1],
        "stage1": {"maxpool": [1, 2], "k1s": [3, 3, 3], "k2s": [1, 3, 3]},
        "stage2": {"maxpool": [1], "k1s": [3, ], "k2s": [1, ]},
        "stage3": {"maxpool": [], "k1s": [1, ], "k2s": [1, ]},
        "block_type": "basic",
    }
    net = Network(model_config)
\ No newline at end of file
......@@ -9,7 +9,7 @@ import pytorch_lightning as pl
from sklearn.metrics import roc_auc_score
class MultiTagger(ptl.LightningModule):
class MultiTagger(pl.LightningModule):
def __init__(self, num_tags=8):
super(MultiTagger, self).__init__()
self.num_tags = num_tags
......
......@@ -25,6 +25,7 @@ from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.vgg_basic import MultiTagger
from models import cp_resnet
logger.info(CURR_RUN_PATH)
......@@ -42,7 +43,20 @@ if __name__=='__main__':
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
model = MultiTagger(num_tags=56)
# model = MultiTagger(num_tags=56)
model_config = {
"input_shape": [1, 1, 256, 96], #[batch,channels,time,freq]
"n_classes": 56,
"depth": 26,
"base_channels": 128,
"n_blocks_per_stage": [3, 1, 1],
"stage1": {"maxpool": [1, 2], "k1s": [3, 3, 3], "k2s": [1, 3, 3]},
"stage2": {"maxpool": [1], "k1s": [3, ], "k2s": [1, ]},
"stage3": {"maxpool": [], "k1s": [1, ], "k2s": [1, ]},
"block_type": "basic"
}
model = cp_resnet.Network(model_config)
trainer.fit(model)
......
model_config = {
"depth": 26,
"base_channels": 128,
"n_blocks_per_stage": [3, 1, 1],
"stage1": {"maxpool": [1, 2], "k1s": [3, 3, 3], "k2s": [1, 3, 3]},
"stage2": {"maxpool": [1], "k1s": [3, ], "k2s": [1, ]},
"stage3": {"maxpool": [], "k1s": [1, ], "k2s": [1, ]},
"block_type": "basic"
}
\ No newline at end of file
"input_shape": [1, 1, 256, 96], #[batch,channels,time,freq]
"n_classes": 56,
"depth": 26,
"base_channels": 128,
"n_blocks_per_stage": [3, 1, 1],
"stage1": {"maxpool": [1, 2], "k1s": [3, 3, 3], "k2s": [1, 3, 3]},
"stage2": {"maxpool": [1], "k1s": [3, ], "k2s": [1, ]},
"stage3": {"maxpool": [], "k1s": [1, ], "k2s": [1, ]},
"block_type": "basic"
}
\ No newline at end of file
......@@ -7,13 +7,15 @@ import torch.nn as nn
import logging
import numpy as np
import pandas as pd
import pytorch_lightning as ptl
from matplotlib import pyplot as plt
from tqdm import tqdm
plt.rcParams["figure.dpi"] = 288 # increase dpi for clearer plots
from plotting import * # mostly for debug
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score, auc
# PARAMS =======================
INPUT_SIZE = (96, 256)
MAX_FRAMES = 40
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment