Commit 997bce85 authored by Richard Vogl
Browse files

fix transcribe

parent cc48279e
......@@ -47,7 +47,7 @@ def get_train_batch_iterator():
return batch_iterator
def predict(net, X, max_seq_len, out_len):
def predict(pfun, X, max_seq_len, out_len):
seq_len, feat_len = X.shape
pad_width = (SPEC_CONTEXT - 1) / 2
......@@ -68,9 +68,9 @@ def predict(net, X, max_seq_len, out_len):
i0 = batch * MAX_PRED_SIZE
i1 = i0 + MAX_PRED_SIZE
i1o = min(i1, seq_len)
p_b[i0:i1o] = net.predict_proba(X_pred[i0:i1])[0:(i1o-i0)]
p_b[i0:i1o] = pfun(X_pred[i0:i1])[0:(i1o-i0)]
else:
p_b = net.predict_proba(X_pred)[:seq_len]
p_b = pfun(X_pred)[:seq_len]
return p_b
......
......@@ -52,7 +52,7 @@ def get_train_batch_iterator():
return batch_iterator
def predict(net, X, max_seq_len, out_len):
def predict(pfun, X, max_seq_len, out_len):
seq_len, feat_len = X.shape
n_batches = int(np.ceil(seq_len / float(MAX_PRED_SIZE)))
......@@ -82,9 +82,9 @@ def predict(net, X, max_seq_len, out_len):
i0 = batch * MAX_PRED_SIZE
i1 = i0 + MAX_PRED_SIZE
i1o = min(i1, seq_len)
p_b[i0:i1o] = net.predict_proba([X_pred[:, i0:i1], m_b[:, i0:i1]])[0, 0:(i1o-i0)]
p_b[i0:i1o] = pfun([X_pred[:, i0:i1], m_b[:, i0:i1]])[0, 0:(i1o-i0)]
else:
p_b = net.predict_proba([X_pred, m_b])[0, :seq_len]
p_b = pfun([X_pred, m_b])[0, :seq_len]
return p_b
......
......@@ -26,6 +26,16 @@ RNN_GRAD_CLIP = 50
FF_GRAD_CLIP = 50
def predict(pfun, X, max_seq_len, out_len):
    """
    Run a prediction function on a single zero-padded sequence.

    ``X`` is unpacked as a 2-D array ``(seq_len, feat_len)``. It is padded
    with zeros along the first axis up to ``max_seq_len`` rows, and a mask of
    shape ``(1, max_seq_len)`` is built holding ones for the original
    ``seq_len`` positions and zeros for the padding. Both arrays are cast to
    ``theano.config.floatX`` and handed to ``pfun`` as one two-element list.

    :param pfun: callable invoked as ``pfun([padded_input, padded_mask])``
    :param X: 2-D feature array
    :param max_seq_len: number of rows to pad ``X`` to
    :param out_len: not used by this implementation
    :return: ``pfun(...)[0, :seq_len]``, i.e. the result with the padded
             positions cut off again
    """
    float_type = theano.config.floatX
    seq_len, feat_len = X.shape
    n_pad = max_seq_len - seq_len

    # the mask starts out as all ones, one entry per real input row
    mask = np.ones((1, seq_len), dtype=float_type)

    # pad the features at the end of axis 0, the mask at the end of axis 1
    padded_input = np.pad(X, ((0, n_pad), (0, 0)), 'constant').astype(float_type)
    padded_mask = np.pad(mask, ((0, 0), (0, n_pad)), 'constant').astype(float_type)

    output = pfun([padded_input, padded_mask])
    return output[0, :seq_len]
def get_batch_iterator_train():
def batch_iterator(batch_size, k_samples, shuffle):
return BatchIterator(batch_size=batch_size, prepare=prepare_train, k_samples=k_samples, shuffle=shuffle)
......
......@@ -6,29 +6,45 @@ import numpy as np
from madmom.features.notes import NotePeakPickingProcessor
from piano_transcription import BEST_MODEL_FILE_NAME, SETTINGS_FILE_NAME, LOSSES_FILE
from piano_transcription.utils import select_model
from piano_transcription.utils import select_model, collect_inputs
from piano_transcription.data import FEAT_SIZE, OUT_SIZE
from piano_transcription.data.annotations import write_txt_annotation
from piano_transcription.data.features import extract_features
import theano
import lasagne
def compile_prediction_function(net):
    """
    Build a compiled theano function that evaluates ``net``.

    The function inputs are whatever ``collect_inputs(net)`` returns
    (presumably the network's input variables; confirm against
    ``piano_transcription.utils``). The function output is the lasagne
    output expression of ``net`` requested with ``deterministic=True``.

    :param net: network object passed to ``collect_inputs`` and
                ``lasagne.layers.get_output``
    :return: the object returned by ``theano.function``
    """
    fn_inputs = collect_inputs(net)
    deterministic_output = lasagne.layers.get_output(net, deterministic=True)

    # inputs that the output expression does not use only trigger a warning
    return theano.function(
        inputs=fn_inputs,
        outputs=deterministic_output,
        on_unused_input='warn',
    )
def run(model_path, input_file):
settings = np.load(os.path.join(model_path, SETTINGS_FILE_NAME))
settings = np.load(os.path.join(model_path, SETTINGS_FILE_NAME)).item(0)
model = select_model(settings['model'])
model_seq_len = model.MAX_PRED_SIZE
model, model_name = select_model(settings['model'])
features = extract_features(input_file)
network = model.build_eval_model(model_seq_len, FEAT_SIZE, OUT_SIZE)
seq_len, feat_len = features.shape
network = model.build_eval_model(seq_len, feat_len, OUT_SIZE)
with np.load(os.path.join(model_path, BEST_MODEL_FILE_NAME)) as f:
param_values = [f['arr_%d' % i] for i in range(len(f.files))]
lasagne.layers.set_all_param_values(network, param_values)
pfun = compile_prediction_function(network)
# transcribe using model
pred = model.predict(network, features, model_seq_len, OUT_SIZE)
pred = model.predict(pfun, features, seq_len, OUT_SIZE)
peak_picker = NotePeakPickingProcessor(pitch_offset=0)
notes = peak_picker.process(pred)
......@@ -51,4 +67,5 @@ def main():
if __name__ == '__main__':
#run('/Users/rich/python-workspace/piano_trans/piano_transcription/output/2018-06-17-16-57-22', '/Users/rich/Desktop/piano_test/CH/MAPS_ISOL_CH0.3_F_AkPnBsdf.flac')
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment