Skip to content
Snippets Groups Projects
Commit ea7393d9 authored by Julien Dejasmin's avatar Julien Dejasmin
Browse files

update omniglot classification notebook

parents c4da6840 5afd9c16
No related branches found
No related tags found
No related merge requests found
results/MNIST_results/heatmap_png/heatmapOmniglot_classif_NonBinaryNetfc.png

175 B

......@@ -211,7 +211,7 @@ class NoBinaryNetOmniglotClassification(Net):
def __init__(self):
super(NoBinaryNetOmniglotClassification, self).__init__()
self.layer1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, stride=2)
self.layer1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, stride=1)
self.batchNorm1 = nn.BatchNorm2d(64)
# self.dropout1 = nn.Dropout(0.5) #50 % probability
# self.maxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
......
......@@ -369,6 +369,7 @@ def viz_filters(model):
plt.show()
<<<<<<< HEAD
def get_activation(name, activation):
def hook(model, input, output):
activation[name] = output.detach()
......@@ -376,6 +377,8 @@ def get_activation(name, activation):
return hook
=======
>>>>>>> 5afd9c16b15a644c659fc0bb1142ff4983a49ae9
def viz_heatmap(model, name_model, loader, index_data=None, save=True):
activation = {}
for name, m in model.named_modules():
......@@ -434,6 +437,13 @@ def test_predict_few_examples(model, loader):
color=("green" if pred_arr[i] == labels_arr[i] else "red"))
def get_activation(name, activation):
def hook(model, input, output):
activation[name] = output.detach()
return hook
def get_train_data():
return datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
......@@ -654,6 +664,12 @@ def standardize_and_clip(tensor, MNIST, min_value=0.0, max_value=1.0,
def get_region_layer1(image, ind_x, ind_y, name, stride, padding, filter_size, len_img_h, len_img_w):
<<<<<<< HEAD
=======
"""
return region of interest from index (x,y) in image
"""
>>>>>>> 5afd9c16b15a644c659fc0bb1142ff4983a49ae9
# determine pixel high left of region of interest:
index_col_hl = (ind_x * stride) - padding
index_raw_hl = (ind_y * stride) - padding
......@@ -683,10 +699,20 @@ def get_region_layer1(image, ind_x, ind_y, name, stride, padding, filter_size, l
if region.shape != (filter_size, filter_size):
region = cv2.resize(region, (filter_size, filter_size), interpolation=cv2.INTER_AREA)
<<<<<<< HEAD
return region, begin_col, end_col, begin_raw, end_raw
def get_region_layer2(image, ind_x, ind_y, name, stride, padding, filter_size, len_img_h, len_img_w):
=======
return region
def get_region_layer2(image, ind_x, ind_y, name, stride, padding, filter_size, len_img_h, len_img_w):
"""
return region of interest from index (x,y)
"""
>>>>>>> 5afd9c16b15a644c659fc0bb1142ff4983a49ae9
region_shape = 7
# determine pixel high left of region of interest:
index_col_hl = (ind_x * stride) - padding
......@@ -739,6 +765,12 @@ def get_filter_layer2():
def get_region_layer3(image, ind_x, ind_y, name, stride, padding, filter_size, len_img_h, len_img_w):
<<<<<<< HEAD
=======
"""
return region of interest from index (x,y)
"""
>>>>>>> 5afd9c16b15a644c659fc0bb1142ff4983a49ae9
region_shape = 15
# determine pixel high left of region of interest:
index_col_hl = (ind_x * stride) - padding
......@@ -807,6 +839,12 @@ def get_filter_layer3():
def get_region_layer4(image, ind_x, ind_y, name, stride, padding, filter_size, len_img_h, len_img_w):
<<<<<<< HEAD
=======
"""
return region of interest from index (x,y)
"""
>>>>>>> 5afd9c16b15a644c659fc0bb1142ff4983a49ae9
region_shape = 31
# determine pixel high left of region of interest:
index_col_hl = (ind_x * stride) - padding
......@@ -862,6 +900,44 @@ def get_region_layer4(image, ind_x, ind_y, name, stride, padding, filter_size, l
region = cv2.resize(region, (region_shape, region_shape), interpolation=cv2.INTER_AREA)
return region
<<<<<<< HEAD
def get_filter_layer4():
return np.array(([[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[2, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1],
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1]]))
=======
def get_filter_layer4():
......@@ -898,6 +974,7 @@ def get_filter_layer4():
[1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1]]))
>>>>>>> 5afd9c16b15a644c659fc0bb1142ff4983a49ae9
def get_all_regions_max(loader, activations):
dataiter = iter(loader)
images, _ = dataiter.next()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment