Commit f02c5d32 authored by Verena Praher's avatar Verena Praher
Browse files

add optimizer to the new models

parent bd916e97
......@@ -78,4 +78,7 @@ class CNN(pl.LightningModule):
return logit
def my_loss(self, y_hat, y):
return F.binary_cross_entropy(y_hat, y)
\ No newline at end of file
return F.binary_cross_entropy(y_hat, y)
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=1e-4)] # from their code
\ No newline at end of file
......@@ -18,7 +18,7 @@ class Network(pl.LightningModule):
def __init__(self, num_tags):
super(Network, self).__init__()
self.num_tags = num_tags
self.model = resnet18(False)
self.model = resnet18(False) # TODO: need to check if optimizer recognizes these parameters
num_features = self.model.fc.in_features
self.model.fc = nn.Linear(num_features, self.num_tags) # overwriting fc layer
self.sig = nn.Sigmoid(self.num_tags)
......@@ -29,4 +29,7 @@ class Network(pl.LightningModule):
return x
def my_loss(self, y_hat, y):
return F.binary_cross_entropy(y_hat, y)
\ No newline at end of file
return F.binary_cross_entropy(y_hat, y)
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.001)]
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment