Commit 4def4edd authored by Verena Praher's avatar Verena Praher
Browse files

replace last layer of pretrained midlevel model

parent e033f5df
......@@ -3,7 +3,6 @@ from datasets.midlevel import df_get_midlevel_set
from models.shared_stuff import *
from utils import *
from datasets.dataset import HDF5Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -119,7 +118,8 @@ class ModelMidlevel(pl.LightningModule):
self._load_model(load_from, map_location, on_gpu)
if dataset == 'mtgjamendo':
self.fc_mtg1 = nn.Linear(7, 10)
self.fc_mtg1 = nn.Linear(256
, 10)
self.fc_mtg2 = nn.Linear(10, 56)
for name, param in self.named_parameters():
......@@ -142,12 +142,14 @@ class ModelMidlevel(pl.LightningModule):
x = self.conv7b(x) # 39 * 18 * 384
x = self.conv11(x) # 2 * 2 * 256
x = x.view(x.size(0), -1)
ml = self.fc_ml(x)
# ml = self.fc_ml(x)
if self.dataset == 'midlevel':
x = self.fc_ml(x)
if self.dataset=='mtgjamendo':
x = self.fc_mtg1(ml)
x = self.fc_mtg1(x)
logit = nn.Sigmoid()(self.fc_mtg2(x))
return logit
return ml
return x
def _load_model(self, load_from, map_location=None, on_gpu=True):
last_epoch = -1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment