Commit 4881ccd5 authored by Shreyan Chowdhury
Browse files

implement midlevel model pretraining

parent ac30bc96
......@@ -230,7 +230,12 @@ class AudioPreprocessDataset(Dataset):
x = self.preprocessor(os.path.join(self.base_dir, self.files[index]))
if self.return_tensor and not isinstance(x, torch.Tensor):
x = torch.from_numpy(x)
return x, self.files[index], self.labels[index]
try:
labels = self.labels.iloc[index].to_numpy()
except:
labels = self.labels[index]
return x, self.files[index], labels
def get_ordered_ids(self):
    """Return the stored file identifiers, in the order the dataset indexes them."""
    ordered_ids = self.files
    return ordered_ids
......
import os
from datasets.dataset import H5FCachedDataset, AudioPreprocessDataset
import torch
import librosa
import numpy as np
import pandas as pd
from utils import PATH_DATA_CACHE
def sample_slicing_function(h5data, idx, xlen):
    """Cut a random fixed-length excerpt out of one cached recording.

    A start offset is drawn with ``torch.randint`` from the
    ``xlen - 512 + 1`` possible positions, 512 consecutive rows are read from
    ``h5data`` starting at ``idx`` plus that offset, and the excerpt is
    returned as a tensor of shape ``(1, 256, 512)``.

    Args:
        h5data: sliceable 2-D array-like whose rows are frames.
        idx: row at which the recording starts inside ``h5data``.
        xlen: number of rows belonging to the recording.
    """
    n_frames = 512
    n_offsets = xlen - n_frames + 1
    offset = torch.randint(n_offsets, (1,))[0].item()
    start = idx + offset
    excerpt = h5data[start:start + n_frames]
    # Put the frame axis last, then add the leading singleton axis.
    frames_last = excerpt.transpose(1, 0)
    return torch.from_numpy(frames_last.reshape(1, 256, n_frames))
# Parsed annotation tables, keyed by the CSV path they were read from.
t2_parse_labels_cache = {}


def t2_parse_labels(csvf):
    """Read an annotation CSV once and return ``(file names, label table)``.

    The ``song_id`` column is turned into file names by appending ``.mp3``;
    every column after the first one is returned as the label table. Results
    are memoised in ``t2_parse_labels_cache`` so repeated calls with the same
    path return the very same tuple.
    """
    cached = t2_parse_labels_cache.get(csvf)
    if cached is not None:
        return cached

    table = pd.read_csv(csvf, sep=',')
    file_names = table['song_id'].astype(str) + '.mp3'
    label_columns = table.columns[1:]
    parsed = (file_names, table[label_columns])
    t2_parse_labels_cache[csvf] = parsed
    return parsed
def processor_midlevel44k(file_path):
    """Load an audio file and turn it into a weighted mel spectrogram tensor.

    The file is loaded with librosa as mono audio at 44.1 kHz. For each
    channel an STFT magnitude (``n_fft=2048``, hop 512, Hann window) is
    computed, librosa's perceptual weighting is applied to the power spectrum,
    and the result is passed through a 256-band mel filterbank.

    Args:
        file_path: path of the audio file to load.

    Returns:
        A float32 ``torch.Tensor`` holding one spectrogram per channel; with
        the mono setting used here the leading dimension is 1 and the second
        dimension is the 256 mel bands.
    """
    n_fft = 2048
    sr = 44100
    mono = True  # @todo ask mattias
    log_spec = False
    n_mels = 256
    hop_length = 512
    fmax = None

    # Loading (which includes resampling) is the slowest part.
    sig, sr = librosa.load(file_path, sr=sr, mono=mono)
    if mono:
        # Give mono audio a channel axis so the loop below is uniform.
        sig = sig[np.newaxis]

    # The bin frequencies do not depend on the channel: compute them once.
    freqs = None if log_spec else librosa.core.fft_frequencies(sr=sr, n_fft=n_fft)

    spectrograms = []
    for y in sig:
        # Keep only the amplitudes of the STFT.
        magnitude = np.abs(
            librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=None, window='hann', center=True,
                         pad_mode='reflect')
        )
        if log_spec:
            weighted = np.log10(magnitude + 1)
        else:
            # NOTE(review): the mel filterbank below is applied to the output
            # of perceptual_weighting rather than to a power spectrum --
            # confirm this is intended.
            weighted = librosa.perceptual_weighting(magnitude ** 2, freqs, ref=1.0, amin=1e-10, top_db=80.0)
        mel = librosa.feature.melspectrogram(S=weighted, sr=sr, n_mels=n_mels, fmax=fmax)
        spectrograms.append(np.asarray(mel))

    return torch.from_numpy(np.asarray(spectrograms, dtype=np.float32))
# Preprocessor and label encoder handed to every AudioPreprocessDataset built below.
audio_processor = processor_midlevel44k
label_encoder = None


def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name):
    """Build the cached mid-level dataset and return it with its length.

    Args:
        name: dataset name passed to the cache wrapper (also printed).
        midlevel_files_csv: annotation CSV parsed by ``t2_parse_labels``.
        audio_path: directory holding the audio files; ``~`` is expanded.
        cache_x_name: forwarded to the cache wrapper as ``x_name``.

    Returns:
        ``(dataset, len(dataset))``.
    """
    audio_root = os.path.expanduser(audio_path)
    print("loading dataset from '{}'".format(name))

    def getdatset():
        # Factory handed to the cache wrapper; builds the uncached dataset.
        file_names, label_table = t2_parse_labels(midlevel_files_csv)
        return AudioPreprocessDataset(file_names, label_table, label_encoder, audio_root, audio_processor)

    cached_set = H5FCachedDataset(
        getdatset,
        name,
        slicing_function=sample_slicing_function,
        x_name=cache_x_name,
        cache_path=PATH_DATA_CACHE,
    )
    return cached_set, len(cached_set)
from utils import USE_GPU, init_experiment, set_paths
from pytorch_lightning import Trainer
from test_tube import Experiment, HyperOptArgumentParser
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
from models.midlevel_vgg import ModelMidlevel as Network
import torch
# Training defaults; the epochs_* presets below overwrite entries in place.
config = {
    'epochs': 1,
    'patience': 50,
    'earlystopping_metric': 'val_loss',  # 'val_prauc'
    'earlystopping_mode': 'min',  # 'max'
}


def epochs_500():
    """Preset: train for 500 epochs."""
    config.update(epochs=500)


def epochs_100():
    """Preset: train for 100 epochs with an early-stopping patience of 20."""
    config.update(epochs=100, patience=20)


def epochs_20():
    """Preset: train for 20 epochs."""
    config.update(epochs=20)
def pretrain_midlevel(hparams):
    """Train the mid-level model and reload it from the written checkpoint.

    Switches the dataset paths to ``'midlevel'``, applies the optional config
    preset named by ``hparams.config``, builds early-stopping and checkpoint
    callbacks, fits the network and finally re-instantiates it from the
    checkpoint location.

    Args:
        hparams: parsed command-line arguments; ``config`` and
            ``train_percent`` are read here.
    """
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger, streamlog  # import these after init_experiment

    streamlog.info("Training midlevel...")
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)

    def setup_config():
        """Apply the preset named on the command line (if any) and log the config."""
        conf = hparams.config
        if conf is not None:
            # Look the preset up defensively: an unknown name must be reported
            # through the logger, not surface as a KeyError from globals().
            conf_func = globals().get(conf)
            if callable(conf_func):
                try:
                    conf_func()
                except Exception as err:
                    logger.error(f"Config {conf} failed: {err}")
            else:
                logger.error(f"Config {conf} not defined")

        summary = '---------CONFIG--------\n'
        summary += ''.join(f"{key}:{config[key]}\n" for key in config)
        logger.info(summary)

    setup_config()

    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )

    chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
    checkpoint_callback = ModelCheckpoint(
        filepath=chkpt_dir,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    if USE_GPU:
        trainer = Trainer(
            gpus=[0], distributed_backend='ddp',
            experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
            fast_dev_run=False,
            early_stop_callback=early_stop,
            checkpoint_callback=checkpoint_callback
        )
    else:
        # CPU runs are a smoke test: a single epoch on 1% of the training data.
        trainer = Trainer(
            experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
            fast_dev_run=False, checkpoint_callback=checkpoint_callback
        )

    model = Network(num_targets=7)
    print(model)

    trainer.fit(model)

    # streamlog.info("Running test")
    # trainer.test()

    # Re-create the network from the checkpoint to check that it loads.
    logger.info(f"Loading model from {chkpt_dir}")
    model = Network(num_targets=7, on_gpu=USE_GPU, load_from=chkpt_dir)
    logger.info("Loaded model successfully")
def run(hparams):
    """Initialise the experiment named in ``hparams`` and start pretraining."""
    experiment_name = hparams.experiment_name
    init_experiment(comment=experiment_name)
    pretrain_midlevel(hparams)
if __name__ == '__main__':
    # Command-line entry point: declare the arguments, let the model add its
    # own, then run the experiment with the parsed values.
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)

    # TODO : multiple arguments for --config using nargs='+' is not working with the test_tube
    #  implementation of argument parser
    argument_specs = (
        ('--experiment_name', dict(type=str, default='pt_lightning_exp_a', help='test tube exp name')),
        ('--config', dict(type=str, help='config function to run')),
        ('--train_percent', dict(type=float, default=1.0, help='how much train data to use')),
    )
    for flag, options in argument_specs:
        parent_parser.add_argument(flag, **options)

    parser = Network.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()
    run(hyperparams)
\ No newline at end of file
import torch.nn as nn
from datasets.midlevel import df_get_midlevel_set
from models.shared_stuff import *
from utils import *
from datasets.dataset import HDF5Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from sklearn import metrics
def initialize_weights(module):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    * ``nn.Conv2d``: Kaiming-normal weights (fan-in, ReLU gain).
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0.
    * ``nn.Linear``: bias set to 0 (weights keep their default initialisation).

    Any other module type is left untouched. Parameters that do not exist
    (``bias=False`` or ``affine=False``) are skipped instead of raising.
    """
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight.data, mode='fan_in', nonlinearity="relu")
        # nn.init.kaiming_normal_(module.weight.data, mode='fan_out')
    elif isinstance(module, nn.BatchNorm2d):
        # With affine=False both parameters are None.
        if module.weight is not None:
            module.weight.data.fill_(1)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Linear):
        # With bias=False the bias parameter is None.
        if module.bias is not None:
            module.bias.data.zero_()
class ModelMidlevel(pl.LightningModule):
    """VGG-style CNN trained with an MSE loss on mid-level annotation targets.

    Constructing the model also builds the mid-level dataset and splits it
    70/20/10 into train, validation and test subsets that the data-loader
    hooks serve in batches of 32.
    """

    def __init__(self, num_targets, initialize=True, load_from=None, on_gpu=None, map_location=None):
        """Build dataset splits and layers, then optionally restore weights.

        Args:
            num_targets: size of the output layer.
            initialize: apply ``initialize_weights`` to every submodule.
            load_from: directory searched for checkpoints to restore.
            on_gpu: when falsy, checkpoint tensors are mapped to CPU storage.
            map_location: forwarded to ``torch.load`` when ``on_gpu`` is truthy.
        """
        super(ModelMidlevel, self).__init__()

        data_root, audio_path, csvs_path = get_paths()
        cache_x_name = '_ap_midlevel44k'
        from torch.utils.data import random_split
        dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'), audio_path, cache_x_name)
        # random_split needs lengths that sum exactly to the dataset length;
        # truncating all three fractions can fall short, so the test split
        # takes whatever remains.
        n_train = int(0.7 * dataset_length)
        n_val = int(0.2 * dataset_length)
        n_test = dataset_length - n_train - n_val
        self.trainset, self.validationset, self.testset = random_split(dataset, [n_train, n_val, n_test])

        self.num_targets = num_targets

        self.conv1 = self._conv_block(1, 64, 5, 2, 2)
        self.conv2 = self._conv_block(64, 64, 3, 1, 1)

        self.mp2x2_dropout = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Dropout2d(0.3)
        )
        self.ap2x2_dropout = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Dropout2d(0.3)
        )

        self.conv3 = self._conv_block(64, 128, 3, 1, 1)
        self.conv4 = self._conv_block(128, 128, 3, 1, 1)
        self.conv5 = self._conv_block(128, 256, 3, 1, 1)
        self.conv6 = self._conv_block(256, 256, 3, 1, 1)
        self.conv7 = self._conv_block(256, 384, 3, 1, 1)
        self.conv7b = self._conv_block(384, 512, 3, 1, 1)
        # 1x1 convolution followed by global average pooling.
        self.conv11 = self._conv_block(512, 256, 1, 1, 0, nn.AdaptiveAvgPool2d((1, 1)))

        # The output size follows num_targets instead of a hard-coded 7.
        self.fc_ml = nn.Linear(256, num_targets)

        if initialize:
            self.apply(initialize_weights)

        if load_from:
            self._load_model(load_from, map_location, on_gpu)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride, padding, *extra_layers):
        """Return Conv2d -> BatchNorm2d -> ReLU (plus optional trailing layers)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            *extra_layers
        )

    def forward(self, x):
        """Map a batch of spectrogram excerpts to ``num_targets`` outputs.

        Shape comments are (channels, height, width) for a 1 x 256 x 512 input.
        """
        x = self.conv1(x)           # 64 x 128 x 256
        x = self.conv2(x)           # 64 x 128 x 256
        x = self.ap2x2_dropout(x)   # 64 x 64 x 128
        x = self.conv3(x)           # 128 x 64 x 128
        x = self.conv4(x)           # 128 x 64 x 128
        x = self.ap2x2_dropout(x)   # 128 x 32 x 64
        x = self.conv5(x)           # 256 x 32 x 64
        x = self.conv6(x)           # 256 x 32 x 64
        x = self.conv7(x)           # 384 x 32 x 64
        x = self.conv7b(x)          # 512 x 32 x 64
        x = self.conv11(x)          # 256 x 1 x 1
        x = x.view(x.size(0), -1)
        ml = self.fc_ml(x)
        return ml

    def _load_model(self, load_from, map_location=None, on_gpu=True):
        """Restore weights from the highest-epoch checkpoint in ``load_from``.

        Files containing ``hpc_`` and files without ``.ckpt`` in their name
        are ignored, as are checkpoint names without an ``epoch_<number>``
        part. Nothing is loaded when no usable checkpoint is found.
        """
        import re

        last_epoch = -1
        last_ckpt_name = None
        for name in os.listdir(load_from):
            # ignore hpc ckpts and anything that is not a checkpoint
            if 'hpc_' in name or '.ckpt' not in name:
                continue
            # Take only the digits that directly follow "epoch_".
            found = re.search(r'epoch_(\d+)', name)
            if found is None:
                continue
            epoch = int(found.group(1))
            if epoch > last_epoch:
                last_epoch = epoch
                last_ckpt_name = name

        if last_ckpt_name is None:
            return

        # restore last checkpoint
        last_ckpt_path = os.path.join(load_from, last_ckpt_name)
        if on_gpu:
            location = map_location
        else:
            # Keep every tensor on CPU storage.
            location = lambda storage, loc: storage
        checkpoint = torch.load(last_ckpt_path, map_location=location)
        self.load_state_dict(checkpoint['state_dict'])

    def my_loss(self, y_hat, y):
        """Mean squared error between predictions and targets."""
        return F.mse_loss(y_hat, y)

    def forward_full_song(self, x, y):
        """Run ``forward`` on the first 512 positions of the last axis of ``x``.

        ``y`` is accepted but not used.
        """
        # TODO full song??? (e.g. average the predictions of overlapping
        # 512-wide windows instead of using only the first one)
        return self.forward(x[:, :, :, :512])

    def training_step(self, data_batch, batch_nb):
        """Return the training loss for one batch."""
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'loss': self.my_loss(y_hat, y)}

    def validation_step(self, data_batch, batch_nb):
        """Return validation loss, targets and predictions for one batch."""
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'val_loss': self.my_loss(y_hat, y),
                'y': y.cpu().numpy(),
                'y_hat': y_hat.cpu().numpy(),
                }

    def validation_end(self, outputs):
        """Average the per-batch validation losses."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def test_step(self, data_batch, batch_nb):
        """Return test loss, targets and predictions for one batch."""
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'test_loss': self.my_loss(y_hat, y),
                'y': y.cpu().numpy(),
                'y_hat': y_hat.cpu().numpy(),
                }

    def test_end(self, outputs):
        """Average the per-batch test losses and log them to the experiment."""
        avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        test_metrics = {"test_loss": avg_test_loss}
        self.experiment.log(test_metrics)
        return test_metrics

    def configure_optimizers(self):
        """Use a single Adam optimizer with learning rate 1e-4."""
        return [torch.optim.Adam(self.parameters(), lr=1e-4)]  # from their code

    @pl.data_loader
    def tng_dataloader(self):
        return DataLoader(dataset=self.trainset, batch_size=32, shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(dataset=self.validationset, batch_size=32, shuffle=True)

    @pl.data_loader
    def test_dataloader(self):
        return DataLoader(dataset=self.testset, batch_size=32, shuffle=True)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return the parser unchanged; the model adds no arguments of its own."""
        return parent_parser
......@@ -19,8 +19,28 @@ plt.rcParams["figure.dpi"] = 288 # increase dpi for clearer plots
# PARAMS =======================
INPUT_SIZE = (96, 256)
MAX_LENGTH = 10000
# Root directory of each dataset, looked up as data_roots[dataset][hostname].
data_roots = {
"mtgjamendo":{
"rechenknecht3.cp.jku.at": "/media/rk3/shared/datasets/MTG-Jamendo",
"rechenknecht2.cp.jku.at": "/media/rk2/shared/datasets/MTG-Jamendo",
"rechenknecht1.cp.jku.at": "/media/rk1/shared/datasets/MTG-Jamendo",
"hermine": "/media/verena/SAMSUNG/Data/MTG-Jamendo",
"verena-830g5": "/media/verena/SAMSUNG/Data/MTG-Jamendo",
"shreyan-HP": "/home/shreyan/mounts/home@rk3/shared/datasets/MTG-Jamendo",
"shreyan-All-Series": "/mnt/2tb/datasets/MTG-Jamendo"
},
"midlevel":{
"rechenknecht3.cp.jku.at": "/media/rk3/shared/datasets/midlevel",
"rechenknecht2.cp.jku.at": "/media/rk2/shared/datasets/midlevel",
"rechenknecht1.cp.jku.at": "/media/rk1/shared/datasets/midlevel",
# NOTE(review): empty roots -- presumably the midlevel data is not available
# on these two machines; confirm before running there.
"hermine":"",
"verena-830g5":"",
"shreyan-HP": "/mnt/2tb/datasets/MidlevelFeatures",
"shreyan-All-Series": "/mnt/2tb/datasets/MidlevelFeatures"
}
}
# Dataset whose root is used for the module-level PATH_DATA_ROOT.
use_dataset= "mtgjamendo"
# CONFIG =======================
# paths:
......@@ -33,33 +53,32 @@ username = getpass.getuser()
# Per-host settings: plotting backend, data/cache locations and GPU usage.
# NOTE(review): not every branch assigns PATH_DATA_CACHE (e.g. 'hermine',
# 'verena-830g5', 'shreyan-HP') -- confirm nothing imports it on those hosts.
if hostname in ['rechenknecht3.cp.jku.at']:
plt.switch_backend('agg')
PATH_DATA_ROOT = '/media/rk3/shared/datasets/MTG-Jamendo'
PATH_DATA_CACHE = '/media/rk3/shared/kofta_cached_datasets'
USE_GPU = True
elif hostname == 'rechenknecht2.cp.jku.at':
plt.switch_backend('agg')
PATH_DATA_ROOT = '/media/rk2/shared/datasets/MTG-Jamendo'
PATH_DATA_CACHE = '/media/rk2/shared/kofta_cached_datasets'
USE_GPU = True
elif hostname == 'rechenknecht1.cp.jku.at':
plt.switch_backend('agg')
PATH_DATA_CACHE = '/media/rk1/shared/kofta_cached_datasets'
USE_GPU = True
elif hostname == 'hermine': # PC verena
plt.switch_backend('agg')
PATH_DATA_ROOT = '/media/verena/SAMSUNG/Data/MTG-Jamendo'
USE_GPU = True
elif hostname == 'verena-830g5': # Laptop Verena
PATH_DATA_ROOT = '/media/verena/SAMSUNG/Data/MTG-Jamendo'
USE_GPU = False
elif hostname == 'shreyan-HP': # Laptop Shreyan
PATH_DATA_ROOT = '/home/shreyan/mounts/home@rk3/shared/datasets/MTG-Jamendo'
USE_GPU = False
else:
# Fallback for any other host.
# NOTE(review): PATH_DATA_CACHE is assigned twice below; the later
# assignment wins.
PATH_DATA_ROOT = '/mnt/2tb/datasets/MTG-Jamendo'
# PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
# PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
PATH_DATA_CACHE = '/mnt/2tb/datasets/data_caches'
USE_GPU = False
if username == 'verena':
PATH_RESULTS = '/home/verena/experiments/moodwalk/'
# The data root for the selected dataset overrides any per-host value above.
# NOTE(review): this lookup raises KeyError for a hostname that is not a key
# of data_roots[use_dataset], although the else branch above accepts
# arbitrary hosts -- confirm.
PATH_DATA_ROOT = data_roots[use_dataset][hostname]
PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_audio')
PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations')
......@@ -68,6 +87,21 @@ PATH_MELSPEC_DOWNLOADED_FRAMED = os.path.join(PATH_MELSPEC_DOWNLOADED, 'framed')
PATH_MELSPEC_DOWNLOADED_HDF5 = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
TRAINED_MODELS_PATH = ''
def set_paths(dataset_name):
    """Point the module-level path globals at ``dataset_name``.

    ``PATH_DATA_ROOT`` is always replaced with the root registered for this
    dataset and the current host. ``PATH_AUDIO`` and ``PATH_ANNOTATIONS`` are
    only rewritten for the two known datasets (``'midlevel'`` and
    ``'mtgjamendo'``); for any other name they keep their previous values.
    """
    global PATH_DATA_ROOT, PATH_AUDIO, PATH_ANNOTATIONS
    PATH_DATA_ROOT = data_roots[dataset_name][hostname]

    # (audio sub-directory, annotations sub-directory) per known dataset
    subdirs = {
        'midlevel': ('audio', 'metadata_annotations'),
        'mtgjamendo': ('MTG-Jamendo_audio', 'MTG-Jamendo_annotations'),
    }.get(dataset_name)
    if subdirs is None:
        return

    audio_dir, annotations_dir = subdirs
    PATH_AUDIO = os.path.join(PATH_DATA_ROOT, audio_dir)
    PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, annotations_dir)
def get_paths():
    """Return the current ``(data root, audio dir, annotations dir)`` globals.

    Reading them through this function picks up values changed by
    ``set_paths`` after import time.
    """
    current_paths = (PATH_DATA_ROOT, PATH_AUDIO, PATH_ANNOTATIONS)
    return current_paths
# run name
def make_run_name(suffix=''):
# assert ' ' not in suffix
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment