Shreyan Chowdhury / moodwalk · Commits

Commit 176d4816
authored Sep 11, 2019 by Shreyan Chowdhury
implement pretrained model loading and fine-tuning
parent eb2252a1

Changes 3
experiments/experiment_midlevel.py
...
...
@@ -30,11 +30,19 @@ def epochs_20():
    config['epochs'] = 20


def midlevel_configs():
    global config
    config['epochs'] = 2


def mtg_configs():
    global config
    config['epochs'] = 1


def pretrain_midlevel(hparams):
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger, streamlog  # import these after init_experiment
    streamlog.info("Training midlevel...")
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)
...
...
@@ -92,19 +100,63 @@ def pretrain_midlevel(hparams):
    print(model)
    logger.info("Training midlevel")
    trainer.fit(model)
    logger.info("Training midlevel completed")
    # streamlog.info("Running test")
    # trainer.test()

    logger.info(f"Loading model from {chkpt_dir}")
    model = Network(num_targets=7, on_gpu=USE_GPU, load_from=chkpt_dir)


def train_mtgjamendo(hparams):
    set_paths('midlevel')
    from utils import CURR_RUN_PATH, logger, streamlog  # import these after init_experiment
    chkpt_dir = os.path.join(CURR_RUN_PATH, 'mtg.ckpt')
    midlevel_chkpt_dir = os.path.join(CURR_RUN_PATH, 'midlevel.ckpt')
    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
    exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)
    mtg_configs()

    logger.info(f"Loading model from {midlevel_chkpt_dir}")
    model = Network(num_targets=7, dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
    logger.info(f"Loaded model successfully")
    pass

    early_stop = EarlyStopping(monitor=config['earlystopping_metric'],
                               patience=config['patience'],
                               verbose=True,
                               mode=config['earlystopping_mode'])

    checkpoint_callback = ModelCheckpoint(filepath=chkpt_dir,
                                          save_best_only=True,
                                          verbose=True,
                                          monitor='val_loss',
                                          mode='min')

    if USE_GPU:
        trainer = Trainer(gpus=[0], distributed_backend='ddp',
                          experiment=exp, max_nb_epochs=config['epochs'],
                          train_percent_check=hparams.train_percent,
                          fast_dev_run=False,
                          early_stop_callback=early_stop,
                          checkpoint_callback=checkpoint_callback)
    else:
        trainer = Trainer(experiment=exp, max_nb_epochs=1,
                          train_percent_check=0.01,
                          fast_dev_run=False,
                          checkpoint_callback=checkpoint_callback)

    logger.info("Training mtgjamendo")
    trainer.fit(model)
    logger.info("Training mtgjamendo completed")


def run(hparams):
    init_experiment(comment=hparams.experiment_name)
    pretrain_midlevel(hparams)
    train_mtgjamendo(hparams)


if __name__ == '__main__':
...
...
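The driver above wires the two stages together: pretrain_midlevel() fits the model on the Mid-level data and writes a checkpoint, and train_mtgjamendo() rebuilds the Network with load_from= pointing at that checkpoint before fine-tuning on MTG-Jamendo. The snippet below is a minimal sketch of that save-then-reload handoff in plain PyTorch, outside the Lightning Trainer/Experiment wiring used here; Backbone, the toy data, and the loop body are illustrative stand-ins, not part of the repository.

import torch
import torch.nn as nn


class Backbone(nn.Module):  # stand-in for the repository's Network wrapper
    def __init__(self, num_targets=7):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(64, 256), nn.ReLU())
        self.fc_ml = nn.Linear(256, num_targets)

    def forward(self, x):
        return self.fc_ml(self.features(x))


def pretrain(ckpt_path='midlevel.ckpt'):
    model = Backbone()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(2):  # placeholder for the midlevel training epochs
        x, y = torch.randn(8, 64), torch.randn(8, 7)
        loss = nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    # store the weights under 'state_dict', the same key _load_model reads below
    torch.save({'state_dict': model.state_dict()}, ckpt_path)
    return ckpt_path


def finetune(ckpt_path):
    model = Backbone()
    model.load_state_dict(torch.load(ckpt_path)['state_dict'])  # warm-start from stage 1
    # ...continue training on the target dataset here...
    return model


finetune(pretrain())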
models/midlevel_vgg.py
...
...
@@ -23,14 +23,25 @@ def initialize_weights(module):
        module.bias.data.zero_()


class ModelMidlevel(pl.LightningModule):
-    def __init__(self, num_targets, initialize=True, load_from=None, on_gpu=None, map_location=None):
+    def __init__(self, num_targets, initialize=True, dataset='midlevel', load_from=None, on_gpu=None, map_location=None):
        super(ModelMidlevel, self).__init__()
-        data_root, audio_path, csvs_path = get_paths()
-        cache_x_name = '_ap_midlevel44k'
-        from torch.utils.data import random_split
-        dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'), audio_path, cache_x_name)
-        self.trainset, self.validationset, self.testset = random_split(dataset, [int(i * dataset_length) for i in [0.7, 0.2, 0.1]])
+        self.dataset = dataset
+        if dataset == 'midlevel':
+            data_root, audio_path, csvs_path = get_paths('midlevel')
+            cache_x_name = '_ap_midlevel44k'
+            from torch.utils.data import random_split
+            dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'), audio_path, cache_x_name)
+            self.trainset, self.validationset, self.testset = random_split(dataset, [int(i * dataset_length) for i in [0.7, 0.2, 0.1]])
+        elif dataset == 'mtgjamendo':
+            data_root, audio_path, csvs_path = get_paths('mtgjamendo')
+            cache_x_name = "_ap_mtgjamendo44k"
+            train_csv = os.path.join(csvs_path, 'train_processed.tsv')
+            validation_csv = os.path.join(csvs_path, 'validation_processed.tsv')
+            test_csv = os.path.join(csvs_path, 'test_processed.tsv')
+            self.trainset = df_get_mtg_set('mtgjamendo', train_csv, audio_path, cache_x_name)
+            self.validationset = df_get_mtg_set('mtgjamendo_val', validation_csv, audio_path, cache_x_name)
+            self.testset = df_get_mtg_set('mtgjamendo_test', test_csv, audio_path, cache_x_name)

        self.num_targets = num_targets
...
...
@@ -100,12 +111,22 @@ class ModelMidlevel(pl.LightningModule):
        )
        self.fc_ml = nn.Linear(256, 7)

        if initialize:
            self.apply(initialize_weights)

        if load_from:
            self._load_model(load_from, map_location, on_gpu)

+        if dataset == 'mtgjamendo':
+            self.fc_mtg1 = nn.Linear(7, 10)
+            self.fc_mtg2 = nn.Linear(10, 56)
+            for name, param in self.named_parameters():
+                if 'mtg' in name:
+                    param.requires_grad = True
+                else:
+                    param.requires_grad = False

    def forward(self, x):
        # 313 * 149 * 1
...
...
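The new branch in __init__ above appends two small fc_mtg layers on top of the pretrained stack and then flips requires_grad so that only those layers receive gradients. Below is a self-contained illustration of that selective-freezing pattern in plain PyTorch; the module names are invented for the example and are not the repository's.

import torch
import torch.nn as nn

# Toy model: a "pretrained" part plus a new task head (names are illustrative).
model = nn.Sequential()
model.add_module('backbone', nn.Linear(16, 7))      # stands in for the frozen conv stack
model.add_module('fc_mtg_head', nn.Linear(7, 56))   # stands in for fc_mtg1/fc_mtg2

# Freeze everything except parameters whose name marks them as the new head,
# mirroring the `'mtg' in name` check in the diff above.
for name, param in model.named_parameters():
    param.requires_grad = 'mtg' in name

# Hand only the trainable parameters to the optimizer; frozen ones get no updates.
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)

print([n for n, p in model.named_parameters() if p.requires_grad])
# -> ['fc_mtg_head.weight', 'fc_mtg_head.bias']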
@@ -122,6 +143,10 @@ class ModelMidlevel(pl.LightningModule):
        x = self.conv11(x)  # 2 * 2 * 256
        x = x.view(x.size(0), -1)
        ml = self.fc_ml(x)
+        if self.dataset == 'mtgjamendo':
+            x = self.fc_mtg1(ml)
+            logit = nn.Sigmoid()(self.fc_mtg2(x))
+            return logit
        return ml

    def _load_model(self, load_from, map_location=None, on_gpu=True):
...
...
@@ -156,7 +181,10 @@ class ModelMidlevel(pl.LightningModule):
        self.load_state_dict(checkpoint['state_dict'])

    def my_loss(self, y_hat, y):
-        return F.mse_loss(y_hat, y)
+        if self.dataset == 'midlevel':
+            return F.mse_loss(y_hat, y)
+        else:
+            return my_loss(y_hat, y)

    def forward_full_song(self, x, y):
        # print(x.shape)
...
...
@@ -173,51 +201,66 @@ class ModelMidlevel(pl.LightningModule):
        # return y_hat/count

    def training_step(self, data_batch, batch_nb):
-        x, _, y = data_batch
-        y_hat = self.forward_full_song(x, y)
-        y = y.float()
-        y_hat = y_hat.float()
-        return {'loss': self.my_loss(y_hat, y)}
+        if self.dataset == 'midlevel':
+            x, _, y = data_batch
+            y_hat = self.forward_full_song(x, y)
+            y = y.float()
+            y_hat = y_hat.float()
+            return {'loss': self.my_loss(y_hat, y)}
+        else:
+            return training_step(self, data_batch, batch_nb)

    def validation_step(self, data_batch, batch_nb):
-        x, _, y = data_batch
-        y_hat = self.forward_full_song(x, y)
-        y = y.float()
-        y_hat = y_hat.float()
-        return {'val_loss': self.my_loss(y_hat, y),
-                'y': y.cpu().numpy(),
-                'y_hat': y_hat.cpu().numpy(),
-                }
+        if self.dataset == 'midlevel':
+            x, _, y = data_batch
+            y_hat = self.forward_full_song(x, y)
+            y = y.float()
+            y_hat = y_hat.float()
+            return {'val_loss': self.my_loss(y_hat, y),
+                    'y': y.cpu().numpy(),
+                    'y_hat': y_hat.cpu().numpy(),
+                    }
+        else:
+            return validation_step(self, data_batch, batch_nb)

    def validation_end(self, outputs):
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-        y = []
-        y_hat = []
-        for output in outputs:
-            y.append(output['y'])
-            y_hat.append(output['y_hat'])
+        if self.dataset == 'midlevel':
+            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+            y = []
+            y_hat = []
+            for output in outputs:
+                y.append(output['y'])
+                y_hat.append(output['y_hat'])

-        y = np.concatenate(y)
-        y_hat = np.concatenate(y_hat)
+            y = np.concatenate(y)
+            y_hat = np.concatenate(y_hat)

-        return {'val_loss': avg_loss}
+            return {'val_loss': avg_loss}
+        else:
+            return validation_end(outputs)

    def test_step(self, data_batch, batch_nb):
-        x, _, y = data_batch
-        y_hat = self.forward_full_song(x, y)
-        y = y.float()
-        y_hat = y_hat.float()
-        return {'test_loss': self.my_loss(y_hat, y),
-                'y': y.cpu().numpy(),
-                'y_hat': y_hat.cpu().numpy(),
-                }
+        if self.dataset == 'midlevel':
+            x, _, y = data_batch
+            y_hat = self.forward_full_song(x, y)
+            y = y.float()
+            y_hat = y_hat.float()
+            return {'test_loss': self.my_loss(y_hat, y),
+                    'y': y.cpu().numpy(),
+                    'y_hat': y_hat.cpu().numpy(),
+                    }
+        else:
+            return test_step(self, data_batch, batch_nb)

    def test_end(self, outputs):
-        avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
-        test_metrics = {"test_loss": avg_test_loss}
-        self.experiment.log(test_metrics)
-        return test_metrics
+        if self.dataset == 'midlevel':
+            avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
+            test_metrics = {"test_loss": avg_test_loss}
+            self.experiment.log(test_metrics)
+            return test_metrics
+        else:
+            return test_end(outputs)

    def configure_optimizers(self):
...
...
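_load_model (only partially visible in this diff) restores checkpoint['state_dict'] into the network. When the fine-tuning model carries head layers the pretrained checkpoint has never seen, the common PyTorch idiom is to load non-strictly and inspect what was skipped; the helper below is a generic sketch of that pattern under those assumptions, not the repository's exact implementation.

import torch


def load_pretrained(model, ckpt_path, map_location='cpu'):
    """Load a Lightning-style checkpoint (weights under 'state_dict') into `model`,
    tolerating keys that exist on only one side."""
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    # strict=False leaves layers missing from the checkpoint (e.g. fc_mtg1/fc_mtg2)
    # at their fresh initialisation and reports mismatches instead of raising.
    missing, unexpected = model.load_state_dict(checkpoint['state_dict'], strict=False)
    print(f'missing keys (kept at init): {missing}')
    print(f'unexpected keys (ignored):   {unexpected}')
    return model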
utils.py
...
...
@@ -30,7 +30,7 @@ data_roots = {
        "shreyan-All-Series": "/mnt/2tb/datasets/MTG-Jamendo"
    },
    "midlevel": {
-        "rechenknecht3.cp.jku.at": "/media/rk3/shared/datasets/midlevel",
+        "rechenknecht3.cp.jku.at": "/media/rk3/shared/midlevel",
        "rechenknecht2.cp.jku.at": "/media/rk2/shared/datasets/midlevel",
        "rechenknecht1.cp.jku.at": "/media/rk1/shared/datasets/midlevel",
        "hermine": "",
...
...
@@ -98,7 +98,17 @@ def set_paths(dataset_name):
        PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations')


-def get_paths():
+def get_paths(dataset_name):
+    PATH_DATA_ROOT = data_roots[dataset_name][hostname]
+    if dataset_name == 'midlevel':
+        PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'audio')
+        PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'metadata_annotations')
+    elif dataset_name == 'mtgjamendo':
+        PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_audio')
+        PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations')
+    else:
+        PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'audio')
+        PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'annotations')
    return PATH_DATA_ROOT, PATH_AUDIO, PATH_ANNOTATIONS
...
...
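With get_paths now keyed by dataset name, callers resolve the per-host data root together with the matching audio and annotation folders in a single call. A small usage sketch; the concrete paths depend on the hostname entry in data_roots above.

from utils import get_paths

# e.g. on rechenknecht3: /media/rk3/shared/midlevel plus its audio / metadata_annotations dirs
data_root, audio_path, csvs_path = get_paths('midlevel')

# MTG-Jamendo_audio / MTG-Jamendo_annotations under the MTG-Jamendo root
_, mtg_audio, mtg_annotations = get_paths('mtgjamendo')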