Commit ac30bc96 authored by Verena Praher's avatar Verena Praher
Browse files

move cutting of spec to baseline model since it's only necessary there

parent 8afd5def
......@@ -47,6 +47,7 @@ class CNN(pl.LightningModule):
def forward(self, x):
# x = x.unsqueeze(1)
x = x[:, :, :, :512]
# init bn
x = self.bn_init(x)
# print(x.shape)
......@@ -84,20 +85,6 @@ class CNN(pl.LightningModule):
def my_loss(self, y_hat, y):
    """Compute the training criterion for this model.

    Thin delegation wrapper: forwards predictions and targets to the
    module-level ``my_loss`` function so the loss definition lives in
    one place.
    """
    return my_loss(y_hat, y)
def forward_full_song(self, x, y):
    """Score a batch of full-song spectrograms with the model.

    Parameters:
        x: input spectrogram batch; assumed shape (batch, channel,
           freq, time) — TODO confirm against the data loader.
        y: targets; accepted for interface compatibility with callers
           (``training_step`` passes them) but unused here.

    Returns:
        The model output for ``x`` (whatever ``self.forward`` yields).

    NOTE(review): ``forward`` already crops the time axis to the first
    512 frames, so the explicit ``x[:, :, :, :512]`` slice that used to
    live here was redundant (slicing twice is a no-op) — dropped for
    consistency with ``validation_step``/``test_step``, which call
    ``model.forward(x)`` directly.
    """
    # TODO: actually aggregate over the *whole* song, e.g. average
    # predictions over sliding 512-frame windows with a hop, instead of
    # scoring only the first window.
    return self.forward(x)
def training_step(self, data_batch, batch_nb):
x, _, y = data_batch
y_hat = self.forward_full_song(x, y)
......
......@@ -24,7 +24,7 @@ def validation_step(model, data_batch, batch_nb):
x, _, y = data_batch
# print("x", x)
# print("y", y)
y_hat = model.forward(x[:, :, :, :512]) # TODO: why is this necessary?
y_hat = model.forward(x)
y = y.float()
y_hat = y_hat.float()
#print("y", y)
......@@ -77,7 +77,7 @@ def test_step(model, data_batch, batch_nb):
x, _, y = data_batch
# print("x", x)
# print("y", y)
y_hat = model.forward(x[:, :, :, :512]) # TODO why is this necessary?
y_hat = model.forward(x)
y = y.float()
y_hat = y_hat.float()
rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu(), average='macro')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment