Commit ce850b60 authored by Verena Praher's avatar Verena Praher

add data augmentation: frequency masking and time masking

parent ae749acd
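The diff below imports freq_mask and time_mask from datasets.augment, a module that is not part of this commit. For orientation, here is a minimal sketch of SpecAugment-style masking of the kind these names suggest, assuming the input is a tensor shaped (..., n_mels, n_frames) and that the third positional argument used later (True) selects zero-filling; the function bodies and the max_width parameter are illustrative, not the repository's actual implementation:

import torch

def freq_mask(spec, num_masks=1, replace_with_zero=False, max_width=8):
    # mask `num_masks` randomly placed bands along the frequency axis
    cloned = spec.clone()
    n_mels = cloned.shape[-2]
    fill = 0.0 if replace_with_zero else cloned.mean()
    for _ in range(num_masks):
        width = int(torch.randint(1, max_width + 1, (1,)))
        start = int(torch.randint(0, max(1, n_mels - width), (1,)))
        cloned[..., start:start + width, :] = fill
    return cloned

def time_mask(spec, num_masks=1, replace_with_zero=False, max_width=8):
    # same idea, masking randomly placed spans along the time axis
    cloned = spec.clone()
    n_frames = cloned.shape[-1]
    fill = 0.0 if replace_with_zero else cloned.mean()
    for _ in range(num_masks):
        width = int(torch.randint(1, max_width + 1, (1,)))
        start = int(torch.randint(0, max(1, n_frames - width), (1,)))
        cloned[..., :, start:start + width] = fill
    return cloned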
......@@ -6,6 +6,7 @@ import numpy as np
from pathlib import Path
import torch
from os.path import expanduser
from datasets.augment import freq_mask, time_mask
class MelSpecDataset(Dataset):
    def __init__(self, phase='train', ann_root=None, spec_root=None, length=MAX_LENGTH, framed=True):
......@@ -73,10 +74,6 @@ class MelSpecDataset(Dataset):
        return tagslist
class HDF5Dataset(Dataset):
    """Represents an abstract HDF5 dataset.
......@@ -251,6 +248,7 @@ class H5FCachedDataset(Dataset):
    def __init__(self, get_dataset_func, dataset_name, slicing_function, x_name="", y_name="",
                 cache_path="~/shared/kofta_cached_datasets/",
                 augment_options=None
                 ):
        """
......@@ -274,6 +272,7 @@ class H5FCachedDataset(Dataset):
        self.cache_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name)
        self.cache_file_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name + ".hdf5")
        self.slicing_function = slicing_function
        self.augment_options = augment_options
        try:
            original_umask = os.umask(0)
            os.makedirs(self.cache_path, exist_ok=True)
......@@ -317,6 +316,15 @@ class H5FCachedDataset(Dataset):
        (idx, xlen), y, z = torch.load(cpath)
        x = self.slicing_function(self.h5data, idx, xlen)
        if self.augment_options is not None:
            mask_freq_num = self.augment_options['mask_freq_num']
            mask_time_num = self.augment_options['mask_time_num']
            if mask_freq_num > 0:
                x = freq_mask(x, mask_freq_num, True)
            if mask_time_num > 0:
                x = time_mask(x, mask_time_num, True)
        return x, y, z

    def get_ordered_labels(self):
......
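Because the masks are applied at item-access time, after the cached tensor is loaded, the cached spectrograms stay unaugmented and a fresh random mask is drawn on every access, so each epoch sees different masks. The options object is a plain dict; an illustrative value matching the two keys read above:

augment_options = {'mask_freq_num': 2, 'mask_time_num': 2}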
......@@ -89,7 +89,7 @@ audio_processor = processor_mtgjamendo44k
label_encoder = None
-def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name):
+def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, augment_options=None):
    audio_path = os.path.expanduser(audio_path)
    global label_encoder
    print("loading dataset from '{}'".format(name))
......@@ -100,7 +100,8 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name):
    df_trset = H5FCachedDataset(getdatset, name, slicing_function=sample_slicing_function,
                                x_name=cache_x_name,
-                               cache_path=PATH_DATA_CACHE
+                               cache_path=PATH_DATA_CACHE,
+                               augment_options=augment_options
                                )
    return df_trset
......
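A hypothetical call with augmentation enabled (paths and the cache name follow the conventions used elsewhere in this repo):

train_csv = os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv')
dataset = df_get_mtg_set('mtgjamendo', train_csv, PATH_AUDIO, '_ap_mtgjamendo44k',
                         augment_options={'mask_freq_num': 1, 'mask_time_num': 1})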
from utils import USE_GPU, init_experiment
from pytorch_lightning import Trainer
from test_tube import Experiment, HyperOptArgumentParser
from models.baseline_w_augmentation import CNN as Network
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
config = {
    'epochs': 1,
    'patience': 20,
    'earlystopping_metric': 'val_loss',  # alternative: 'val_prauc'
    'earlystopping_mode': 'min'  # alternative: 'max'
}
def epochs_100():
    global config
    config['epochs'] = 100
def run(hparams):
    init_experiment(comment=hparams.experiment_name)
    from utils import CURR_RUN_PATH, logger  # import these after init_experiment
    logger.info(CURR_RUN_PATH)
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(save_dir=CURR_RUN_PATH)

    def setup_config():
        def print_config():
            global config
            st = '---------CONFIG--------\n'
            for k in config.keys():
                st += k + ':' + str(config.get(k)) + '\n'
            return st

        conf = hparams.config
        if conf is not None:
            # look up the named config function inside the try block; a bare
            # except around the call alone would miss the KeyError raised when
            # the name does not exist in globals()
            try:
                conf_func = globals()[conf]
                conf_func()
            except KeyError:
                logger.error(f"Config {conf} not defined")
        logger.info(print_config())

    setup_config()
    global config
    # parameters used in the baseline (read from main.py and solver.py)
    # n_epochs = 500
    # lr = 1e-4
    # num_class = 56
    # batch_size = 32
    early_stop = EarlyStopping(
        monitor=config['earlystopping_metric'],
        patience=config['patience'],
        verbose=True,
        mode=config['earlystopping_mode']
    )
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(CURR_RUN_PATH, 'best.ckpt'),
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )
    if USE_GPU:
        trainer = Trainer(gpus=[0], distributed_backend='ddp',
                          experiment=exp, max_nb_epochs=config['epochs'],
                          train_percent_check=hparams.train_percent,
                          fast_dev_run=False,
                          early_stop_callback=early_stop,
                          checkpoint_callback=checkpoint_callback
                          )
    else:
        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
                          fast_dev_run=True)
    model = Network(num_class=56, hparams=hparams)  # TODO num_class
    print(model)
    trainer.fit(model)
    trainer.test()
if __name__ == '__main__':
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parent_parser.add_argument('--experiment_name', type=str,
                               default='pt_lightning_exp_a', help='test tube exp name')
    parent_parser.add_argument('--config', type=str, help='config function to run')
    parent_parser.add_argument('--train_percent', type=float,
                               default=1.0, help='how much train data to use')
    parser = Network.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()
    run(hyperparams)
\ No newline at end of file
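Putting it together, a plausible invocation of this training script (filename assumed) that trains for 100 epochs with two masks of each kind would be: python train_with_augmentation.py --experiment_name aug_test --config epochs_100 --mask_freq_num 2 --mask_time_num 2. Leaving the mask arguments at their default of 0 disables augmentation entirely, since the dataset only applies a mask when the corresponding count is positive.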
from utils import *
import pytorch_lightning as pl
from models.shared_stuff import *
from sklearn import metrics
from test_tube import HyperOptArgumentParser
# TODO pr-auc
# TODO f1-score
class CNN(pl.LightningModule):
    def __init__(self, hparams, num_class):
        super(CNN, self).__init__()
        self.hparams = hparams
        self.augment_options = {'mask_freq_num': hparams.mask_freq_num,
                                'mask_time_num': hparams.mask_time_num
                                }
        # init bn
        self.bn_init = nn.BatchNorm2d(1)
        # layer 1
        self.conv_1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn_1 = nn.BatchNorm2d(64)
        self.mp_1 = nn.MaxPool2d((2, 4))
        # layer 2
        self.conv_2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn_2 = nn.BatchNorm2d(128)
        self.mp_2 = nn.MaxPool2d((2, 4))
        # layer 3
        self.conv_3 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn_3 = nn.BatchNorm2d(128)
        self.mp_3 = nn.MaxPool2d((2, 4))
        # layer 4
        self.conv_4 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn_4 = nn.BatchNorm2d(128)
        self.mp_4 = nn.MaxPool2d((3, 5))
        # layer 5
        self.conv_5 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn_5 = nn.BatchNorm2d(64)
        self.mp_5 = nn.MaxPool2d((4, 4))
        # classifier
        self.dense = nn.Linear(320, num_class)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        # x = x.unsqueeze(1)
        x = x[:, :, :, :512]
        # init bn
        x = self.bn_init(x)
        # layer 1
        x = self.mp_1(nn.ELU()(self.bn_1(self.conv_1(x))))
        # layer 2 (max-pooling disabled)
        x = nn.ELU()(self.bn_2(self.conv_2(x)))
        # x = self.mp_2(nn.ELU()(self.bn_2(self.conv_2(x))))
        # layer 3
        x = self.mp_3(nn.ELU()(self.bn_3(self.conv_3(x))))
        # layer 4
        # x = nn.ELU()(self.bn_4(self.conv_4(x)))
        x = self.mp_4(nn.ELU()(self.bn_4(self.conv_4(x))))
        # layer 5
        x = self.mp_5(nn.ELU()(self.bn_5(self.conv_5(x))))
        # classifier
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        logit = nn.Sigmoid()(self.dense(x))
        return logit
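For reference, a worked shape trace through the pooling stack, in place of the commented print(x.shape) calls that were sprinkled through forward. It assumes a (batch, 1, 256, 512) input after the time crop; the 256 mel bins are an inference from the classifier width, since 64 channels x 5 x 1 = 320 matches nn.Linear(320, num_class):

# input after x[:, :, :, :512]:     (B, 1, 256, 512)
# layer 1: conv + mp (2, 4)      -> (B, 64, 128, 128)
# layer 2: conv, mp_2 disabled   -> (B, 128, 128, 128)
# layer 3: conv + mp (2, 4)      -> (B, 128, 64, 32)
# layer 4: conv + mp (3, 5)      -> (B, 128, 21, 6)   (floor division)
# layer 5: conv + mp (4, 4)      -> (B, 64, 5, 1)
# flatten: 64 * 5 * 1 = 320      -> dense -> (B, num_class)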
    def my_loss(self, y_hat, y):
        return my_loss(y_hat, y)

    def training_step(self, data_batch, batch_nb):
        return training_step(self, data_batch, batch_nb)

    def validation_step(self, data_batch, batch_nb):
        return validation_step(self, data_batch, batch_nb)

    def test_step(self, data_batch, batch_nb):
        return test_step(self, data_batch, batch_nb)

    def test_end(self, outputs):
        test_metrics = test_end(outputs)
        self.experiment.log(test_metrics)
        return test_metrics

    def validation_end(self, outputs):
        return validation_end(outputs)

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=1e-4)]  # from their code

    @pl.data_loader
    def tng_dataloader(self):
        return tng_dataloader(augment_options=self.augment_options)

    @pl.data_loader
    def val_dataloader(self):
        return val_dataloader(augment_options=self.augment_options)

    @pl.data_loader
    def test_dataloader(self):
        return test_dataloader()

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--mask_freq_num', default=0, type=int)
        parser.add_argument('--mask_time_num', default=0, type=int)
        return parser
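End to end, the flag values travel from the parser into hparams, are packed into augment_options in __init__, and reach the dataset through the train and validation loaders. A hypothetical check (parent_parser as defined in the training script above; HyperOptArgumentParser accepts an explicit argument list like argparse):

parser = CNN.add_model_specific_args(parent_parser)
hparams = parser.parse_args(['--mask_freq_num', '2', '--mask_time_num', '2'])
model = CNN(hparams=hparams, num_class=56)
assert model.augment_options == {'mask_freq_num': 2, 'mask_time_num': 2}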
......@@ -117,19 +117,19 @@ def test_end(outputs):
'test_fscore': fscore}
-def tng_dataloader(batch_size=32):
+def tng_dataloader(batch_size=32, augment_options=None):
    train_csv = os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv')
    cache_x_name = "_ap_mtgjamendo44k"
-    dataset = df_get_mtg_set('mtgjamendo', train_csv, PATH_AUDIO, cache_x_name)
+    dataset = df_get_mtg_set('mtgjamendo', train_csv, PATH_AUDIO, cache_x_name, augment_options)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=True)

-def val_dataloader(batch_size=32):
+def val_dataloader(batch_size=32, augment_options=None):
    validation_csv = os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv')
    cache_x_name = "_ap_mtgjamendo44k"
-    dataset = df_get_mtg_set('mtgjamendo_val', validation_csv, PATH_AUDIO, cache_x_name)
+    dataset = df_get_mtg_set('mtgjamendo_val', validation_csv, PATH_AUDIO, cache_x_name, augment_options)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=True)
......
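Note that, as wired in the model above, val_dataloader receives the same augment_options as the training loader, so validation spectrograms are masked as well; the default of None keeps a loader unaugmented. Illustrative usage:

train_loader = tng_dataloader(batch_size=32, augment_options={'mask_freq_num': 1, 'mask_time_num': 1})
val_loader = val_dataloader(batch_size=32)  # no masking by default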