Commit cc48279e authored by Richard Vogl

parameter for data

parent b8d9b6aa
@@ -10,6 +10,11 @@ SETTINGS_FILE_NAME = 'settings.npy'
 LOSSES_FILE = 'losses.npy'
 LIBROSA_PATH = os.path.join(DATASET_PATH, 'librosa_disklavier')
+MAPS_PATH = os.path.join(DATASET_PATH, 'maps_piano')
+BATIK_PATH = os.path.join(DATASET_PATH, 'mozart_by_batik/')
+MIDI_PATH = os.path.join(DATASET_PATH, 'midi_maestro_concert')
+JAZZ_PATH = os.path.join(DATASET_PATH, 'jazz_piano')
 CACHED_DS_NAME = 'cached.npz'
...
 import numpy as np
 import os
-from piano_transcription import LIBROSA_PATH, CACHED_DS_NAME
+from piano_transcription import LIBROSA_PATH, MAPS_PATH, MIDI_PATH, BATIK_PATH, JAZZ_PATH, CACHED_DS_NAME

 # TODO remove this and load and use real data!
 FEAT_SIZE = 482
@@ -44,8 +44,22 @@ def load_ramdom_data():
     return feat_train, targ_train, feat_valid, targ_valid, feat_test, targ_test, FULL_SEQ_LEN, FEAT_SIZE, OUT_SIZE


-def load_data():
-    data_dict = np.load(os.path.join(LIBROSA_PATH, CACHED_DS_NAME))
+def load_data(dataset, split):
+    # TODO how do we handle splits?
+    if dataset == 'librosa':
+        data_dict = np.load(os.path.join(LIBROSA_PATH, CACHED_DS_NAME))
+    elif dataset == 'maps':
+        data_dict = np.load(os.path.join(MAPS_PATH, CACHED_DS_NAME))
+    elif dataset == 'batik':
+        data_dict = np.load(os.path.join(BATIK_PATH, CACHED_DS_NAME))
+    elif dataset == 'midi':
+        data_dict = np.load(os.path.join(MIDI_PATH, CACHED_DS_NAME))
+    elif dataset == 'jazz':
+        data_dict = np.load(os.path.join(JAZZ_PATH, CACHED_DS_NAME))
+    else:
+        print('unknown dataset %s, use one of: ' % dataset)
+        exit(1)

     train_names = data_dict['train_names']
     train_feats = data_dict['train_feats']
...
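Side note: the dataset dispatch added to load_data above could also be written as a table lookup. The following is a minimal sketch, not code from this commit; the helper name load_cached_dataset is hypothetical, while the path constants and CACHED_DS_NAME are the ones the commit imports from piano_transcription.

    # Sketch only: dict-based dataset -> cache-path dispatch (not part of this commit).
    import os
    import numpy as np
    from piano_transcription import (LIBROSA_PATH, MAPS_PATH, BATIK_PATH,
                                     MIDI_PATH, JAZZ_PATH, CACHED_DS_NAME)

    DATASET_PATHS = {
        'librosa': LIBROSA_PATH,
        'maps': MAPS_PATH,
        'batik': BATIK_PATH,
        'midi': MIDI_PATH,
        'jazz': JAZZ_PATH,
    }

    def load_cached_dataset(dataset):
        # Unknown names fail with the list of valid options instead of exit(1).
        if dataset not in DATASET_PATHS:
            raise ValueError('unknown dataset %s, use one of: %s'
                             % (dataset, ', '.join(sorted(DATASET_PATHS))))
        return np.load(os.path.join(DATASET_PATHS[dataset], CACHED_DS_NAME))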
@@ -141,6 +141,8 @@ class UniversalRegressionDataPool(object):
         for i_seq in xrange(self.n_sequences):
             sequence = self.sequences[i_seq]
             target = self.target_sequences[i_seq]
+            if isinstance(sequence, int) or isinstance(target, int):
+                print("war?")
             assert len(sequence) == len(target)

             start_idx = self.half_context
...
@@ -23,7 +23,7 @@ MAX_NUMEPOCHS = 10000
 col = BColors()


-def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_samples):
+def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_samples, dataset):
     shuffle_data = True

     # make output directory
@@ -32,7 +32,8 @@ def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_sam
         os.makedirs(out_directory)

     print("Loading data...")
-    feat_train, targ_train, feat_valid, targ_valid, feat_test, targ_test, max_seq_len, feat_size, out_len = load_data()
+    feat_train, targ_train, feat_valid, targ_valid, feat_test, targ_test, max_seq_len, feat_size, out_len = \
+        load_data(dataset, split)

     print("Building model...")
     network = model.build_model(batch_size=batch_size, seq_length=model.SEQ_LENGTH, feat_len=feat_size, out_len=out_len)
@@ -102,7 +103,7 @@ def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_sam
             train_batches += 1
             tr_loss = train_loss_sum / train_batches
-            sys.stdout.write("\rEpoch: %3d of %d | Batch: %d | cur loss: %1.3f mean loss: %1.3f" %
+            sys.stdout.write("\rEpoch: %3d of %d | Batch: %d | cur loss: %1.5f mean loss: %1.5f" %
                              (epoch + 1, MAX_NUMEPOCHS, batch_nr, cur_loss, tr_loss))

         sys.stdout.write(" validiating... ")
@@ -119,15 +120,15 @@ def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_sam
         print("\rEpoch %3d of %d took %1.3f s (valid: %1.3f s) -- patience: %d " %
               (epoch + 1, MAX_NUMEPOCHS, time.time() - start_time, time.time() - valid_start_time, cur_patience))
         error = train_loss_sum / train_batches
-        print(("  training loss: %1.3f " + col.print_colored("valid loss: %1.3f", BColors.WARNING) + " @ lr %1.6f") %
+        print(("  training loss: %1.5f " + col.print_colored("valid loss: %1.5f", BColors.WARNING) + " @ lr %1.6f") %
               (error, valid_loss[epoch], new_lr))

         better = valid_loss[epoch] < valid_loss[best_valid_loss_epoch]
         if epoch == 0 or better:
             best_valid_loss_epoch = epoch
             np.savez(os.path.join(out_directory, BEST_MODEL_FILE_NAME), *lasagne.layers.get_all_param_values(network))
-            print('  new best validation loss at epoch %3d: %1.3f' % (epoch, valid_loss[epoch]))
+            print('  new best validation loss at epoch %3d: %1.5f' % (epoch, valid_loss[epoch]))

-        np.save(os.path.join(out_directory, LOSSES_FILE.npy), [train_loss[:epoch], valid_loss[:epoch]])
+        np.save(os.path.join(out_directory, LOSSES_FILE), [train_loss[:epoch], valid_loss[:epoch]])

         if epoch > 0 and not better:
             cur_patience -= 1
@@ -154,6 +155,7 @@ def main():
     # add argument parser
     parser = argparse.ArgumentParser(description='Train piano transcription model.')
     parser.add_argument('--model', help='select model to train.', default='models/crnn_1.py')
+    parser.add_argument('--dataset', help='select dataset for training.', default='librosa')
     parser.add_argument('--learnrate', help='initial learning rate.', type=float, default=-1)
     parser.add_argument('--split', help='split for cross validation.', type=int, default=0)
     parser.add_argument('--batchsize', help='batchsize for minibatches.', type=int, default=-1)
@@ -170,8 +172,9 @@ def main():
         batchsize = model_arg.BATCH_SIZE
     split = args.split
     k_samples = args.ksamples
+    dataset = args.dataset

-    run(model_arg, model_name_arg, model_name_full, lr, batchsize, split, k_samples)
+    run(model_arg, model_name_arg, model_name_full, lr, batchsize, split, k_samples, dataset)


 if __name__ == '__main__':
...
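Usage note: a minimal sketch of how the new --dataset option is expected to behave, assuming main() keeps the argparse defaults shown above. The parser below mirrors only the two flags relevant to this commit; the real script defines more.

    # Sketch only: --dataset defaults to 'librosa' and can be overridden.
    import argparse

    parser = argparse.ArgumentParser(description='Train piano transcription model.')
    parser.add_argument('--model', help='select model to train.', default='models/crnn_1.py')
    parser.add_argument('--dataset', help='select dataset for training.', default='librosa')

    args = parser.parse_args(['--dataset', 'maps'])
    print(args.model)    # 'models/crnn_1.py' (default)
    print(args.dataset)  # 'maps'; without the flag this would be 'librosa'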