Commit cc48279e authored by Richard Vogl
Browse files

parameter for data

parent b8d9b6aa
......@@ -10,6 +10,11 @@ SETTINGS_FILE_NAME = 'settings.npy'
# File written next to a trained model holding [train_loss, valid_loss] curves.
LOSSES_FILE = 'losses.npy'
# Root directories of the individual piano datasets, all located under
# DATASET_PATH (defined above, outside this hunk).  load_data() expects each
# of these directories to contain a cached feature archive named
# CACHED_DS_NAME, which it opens with np.load().
LIBROSA_PATH = os.path.join(DATASET_PATH, 'librosa_disklavier')
MAPS_PATH = os.path.join(DATASET_PATH, 'maps_piano')
# NOTE(review): trailing slash is inconsistent with the other paths; harmless
# under os.path.join, but consider dropping it for uniformity.
BATIK_PATH = os.path.join(DATASET_PATH, 'mozart_by_batik/')
MIDI_PATH = os.path.join(DATASET_PATH, 'midi_maestro_concert')
JAZZ_PATH = os.path.join(DATASET_PATH, 'jazz_piano')
# Name of the per-dataset cached numpy archive (contains train/valid/test
# names, features and targets — see load_data()).
CACHED_DS_NAME = 'cached.npz'
......
import numpy as np
import os
from piano_transcription import LIBROSA_PATH, CACHED_DS_NAME
from piano_transcription import LIBROSA_PATH, MAPS_PATH, MIDI_PATH, BATIK_PATH, JAZZ_PATH, CACHED_DS_NAME
# TODO remove this and load and use real data!
FEAT_SIZE = 482
......@@ -44,8 +44,22 @@ def load_ramdom_data():
return feat_train, targ_train, feat_valid, targ_valid, feat_test, targ_test, FULL_SEQ_LEN, FEAT_SIZE, OUT_SIZE
def load_data():
data_dict = np.load(os.path.join(LIBROSA_PATH, CACHED_DS_NAME))
def load_data(dataset, split):
# TODO how do we handle splits?
if dataset == 'librosa':
data_dict = np.load(os.path.join(LIBROSA_PATH, CACHED_DS_NAME))
elif dataset == 'maps':
data_dict = np.load(os.path.join(MAPS_PATH, CACHED_DS_NAME))
elif dataset == 'batik':
data_dict = np.load(os.path.join(BATIK_PATH, CACHED_DS_NAME))
elif dataset == 'midi':
data_dict = np.load(os.path.join(MIDI_PATH, CACHED_DS_NAME))
elif dataset == 'jazz':
data_dict = np.load(os.path.join(JAZZ_PATH, CACHED_DS_NAME))
else:
print('unknown dataset %s, use one of: '%dataset)
exit(1)
train_names = data_dict['train_names']
train_feats = data_dict['train_feats']
......
......@@ -141,6 +141,8 @@ class UniversalRegressionDataPool(object):
for i_seq in xrange(self.n_sequences):
sequence = self.sequences[i_seq]
target = self.target_sequences[i_seq]
if isinstance(sequence, int) or isinstance(target, int):
print("war?")
assert len(sequence) == len(target)
start_idx = self.half_context
......
......@@ -23,7 +23,7 @@ MAX_NUMEPOCHS = 10000
col = BColors()
def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_samples):
def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_samples, dataset):
shuffle_data = True
# make output directory
......@@ -32,7 +32,8 @@ def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_sam
os.makedirs(out_directory)
print("Loading data...")
feat_train, targ_train, feat_valid, targ_valid, feat_test, targ_test, max_seq_len, feat_size, out_len = load_data()
feat_train, targ_train, feat_valid, targ_valid, feat_test, targ_test, max_seq_len, feat_size, out_len = \
load_data(dataset, split)
print("Building model...")
network = model.build_model(batch_size=batch_size, seq_length=model.SEQ_LENGTH, feat_len=feat_size, out_len=out_len)
......@@ -102,7 +103,7 @@ def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_sam
train_batches += 1
tr_loss = train_loss_sum / train_batches
sys.stdout.write("\rEpoch: %3d of %d | Batch: %d | cur loss: %1.3f mean loss: %1.3f" %
sys.stdout.write("\rEpoch: %3d of %d | Batch: %d | cur loss: %1.5f mean loss: %1.5f" %
(epoch + 1, MAX_NUMEPOCHS, batch_nr, cur_loss, tr_loss))
sys.stdout.write(" validiating... ")
......@@ -119,15 +120,15 @@ def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_sam
print("\rEpoch %3d of %d took %1.3f s (valid: %1.3f s) -- patience: %d " %
(epoch + 1, MAX_NUMEPOCHS, time.time() - start_time, time.time() - valid_start_time, cur_patience))
error = train_loss_sum / train_batches
print((" training loss: %1.3f "+col.print_colored("valid loss: %1.3f", BColors.WARNING)+" @ lr %1.6f") %
print((" training loss: %1.5f "+col.print_colored("valid loss: %1.5f", BColors.WARNING)+" @ lr %1.6f") %
(error, valid_loss[epoch], new_lr))
better = valid_loss[epoch] < valid_loss[best_valid_loss_epoch]
if epoch == 0 or better:
best_valid_loss_epoch = epoch
np.savez(os.path.join(out_directory, BEST_MODEL_FILE_NAME), *lasagne.layers.get_all_param_values(network))
print(' new best validation loss at epoch %3d: %1.3f' % (epoch, valid_loss[epoch]))
print(' new best validation loss at epoch %3d: %1.5f' % (epoch, valid_loss[epoch]))
np.save(os.path.join(out_directory, LOSSES_FILE.npy), [train_loss[:epoch], valid_loss[:epoch]])
np.save(os.path.join(out_directory, LOSSES_FILE), [train_loss[:epoch], valid_loss[:epoch]])
if epoch > 0 and not better:
cur_patience -= 1
......@@ -154,6 +155,7 @@ def main():
# add argument parser
parser = argparse.ArgumentParser(description='Train piano transcription model.')
parser.add_argument('--model', help='select model to train.', default='models/crnn_1.py')
parser.add_argument('--dataset', help='select dataset for training.', default='librosa')
parser.add_argument('--learnrate', help='initial learning rate.', type=float, default=-1)
parser.add_argument('--split', help='split for cross validation.', type=int, default=0)
parser.add_argument('--batchsize', help='batchsize for minibatches.', type=int, default=-1)
......@@ -170,8 +172,9 @@ def main():
batchsize = model_arg.BATCH_SIZE
split = args.split
k_samples = args.ksamples
dataset = args.dataset
run(model_arg, model_name_arg, model_name_full, lr, batchsize, split, k_samples)
run(model_arg, model_name_arg, model_name_full, lr, batchsize, split, k_samples, dataset)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment