Commit 3982785c authored by Verena Praher's avatar Verena Praher
Browse files

setup effective receptive field experiments

parent b3ba542d
from utils import *
from pytorch_lightning import Trainer
from test_tube import Experiment
from models.cp_resnet import Network
import torch
from datasets import MelSpecDataset
from torch.utils.data import Dataset, DataLoader
from matplotlib.transforms import Affine2D
import mpl_toolkits.axisartist.floating_axes as floating_axes
import matplotlib.pyplot as plt
def setup_axes(fig, rect, rotation, axisScale, axisLimits, doShift):
tr_rot = Affine2D().scale(axisScale[0], axisScale[1]).rotate_deg(rotation)
# This seems to do nothing
if doShift:
tr_trn = Affine2D().translate(-90,-5)
else:
tr_trn = Affine2D().translate(0,0)
tr = tr_rot + tr_trn
grid_helper = floating_axes.GridHelperCurveLinear(tr, extremes=axisLimits)
ax = floating_axes.FloatingSubplot(fig, rect, grid_helper=grid_helper)
fig.add_subplot(ax)
aux_ax = ax.get_aux_axes(tr)
return ax, aux_ax
def ERF_plot(a, results={}, fileid="", savefile="erf.png"):
fig = plt.figure(1, figsize=(15, 9))
a0=a.sum(axis=0)
a0=a0/a0.max()
a1=a.sum(axis=1)
a1=a1/a1.max()
a0max=1
a1max=1
axes = []
axisOrientation = [0, 0, 270]
axisScale = [[1,1],[1,20],[-1,20]]
axisPosition = [223,221,224]
axisLimits = [(0, 431, 0, 256),
(0, 431, 0, 2),
(0,256, 0, 2),
]
doShift = [False, False, False]
label_axes = []
for i in range(0, len(axisOrientation)):
ax, aux_ax = setup_axes(fig, axisPosition[i], axisOrientation[i],
axisScale[i], axisLimits[i], doShift[i])
axes.append(aux_ax)
label_axes.append(ax)
axes[0].imshow(a/a.max(), cmap=plt.get_cmap('gray'),vmin=0,vmax=1)
label_axes[0].axis["bottom"].label.set_text(fileid + " loss:" + str(results.get("loss", "")) + " " +
" acc:" + str(results.get("acc", "")) + " ")
label_axes[0].axis["left"].label.set_text('Freq')
#b = np.linspace(-0.5,4.5,50)
axes[1].plot(a0)
axes[2].plot(a1)
#b = np.linspace(-3.5,3.5,50)
#axes[3].hist(np.array(x)-np.array(y), bins=b)
for i in range(1,len(label_axes)):
for axisLoc in ['top','left','right']:
label_axes[i].axis[axisLoc].set_visible(False)
label_axes[i].axis['bottom'].toggle(ticklabels=False)
fig.subplots_adjust(wspace=-0.35, hspace=-0.35, left=0.00, right=0.99, top=0.99, bottom=0.0)
plt.savefig(savefile)
def ERF_generate(model, loader):
#print('ERF_generate on ({}) :'.format(dataset_name + extra_name))
model.eval()
counter = 0
accum = None
for step, (data, _, targets) in enumerate(loader):
data = data.cuda()
data.requires_grad = True
outputs = model(data)
grads = torch.zeros_like(outputs)
grads[:, :, grads.size(2) // 2, grads.size(3) // 2] = 1
outputs.backward(grads)
me = np.abs(data.grad.cpu().numpy()).mean(axis=0).mean(axis=0)
if accum is None:
accum = me
else:
accum += me
counter += 1
# torch.save({"arr": accum, "counter": counter}, os.path.join(self.config.out_dir, 'ERF_dict.pth'))
ERF_plot(accum, savefile=os.path.join('/home/verena/experiments/moodwalk', 'erf.png'))
# self.experiment.add_artifact(os.path.join(self.config.out_dir, 'erf.png'), "erf.png", {"dataset": dataset_name})
return True
def run():
logger.info(CURR_RUN_PATH)
exp = Experiment(save_dir=CURR_RUN_PATH)
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=10, train_percent_check=1.0,
fast_dev_run=False)
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
model_config = {
"input_shape": [1, 1, 256, 96], #[batch,channels,time,freq]
"n_classes": 56,
"depth": 26,
"base_channels": 128,
"n_blocks_per_stage": [3, 1, 1],
"stage1": {"maxpool": [1, 2], "k1s": [3, 3, 3], "k2s": [1, 3, 3]},
"stage2": {"maxpool": [1], "k1s": [3, ], "k2s": [1, ]},
"stage3": {"maxpool": [], "k1s": [1, ], "k2s": [1, ]},
"block_type": "basic"
}
model = Network(model_config)
trainer.fit(model)
dataset = MelSpecDataset(phase='test', ann_root=PATH_ANNOTATIONS, spec_root=PATH_MELSPEC_DOWNLOADED_FRAMED)
test_loader = DataLoader(dataset=dataset,
batch_size=32,
shuffle=True)
ERF_generate(model, test_loader)
if __name__=='__main__':
run()
\ No newline at end of file
......@@ -29,6 +29,10 @@ if hostname in ['rechenknecht0.cp.jku.at', 'rechenknecht1.cp.jku.at', 'rechenkne
plt.switch_backend('agg')
PATH_DATA_ROOT = '/home/shreyan/data/MTG-Jamendo'
USE_GPU = True
elif hostname == 'hermine': # PC verena
plt.switch_backend('agg')
PATH_DATA_ROOT = '/media/verena/SAMSUNG/Data/MTG-Jamendo'
USE_GPU = True
else:
PATH_DATA_ROOT = '/mnt/2tb/datasets/MTG-Jamendo'
USE_GPU = False
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment