Commit caf8592a authored by Paul Primus's avatar Paul Primus
Browse files

Initial commit

parents
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/dcase_2020_task_2" isTestSource="false" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Encoding" addBOMForNewFiles="with NO BOM" />
</project>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (dcase_2020_task_2)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/dcase_2020_task_2.iml" filepath="$PROJECT_DIR$/.idea/dcase_2020_task_2.iml" />
</modules>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="30c4609b-c0f3-47c2-96b2-44333599f4dd" name="Default Changelist" comment="" />
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FavoritesManager">
<favorites_list name="dcase_2020_task_2" />
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="ProjectId" id="1ZLpUs7YbXsYOxrGVqx7BbcwKzO" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showExcludedFiles" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="WebServerToolWindowFactoryState" value="false" />
<property name="last_opened_file_path" value="$PROJECT_DIR$/../Intrument_tagging" />
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
<property name="nodejs_npm_path_reset_for_default_project" value="true" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PythonContentEntriesConfigurable" />
</component>
<component name="RunManager">
<configuration name="mcm_dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="dcase_2020_task_2" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
<env name="OMP_NUM_THREADS" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/dcase_2020_task_2/data_sets" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/dcase_2020_task_2/data_sets/mcm_dataset.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.mcm_dataset" />
</list>
</recent_temporary>
</component>
<component name="SvnConfiguration">
<configuration />
</component>
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="30c4609b-c0f3-47c2-96b2-44333599f4dd" name="Default Changelist" comment="" />
<created>1584628247565</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1584628247565</updated>
<workItem from="1584628276554" duration="416000" />
<workItem from="1584628750333" duration="12003000" />
<workItem from="1584641483829" duration="87000" />
<workItem from="1584641900040" duration="6767000" />
<workItem from="1584649733866" duration="4687000" />
<workItem from="1584655232955" duration="2786000" />
<workItem from="1584697347982" duration="7290000" />
</task>
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="1" />
</component>
<component name="WindowStateProjectService">
<state x="418" y="182" key="#com.intellij.execution.impl.EditConfigurationsDialog" timestamp="1584641561314">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state x="418" y="182" key="#com.intellij.execution.impl.EditConfigurationsDialog/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584641561314" />
<state x="693" y="273" width="524" height="509" key="#com.intellij.refactoring.safeDelete.UnsafeUsagesDialog" timestamp="1584710054191">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state x="693" y="273" width="524" height="509" key="#com.intellij.refactoring.safeDelete.UnsafeUsagesDialog/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584710054191" />
<state x="697" y="253" width="515" height="549" key="EnvironmentVariablesDialog" timestamp="1584640280238">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state x="697" y="253" width="515" height="549" key="EnvironmentVariablesDialog/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584640280238" />
<state x="2284" y="474" width="424" height="491" key="FileChooserDialogImpl" timestamp="1584629978620">
<screen x="1920" y="0" width="1920" height="1200" />
</state>
<state x="2284" y="474" width="424" height="491" key="FileChooserDialogImpl/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584629978620" />
<state width="1874" height="174" key="GridCell.Tab.0.bottom" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.0.bottom/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state width="1874" height="174" key="GridCell.Tab.0.center" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.0.center/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state width="1874" height="174" key="GridCell.Tab.0.left" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.0.left/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state width="1874" height="174" key="GridCell.Tab.0.right" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.0.right/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state width="1874" height="174" key="GridCell.Tab.1.bottom" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.1.bottom/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state width="1874" height="174" key="GridCell.Tab.1.center" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.1.center/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state width="1874" height="174" key="GridCell.Tab.1.left" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.1.left/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state width="1874" height="174" key="GridCell.Tab.1.right" timestamp="1584709972934">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state width="1874" height="174" key="GridCell.Tab.1.right/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584709972934" />
<state x="449" y="159" width="1022" height="737" key="SettingsEditor" timestamp="1584628926988">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state x="449" y="159" width="1022" height="737" key="SettingsEditor/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584628926988" />
<state x="616" y="351" key="com.intellij.ide.util.TipDialog" timestamp="1584697356692">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state x="616" y="351" key="com.intellij.ide.util.TipDialog/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584697356692" />
<state x="2578" y="53" width="1536" height="748" maximized="true" key="dock-window-1" timestamp="1584710032199">
<screen x="1920" y="0" width="1920" height="1200" />
</state>
<state x="2578" y="53" width="1536" height="748" maximized="true" key="dock-window-1/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584710032199" />
<state x="623" y="244" width="672" height="678" key="search.everywhere.popup" timestamp="1584635196518">
<screen x="0" y="27" width="1920" height="994" />
</state>
<state x="623" y="244" width="672" height="678" key="search.everywhere.popup/0.27.1920.994/1920.0.1920.1200@0.27.1920.994" timestamp="1584635196518" />
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$USER_HOME$/miniconda3/envs/dcase_2020_task_2/lib/python3.7/site-packages/torchaudio/functional.py</url>
<line>42</line>
<option name="timeStamp" value="18" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/dcase_2020_task_2/data_sets/mcm_dataset.py</url>
<line>72</line>
<option name="timeStamp" value="30" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
<watches-manager>
<configuration name="PythonConfigurationType">
<watch expression="len(indices) / self.num_sampels_per_file" language="Python" />
</configuration>
</watches-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/dcase_2020_task_2$mcm_dataset.coverage" NAME="mcm_dataset Coverage Results" MODIFIED="1584709402422" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/dcase_2020_task_2/data_sets" />
</component>
</project>
\ No newline at end of file
## Install
1. To set up the project and download the data, run the following commands:
- ```conda env create -f environment.yml```
- ```cd raw_data```
- ```download_data.sh```
2. Set up MongoDB and Omniboard for the Sacred logger:
- https://docs.mongodb.com/manual/installation/
- https://github.com/vivekratnavel/omniboard
## Run experiment
Run `python -m experiments.vae_experiment with <options>`, for example:
```
cd vae_priors
# Vanilla
python -m experiments.vae_experiment with -p prior_class=priors.StandardNormalPrior prior.kwargs.weight=1.0 -m student2.cp.jku.at:27017:better_priors
# Beta VAE
python -m experiments.vae_experiment with -p prior_class=priors.StandardNormalPrior prior.kwargs.weight=150.0 -m student2.cp.jku.at:27017:better_priors
# Annealed VAE
python -m experiments.vae_experiment with -p prior_class=priors.StandardNormalPrior prior.kwargs.c_max=10 prior.kwargs.c_stop_epoch=200 prior.kwargs.weight=1000.0 -m student2.cp.jku.at:27017:better_priors
# Factor VAE
python -m experiments.vae_experiment with -p prior_class=priors.StandardNormalPrior prior.kwargs.weight=1.0 use_factor_loss=True factor.kwargs.weight=10.0 -m student2.cp.jku.at:27017:better_priors
# Orthogonal
python -m experiments.vae_experiment with -p prior_class=priors.OrthogonalPrior prior.kwargs.weight=1000.0 -m student2.cp.jku.at:27017:better_priors
# Simplex
python -m experiments.vae_experiment with -p prior_class=priors.SimplexPrior prior.kwargs.weight=1000.0 -m student2.cp.jku.at:27017:better_priors
# BETA TCVAE
python -m experiments.vae_experiment with -p prior_class=priors.BetaTCVaePrior prior.kwargs.weight=1.0 -m student2.cp.jku.at:27017:better_priors
```
Results are stored in MongoDB and in the `experiment_logs` folder.
## TODOs
- Testing
- Typing
- Datasets
- Priors
- Models
- Make model more readable
- ...
\ No newline at end of file
from auxiliary_losses.auxiliary_loss_base import AuxiliaryLossBase
from auxiliary_losses.factor_vae import FactorVAE
from abc import ABC, abstractmethod
class AuxiliaryLossBase(ABC):
    """Abstract parent for loss terms that carry a scalar weighting factor.

    Concrete subclasses have to provide :meth:`auxiliary_loss`; the base class
    itself cannot be instantiated.

    Args:
        weight: Factor stored on the instance as ``self.weight``. Defaults to 1.0.
    """

    def __init__(self, weight=1.0):
        # Explicit two-argument form of the cooperative super() call, so any
        # further base classes in a subclass's MRO are initialised as well.
        super(AuxiliaryLossBase, self).__init__()
        self.weight = weight

    @abstractmethod
    def auxiliary_loss(self, batch):
        """Compute the auxiliary loss for ``batch``; must be overridden."""
        raise NotImplementedError
from auxiliary_losses import AuxiliaryLossBase
import torch
import torch.nn.functional as F
class FactorVAE(AuxiliaryLossBase, torch.nn.Module):
    """Discriminator-based auxiliary loss in the style of FactorVAE.

    Wraps a discriminator ``model`` that is applied to latent codes. The class
    exposes two losses:

    * :meth:`auxiliary_loss` - penalty added to the auto-encoder objective.
    * :meth:`training_loss` - classification loss used to train the
      discriminator itself.

    Both methods read ``batch['codes']`` and write their result back into the
    ``batch`` dict.

    Args:
        model: Module applied to the codes. Its output is indexed as a 2-D
            tensor and passed to ``F.cross_entropy`` with labels 0 and 1, so it
            is assumed to return one row per sample with two class logits
            - TODO confirm against the model that is configured.
        weight: Multiplier applied to the auxiliary loss. Defaults to 1.0.
    """

    def __init__(self, model, weight=1.0):
        super().__init__(weight=weight)
        self.model = model
        # Batch seen by the previous call to training_loss(); None until the first call.
        self.prev_batch = None

    def forward(self, x):
        """Apply the wrapped discriminator to ``x``."""
        return self.model(x)

    def auxiliary_loss(self, batch):
        """Return the weighted penalty for ``batch['codes']``.

        The penalty is the mean difference between the first output column and
        the remaining column(s) of the discriminator. The result is also stored
        in ``batch['auxiliary_loss']``.
        """
        discriminator = self(batch['codes'])
        batch['auxiliary_loss'] = self.weight * (discriminator[:, :1] - discriminator[:, 1:]).mean()
        return batch['auxiliary_loss']

    def training_loss(self, batch):
        """Return the discriminator loss and remember ``batch`` for the next call.

        Codes of the previously seen batch are labelled 0, dimension-wise
        permuted codes of the current batch are labelled 1, and the loss is the
        average of the two cross-entropies. On the very first call there is no
        previous batch, so a zero loss (with ``requires_grad=True``) is returned.
        The result is also stored in ``batch['factor_loss']``.
        """
        device = batch['codes'].device

        if self.prev_batch is not None:
            # Detach the stored codes: this loss must only update the
            # discriminator. Without the detach, backward() would try to
            # propagate into the graph that produced the *previous* batch's
            # codes, which has typically been freed already (or would leak
            # stale gradients into the encoder).
            prev_codes = self.prev_batch['codes'].detach()
            d_z_prim_prev = self(prev_codes)

            perm = permute_dims(batch['codes']).detach()
            d_z_prim_perm = self(perm)

            ones = torch.ones(d_z_prim_prev.shape[0], dtype=torch.long, device=device)
            zeros = torch.zeros(d_z_prim_perm.shape[0], dtype=torch.long, device=device)

            # NOTE(review): `zeros` is sized from the permuted batch but paired
            # with the previous batch (and vice versa for `ones`); this is only
            # consistent when both batches have the same size - kept as in the
            # original, verify for a smaller final batch.
            factor_loss = 0.5 * (
                F.cross_entropy(d_z_prim_prev, zeros) + F.cross_entropy(d_z_prim_perm, ones)
            )
            batch['factor_loss'] = factor_loss
        else:
            factor_loss = torch.tensor(0.0, requires_grad=True, device=device)
            batch['factor_loss'] = factor_loss

        # NOTE(review): the whole batch dict is retained until the next call
        # although only its 'codes' entry is read above.
        self.prev_batch = batch

        return batch['factor_loss']
def permute_dims(z):
    """Shuffle every column of a 2-D tensor independently along the batch axis.

    Args:
        z: Tensor with exactly two dimensions (rows are samples).

    Returns:
        Tensor of the same shape as ``z`` in which each column holds the same
        values as the corresponding column of ``z``, reordered by its own
        random permutation of the rows.

    Raises:
        AssertionError: If ``z`` is not 2-dimensional.
    """
    assert z.dim() == 2
    num_rows = z.shape[0]

    # One fresh permutation per column, drawn in column order; the permutation
    # is generated with the default generator and then moved to z's device.
    shuffled_columns = [
        column[torch.randperm(num_rows).to(z.device)]
        for column in z.split(1, 1)
    ]
    return torch.cat(shuffled_columns, 1)
from datetime import datetime
import os
def configuration():
    # Default experiment configuration, written purely as local variable
    # assignments; the function returns nothing.
    # NOTE(review): presumably a Sacred config scope (the README launches the
    # experiment with Sacred's `with key=value` syntax). If so, the local names
    # below are the configuration keys and must not be renamed - confirm
    # against the experiment module.
    # NOTE(review): strings starting with '@' look like references to other
    # entries (or their attributes / method calls) that are resolved by the
    # code consuming this configuration - confirm there.
    seed = 1220
    deterministic = False
    # Timestamp taken when the configuration is evaluated; it becomes the last
    # component of log_path.
    # NOTE(review): `id` shadows the builtin, and the ':' characters in the
    # timestamp are not valid in Windows path names.
    id = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    log_path = os.path.join('..', 'experiment_logs', id)

    #####################
    # quick configuration, uses default parameters of more detailed configuration
    #####################

    prior_class = 'priors.StandardNormalPrior'
    latent_size = 10
    use_factor_loss = False

    data_set_class = 'data_sets.SimpleDots'

    batch_size = 64
    epochs = 200

    ########################
    # detailed configuration
    ########################

    # set default values for different priors
    # NOTE(review): there is no else-branch, so an unrecognised prior_class
    # leaves `prior` unassigned although '@prior' is referenced further down.
    if prior_class == 'priors.NoPrior':
        prior = {
            'class': prior_class,
            'kwargs': {
                'latent_size': latent_size,
                'weight': 1.0
            }
        }
    elif prior_class == 'priors.StandardNormalPrior':
        prior = {
            'class': prior_class,
            'kwargs': {
                'latent_size': latent_size,
                'weight': 1.0,
                'c_max': 0.0,
                'c_stop_epoch': epochs
            }
        }
    elif prior_class == 'priors.SimplexPrior' or prior_class == 'priors.OrthogonalPrior':
        prior = {
            'class': prior_class,
            'kwargs': {
                'min_anneal': 0.0,
                'max_anneal': 1.0,
                'anneal_stop_epoch': epochs,
                'latent_size': latent_size,
                'weight': 1.0,
                'c_max': 0.0,
                'c_stop_epoch': epochs
            }
        }
    elif prior_class == 'priors.BetaTCVaePrior':
        # TODO: add default parameters
        prior = {
            'class': prior_class,
            'args': [
                '@training_data_set.size'
            ],
            'kwargs': {
                'min_anneal': 0.0,
                'max_anneal': 1.0,
                'anneal_stop_epoch': epochs,
                'latent_size': latent_size,
                'weight': 1.0,
                'c_max': 0.0,
                'c_stop_epoch': epochs
            }
        }
    elif prior_class == 'priors.DIPVaePrior':
        # TODO: add default parameters
        prior = {
            'class': prior_class,
            'args': [
                '@training_data_set.size'
            ],
            'kwargs': {
                'min_anneal': 0.0,
                'max_anneal': 1.0,
                'anneal_stop_epoch': epochs,
                'latent_size': latent_size,
                'weight': 1.0,
                'c_max': 0.0,
                'c_stop_epoch': epochs,
                'dip_first': False,
            }
        }

    # Data sets: SimpleDots is the only data set with defaults here; the
    # training and testing entries receive identical keyword arguments.
    # NOTE(review): as with the priors, any other data_set_class leaves
    # training_data_set / testing_data_set unassigned.
    if data_set_class == 'data_sets.SimpleDots':
        circle_radius = 8
        image_size = (64, 64)
        num_images = 64 * 64

        training_data_set = {
            'class': data_set_class,
            'kwargs': {
                'circle_radius': circle_radius,
                'image_size': image_size,
                'num_images': num_images
            }
        }

        testing_data_set = {
            'class': data_set_class,
            'kwargs': {
                'circle_radius': circle_radius,
                'image_size': image_size,
                'num_images': num_images
            }
        }
    # TODO: include other datasets

    # Reconstruction loss description.
    reconstruction = {
        'class': 'reconstructions.BinaryCrossEntropy',
        'kwargs': {
            'weight': 1.0,
            'input_shape': '@training_data_set.observation_shape'
        }
    }

    # Auto-encoder description; its positional arguments refer to the data set,
    # reconstruction and prior entries above.
    auto_encoder_model = {
        'class': 'models.VanillaCNN',
        'args': [
            '@training_data_set.observation_shape',
            '@reconstruction',
            '@prior',
            '@training_data_set.normalize_observations'
        ]
    }

    # NOTE(review): '@optimizer' is referenced here before `optimizer` is
    # assigned below, so the reference is presumably resolved only after the
    # whole configuration has been built - confirm.
    # NOTE(review): step_size equals the number of epochs - check that a decay
    # this late in training is intended.
    lr_scheduler = {
        'class': 'torch.optim.lr_scheduler.StepLR',
        'args': [
            '@optimizer',
        ],
        'kwargs': {
            'step_size': epochs
        }
    }

    optimizer = {
        'class': 'torch.optim.AdamW',
        'args': [
            '@auto_encoder_model.parameters()'
        ],
        'kwargs': {
            'lr': 5e-4,
            'betas': (0.9, 0.999),
            'amsgrad': True,
            'weight_decay': 0.01,
        }
    }

    training_data_loader = {
        'class': 'torch.utils.data.DataLoader',
        'args': [
            '@training_data_set'
        ],
        'kwargs': {
            'batch_size': batch_size,
            'shuffle': True,
            'num_workers': 0
        }
    }

    # NOTE(review): the test loader hard-codes batch_size 64 instead of using
    # the `batch_size` entry; it matches the default only - confirm intent.
    testing_data_loader = {
        'class': 'torch.utils.data.DataLoader',
        'args': [
            '@testing_data_set'
        ],
        'kwargs': {
            'batch_size': 64,
            'shuffle': False,
            'num_workers': 0
        }
    }

    trainer = {
        'class': 'trainers.PTLTrainer',
        'kwargs': {
            'max_epochs': epochs,
            'checkpoint_callback': False,
            'logger': False,
            'early_stop_callback': False,
            'gpus': [0],
        }
    }

    # Entries that are only defined when the factor loss is enabled.
    if use_factor_loss:
        # The nested dict describes the critic handed to FactorVAE as its first
        # positional argument; its 'ref' value matches the name used by
        # '@factor_model.parameters()' below (presumably that is how the nested
        # object becomes referable - confirm).
        factor = {
            'class': 'auxiliary_losses.FactorVAE',
            'args': [
                {
                    'class': 'models.Critic',
                    'args': [
                        '@prior.latent_size',
                    ],
                    'ref': 'factor_model'
                }
            ],
            'kwargs': {
                'weight': 1.0
            }
        }

        factor_optimizer = {
            'class': 'torch.optim.AdamW',
            'args': [
                '@factor_model.parameters()'
            ],
            'kwargs': {
                'lr': 5e-4,
                'betas': (0.9, 0.999),
                'amsgrad': True,
                'weight_decay': 0.01,
            }
        }

        # NOTE(review): step_size is hard-coded to 200 here, whereas
        # lr_scheduler above uses `epochs`; the two agree only for the default
        # epochs value - confirm whether this should follow `epochs`.
        factor_lr_scheduler = {
            'class': 'torch.optim.lr_scheduler.StepLR',
            'args': [
                '@factor_optimizer',
            ],
            'kwargs': {
                'step_size': 200
            }
        }
from data_sets.base_data_set import BaseDataSet
from data_sets.mcm_dataset import MCMDataset
\ No newline at end of file
from abc import ABC, abstractmethod, abstractproperty
from typing import Any
class BaseDataSet(ABC):
@property
@abstractmethod
def observation_shape(self) -> tuple:
raise NotImplementedError
@abstractmethod
def __getitem__(self, item: int) -> Any:
raise NotImplementedError