Commit 8db5855f authored by Verena Praher

add mask_mode; more intelligent augmentation; cleanup code

parent 963a2464
 import random

-def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False):
+def freq_mask(spec, F=30, num_masks=1, replace_with=0.):
     cloned = spec.clone()
     num_mel_channels = cloned.shape[1]
@@ -13,15 +13,12 @@ def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False):
         if (f_zero == f_zero + f): return cloned

         mask_end = random.randrange(f_zero, f_zero + f)
-        if (replace_with_zero):
-            cloned[0][f_zero:mask_end] = 0
-        else:
-            cloned[0][f_zero:mask_end] = cloned.mean()
+        cloned[0][f_zero:mask_end] = replace_with

     return cloned

-def time_mask(spec, T=40, num_masks=1, replace_with_zero=False):
+def time_mask(spec, T=40, num_masks=1, replace_with=0.):
     cloned = spec.clone()
     len_spectro = cloned.shape[2]
@@ -33,8 +30,6 @@ def time_mask(spec, T=40, num_masks=1, replace_with_zero=False):
         if (t_zero == t_zero + t): return cloned

         mask_end = random.randrange(t_zero, t_zero + t)
-        if (replace_with_zero):
-            cloned[0][:, t_zero:mask_end] = 0
-        else:
-            cloned[0][:, t_zero:mask_end] = cloned.mean()
+        cloned[0][:, t_zero:mask_end] = replace_with

     return cloned
\ No newline at end of file
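A quick usage sketch of the reworked API (not part of the commit; the torch.randn input and its (channels, mel bins, frames) shape are placeholders). The caller now passes the fill value directly instead of toggling replace_with_zero:

import torch

spec = torch.randn(1, 96, 256)  # placeholder spectrogram tensor

masked = freq_mask(spec, num_masks=2, replace_with=0.)             # old replace_with_zero=True path
masked = time_mask(masked, num_masks=2, replace_with=spec.mean())  # old else-branch (mean fill)
masked = time_mask(masked, num_masks=1, replace_with=spec.min())   # newly possible: min fill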
@@ -272,6 +272,7 @@ class H5FCachedDataset(Dataset):
         self.cache_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name)
         self.cache_file_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name + ".hdf5")
         self.slicing_function = slicing_function
+        print("slicing_function", self.slicing_function)
         self.augment_options = augment_options
         try:
             original_umask = os.umask(0)
@@ -319,11 +320,19 @@
         if self.augment_options is not None:
             mask_freq_num = self.augment_options['mask_freq_num']
             mask_time_num = self.augment_options['mask_time_num']
+            mask_mode = self.augment_options['mask_mode']
+            # x_replace = 0.
+            if mask_mode == 'mean':
+                x_replace = x.mean()
+            elif mask_mode == 'min':
+                x_replace = x.min()
+            else:
+                raise ValueError('unknown mask_mode', mask_mode)
             if mask_freq_num > 0:
-                x = freq_mask(x, mask_freq_num, True)
+                x = freq_mask(x, num_masks=mask_freq_num, replace_with=x_replace)
             if mask_time_num > 0:
-                x = time_mask(x, mask_time_num, True)
+                x = time_mask(x, num_masks=mask_time_num, replace_with=x_replace)
         return x, y, z
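For reference, a sketch of the augment_options contract the dataset now enforces (values illustrative): all three keys must be present, and mask_mode must be 'mean' or 'min', otherwise __getitem__ raises ValueError.

augment_options = {
    'mask_freq_num': 2,   # > 0 applies freq_mask with this many masks
    'mask_time_num': 2,   # > 0 applies time_mask with this many masks
    'mask_mode': 'mean',  # 'mean' fills with x.mean(), 'min' with x.min()
}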
@@ -5,17 +5,12 @@ from models.baseline_w_augmentation import CNN as Network
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 import os

-config = {
-    'epochs': 1,
-    'patience': 20,
-    'earlystopping_metric': 'val_loss', # 'val_prauc'
-    'earlystopping_mode': 'min' # 'max'
+model_config = {
+    'data_source': 'mtgjamendo',
+    'validation_metrics': ['rocauc', 'prauc'],
+    'test_metrics': ['rocauc', 'prauc']
 }

-def epochs_100():
-    global config
-    config['epochs'] = 100

 def run(hparams):
     init_experiment(comment=hparams.experiment_name)
@@ -24,26 +19,6 @@ def run(hparams):
     logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
     exp = Experiment(save_dir=CURR_RUN_PATH)

-    def setup_config():
-        def print_config():
-            global config
-            st = '---------CONFIG--------\n'
-            for k in config.keys():
-                st += k+':'+str(config.get(k))+'\n'
-            return st
-
-        conf = hparams.config
-        if conf is not None:
-            conf_func = globals()[conf]
-            try:
-                conf_func()
-            except:
-                logger.error(f"Config {conf} not defined")
-        logger.info(print_config())
-
-    setup_config()
-    global config

     # parameters used in the baseline (read from main.py and solver.py)
     # n_epochs = 500
     # lr = 1e-4
@@ -51,10 +26,10 @@ def run(hparams):
     # batch_size = 32
     early_stop = EarlyStopping(
-        monitor=config['earlystopping_metric'],
-        patience=config['patience'],
+        monitor='val_loss',
+        patience=20,
         verbose=True,
-        mode=config['earlystopping_mode']
+        mode='min'
     )

     checkpoint_callback = ModelCheckpoint(
@@ -67,7 +42,7 @@ def run(hparams):
     if USE_GPU:
         trainer = Trainer(gpus=[0], distributed_backend='ddp',
-                          experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
+                          experiment=exp, max_nb_epochs=hparams.max_epochs, train_percent_check=hparams.train_percent,
                           fast_dev_run=False,
                           early_stop_callback=early_stop,
                           checkpoint_callback=checkpoint_callback
@@ -76,7 +51,7 @@ def run(hparams):
         trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
                           fast_dev_run=True)

-    model = Network(num_class=56, hparams=hparams) # TODO num_class
+    model = Network(model_config, num_class=56, hparams=hparams) # TODO num_class
     print(model)
@@ -91,6 +66,11 @@ if __name__=='__main__':
     parent_parser.add_argument('--config', type=str, help='config function to run')
     parent_parser.add_argument('--train_percent', type=float,
                                default=1.0, help='how much train data to use')
+    parent_parser.add_argument('--max_epochs', type=int,
+                               default=10, help='maximum number of epochs')
+    parent_parser.add_argument('--slicing_mode', default='slice', type=str)
+    parent_parser.add_argument('--input_size', default=512, type=int)
+    parent_parser.add_argument('--batch_size', default=32, type=int)
     parser = Network.add_model_specific_args(parent_parser)
     hyperparams = parser.parse_args()
     run(hyperparams)
\ No newline at end of file
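A sketch of how the new flags reach run(). This assumes test_tube's HyperOptArgumentParser accepts an explicit argv list the way argparse does, and it omits any flags defined in the elided argparse lines (e.g. the experiment name):

hyperparams = parser.parse_args([
    '--max_epochs', '100',   # replaces the deleted config['epochs'] / epochs_100() mechanism
    '--batch_size', '32',
    '--mask_freq_num', '2',
    '--mask_time_num', '2',
    '--mask_mode', 'min',
])
run(hyperparams)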
@@ -2,11 +2,9 @@ from utils import *
 import pytorch_lightning as pl
 from models.shared_stuff import *
 from sklearn import metrics
+from models.shared_stuff import BasePtlModel

 # TODO pr-auc
 # TODO f1-score

-class CNN(pl.LightningModule):
+class CNN(BasePtlModel):
     def __init__(self, num_class):
         super(CNN, self).__init__()
@@ -80,30 +78,6 @@ class CNN(pl.LightningModule):
         return logit

-    def my_loss(self, y_hat, y):
-        return my_loss(y_hat, y)
-
-    def training_step(self, data_batch, batch_nb):
-        x, _, y = data_batch
-        y_hat = self.forward(x)
-        y = y.float()
-        y_hat = y_hat.float()
-        return {'loss': self.my_loss(y_hat, y)}
-
-    def validation_step(self, data_batch, batch_nb):
-        return validation_step(self, data_batch, batch_nb)
-
-    def test_step(self, data_batch, batch_nb):
-        return test_step(self, data_batch, batch_nb)
-
-    def test_end(self, outputs):
-        test_metrics = test_end(outputs)
-        self.experiment.log(test_metrics)
-        return test_metrics
-
-    def validation_end(self, outputs):
-        return validation_end(outputs)

     def configure_optimizers(self):
         return [torch.optim.Adam(self.parameters(), lr=1e-4)] # from their code
@@ -111,14 +85,6 @@ class CNN(pl.LightningModule):
     def tng_dataloader(self):
         return tng_dataloader()

-    @pl.data_loader
-    def val_dataloader(self):
-        return val_dataloader()
-
-    @pl.data_loader
-    def test_dataloader(self):
-        return test_dataloader()

     @staticmethod
     def add_model_specific_args(parent_parser):
         return parent_parser
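This file (and baseline_w_augmentation.py below) now delegates to BasePtlModel from models.shared_stuff, which is not part of this diff. The following is only an inferred outline, based on the methods the subclasses stop defining; the real class may differ:

class BasePtlModel(pl.LightningModule):
    """Inferred outline only; the real class lives in models/shared_stuff.py.

    Presumably supplies the hooks deleted from the subclasses above:
    my_loss, training_step, validation_step, validation_end,
    test_step, test_end, and the tng/val/test dataloaders, configured
    from a config dict ('data_source', 'validation_metrics',
    'test_metrics') plus hparams.
    """
    def __init__(self, config, hparams):
        super(BasePtlModel, self).__init__()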
 from utils import *
 import pytorch_lightning as pl
 from models.shared_stuff import *
+from models.shared_stuff import BasePtlModel
 from sklearn import metrics
 from test_tube import HyperOptArgumentParser
@@ -9,13 +9,14 @@ from test_tube import HyperOptArgumentParser
 # TODO pr-auc
 # TODO f1-score

-class CNN(pl.LightningModule):
-    def __init__(self, hparams, num_class):
-        super(CNN, self).__init__()
+class CNN(BasePtlModel):
+    def __init__(self, config, num_class, hparams):
+        super(CNN, self).__init__(config, hparams)
         self.hparams = hparams
         self.augment_options = {'mask_freq_num': hparams.mask_freq_num,
-                                'mask_time_num': hparams.mask_time_num
+                                'mask_time_num': hparams.mask_time_num,
+                                'mask_mode': hparams.mask_mode
                                 }

         # init bn
@@ -88,44 +89,45 @@ class CNN(pl.LightningModule):
         return logit

-    def my_loss(self, y_hat, y):
-        return my_loss(y_hat, y)
-
-    def training_step(self, data_batch, batch_nb):
-        return training_step(self, data_batch, batch_nb)
-
-    def validation_step(self, data_batch, batch_nb):
-        return validation_step(self, data_batch, batch_nb)
-
-    def test_step(self, data_batch, batch_nb):
-        return test_step(self, data_batch, batch_nb)
-
-    def test_end(self, outputs):
-        test_metrics = test_end(outputs)
-        self.experiment.log(test_metrics)
-        return test_metrics
-
-    def validation_end(self, outputs):
-        return validation_end(outputs)
+    # def my_loss(self, y_hat, y):
+    #     return my_loss(y_hat, y)
+    #
+    # def training_step(self, data_batch, batch_nb):
+    #     return training_step(self, data_batch, batch_nb)
+    #
+    # def validation_step(self, data_batch, batch_nb):
+    #     return validation_step(self, data_batch, batch_nb)
+    #
+    # def test_step(self, data_batch, batch_nb):
+    #     return test_step(self, data_batch, batch_nb)
+    #
+    # def test_end(self, outputs):
+    #     test_metrics = test_end(outputs)
+    #     self.experiment.log(test_metrics)
+    #     return test_metrics
+    #
+    # def validation_end(self, outputs):
+    #     return validation_end(outputs)

     def configure_optimizers(self):
         return [torch.optim.Adam(self.parameters(), lr=1e-4)] # from their code

-    @pl.data_loader
-    def tng_dataloader(self):
-        return tng_dataloader(augment_options=self.augment_options)
-
-    @pl.data_loader
-    def val_dataloader(self):
-        return val_dataloader(augment_options=self.augment_options)
-
-    @pl.data_loader
-    def test_dataloader(self):
-        return test_dataloader()
+    # @pl.data_loader
+    # def tng_dataloader(self):
+    #     return tng_dataloader(augment_options=self.augment_options)
+    #
+    # @pl.data_loader
+    # def val_dataloader(self):
+    #     return val_dataloader(augment_options=self.augment_options)
+    #
+    # @pl.data_loader
+    # def test_dataloader(self):
+    #     return test_dataloader()

     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
         parser.add_argument('--mask_freq_num', default=0, type=int)
         parser.add_argument('--mask_time_num', default=0, type=int)
+        parser.add_argument('--mask_mode', default='mean', type=str)
         return parser
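A sketch of wiring the new --mask_mode flag through add_model_specific_args and into the model, mirroring the run script above. The strategy value, add_help kwarg, and argv list are illustrative, and the elided __init__ body may require further hparams:

parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parser = CNN.add_model_specific_args(parent_parser)
hparams = parser.parse_args(['--mask_freq_num', '2', '--mask_time_num', '2', '--mask_mode', 'min'])

model = CNN(model_config, num_class=56, hparams=hparams)
# model.augment_options == {'mask_freq_num': 2, 'mask_time_num': 2, 'mask_mode': 'min'}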