Commit 672f46f3 authored by Richard Vogl's avatar Richard Vogl
Browse files

model changes

parent 8ef23822
......@@ -94,8 +94,7 @@ def run(model, model_name, learn_rate, batch_size, split, k_samples):
train_batches = 0
start_time = time.time()
for batch_nr, f_ins in enumerate(train_batch_iterator):
cur_loss = train_fn(*f_ins)
print('cur loss shape = %d' % len(cur_loss))
cur_loss = train_fn(*f_ins)[0]
train_loss_sum += cur_loss[0]
train_batches += 1
......@@ -108,7 +107,7 @@ def run(model, model_name, learn_rate, batch_size, split, k_samples):
valid_batches = 0
valid_start_time = time.time()
for batch_nr, f_ins in enumerate(valid_batch_iterator):
cur_loss = eval_fn(*f_ins)
cur_loss = eval_fn(*f_ins)[0]
valid_loss_sum += cur_loss
valid_batches += 1
valid_loss[epoch] = valid_loss_sum / valid_batches
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment