Commit 4191dde3 authored by Verena Haunschmid's avatar Verena Haunschmid
Browse files

Clean up save_1_or_2_dim_dualnet_visualizations interface

parent 285fdd1c
......@@ -106,7 +106,7 @@ def train_dualnet(model, loaders, config):
# Visualize the learned critic landscape.
if config.visualize:
save_1_or_2_dim_dualnet_visualizations(model, config.distrib1.dim, dirs.figures_dir, config,
save_1_or_2_dim_dualnet_visualizations(model, dirs.figures_dir, config,
state['epoch'], state['loss'])
# Check if this is the best model.
......@@ -158,7 +158,7 @@ def train_dualnet(model, loaders, config):
# Visualize the learned critic landscape.
if config.visualize:
save_1_or_2_dim_dualnet_visualizations(model, config.distrib1.dim, dirs.figures_dir, config,
save_1_or_2_dim_dualnet_visualizations(model, dirs.figures_dir, config,
after_training=False)
return test_state
......
......@@ -94,17 +94,19 @@ def load_best_model_and_optimizer(model, optimizer, best_path):
load_optimizer(optimizer, best_optimizer_path)
def save_1_or_2_dim_dualnet_visualizations(model, dim, figures_dir, config, epoch=None, loss=None,
def save_1_or_2_dim_dualnet_visualizations(model, figures_dir, config, epoch=None, loss=None,
after_training=False):
dim = config.distrib1.dim
if not after_training:
if dim == 2:
save_2d_dualnet_visualizations(model, figures_dir, config, epoch, loss)
if dim == 1:
save_1d_dualnet_visualizations(model, figures_dir, config, epoch, loss)
else:
if config.distrib1.dim == 2:
if dim == 2:
save_2d_dualnet_visualizations(model, figures_dir, config, after_training=True)
if config.distrib1.dim == 1:
if dim == 1:
save_1d_dualnet_visualizations(model, figures_dir, config, after_training=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment