Skip to content
Snippets Groups Projects
Commit 708e8ed9 authored by Julien Dejasmin's avatar Julien Dejasmin
Browse files

trained models on omniglot

parent e363cc4c
No related branches found
No related tags found
No related merge requests found
File deleted
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
...@@ -137,10 +137,10 @@ def get_my_model_Omniglot(binary, maxpooling=True, mixt=False, stochastic=True, ...@@ -137,10 +137,10 @@ def get_my_model_Omniglot(binary, maxpooling=True, mixt=False, stochastic=True,
elif mixt: elif mixt:
if stochastic: if stochastic:
mode = 'Stochastic' mode = 'Stochastic'
names_model = 'MNIST_Stochastic' names_model = 'Omniglot_Stochastic'
else: else:
mode = 'Deterministic' mode = 'Deterministic'
names_model = 'MNIST_Deterministic' names_model = 'Omniglot_Deterministic'
if reinforce: if reinforce:
estimator = 'REINFORCE' estimator = 'REINFORCE'
names_model += '_REINFORCE' names_model += '_REINFORCE'
...@@ -153,10 +153,10 @@ def get_my_model_Omniglot(binary, maxpooling=True, mixt=False, stochastic=True, ...@@ -153,10 +153,10 @@ def get_my_model_Omniglot(binary, maxpooling=True, mixt=False, stochastic=True,
model = MixtNetOmniglotClassification(maxpooling, mode=mode, estimator=estimator) model = MixtNetOmniglotClassification(maxpooling, mode=mode, estimator=estimator)
else: else:
model = NoBinaryNetOmniglotClassification(maxpooling)
names_model = 'Omniglot_classif_NonBinaryNet' names_model = 'Omniglot_classif_NonBinaryNet'
mode = None if maxpooling:
estimator = None names_model += '_maxpooling'
model = NoBinaryNetOmniglotClassification(maxpooling)
return model, names_model return model, names_model
...@@ -383,6 +383,9 @@ class NoBinaryNetOmniglotClassification(Net): ...@@ -383,6 +383,9 @@ class NoBinaryNetOmniglotClassification(Net):
self.batchNorm5 = nn.BatchNorm2d(512) self.batchNorm5 = nn.BatchNorm2d(512)
self.act_layer5 = nn.ReLU() self.act_layer5 = nn.ReLU()
if self.maxpooling:
self.fc1 = nn.Linear(3 * 3 * 512, 4096)
else:
self.fc1 = nn.Linear(4 * 4 * 512, 4096) self.fc1 = nn.Linear(4 * 4 * 512, 4096)
self.act_fc1 = nn.ReLU() self.act_fc1 = nn.ReLU()
self.dropout1 = nn.Dropout(0.5) self.dropout1 = nn.Dropout(0.5)
...@@ -485,7 +488,11 @@ class BinaryNetOmniglotClassification(Net): ...@@ -485,7 +488,11 @@ class BinaryNetOmniglotClassification(Net):
self.act_layer5 = nn.ReLU() self.act_layer5 = nn.ReLU()
if self.maxpooling:
self.fc1 = nn.Linear(3 * 3 * 512, 4096)
else:
self.fc1 = nn.Linear(4 * 4 * 512, 4096) self.fc1 = nn.Linear(4 * 4 * 512, 4096)
self.act_fc1 = nn.ReLU() self.act_fc1 = nn.ReLU()
self.dropout1 = nn.Dropout(0.5) self.dropout1 = nn.Dropout(0.5)
self.fc2 = nn.Linear(4096, 1623) self.fc2 = nn.Linear(4096, 1623)
...@@ -609,7 +616,11 @@ class MixtNetOmniglotClassification(Net): ...@@ -609,7 +616,11 @@ class MixtNetOmniglotClassification(Net):
self.act_layer4_binary = StochasticBinaryActivation(estimator=estimator) self.act_layer4_binary = StochasticBinaryActivation(estimator=estimator)
self.act_layer5_binary = StochasticBinaryActivation(estimator=estimator) self.act_layer5_binary = StochasticBinaryActivation(estimator=estimator)
if self.maxpooling:
self.fc1 = nn.Linear(3*3*256*2, 4096)
else:
self.fc1 = nn.Linear(4*4*256*2, 4096) self.fc1 = nn.Linear(4*4*256*2, 4096)
self.act_fc1 = nn.ReLU() self.act_fc1 = nn.ReLU()
self.dropout1 = nn.Dropout(0.5) self.dropout1 = nn.Dropout(0.5)
self.fc2 = nn.Linear(4096, 1623) self.fc2 = nn.Linear(4096, 1623)
...@@ -621,34 +632,35 @@ class MixtNetOmniglotClassification(Net): ...@@ -621,34 +632,35 @@ class MixtNetOmniglotClassification(Net):
if self.maxpooling: if self.maxpooling:
# For binary: # For binary:
x_layer1_binary = self.act_layer1_binary(((self.maxpool1_binary(self.batchnorm1_binary(self.layer1_binary(x)))), slope)) x_layer1_binary = self.act_layer1_binary(((self.maxPool1_binary(self.batchNorm1_binary(self.layer1_binary(x)))), slope))
x_layer2_binary = self.act_layer2_binary(((self.maxpool2_binary(self.batchnorm2_binary(self.layer2_binary(x_layer1_binary)))), slope)) x_layer2_binary = self.act_layer2_binary(((self.maxPool2_binary(self.batchNorm2_binary(self.layer2_binary(x_layer1_binary)))), slope))
x_layer3_binary = self.act_layer3_binary(((self.maxpool3_binary(self.batchnorm3_binary(self.layer3_binary(x_layer2_binary)))), slope)) x_layer3_binary = self.act_layer3_binary(((self.maxPool3_binary(self.batchNorm3_binary(self.layer3_binary(x_layer2_binary)))), slope))
x_layer4_binary = self.act_layer4_binary(((self.maxpool4_binary(self.batchnorm4_binary(self.layer4_binary(x_layer3_binary)))), slope)) x_layer4_binary = self.act_layer4_binary(((self.maxPool4_binary(self.batchNorm4_binary(self.layer4_binary(x_layer3_binary)))), slope))
x_layer5_binary = self.act_layer5_binary(((self.maxpool5_binary(self.batchnorm5_binary(self.layer5_binary(x_layer4_binary)))), slope)) x_layer5_binary = self.act_layer5_binary(((self.maxPool5_binary(self.batchNorm5_binary(self.layer5_binary(x_layer4_binary)))), slope))
# For No binary: # For No binary:
x_layer1_no_binary = self.act_layer1_no_binary(self.maxpool1_no_binary(self.batchnorm1_no_binary(self.layer1_no_binary(x) * slope))) x_layer1_no_binary = self.act_layer1_no_binary(self.maxPool1_no_binary(self.batchNorm1_no_binary(self.layer1_no_binary(x) * slope)))
x_layer2_no_binary = self.act_layer2_no_binary(self.maxpool2_no_binary(self.batchnorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope))) x_layer2_no_binary = self.act_layer2_no_binary(self.maxPool2_no_binary(self.batchNorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope)))
x_layer3_no_binary = self.act_layer3_no_binary(self.maxpool3_no_binary(self.batchnorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope))) x_layer3_no_binary = self.act_layer3_no_binary(self.maxPool3_no_binary(self.batchNorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope)))
x_layer4_no_binary = self.act_layer4_no_binary(self.maxpool4_no_binary(self.batchnorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope))) x_layer4_no_binary = self.act_layer4_no_binary(self.maxPool4_no_binary(self.batchNorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope)))
x_layer5_no_binary = self.act_layer5_no_binary(self.maxpool5_no_binary(self.batchnorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope))) x_layer5_no_binary = self.act_layer5_no_binary(self.maxPool5_no_binary(self.batchNorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope)))
else: else:
# For binary: # For binary:
x_layer1_binary = self.act_layer1_binary(((self.batchnorm1_binary(self.layer1_binary(x))), slope)) x_layer1_binary = self.act_layer1_binary(((self.batchNorm1_binary(self.layer1_binary(x))), slope))
x_layer2_binary = self.act_layer2_binary(((self.batchnorm2_binary(self.layer2_binary(x_layer1_binary))), slope))
x_layer3_binary = self.act_layer3_binary(((self.batchnorm3_binary(self.layer3_binary(x_layer2_binary))), slope))
x_layer4_binary = self.act_layer4_binary(((self.batchnorm4_binary(self.layer4_binary(x_layer3_binary))), slope))
x_layer5_binary = self.act_layer5_binary(((self.batchnorm5_binary(self.layer5_binary(x_layer4_binary))), slope))
# For no binary:
x_layer1_no_binary = self.act_layer1_no_binary(self.batchnorm1_no_binary(self.layer1_no_binary(x) * slope))
x_layer2_no_binary = self.act_layer2_no_binary(self.batchnorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope))
x_layer3_no_binary = self.act_layer3_no_binary(self.batchnorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope))
x_layer4_no_binary = self.act_layer4_no_binary(self.batchnorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope))
x_layer5_no_binary = self.act_layer5_no_binary(self.batchnorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope))
x_layer5_binary = x_layer2_binary.view(x_layer5_binary.size(0), -1) x_layer2_binary = self.act_layer2_binary(((self.batchNorm2_binary(self.layer2_binary(x_layer1_binary))), slope))
x_layer5_no_binary = x_layer2_no_binary.view(x_layer5_no_binary.size(0), -1) x_layer3_binary = self.act_layer3_binary(((self.batchNorm3_binary(self.layer3_binary(x_layer2_binary))), slope))
x_layer4_binary = self.act_layer4_binary(((self.batchNorm4_binary(self.layer4_binary(x_layer3_binary))), slope))
x_layer5_binary = self.act_layer5_binary(((self.batchNorm5_binary(self.layer5_binary(x_layer4_binary))), slope))
# For no binary:
x_layer1_no_binary = self.act_layer1_no_binary(self.batchNorm1_no_binary(self.layer1_no_binary(x) * slope))
x_layer2_no_binary = self.act_layer2_no_binary(self.batchNorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope))
x_layer3_no_binary = self.act_layer3_no_binary(self.batchNorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope))
x_layer4_no_binary = self.act_layer4_no_binary(self.batchNorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope))
x_layer5_no_binary = self.act_layer5_no_binary(self.batchNorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope))
x_layer5_binary = x_layer5_binary.view(x_layer5_binary.size(0), -1)
x_layer5_no_binary = x_layer5_no_binary.view(x_layer5_no_binary.size(0), -1)
x_concatenate = torch.cat((x_layer5_binary, x_layer5_no_binary), 1) x_concatenate = torch.cat((x_layer5_binary, x_layer5_no_binary), 1)
x_fc1 = self.dropout1(self.act_fc1(self.fc1(x_concatenate))) x_fc1 = self.dropout1(self.act_fc1(self.fc1(x_concatenate)))
x_fc2 = self.fc2(x_fc1) x_fc2 = self.fc2(x_fc1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment