Commit 4def4edd authored by Verena Praher's avatar Verena Praher
Browse files

replace last layer of pretrained midlevel model

parent e033f5df
...@@ -3,7 +3,6 @@ from datasets.midlevel import df_get_midlevel_set ...@@ -3,7 +3,6 @@ from datasets.midlevel import df_get_midlevel_set
from models.shared_stuff import * from models.shared_stuff import *
from utils import * from utils import *
from datasets.dataset import HDF5Dataset
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -119,7 +118,8 @@ class ModelMidlevel(pl.LightningModule): ...@@ -119,7 +118,8 @@ class ModelMidlevel(pl.LightningModule):
self._load_model(load_from, map_location, on_gpu) self._load_model(load_from, map_location, on_gpu)
if dataset == 'mtgjamendo': if dataset == 'mtgjamendo':
self.fc_mtg1 = nn.Linear(7, 10) self.fc_mtg1 = nn.Linear(256
, 10)
self.fc_mtg2 = nn.Linear(10, 56) self.fc_mtg2 = nn.Linear(10, 56)
for name, param in self.named_parameters(): for name, param in self.named_parameters():
...@@ -142,12 +142,14 @@ class ModelMidlevel(pl.LightningModule): ...@@ -142,12 +142,14 @@ class ModelMidlevel(pl.LightningModule):
x = self.conv7b(x) # 39 * 18 * 384 x = self.conv7b(x) # 39 * 18 * 384
x = self.conv11(x) # 2 * 2 * 256 x = self.conv11(x) # 2 * 2 * 256
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
ml = self.fc_ml(x) # ml = self.fc_ml(x)
if self.dataset == 'midlevel':
x = self.fc_ml(x)
if self.dataset=='mtgjamendo': if self.dataset=='mtgjamendo':
x = self.fc_mtg1(ml) x = self.fc_mtg1(x)
logit = nn.Sigmoid()(self.fc_mtg2(x)) logit = nn.Sigmoid()(self.fc_mtg2(x))
return logit return logit
return ml return x
def _load_model(self, load_from, map_location=None, on_gpu=True): def _load_model(self, load_from, map_location=None, on_gpu=True):
last_epoch = -1 last_epoch = -1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment