Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Shreyan Chowdhury
moodwalk
Commits
c070c7d9
Commit
c070c7d9
authored
Sep 11, 2019
by
Verena Praher
Browse files
add option to reload pretrained midlevel model
parent
6ca30b70
Changes
1
Show whitespace changes
Inline
Side-by-side
experiments/experiment_midlevel.py
View file @
c070c7d9
...
...
@@ -85,7 +85,7 @@ def pretrain_midlevel(hparams):
if
USE_GPU
:
trainer
=
Trainer
(
gpus
=
[
0
],
distributed_backend
=
'ddp'
,
experiment
=
exp
,
max_nb_epochs
=
config
[
'epochs'
]
,
train_percent_check
=
hparams
.
train_percent
,
experiment
=
exp
,
max_nb_epochs
=
20
,
train_percent_check
=
hparams
.
train_percent
,
fast_dev_run
=
False
,
early_stop_callback
=
early_stop
,
checkpoint_callback
=
checkpoint_callback
...
...
@@ -106,11 +106,10 @@ def pretrain_midlevel(hparams):
# streamlog.info("Running test")
# trainer.test()
def
train_mtgjamendo
(
hparams
):
def
train_mtgjamendo
(
hparams
,
midlevel_chkpt_dir
):
set_paths
(
'midlevel'
)
from
utils
import
CURR_RUN_PATH
,
logger
,
streamlog
# import these after init_experiment
chkpt_dir
=
os
.
path
.
join
(
CURR_RUN_PATH
,
'mtg.ckpt'
)
midlevel_chkpt_dir
=
os
.
path
.
join
(
CURR_RUN_PATH
,
'midlevel.ckpt'
)
logger
.
info
(
f
"tensorboard --logdir=
{
CURR_RUN_PATH
}
"
)
exp
=
Experiment
(
name
=
'mtg'
,
save_dir
=
CURR_RUN_PATH
)
...
...
@@ -137,7 +136,7 @@ def train_mtgjamendo(hparams):
if
USE_GPU
:
trainer
=
Trainer
(
gpus
=
[
0
],
distributed_backend
=
'ddp'
,
experiment
=
exp
,
max_nb_epochs
=
config
[
'epochs'
]
,
train_percent_check
=
hparams
.
train_percent
,
experiment
=
exp
,
max_nb_epochs
=
20
,
train_percent_check
=
hparams
.
train_percent
,
fast_dev_run
=
False
,
early_stop_callback
=
early_stop
,
checkpoint_callback
=
checkpoint_callback
...
...
@@ -155,8 +154,15 @@ def train_mtgjamendo(hparams):
def
run
(
hparams
):
init_experiment
(
comment
=
hparams
.
experiment_name
)
if
hparams
.
pretrained_midlevel
is
None
:
pretrain_midlevel
(
hparams
)
train_mtgjamendo
(
hparams
)
from
utils
import
CURR_RUN_PATH
midlevel_chkpt_dir
=
os
.
path
.
join
(
CURR_RUN_PATH
,
'midlevel.ckpt'
)
else
:
from
utils
import
logger
midlevel_chkpt_dir
=
hparams
.
pretrained_midlevel
logger
.
info
(
"Using pretrained model"
,
midlevel_chkpt_dir
)
train_mtgjamendo
(
hparams
,
midlevel_chkpt_dir
)
if
__name__
==
'__main__'
:
...
...
@@ -168,6 +174,8 @@ if __name__=='__main__':
# implementation of argument parser
parent_parser
.
add_argument
(
'--train_percent'
,
type
=
float
,
default
=
1.0
,
help
=
'how much train data to use'
)
parent_parser
.
add_argument
(
'--pretrained_midlevel'
,
default
=
"/home/verena/experiments/moodwalk/runs/ca601 - pretrain_midlevel 20 epochs/midlevel.ckpt/"
,
type
=
str
)
parser
=
Network
.
add_model_specific_args
(
parent_parser
)
hyperparams
=
parser
.
parse_args
()
run
(
hyperparams
)
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment