diff --git a/experiments/omniglot_binary_classif.py b/experiments/omniglot_binary_classif.py index 3a913f61d01fbef10228dd77494bdc632a79353e..c642278f17534285c35743ddc216b8633363fcf8 100644 --- a/experiments/omniglot_binary_classif.py +++ b/experiments/omniglot_binary_classif.py @@ -85,4 +85,8 @@ print(name_model) path_model_checkpoint = 'trained_models/Omniglot_classif/Mixt_models/maxpooling/' path_save_plot = 'results/Omniglot_results/plot_acc_loss/Omniglot_classif/' +<<<<<<< HEAD run(model, path_model_checkpoint, path_save_plot, name_model, train_loader, valid_loader, epochs, lr, momentum, criterion, log_interval) +======= +run(model, path_model_checkpoint, path_save_plot, name_model, train_loader, valid_loader, epochs, lr, momentum, criterion, log_interval) +>>>>>>> d798839d86ffc471c6d4f24a2bf4fad0e0b52036 diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_acc.png deleted file mode 100644 index 5999710e4fbeae76ddc7d66e67a62b11cedc61f8..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_acc.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_loss.png deleted file mode 100644 index 57fe9cef857281fa1dcfca9cd41e33f62c78957a..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_loss.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_acc.png deleted file mode 100644 index aafadbdb2496d7292b478ac88978c5c44c638e92..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_acc.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_loss.png deleted file mode 100644 index c4d410a0bd04a392feb2033a34a61e4c28ef6683..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_loss.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_acc.png deleted file mode 100644 index 0ba28e939e83ef55a406a3546b0cd02b73c23bad..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_acc.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_loss.png deleted file mode 100644 index 504cccc37f89ace7333d77a2b034915cdf16bbe9..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_loss.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_acc.png deleted file mode 100644 index 1b98880d838bdc9202cfd6c2550f9160a65078b1..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_acc.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_loss.png deleted file mode 100644 index 2810cba558581313743417bd94933943dfea5a54..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_loss.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_acc.png deleted file mode 100644 index b23428ed1d1e21dc3b9a7b4e8e353bbb5f83b610..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_acc.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_loss.png deleted file mode 100644 index 7c1bf1bfd40be15043e02d423d6ded6c3b022f9e..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_loss.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_acc.png deleted file mode 100644 index db7c814ae7a38e171d50e91261b3a5105feced8c..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_acc.png and /dev/null differ diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_loss.png deleted file mode 100644 index 6d61e0332840d1cd301ddd43c47f145db4af2faf..0000000000000000000000000000000000000000 Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_loss.png and /dev/null differ diff --git a/utils/models.py b/utils/models.py index 12fc579dde615d32782e50f01363383c60ffbcd8..c69a76d8fad8ab091b0e68f96146ad95c7b7d6b8 100644 --- a/utils/models.py +++ b/utils/models.py @@ -492,7 +492,6 @@ class BinaryNetOmniglotClassification(Net): self.fc1 = nn.Linear(3 * 3 * 512, 4096) else: self.fc1 = nn.Linear(4 * 4 * 512, 4096) - self.act_fc1 = nn.ReLU() self.dropout1 = nn.Dropout(0.5) self.fc2 = nn.Linear(4096, 1623) @@ -520,7 +519,7 @@ class BinaryNetOmniglotClassification(Net): else: x_layer4 = self.act_layer4(self.maxPool4(self.batchNorm4(self.layer4(x_layer3) * slope))) x_layer5 = self.act_layer5(self.maxPool5(self.batchNorm5(self.layer5(x_layer4) * slope))) - + else: if self.first_conv_layer: x_layer1 = self.act_layer1(((self.batchNorm1(self.layer1(x))), slope)) @@ -668,6 +667,122 @@ class MixtNetOmniglotClassification(Net): return x_out + +class MixtNetOmniglotClassification(Net): + + def __init__(self, maxpooling, mode='Deterministic', estimator='ST'): + super(MixtNetOmniglotClassification, self).__init__() + + assert mode in ['Deterministic', 'Stochastic'] + assert estimator in ['ST', 'REINFORCE'] + + self.maxpooling = maxpooling + self.mode = mode + self.estimator = estimator + + if self.maxpooling: + self.stride = 1 + self.maxPool1_no_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool2_no_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool3_no_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool4_no_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool5_no_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool1_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool2_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool3_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool4_binary = nn.MaxPool2d(kernel_size=2, stride=2) + self.maxPool5_binary = nn.MaxPool2d(kernel_size=2, stride=2) + else: + self.stride = 2 + + self.layer1_no_binary = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=self.stride) + self.layer2_no_binary = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=self.stride) + self.layer3_no_binary = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=self.stride) + self.layer4_no_binary = nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=self.stride) + self.layer5_no_binary = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=self.stride) + self.layer1_binary = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=self.stride) + self.layer2_binary = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=self.stride) + self.layer3_binary = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=self.stride) + self.layer4_binary = nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=self.stride) + self.layer5_binary = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=self.stride) + + self.batchNorm1_no_binary = nn.BatchNorm2d(32) + self.batchNorm2_no_binary = nn.BatchNorm2d(64) + self.batchNorm3_no_binary = nn.BatchNorm2d(128) + self.batchNorm4_no_binary = nn.BatchNorm2d(256) + self.batchNorm5_no_binary = nn.BatchNorm2d(256) + self.batchNorm1_binary = nn.BatchNorm2d(32) + self.batchNorm2_binary = nn.BatchNorm2d(64) + self.batchNorm3_binary = nn.BatchNorm2d(128) + self.batchNorm4_binary = nn.BatchNorm2d(256) + self.batchNorm5_binary = nn.BatchNorm2d(256) + + self.act_layer1_no_binary = nn.ReLU() + self.act_layer2_no_binary = nn.ReLU() + self.act_layer3_no_binary = nn.ReLU() + self.act_layer4_no_binary = nn.ReLU() + self.act_layer5_no_binary = nn.ReLU() + + if self.mode == 'Deterministic': + self.act_layer1 = DeterministicBinaryActivation(estimator=estimator) + self.act_layer2 = DeterministicBinaryActivation(estimator=estimator) + self.act_layer3 = DeterministicBinaryActivation(estimator=estimator) + self.act_layer4 = DeterministicBinaryActivation(estimator=estimator) + self.act_layer5 = DeterministicBinaryActivation(estimator=estimator) + elif self.mode == 'Stochastic': + self.act_layer1_binary = StochasticBinaryActivation(estimator=estimator) + self.act_layer2_binary = StochasticBinaryActivation(estimator=estimator) + self.act_layer3_binary = StochasticBinaryActivation(estimator=estimator) + self.act_layer4_binary = StochasticBinaryActivation(estimator=estimator) + self.act_layer5_binary = StochasticBinaryActivation(estimator=estimator) + + self.fc1 = nn.Linear(4*4*256*2, 4096) + self.act_fc1 = nn.ReLU() + self.dropout1 = nn.Dropout(0.5) + self.fc2 = nn.Linear(4096, 1623) + + + def forward(self, input): + x = input + slope = 1.0 + + if self.maxpooling: + # For binary: + x_layer1_binary = self.act_layer1_binary(((self.maxpool1_binary(self.batchnorm1_binary(self.layer1_binary(x)))), slope)) + x_layer2_binary = self.act_layer2_binary(((self.maxpool2_binary(self.batchnorm2_binary(self.layer2_binary(x_layer1_binary)))), slope)) + x_layer3_binary = self.act_layer3_binary(((self.maxpool3_binary(self.batchnorm3_binary(self.layer3_binary(x_layer2_binary)))), slope)) + x_layer4_binary = self.act_layer4_binary(((self.maxpool4_binary(self.batchnorm4_binary(self.layer4_binary(x_layer3_binary)))), slope)) + x_layer5_binary = self.act_layer5_binary(((self.maxpool5_binary(self.batchnorm5_binary(self.layer5_binary(x_layer4_binary)))), slope)) + # For No binary: + x_layer1_no_binary = self.act_layer1_no_binary(self.maxpool1_no_binary(self.batchnorm1_no_binary(self.layer1_no_binary(x) * slope))) + x_layer2_no_binary = self.act_layer2_no_binary(self.maxpool2_no_binary(self.batchnorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope))) + x_layer3_no_binary = self.act_layer3_no_binary(self.maxpool3_no_binary(self.batchnorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope))) + x_layer4_no_binary = self.act_layer4_no_binary(self.maxpool4_no_binary(self.batchnorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope))) + x_layer5_no_binary = self.act_layer5_no_binary(self.maxpool5_no_binary(self.batchnorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope))) + + else: + # For binary: + x_layer1_binary = self.act_layer1_binary(((self.batchnorm1_binary(self.layer1_binary(x))), slope)) + x_layer2_binary = self.act_layer2_binary(((self.batchnorm2_binary(self.layer2_binary(x_layer1_binary))), slope)) + x_layer3_binary = self.act_layer3_binary(((self.batchnorm3_binary(self.layer3_binary(x_layer2_binary))), slope)) + x_layer4_binary = self.act_layer4_binary(((self.batchnorm4_binary(self.layer4_binary(x_layer3_binary))), slope)) + x_layer5_binary = self.act_layer5_binary(((self.batchnorm5_binary(self.layer5_binary(x_layer4_binary))), slope)) + # For no binary: + x_layer1_no_binary = self.act_layer1_no_binary(self.batchnorm1_no_binary(self.layer1_no_binary(x) * slope)) + x_layer2_no_binary = self.act_layer2_no_binary(self.batchnorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope)) + x_layer3_no_binary = self.act_layer3_no_binary(self.batchnorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope)) + x_layer4_no_binary = self.act_layer4_no_binary(self.batchnorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope)) + x_layer5_no_binary = self.act_layer5_no_binary(self.batchnorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope)) + + x_layer5_binary = x_layer2_binary.view(x_layer5_binary.size(0), -1) + x_layer5_no_binary = x_layer2_no_binary.view(x_layer5_no_binary.size(0), -1) + x_concatenate = torch.cat((x_layer5_binary, x_layer5_no_binary), 1) + x_fc1 = self.dropout1(self.act_fc1(self.fc1(x_concatenate))) + x_fc2 = self.fc2(x_fc1) + x_out = F.log_softmax(x_fc2, dim=1) + return x_out + + class NoBinaryMatchingNetwork(nn.Module): def __init__(self, n: int, k: int, q: int, num_input_channels: int): """Creates a Matching Network as described in Vinyals et al.