Commit 8db5855f authored by Verena Praher

add mask_mode; smarter augmentation (mask replacement value is now chosen per example); clean up code

parent 963a2464
 import random

-def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False):
+def freq_mask(spec, F=30, num_masks=1, replace_with=0.):
     cloned = spec.clone()
     num_mel_channels = cloned.shape[1]
@@ -13,15 +13,12 @@ def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False):
         if (f_zero == f_zero + f): return cloned
         mask_end = random.randrange(f_zero, f_zero + f)
-        if (replace_with_zero):
-            cloned[0][f_zero:mask_end] = 0
-        else:
-            cloned[0][f_zero:mask_end] = cloned.mean()
+        cloned[0][f_zero:mask_end] = replace_with
     return cloned

-def time_mask(spec, T=40, num_masks=1, replace_with_zero=False):
+def time_mask(spec, T=40, num_masks=1, replace_with=0.):
     cloned = spec.clone()
     len_spectro = cloned.shape[2]
@@ -33,8 +30,6 @@ def time_mask(spec, T=40, num_masks=1, replace_with_zero=False):
         if (t_zero == t_zero + t): return cloned
         mask_end = random.randrange(t_zero, t_zero + t)
-        if (replace_with_zero):
-            cloned[0][:, t_zero:mask_end] = 0
-        else:
-            cloned[0][:, t_zero:mask_end] = cloned.mean()
+        cloned[0][:, t_zero:mask_end] = replace_with
     return cloned
\ No newline at end of file
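Note on this change: the boolean replace_with_zero flag becomes an explicit replace_with value, so the caller now chooses what fills the masked band instead of the function picking between zero and the clone's mean. A minimal usage sketch under the new signature (the tensor shape is illustrative; the functions index the spectrogram as (channel, mel_bins, frames)):

import torch

# toy log-mel spectrogram, shaped (channel, mel_bins, time_frames)
# to match the cloned[0][...] indexing in freq_mask/time_mask
spec = torch.randn(1, 80, 400)

# mask two frequency bands with the mean, then two time spans with the min
out = freq_mask(spec, F=30, num_masks=2, replace_with=spec.mean())
out = time_mask(out, T=40, num_masks=2, replace_with=spec.min())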
@@ -272,6 +272,7 @@ class H5FCachedDataset(Dataset):
         self.cache_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name)
         self.cache_file_path = os.path.join(cache_path, dataset_name, "h5py_files_cache", self.x_name + ".hdf5")
         self.slicing_function = slicing_function
+        print("slicing_function", self.slicing_function)
         self.augment_options = augment_options
         try:
             original_umask = os.umask(0)
@@ -319,11 +320,19 @@ class H5FCachedDataset(Dataset):
         if self.augment_options is not None:
             mask_freq_num = self.augment_options['mask_freq_num']
             mask_time_num = self.augment_options['mask_time_num']
+            mask_mode = self.augment_options['mask_mode']
+            # x_replace = 0.
+            if mask_mode == 'mean':
+                x_replace = x.mean()
+            elif mask_mode == 'min':
+                x_replace = x.min()
+            else:
+                raise ValueError('unknown mask_mode', mask_mode)
             if mask_freq_num > 0:
-                x = freq_mask(x, mask_freq_num, True)
+                x = freq_mask(x, num_masks=mask_freq_num, replace_with=x_replace)
             if mask_time_num > 0:
-                x = time_mask(x, mask_time_num, True)
+                x = time_mask(x, num_masks=mask_time_num, replace_with=x_replace)
         return x, y, z
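Two things happen in this hunk: mask_mode now selects the replacement value once per example ('mean' or 'min' of the input, anything else raises), and the masking calls switch to keyword arguments. The keyword form also fixes a latent bug, since the old positional call freq_mask(x, mask_freq_num, True) was binding mask_freq_num to the mask width F and True to num_masks. A standalone sketch of the new logic (lookup_replacement is a hypothetical helper name; the input tensor is illustrative, and freq_mask is assumed importable as in this dataset module):

import torch

def lookup_replacement(x, mask_mode):
    # hypothetical helper mirroring the branch added in the hunk above
    if mask_mode == 'mean':
        return x.mean()
    elif mask_mode == 'min':
        return x.min()
    raise ValueError('unknown mask_mode', mask_mode)

x = torch.randn(1, 80, 400)  # illustrative spectrogram
x_replace = lookup_replacement(x, 'min')
x = freq_mask(x, num_masks=2, replace_with=x_replace)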
@@ -5,17 +5,12 @@ from models.baseline_w_augmentation import CNN as Network
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 import os

-config = {
-    'epochs': 1,
-    'patience': 20,
-    'earlystopping_metric': 'val_loss',  # 'val_prauc'
-    'earlystopping_mode': 'min'  # 'max'
-}
-
-def epochs_100():
-    global config
-    config['epochs'] = 100
+model_config = {
+    'data_source': 'mtgjamendo',
+    'validation_metrics': ['rocauc', 'prauc'],
+    'test_metrics': ['rocauc', 'prauc']
+}

 def run(hparams):
     init_experiment(comment=hparams.experiment_name)
@@ -24,26 +19,6 @@ def run(hparams):
     logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
     exp = Experiment(save_dir=CURR_RUN_PATH)

-    def setup_config():
-        def print_config():
-            global config
-            st = '---------CONFIG--------\n'
-            for k in config.keys():
-                st += k+':'+str(config.get(k))+'\n'
-            return st
-        conf = hparams.config
-        if conf is not None:
-            conf_func = globals()[conf]
-            try:
-                conf_func()
-            except:
-                logger.error(f"Config {conf} not defined")
-        logger.info(print_config())
-    setup_config()
-    global config
-
     # parameters used in the baseline (read from main.py and solver.py)
     # n_epochs = 500
     # lr = 1e-4
@@ -51,10 +26,10 @@ def run(hparams):
     # batch_size = 32
     early_stop = EarlyStopping(
-        monitor=config['earlystopping_metric'],
-        patience=config['patience'],
+        monitor='val_loss',
+        patience=20,
         verbose=True,
-        mode=config['earlystopping_mode']
+        mode='min'
     )

     checkpoint_callback = ModelCheckpoint(
@@ -67,7 +42,7 @@ def run(hparams):
     if USE_GPU:
         trainer = Trainer(gpus=[0], distributed_backend='ddp',
-                          experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
+                          experiment=exp, max_nb_epochs=hparams.max_epochs, train_percent_check=hparams.train_percent,
                           fast_dev_run=False,
                           early_stop_callback=early_stop,
                           checkpoint_callback=checkpoint_callback
@@ -76,7 +51,7 @@ def run(hparams):
         trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
                           fast_dev_run=True)

-    model = Network(num_class=56, hparams=hparams)  # TODO num_class
+    model = Network(model_config, num_class=56, hparams=hparams)  # TODO num_class
     print(model)
@@ -91,6 +66,11 @@ if __name__=='__main__':
     parent_parser.add_argument('--config', type=str, help='config function to run')
     parent_parser.add_argument('--train_percent', type=float,
                                default=1.0, help='how much train data to use')
+    parent_parser.add_argument('--max_epochs', type=int,
+                               default=10, help='maximum number of epochs')
+    parent_parser.add_argument('--slicing_mode', default='slice', type=str)
+    parent_parser.add_argument('--input_size', default=512, type=int)
+    parent_parser.add_argument('--batch_size', default=32, type=int)
     parser = Network.add_model_specific_args(parent_parser)
     hyperparams = parser.parse_args()
     run(hyperparams)
\ No newline at end of file
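The per-script config dict and the epochs_100-style config functions are gone; everything they controlled is now either a fixed default (monitor 'val_loss', patience 20) or a CLI flag. One leftover: the --config argument is still registered, but nothing consumes it after this commit. A plausible shape of hparams after parsing with the new flags (all values and the direct run(...) call are illustrative):

from argparse import Namespace

# equivalent of e.g.:
#   python <experiment script> --experiment_name specaug --max_epochs 100 \
#          --mask_freq_num 2 --mask_time_num 2 --mask_mode min
hparams = Namespace(experiment_name='specaug', config=None, train_percent=1.0,
                    max_epochs=100, slicing_mode='slice', input_size=512,
                    batch_size=32, mask_freq_num=2, mask_time_num=2,
                    mask_mode='min')
run(hparams)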
@@ -2,11 +2,9 @@ from utils import *
 import pytorch_lightning as pl
 from models.shared_stuff import *
 from sklearn import metrics
+from models.shared_stuff import BasePtlModel

-# TODO pr-auc
-# TODO f1-score
-class CNN(pl.LightningModule):
+class CNN(BasePtlModel):
     def __init__(self, num_class):
         super(CNN, self).__init__()
@@ -80,30 +78,6 @@ class CNN(pl.LightningModule):
         return logit

-    def my_loss(self, y_hat, y):
-        return my_loss(y_hat, y)
-
-    def training_step(self, data_batch, batch_nb):
-        x, _, y = data_batch
-        y_hat = self.forward(x)
-        y = y.float()
-        y_hat = y_hat.float()
-        return {'loss':self.my_loss(y_hat, y)}
-
-    def validation_step(self, data_batch, batch_nb):
-        return validation_step(self, data_batch, batch_nb)
-
-    def test_step(self, data_batch, batch_nb):
-        return test_step(self, data_batch, batch_nb)
-
-    def test_end(self, outputs):
-        test_metrics = test_end(outputs)
-        self.experiment.log(test_metrics)
-        return test_metrics
-
-    def validation_end(self, outputs):
-        return validation_end(outputs)
-
     def configure_optimizers(self):
         return [torch.optim.Adam(self.parameters(), lr=1e-4)]  # from their code
@@ -111,14 +85,6 @@ class CNN(pl.LightningModule):
     def tng_dataloader(self):
         return tng_dataloader()

-    @pl.data_loader
-    def val_dataloader(self):
-        return val_dataloader()
-
-    @pl.data_loader
-    def test_dataloader(self):
-        return test_dataloader()
-
     @staticmethod
     def add_model_specific_args(parent_parser):
         return parent_parser
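Both models now inherit from BasePtlModel, which presumably hosts the step and dataloader hooks deleted here. models/shared_stuff.py is not part of this diff, so the following skeleton is purely hypothetical; it only records the interface the two subclasses imply (the plain baseline calls super().__init__() with no arguments, while the augmented model passes (config, hparams), so both parameters presumably default):

import pytorch_lightning as pl

# Hypothetical sketch, not the real models/shared_stuff.py.
class BasePtlModel(pl.LightningModule):
    def __init__(self, config=None, hparams=None):
        super(BasePtlModel, self).__init__()
        self.config = config or {}
        self.hparams = hparams

    # Presumably also defines the hooks removed from the subclasses:
    # my_loss, training_step, validation_step, validation_end,
    # test_step, test_end, and the @pl.data_loader methods.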
 from utils import *
 import pytorch_lightning as pl
-from models.shared_stuff import *
+from models.shared_stuff import BasePtlModel
 from sklearn import metrics
 from test_tube import HyperOptArgumentParser
@@ -9,13 +9,14 @@ from test_tube import HyperOptArgumentParser
 # TODO pr-auc
 # TODO f1-score

-class CNN(pl.LightningModule):
-    def __init__(self, hparams, num_class):
-        super(CNN, self).__init__()
+class CNN(BasePtlModel):
+    def __init__(self, config, num_class, hparams):
+        super(CNN, self).__init__(config, hparams)
         self.hparams = hparams
         self.augment_options = {'mask_freq_num': hparams.mask_freq_num,
-                                'mask_time_num': hparams.mask_time_num
+                                'mask_time_num': hparams.mask_time_num,
+                                'mask_mode': hparams.mask_mode
                                 }

     # init bn
@@ -88,44 +89,45 @@
         return logit

-    def my_loss(self, y_hat, y):
-        return my_loss(y_hat, y)
+    # def my_loss(self, y_hat, y):
+    #     return my_loss(y_hat, y)

-    def training_step(self, data_batch, batch_nb):
-        return training_step(self, data_batch, batch_nb)
+    # def training_step(self, data_batch, batch_nb):
+    #     return training_step(self, data_batch, batch_nb)

-    def validation_step(self, data_batch, batch_nb):
-        return validation_step(self, data_batch, batch_nb)
+    # def validation_step(self, data_batch, batch_nb):
+    #     return validation_step(self, data_batch, batch_nb)

-    def test_step(self, data_batch, batch_nb):
-        return test_step(self, data_batch, batch_nb)
+    # def test_step(self, data_batch, batch_nb):
+    #     return test_step(self, data_batch, batch_nb)

-    def test_end(self, outputs):
-        test_metrics = test_end(outputs)
-        self.experiment.log(test_metrics)
-        return test_metrics
+    # def test_end(self, outputs):
+    #     test_metrics = test_end(outputs)
+    #     self.experiment.log(test_metrics)
+    #     return test_metrics

-    def validation_end(self, outputs):
-        return validation_end(outputs)
+    # def validation_end(self, outputs):
+    #     return validation_end(outputs)

     def configure_optimizers(self):
         return [torch.optim.Adam(self.parameters(), lr=1e-4)]  # from their code

-    @pl.data_loader
-    def tng_dataloader(self):
-        return tng_dataloader(augment_options=self.augment_options)
+    # @pl.data_loader
+    # def tng_dataloader(self):
+    #     return tng_dataloader(augment_options=self.augment_options)

-    @pl.data_loader
-    def val_dataloader(self):
-        return val_dataloader(augment_options=self.augment_options)
+    # @pl.data_loader
+    # def val_dataloader(self):
+    #     return val_dataloader(augment_options=self.augment_options)

-    @pl.data_loader
-    def test_dataloader(self):
-        return test_dataloader()
+    # @pl.data_loader
+    # def test_dataloader(self):
+    #     return test_dataloader()

     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
         parser.add_argument('--mask_freq_num', default=0, type=int)
         parser.add_argument('--mask_time_num', default=0, type=int)
+        parser.add_argument('--mask_mode', default='mean', type=str)
         return parser
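The new --mask_mode flag defaults to 'mean', but note that the dataset hunk above accepts only 'mean' and 'min' and raises on anything else, so a typo surfaces at training time rather than at parse time. A quick parsing sketch (the strategy value and add_help=False are assumptions about how the parent parser is built elsewhere):

from test_tube import HyperOptArgumentParser

parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parser = CNN.add_model_specific_args(parent_parser)

hparams = parser.parse_args(['--mask_freq_num', '2', '--mask_mode', 'min'])
assert hparams.mask_mode == 'min'
assert hparams.mask_time_num == 0  # default: no time masking

Defining the flag with choices=('mean', 'min') would move that failure to parse time instead.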