Commit 9bf4bd54 authored by Paul Primus

add vae training

parent 5284c500
......@@ -93,4 +93,5 @@ class BaseExperiment(ABC, torch.nn.Module):
def run(self):
self.trainer.fit(self)
self.trainer.test(self)
self.trainer.save_checkpoint(os.path.join(self.objects['log_path'], "model.ckpt"))
return self.result
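The added `save_checkpoint` call persists the trained weights once `run()` finishes. A minimal sketch of loading them back afterwards, assuming the usual PyTorch Lightning checkpoint layout (a dict with a `state_dict` entry); `log_path` and `experiment` are placeholder names for illustration, not identifiers from this commit:

```python
# Sketch: restore the weights written by trainer.save_checkpoint.
# Assumes the standard PyTorch Lightning checkpoint layout (a dict
# containing 'state_dict'); 'log_path' and 'experiment' are placeholders.
import os
import torch

log_path = "../experiment_logs/<run-id>"  # placeholder run directory
checkpoint = torch.load(os.path.join(log_path, "model.ckpt"), map_location="cpu")

# 'experiment' would be a freshly constructed VAEExperiment instance:
experiment.load_state_dict(checkpoint["state_dict"])
experiment.eval()  # switch to inference mode before scoring
```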
from experiments import BaseExperiment
from datetime import datetime
import os
import pytorch_lightning as pl
import torch
from sacred import Experiment
from utils.logger import Logger
import torch.utils.data
# workaround...
from sacred import SETTINGS
SETTINGS['CAPTURE_MODE'] = 'sys'
import numpy as np
class VAEExperiment(BaseExperiment, pl.LightningModule):
'''
Reproduction of the DCASE baseline: essentially an auto-encoder, with the reconstruction error used as the anomaly score.
'''
def __init__(self, configuration_dict, _run):
super().__init__(configuration_dict)
self.network = self.objects['auto_encoder_model']
self.prior = self.objects['prior']
self.reconstruction = self.objects['reconstruction']
self.logger_ = Logger(_run, self, self.configuration_dict, self.objects)
# experiment state variables
self.epoch = -1
self.step = 0
self.result = None
def forward(self, batch):
batch['epoch'] = self.epoch
batch = self.network(batch)
return batch
def training_step(self, batch, batch_num, optimizer_idx=0):
if batch_num == 0 and optimizer_idx == 0:
self.epoch += 1
if optimizer_idx == 0:
batch = self(batch)
reconstruction_loss = self.reconstruction.loss(batch)
prior_loss = self.prior.loss(batch)
batch['reconstruction_loss'] = reconstruction_loss / (self.objects['batch_size'] * self.objects['num_mel'] * self.objects['context'])
batch['prior_loss'] = prior_loss / self.objects['batch_size']
batch['loss'] = reconstruction_loss + prior_loss
if batch_num == 0:
self.logger_.log_reconstruction(batch, self.epoch)
self.logger_.log_training_step(batch, self.step)
self.step += 1
else:
raise AttributeError
return {
'loss': batch['loss'],
'tqdm': {'loss': batch['loss']},
}
def validation_step(self, batch, batch_num):
self(batch)
return {
'targets': batch['targets'],
'scores': batch['scores'],
'machine_types': batch['machine_types'],
'machine_ids': batch['machine_ids'],
'part_numbers': batch['part_numbers'],
'file_ids': batch['file_ids']
}
def validation_end(self, outputs):
self.logger_.log_vae_validation(outputs, self.step, self.epoch)
return {
'val_loss': np.concatenate([o['scores'].detach().cpu().numpy() for o in outputs]).mean()
}
def test_step(self, batch, batch_num):
return self.validation_step(batch, batch_num)
def test_end(self, outputs):
# TODO: add new logging method
# self.result = self.logger_.log_testing(outputs)
self.logger_.close()
return {}
def train_dataloader(self):
if self.objects['debug']:
ds = torch.utils.data.Subset(self.objects['data_set'].get_whole_training_data_set(), np.arange(1024))
else:
ds = self.objects['data_set'].get_whole_training_data_set()
dl = torch.utils.data.DataLoader(
ds,
batch_size=self.objects['batch_size'],
shuffle=True,
num_workers=self.objects['num_workers'],
drop_last=False
)
return dl
def configuration():
seed = 1220
deterministic = False
id = datetime.now().strftime("%Y-%m-%d_%H:%M:%S:%f")
log_path = os.path.join('..', 'experiment_logs', id)
#####################
# quick configuration, uses default parameters of more detailed configuration
#####################
machine_type = 0
machine_id = 0
latent_size = 40
batch_size = 512
debug = False
if debug:
epochs = 1
num_workers = 0
else:
epochs = 50
num_workers = 4
learning_rate = 1e-4
weight_decay = 0
normalize = 'all'
normalize_raw = True
prior_class = 'priors.StandardNormalPrior'
context = 11
descriptor = "vae_training_{}_{}_{}_{}_{}_{}_{}_{}".format(
prior_class,
latent_size,
batch_size,
learning_rate,
weight_decay,
normalize,
normalize_raw,
context
)
########################
# detailed configuration
########################
num_mel = 40
n_fft = 512
hop_size = 256
prior = {
'class': prior_class,
'kwargs': {
'latent_size': latent_size,
'weight': 1
}
}
data_set = {
'class': 'data_sets.MCMDataSet',
'kwargs': {
'context': context,
'num_mel': num_mel,
'n_fft': n_fft,
'hop_size': hop_size,
'normalize': normalize,
'normalize_raw': normalize_raw
}
}
reconstruction = {
'class': 'losses.MSE',
'kwargs': {
'weight': 1,
'input_shape': '@data_set.observation_shape'
}
}
auto_encoder_model = {
'class': 'models.SamplingFCAE',
'args': [
'@data_set.observation_shape',
'@reconstruction',
'@prior'
]
}
lr_scheduler = {
'class': 'torch.optim.lr_scheduler.StepLR',
'args': [
'@optimizer',
],
'kwargs': {
'step_size': epochs
}
}
optimizer = {
'class': 'torch.optim.Adam',
'args': [
'@auto_encoder_model.parameters()'
],
'kwargs': {
'lr': learning_rate,
'betas': (0.9, 0.999),
'amsgrad': False,
'weight_decay': weight_decay,
}
}
trainer = {
'class': 'trainers.PTLTrainer',
'kwargs': {
'max_epochs': epochs,
'checkpoint_callback': False,
'logger': False,
'early_stop_callback': False,
'gpus': [0],
'show_progress_bar': True,
'progress_bar_refresh_rate': 1000
}
}
ex = Experiment('dcase2020_task2_vae_training')
cfg = ex.config(configuration)
@ex.automain
def run(_config, _run):
experiment = VAEExperiment(_config, _run)
return experiment.run()
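Because the script is registered via `@ex.automain`, it can be launched with sacred's command-line override syntax or programmatically. A usage sketch; the module name `vae_training` is an assumption about the file name:

```python
# Hypothetical usage: run the sacred experiment with config overrides.
# The module name 'vae_training' is an assumed file name.
from vae_training import ex

# equivalent to: python vae_training.py with debug=True machine_type=1
ex.run(config_updates={'debug': True, 'machine_type': 1})
```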
......@@ -10,7 +10,7 @@ class MSE(ReconstructionBase):
self.p = p
def loss(self, batch, *args, **kwargs):
-        bce = F.mse_loss(batch['predictions'], batch['observations'], reduction='mean')
+        bce = F.mse_loss(batch['predictions'], batch['observations'], reduction='sum')
batch['reconstruction_loss'] = self.weight * bce
return batch['reconstruction_loss']
......
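The switch from `reduction='mean'` to `reduction='sum'` puts the reconstruction term on the same per-batch scale as the summed KL term; `training_step` above then divides by `batch_size * num_mel * context` purely for logging. A minimal sketch of the relationship between the two reductions, with illustrative shapes (batch=512, num_mel=40, context=11):

```python
# Sketch: 'sum' is 'mean' scaled by the total element count.
import torch
import torch.nn.functional as F

predictions = torch.randn(512, 40, 11)
observations = torch.randn(512, 40, 11)

mse_mean = F.mse_loss(predictions, observations, reduction='mean')
mse_sum = F.mse_loss(predictions, observations, reduction='sum')

assert torch.allclose(mse_sum, mse_mean * predictions.numel())
```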
......@@ -42,8 +42,8 @@ class BaselineFCAE(torch.nn.Module, VAEBase):
torch.nn.BatchNorm1d(128),
torch.nn.ReLU(True),
# bn
-            torch.nn.Linear(128, prior.latent_size),
-            torch.nn.BatchNorm1d(prior.latent_size),
+            torch.nn.Linear(128, prior.input_size),
+            torch.nn.BatchNorm1d(prior.input_size),
torch.nn.ReLU(True)
)
......
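Replacing `prior.latent_size` with `prior.input_size` lets the encoder emit whatever the prior consumes rather than the latent dimensionality itself; for a diagonal-Gaussian posterior that is typically twice the latent size (a mean and a log-variance per latent dimension). A hypothetical sketch of such a prior interface; the attribute names follow the code above, but the factor of two is an assumption about the repository's priors:

```python
# Hypothetical sketch of the prior interface implied by the change above.
# The 2x factor (mu and logvar per latent dimension) is an assumption.
class GaussianPriorStub:
    def __init__(self, latent_size):
        self.latent_size = latent_size

    @property
    def input_size(self):
        # the encoder must produce mu and logvar for every latent dimension
        return 2 * self.latent_size
```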
......@@ -27,7 +27,7 @@ class SamplingFCAE(torch.nn.Module, VAEBase):
torch.nn.ReLU(True),
torch.nn.Linear(512, 512),
torch.nn.ReLU(True),
-            torch.nn.Linear(512, prior.latent_size),
+            torch.nn.Linear(512, prior.input_size),
)
self.decoder = torch.nn.Sequential(
......
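SamplingFCAE's encoder likewise now ends in `prior.input_size` units. Splitting that output and drawing a latent sample via the reparameterization trick would look roughly like this; a sketch of the standard technique, not the repository's actual prior code:

```python
# Sketch of the reparameterization step a sampling prior presumably
# performs on the encoder output (shape: batch x 2*latent_size).
import torch

def reparameterize(encoder_out: torch.Tensor) -> torch.Tensor:
    mus, logvars = encoder_out.chunk(2, dim=1)  # split into mu and logvar
    std = torch.exp(0.5 * logvars)
    eps = torch.randn_like(std)                 # standard normal noise
    return mus + eps * std                      # z ~ N(mu, sigma^2)
```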
......@@ -25,7 +25,7 @@ class StandardNormalPrior(PriorBase):
def loss(self, batch):
batch['klds'] = -0.5 * (1 + batch['logvars'] - batch['mus'].pow(2) - batch['logvars'].exp())
-        batch['prior_loss'] = batch['klds'].sum(1).mean(0)
+        batch['prior_loss'] = batch['klds'].sum()
return self.weight_anneal(batch)
@property
......
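The KLD change mirrors the MSE change: summing over the whole batch instead of averaging per example keeps both ELBO terms on the same scale. The two reductions differ exactly by a factor of `batch_size`:

```python
# Sketch of the two KLD reductions, for mus/logvars of
# shape (batch_size, latent_size).
import torch

mus, logvars = torch.randn(512, 40), torch.randn(512, 40)
klds = -0.5 * (1 + logvars - mus.pow(2) - logvars.exp())

per_example = klds.sum(1).mean(0)  # old: mean KLD per example
batch_sum = klds.sum()             # new: summed over the whole batch

assert torch.allclose(batch_sum, per_example * mus.shape[0])
```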
......@@ -46,6 +46,14 @@ class Logger:
elif type(batch[key]) == torch.Tensor and batch[key].ndim == 1 and batch[key].shape[0] == 1:
self.__log_metric__(key, batch[key].item(), step)
def log_reconstruction(self, batch, epoch):
for i, (observation, reconstructed) in enumerate(zip(batch['observations'], batch['predictions'])):
if i == 10:
break
self.__log_image__(observation, '{}_{}_image_x.png'.format(epoch, i))
self.__log_image__(reconstructed, '{}_{}_image_xhat.png'.format(epoch, i))
def log_validation(self, outputs, step, epoch, all_ids=False):
if epoch == -1:
......@@ -63,10 +71,12 @@ class Logger:
for i, typ in enumerate(np.arange(6)):
for j, id in enumerate(TRAINING_ID_MAP[typ]):
-                plt.subplot(6, 4, ((i*4) + j)+1)
+                plt.subplot(6, 4, ((i * 4) + j) + 1)
-                x_normal = scores_mean[np.logical_and(ground_truth == 0, np.logical_and(machine_ids == id, machine_types == typ))]
-                x_abnormal = scores_mean[np.logical_and(ground_truth == 1, np.logical_and(machine_ids == id, machine_types == typ))]
+                x_normal = scores_mean[
+                    np.logical_and(ground_truth == 0, np.logical_and(machine_ids == id, machine_types == typ))]
+                x_abnormal = scores_mean[
+                    np.logical_and(ground_truth == 1, np.logical_and(machine_ids == id, machine_types == typ))]
plt.hist(x_normal, bins, alpha=0.5, label='normal')
plt.hist(x_abnormal, bins, alpha=0.5, label='abnormal')
......@@ -74,7 +84,6 @@ class Logger:
if i == 0 and j == 0:
plt.legend(loc='upper right')
plt.savefig(os.path.join(self.log_dir, 'score_distribution_{}.png'.format(epoch)), bbox_inches='tight')
plt.close()
......@@ -106,6 +115,27 @@ class Logger:
'pauroc_max': float(np.mean(pauroc_max))
}
def log_vae_validation(self, outputs, step, epoch):
if epoch == -1:
return None
errors = np.concatenate([o['scores'].detach().cpu().numpy() for o in outputs])
machine_types = np.concatenate([o['machine_types'].detach().cpu().numpy() for o in outputs])
machine_ids = np.concatenate([o['machine_ids'].detach().cpu().numpy() for o in outputs])
for ty in range(6):
for id in TRAINING_ID_MAP[ty]:
idxs = np.logical_and(machine_types == ty, machine_ids == id)
if np.any(idxs):
self.__log_metric__(
'validation_reconstruction_error_{}_{}'.format(id, ty),
np.mean(errors[idxs]),
step
)
self.__log_metric__('validation_reconstruction_error', np.mean(errors), step)
def log_testing(self, outputs, all_ids=False):
return self.log_validation(outputs, 0, -2, all_ids=all_ids)
......
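The new `log_vae_validation` groups reconstruction errors per (machine_type, machine_id) pair with boolean masks before averaging. The masking pattern in isolation, with toy data:

```python
# Sketch of the boolean-mask grouping used in log_vae_validation.
import numpy as np

errors = np.array([0.1, 0.9, 0.3, 0.7])
machine_types = np.array([0, 0, 1, 1])
machine_ids = np.array([0, 2, 0, 2])

mask = np.logical_and(machine_types == 0, machine_ids == 2)
print(errors[mask].mean())  # -> 0.9
```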