diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_acc.png
deleted file mode 100644
index 5999710e4fbeae76ddc7d66e67a62b11cedc61f8..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_acc.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_loss.png
deleted file mode 100644
index 57fe9cef857281fa1dcfca9cd41e33f62c78957a..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_maxpooling_mixt_loss.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_acc.png
deleted file mode 100644
index aafadbdb2496d7292b478ac88978c5c44c638e92..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_acc.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_loss.png
deleted file mode 100644
index c4d410a0bd04a392feb2033a34a61e4c28ef6683..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/MNIST_Stochastic_ST_mixt_loss.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_acc.png
deleted file mode 100644
index 0ba28e939e83ef55a406a3546b0cd02b73c23bad..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_acc.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_loss.png
deleted file mode 100644
index 504cccc37f89ace7333d77a2b034915cdf16bbe9..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_loss.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_acc.png
deleted file mode 100644
index 1b98880d838bdc9202cfd6c2550f9160a65078b1..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_acc.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_loss.png
deleted file mode 100644
index 2810cba558581313743417bd94933943dfea5a54..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_NonBinaryNet_maxpooling_loss.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_acc.png
deleted file mode 100644
index b23428ed1d1e21dc3b9a7b4e8e353bbb5f83b610..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_acc.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_loss.png
deleted file mode 100644
index 7c1bf1bfd40be15043e02d423d6ded6c3b022f9e..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_loss.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_acc.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_acc.png
deleted file mode 100644
index db7c814ae7a38e171d50e91261b3a5105feced8c..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_acc.png and /dev/null differ
diff --git a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_loss.png b/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_loss.png
deleted file mode 100644
index 6d61e0332840d1cd301ddd43c47f145db4af2faf..0000000000000000000000000000000000000000
Binary files a/results/Omniglot_results/plot_acc_loss/Omniglot_classif/Omniglot_classif_Stochastic_ST_first_conv_binary_maxpooling_loss.png and /dev/null differ
diff --git a/utils/models.py b/utils/models.py
index 12fc579dde615d32782e50f01363383c60ffbcd8..c69a76d8fad8ab091b0e68f96146ad95c7b7d6b8 100644
--- a/utils/models.py
+++ b/utils/models.py
@@ -492,7 +492,6 @@ class BinaryNetOmniglotClassification(Net):
             self.fc1 = nn.Linear(3 * 3 * 512, 4096)
         else:
             self.fc1 = nn.Linear(4 * 4 * 512, 4096)
-
         self.act_fc1 = nn.ReLU()
         self.dropout1 = nn.Dropout(0.5)
         self.fc2 = nn.Linear(4096, 1623)
@@ -520,7 +519,7 @@ class BinaryNetOmniglotClassification(Net):
             else:
                 x_layer4 = self.act_layer4(self.maxPool4(self.batchNorm4(self.layer4(x_layer3) * slope)))
             x_layer5 = self.act_layer5(self.maxPool5(self.batchNorm5(self.layer5(x_layer4) * slope)))
-            
+        
         else:
             if self.first_conv_layer:
                 x_layer1 = self.act_layer1(((self.batchNorm1(self.layer1(x))), slope))
@@ -668,6 +667,133 @@ class MixtNetOmniglotClassification(Net):
         return x_out
         
 
+
+class MixtNetOmniglotClassification(Net):
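+    # Note: this definition shadows the MixtNetOmniglotClassification defined above
+    # (the later definition wins at import time). It runs two parallel conv stacks,
+    # one with binary activations and one with plain ReLUs, and concatenates them.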
+
+    def __init__(self, maxpooling, mode='Deterministic', estimator='ST'):
+        super(MixtNetOmniglotClassification, self).__init__()
+
+        assert mode in ['Deterministic', 'Stochastic']
+        assert estimator in ['ST', 'REINFORCE']
+
+        self.maxpooling = maxpooling
+        self.mode = mode
+        self.estimator = estimator
+
+        if self.maxpooling:
+            self.stride = 1
+            self.maxPool1_no_binary = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.maxPool2_no_binary = nn.MaxPool2d(kernel_size=2, stride=2)    
+            self.maxPool3_no_binary = nn.MaxPool2d(kernel_size=2, stride=2)  
+            self.maxPool4_no_binary = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.maxPool5_no_binary = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.maxPool1_binary = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.maxPool2_binary = nn.MaxPool2d(kernel_size=2, stride=2)    
+            self.maxPool3_binary = nn.MaxPool2d(kernel_size=2, stride=2)  
+            self.maxPool4_binary = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.maxPool5_binary = nn.MaxPool2d(kernel_size=2, stride=2)
+        else:
+            self.stride = 2
+            
+        self.layer1_no_binary = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=self.stride)
+        self.layer2_no_binary = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=self.stride)
+        self.layer3_no_binary = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=self.stride)
+        self.layer4_no_binary = nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=self.stride)
+        self.layer5_no_binary = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=self.stride)
+        self.layer1_binary = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=self.stride)
+        self.layer2_binary = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=self.stride)
+        self.layer3_binary = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=self.stride)
+        self.layer4_binary = nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=self.stride)
+        self.layer5_binary = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=self.stride)
+            
+        self.batchNorm1_no_binary = nn.BatchNorm2d(32)
+        self.batchNorm2_no_binary = nn.BatchNorm2d(64)
+        self.batchNorm3_no_binary = nn.BatchNorm2d(128)
+        self.batchNorm4_no_binary = nn.BatchNorm2d(256) 
+        self.batchNorm5_no_binary = nn.BatchNorm2d(256) 
+        self.batchNorm1_binary = nn.BatchNorm2d(32)
+        self.batchNorm2_binary = nn.BatchNorm2d(64)
+        self.batchNorm3_binary = nn.BatchNorm2d(128)
+        self.batchNorm4_binary = nn.BatchNorm2d(256)  
+        self.batchNorm5_binary = nn.BatchNorm2d(256)  
+        
+        self.act_layer1_no_binary = nn.ReLU()
+        self.act_layer2_no_binary = nn.ReLU()
+        self.act_layer3_no_binary = nn.ReLU()
+        self.act_layer4_no_binary = nn.ReLU()
+        self.act_layer5_no_binary = nn.ReLU()
+        
+        if self.mode == 'Deterministic':
+            self.act_layer1_binary = DeterministicBinaryActivation(estimator=estimator)
+            self.act_layer2_binary = DeterministicBinaryActivation(estimator=estimator)
+            self.act_layer3_binary = DeterministicBinaryActivation(estimator=estimator)
+            self.act_layer4_binary = DeterministicBinaryActivation(estimator=estimator)
+            self.act_layer5_binary = DeterministicBinaryActivation(estimator=estimator)
+        elif self.mode == 'Stochastic':
+            self.act_layer1_binary = StochasticBinaryActivation(estimator=estimator)
+            self.act_layer2_binary = StochasticBinaryActivation(estimator=estimator)
+            self.act_layer3_binary = StochasticBinaryActivation(estimator=estimator)
+            self.act_layer4_binary = StochasticBinaryActivation(estimator=estimator)
+            self.act_layer5_binary = StochasticBinaryActivation(estimator=estimator)
+
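+        # Assumes 105x105 Omniglot inputs, mirroring the 3x3 / 4x4 final feature maps
+        # of BinaryNetOmniglotClassification above; the factor 2 covers both branches.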
+        if self.maxpooling:
+            self.fc1 = nn.Linear(3 * 3 * 256 * 2, 4096)
+        else:
+            self.fc1 = nn.Linear(4 * 4 * 256 * 2, 4096)
+        self.act_fc1 = nn.ReLU()
+        self.dropout1 = nn.Dropout(0.5)
+        self.fc2 = nn.Linear(4096, 1623)
+
+
+    def forward(self, input):
+        x = input
+        slope = 1.0
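+        # slope is passed to the binary activations and used as a scale factor in
+        # the ReLU branch; it is fixed at 1.0 here.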
+        
+        if self.maxpooling:
+            # Binary-activation branch
+            x_layer1_binary = self.act_layer1_binary((self.maxPool1_binary(self.batchNorm1_binary(self.layer1_binary(x))), slope))
+            x_layer2_binary = self.act_layer2_binary((self.maxPool2_binary(self.batchNorm2_binary(self.layer2_binary(x_layer1_binary))), slope))
+            x_layer3_binary = self.act_layer3_binary((self.maxPool3_binary(self.batchNorm3_binary(self.layer3_binary(x_layer2_binary))), slope))
+            x_layer4_binary = self.act_layer4_binary((self.maxPool4_binary(self.batchNorm4_binary(self.layer4_binary(x_layer3_binary))), slope))
+            x_layer5_binary = self.act_layer5_binary((self.maxPool5_binary(self.batchNorm5_binary(self.layer5_binary(x_layer4_binary))), slope))
+            # Full-precision (ReLU) branch
+            x_layer1_no_binary = self.act_layer1_no_binary(self.maxPool1_no_binary(self.batchNorm1_no_binary(self.layer1_no_binary(x) * slope)))
+            x_layer2_no_binary = self.act_layer2_no_binary(self.maxPool2_no_binary(self.batchNorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope)))
+            x_layer3_no_binary = self.act_layer3_no_binary(self.maxPool3_no_binary(self.batchNorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope)))
+            x_layer4_no_binary = self.act_layer4_no_binary(self.maxPool4_no_binary(self.batchNorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope)))
+            x_layer5_no_binary = self.act_layer5_no_binary(self.maxPool5_no_binary(self.batchNorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope)))
+
+        else:
+            # Binary-activation branch
+            x_layer1_binary = self.act_layer1_binary((self.batchNorm1_binary(self.layer1_binary(x)), slope))
+            x_layer2_binary = self.act_layer2_binary((self.batchNorm2_binary(self.layer2_binary(x_layer1_binary)), slope))
+            x_layer3_binary = self.act_layer3_binary((self.batchNorm3_binary(self.layer3_binary(x_layer2_binary)), slope))
+            x_layer4_binary = self.act_layer4_binary((self.batchNorm4_binary(self.layer4_binary(x_layer3_binary)), slope))
+            x_layer5_binary = self.act_layer5_binary((self.batchNorm5_binary(self.layer5_binary(x_layer4_binary)), slope))
+            # Full-precision (ReLU) branch
+            x_layer1_no_binary = self.act_layer1_no_binary(self.batchNorm1_no_binary(self.layer1_no_binary(x) * slope))
+            x_layer2_no_binary = self.act_layer2_no_binary(self.batchNorm2_no_binary(self.layer2_no_binary(x_layer1_no_binary) * slope))
+            x_layer3_no_binary = self.act_layer3_no_binary(self.batchNorm3_no_binary(self.layer3_no_binary(x_layer2_no_binary) * slope))
+            x_layer4_no_binary = self.act_layer4_no_binary(self.batchNorm4_no_binary(self.layer4_no_binary(x_layer3_no_binary) * slope))
+            x_layer5_no_binary = self.act_layer5_no_binary(self.batchNorm5_no_binary(self.layer5_no_binary(x_layer4_no_binary) * slope))
+        
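+        # Flatten the final feature maps of both branches and concatenate them.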
+        x_layer5_binary = x_layer5_binary.view(x_layer5_binary.size(0), -1)
+        x_layer5_no_binary = x_layer5_no_binary.view(x_layer5_no_binary.size(0), -1)
+        x_concatenate = torch.cat((x_layer5_binary, x_layer5_no_binary), 1)
+        x_fc1 = self.dropout1(self.act_fc1(self.fc1(x_concatenate)))
+        x_fc2 = self.fc2(x_fc1)
+        x_out = F.log_softmax(x_fc2, dim=1)
+        return x_out
+        
+
 class NoBinaryMatchingNetwork(nn.Module):
     def __init__(self, n: int, k: int, q: int, num_input_channels: int):
         """Creates a Matching Network as described in Vinyals et al.