Commit de220f20 authored by Richard Vogl's avatar Richard Vogl
Browse files

model changes

parent 791356a3
......@@ -95,7 +95,8 @@ def run(model, model_name, learn_rate, batch_size, split, k_samples):
start_time = time.time()
for batch_nr, f_ins in enumerate(train_batch_iterator):
cur_loss = train_fn(*f_ins)
train_loss_sum += cur_loss
print('cur loss shape = '+str(cur_loss.shape))
train_loss_sum += cur_loss[0]
train_batches += 1
tr_loss = train_loss_sum / train_batches
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment