Skip to content
Snippets Groups Projects
Commit 61d7e641 authored by Julien Dejasmin's avatar Julien Dejasmin
Browse files

add binary matching network

parent dcde6c2c
No related branches found
No related tags found
No related merge requests found
# Default ignored files # Default ignored files
/workspace.xml /workspace.xml
/shelf/
...@@ -3,8 +3,13 @@ import os ...@@ -3,8 +3,13 @@ import os
PATH = os.path.dirname(os.path.realpath(__file__)) PATH = os.path.dirname(os.path.realpath(__file__))
# local DATA_PATH
DATA_PATH = '/home/julien/PycharmProjects/thesis/work/Pytorch/MNIST_Binary_V2/data/' DATA_PATH = '/home/julien/PycharmProjects/thesis/work/Pytorch/MNIST_Binary_V2/data/'
# colab DATA_PATH
# DATA_PATH = 'data/'
EPSILON = 1e-8 EPSILON = 1e-8
if DATA_PATH is None: if DATA_PATH is None:
......
...@@ -6,7 +6,7 @@ from torch.optim import Adam ...@@ -6,7 +6,7 @@ from torch.optim import Adam
from DataLoader.dataset import OmniglotDataset from DataLoader.dataset import OmniglotDataset
from utils.core import prepare_nshot_task, EvaluateFewShot from utils.core import prepare_nshot_task, EvaluateFewShot
from utils.models import NoBinaryMatchingNetwork from utils.models import NoBinaryMatchingNetwork, BinaryMatchingNetwork
from utils.training import fit from utils.training import fit
from utils.callback import * from utils.callback import *
from config import PATH from config import PATH
...@@ -28,9 +28,9 @@ k_test = 5 ...@@ -28,9 +28,9 @@ k_test = 5
q_test = 1 q_test = 1
evaluation_episodes = 1000 evaluation_episodes = 1000
episodes_per_epoch = 100 episodes_per_epoch = 10
n_epochs = 100 n_epochs = 2
dataset_class = OmniglotDataset dataset_class = OmniglotDataset
num_input_channels = 1 num_input_channels = 1
...@@ -38,27 +38,35 @@ param_str = f'_n={n_train}_k={k_train}_q={q_train}_' \ ...@@ -38,27 +38,35 @@ param_str = f'_n={n_train}_k={k_train}_q={q_train}_' \
f'nv={n_test}_kv={k_test}_qv={q_test}_' \ f'nv={n_test}_kv={k_test}_qv={q_test}_' \
f'dist={distance}' f'dist={distance}'
######### first_conv_layer = True
# Model # second_conv_layer = True
######### third_conv_layer = True
model = NoBinaryMatchingNetwork(n_train, k_train, q_train, num_input_channels) fourth_conv_layer = True
slope_annealing = True
binary_model = True
###################
# No binary Model #
###################
# model = NoBinaryMatchingNetwork(n_train, k_train, q_train, num_input_channels)
# model, use_gpu = gpu_config(model)
# model.double()
################
# Binary Model #
################
model = BinaryMatchingNetwork(first_conv_layer, second_conv_layer, third_conv_layer, fourth_conv_layer,
n_train, k_train, q_train, num_input_channels)
model, use_gpu = gpu_config(model) model, use_gpu = gpu_config(model)
model.double() model.double()
########### ###########
# Dataset # # Dataset #
########### ###########
# background_taskloader, evaluation_taskloader = get_omniglot_dataloader_v2(episodes_per_epoch, n_train, k_train, background_taskloader, evaluation_taskloader = get_omniglot_dataloader_v2(episodes_per_epoch, n_train, k_train,
# q_train, n_test, k_test, q_test, q_train, n_test, k_test, q_test,
# dataset_class) dataset_class)
# save dataloader:
# torch.save(background_taskloader, 'background_taskloader.pth')
# torch.save(evaluation_taskloader, 'evaluation_taskloader.pth')
# load dataloader
background_taskloader = torch.load('background_taskloader.pth')
evaluation_taskloader = torch.load('evaluation_taskloader.pth')
"""
############ ############
# Training # # Training #
############ ############
...@@ -68,6 +76,9 @@ loss_fn = torch.nn.NLLLoss().cuda() ...@@ -68,6 +76,9 @@ loss_fn = torch.nn.NLLLoss().cuda()
callbacks = [ callbacks = [
EvaluateFewShot( EvaluateFewShot(
binary_model=binary_model,
slope=1.0,
use_gpu=use_gpu,
eval_fn=matching_net_episode, eval_fn=matching_net_episode,
num_tasks=evaluation_episodes, num_tasks=evaluation_episodes,
n_shot=n_test, n_shot=n_test,
...@@ -87,6 +98,8 @@ callbacks = [ ...@@ -87,6 +98,8 @@ callbacks = [
] ]
fit( fit(
binary_model,
slope_annealing,
use_gpu, use_gpu,
model, model,
optimiser, optimiser,
...@@ -99,4 +112,3 @@ fit( ...@@ -99,4 +112,3 @@ fit(
fit_function=matching_net_episode, fit_function=matching_net_episode,
fit_function_kwargs={'n_shot': n_train, 'k_way': k_train, 'q_queries': q_train, 'train': True, 'distance': distance} fit_function_kwargs={'n_shot': n_train, 'k_way': k_train, 'q_queries': q_train, 'train': True, 'distance': distance}
) )
"""
\ No newline at end of file
epoch,categorical_accuracy,loss,lr,val_1-shot_5-way_acc,val_loss
1,0.29733333333333334,2.2908880733978103,0.001,0.3,2.058387121346762
2,0.36133333333333334,1.7418731719399887,0.001,0.46,1.415031428749744
File added
...@@ -124,6 +124,9 @@ class EvaluateFewShot(Callback): ...@@ -124,6 +124,9 @@ class EvaluateFewShot(Callback):
""" """
def __init__(self, def __init__(self,
binary_model,
slope,
use_gpu,
eval_fn: Callable, eval_fn: Callable,
num_tasks: int, num_tasks: int,
n_shot: int, n_shot: int,
...@@ -134,6 +137,9 @@ class EvaluateFewShot(Callback): ...@@ -134,6 +137,9 @@ class EvaluateFewShot(Callback):
prefix: str = 'val_', prefix: str = 'val_',
**kwargs): **kwargs):
super(EvaluateFewShot, self).__init__() super(EvaluateFewShot, self).__init__()
self.binary_model = binary_model
self.slope = slope
self.use_gpu = use_gpu
self.eval_fn = eval_fn self.eval_fn = eval_fn
self.num_tasks = num_tasks self.num_tasks = num_tasks
self.n_shot = n_shot self.n_shot = n_shot
...@@ -157,6 +163,9 @@ class EvaluateFewShot(Callback): ...@@ -157,6 +163,9 @@ class EvaluateFewShot(Callback):
x, y = self.prepare_batch(batch) x, y = self.prepare_batch(batch)
loss, y_pred = self.eval_fn( loss, y_pred = self.eval_fn(
self.binary_model,
self.slope,
self.use_gpu,
self.model, self.model,
self.optimiser, self.optimiser,
self.loss_fn, self.loss_fn,
......
...@@ -9,7 +9,9 @@ from torch.nn.utils import clip_grad_norm_ ...@@ -9,7 +9,9 @@ from torch.nn.utils import clip_grad_norm_
EPSILON = 1e-8 EPSILON = 1e-8
def matching_net_episode(use_gpu, def matching_net_episode(binary_model,
slope,
use_gpu,
model: Module, model: Module,
optimiser: optimizer, optimiser: optimizer,
loss_fn: Loss, loss_fn: Loss,
...@@ -45,6 +47,9 @@ def matching_net_episode(use_gpu, ...@@ -45,6 +47,9 @@ def matching_net_episode(use_gpu,
model.eval() model.eval()
# Embed all samples # Embed all samples
if binary_model:
embeddings = model((x, slope))
else:
embeddings = model.encoder(x) embeddings = model.encoder(x)
# Samples are ordered by the NShotWrapper class as follows: # Samples are ordered by the NShotWrapper class as follows:
# k lots of n support samples from a particular class # k lots of n support samples from a particular class
......
...@@ -159,5 +159,93 @@ class Flatten(nn.Module): ...@@ -159,5 +159,93 @@ class Flatten(nn.Module):
# Arguments # Arguments
input: Input tensor input: Input tensor
""" """
def forward(self, input): def forward(self, input):
return input.view(input.size(0), -1) return input.view(input.size(0), -1)
class BinaryMatchingNetwork(nn.Module):
    """Four-block convolutional encoder whose activations can be binarised per block.

    Every block is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> MaxPool2d(2)`` with 64
    output channels, followed by an activation. For each block a flag selects the
    activation: a binary activation (deterministic or stochastic, depending on
    ``mode``) when the flag is truthy, a plain ``nn.ReLU`` otherwise.

    # Arguments
        first_conv_layer: Truthy to use a binary activation after block 1.
        second_conv_layer: Truthy to use a binary activation after block 2.
        third_conv_layer: Truthy to use a binary activation after block 3.
        fourth_conv_layer: Truthy to use a binary activation after block 4.
        n: Stored on the instance as ``self.n``; not used by this class itself.
        k: Stored on the instance as ``self.k``; not used by this class itself.
        q: Stored on the instance as ``self.q``; not used by this class itself.
        num_input_channels: Number of channels of the input images.
        mode: ``'Deterministic'`` or ``'Stochastic'``; selects the binary activation class.
        estimator: ``'ST'`` or ``'REINFORCE'``; forwarded to the binary activation.

    # Raises
        ValueError: If ``mode`` or ``estimator`` is not one of the accepted values.
    """

    _MODES = ('Deterministic', 'Stochastic')
    _ESTIMATORS = ('ST', 'REINFORCE')

    def __init__(self, first_conv_layer, second_conv_layer, third_conv_layer, fourth_conv_layer,
                 n: int, k: int, q: int, num_input_channels: int, mode='Deterministic', estimator='ST'):
        super(BinaryMatchingNetwork, self).__init__()

        # Explicit raises instead of `assert`: asserts are stripped under `python -O`,
        # which would let an invalid mode through and leave the activations undefined.
        if mode not in self._MODES:
            raise ValueError(f'mode must be one of {self._MODES}, got {mode!r}')
        if estimator not in self._ESTIMATORS:
            raise ValueError(f'estimator must be one of {self._ESTIMATORS}, got {estimator!r}')

        self.mode = mode
        self.estimator = estimator

        self.first_conv_layer = first_conv_layer
        self.second_conv_layer = second_conv_layer
        self.third_conv_layer = third_conv_layer
        self.fourth_conv_layer = fourth_conv_layer

        self.n = n
        self.k = k
        self.q = q
        self.num_input_channels = num_input_channels

        # Attribute names and creation order are kept (layer1, act_layer1, layer2, ...)
        # so that state_dict keys and parameter initialisation order do not change.
        self.layer1 = self._conv_block(num_input_channels)
        self.act_layer1 = self._make_activation(first_conv_layer)

        self.layer2 = self._conv_block(64)
        self.act_layer2 = self._make_activation(second_conv_layer)

        self.layer3 = self._conv_block(64)
        self.act_layer3 = self._make_activation(third_conv_layer)

        self.layer4 = self._conv_block(64)
        self.act_layer4 = self._make_activation(fourth_conv_layer)

    @staticmethod
    def _conv_block(in_channels, out_channels=64):
        """Build one Conv2d -> BatchNorm2d -> MaxPool2d feature block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(kernel_size=2, stride=2))

    def _make_activation(self, binary):
        """Return the activation module for a block: binary if `binary` is truthy, else ReLU."""
        if not binary:
            return nn.ReLU()
        if self.mode == 'Deterministic':
            return DeterministicBinaryActivation(estimator=self.estimator)
        return StochasticBinaryActivation(estimator=self.estimator)

    @staticmethod
    def _apply_block(block, activation, binary, x, slope):
        """Run one feature block and its activation, feeding `slope` the way the activation expects."""
        features = block(x)
        if binary:
            # Binary activations are called with a (tensor, slope) tuple.
            return activation((features, slope))
        # ReLU path: the pre-activation is scaled by the slope instead.
        return activation(features * slope)

    def forward(self, inputs):
        """Embed a batch of images.

        # Arguments
            inputs: Tuple ``(x, slope)``; ``x`` is the image batch fed to the first
                convolution and ``slope`` is the slope value passed to every block.

        # Returns
            The output of the fourth block flattened to ``(x.size(0), -1)``.
        """
        x, slope = inputs

        x = self._apply_block(self.layer1, self.act_layer1, self.first_conv_layer, x, slope)
        x = self._apply_block(self.layer2, self.act_layer2, self.second_conv_layer, x, slope)
        x = self._apply_block(self.layer3, self.act_layer3, self.third_conv_layer, x, slope)
        x = self._apply_block(self.layer4, self.act_layer4, self.fourth_conv_layer, x, slope)

        return x.view(x.size(0), -1)
...@@ -190,7 +190,7 @@ def gradient_step(model: Module, optimiser: optimizer, loss_fn: Callable, x: tor ...@@ -190,7 +190,7 @@ def gradient_step(model: Module, optimiser: optimizer, loss_fn: Callable, x: tor
""" """
model.train() model.train()
optimiser.zero_grad() optimiser.zero_grad()
y_pred = model(x) y_pred = model((x, slope))
loss = loss_fn(y_pred, y) loss = loss_fn(y_pred, y)
loss.backward() loss.backward()
optimiser.step() optimiser.step()
...@@ -218,7 +218,8 @@ def batch_metrics(model: Module, y_pred: torch.Tensor, y: torch.Tensor, metrics: ...@@ -218,7 +218,8 @@ def batch_metrics(model: Module, y_pred: torch.Tensor, y: torch.Tensor, metrics:
return batch_logs return batch_logs
def fit(use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader, def fit(binary_model, slope_annealing, use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable,
epochs: int, dataloader: DataLoader,
prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None, prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None,
verbose: bool = True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = {}): verbose: bool = True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = {}):
"""Function to abstract away training loop. """Function to abstract away training loop.
...@@ -257,6 +258,14 @@ def fit(use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable, epochs: ...@@ -257,6 +258,14 @@ def fit(use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable, epochs:
'optimiser': optimiser 'optimiser': optimiser
}) })
# Slope annealing
if slope_annealing:
def get_slope(epochs):
return 1.0 * (1.005 ** (epochs - 1))
else:
def get_slope(epochs):
return 1.0
global slope
if verbose: if verbose:
print('Begin training...') print('Begin training...')
...@@ -264,6 +273,7 @@ def fit(use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable, epochs: ...@@ -264,6 +273,7 @@ def fit(use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable, epochs:
for epoch in range(1, epochs + 1): for epoch in range(1, epochs + 1):
callbacks.on_epoch_begin(epoch) callbacks.on_epoch_begin(epoch)
slope = get_slope(epoch)
epoch_logs = {} epoch_logs = {}
for batch_index, batch in enumerate(dataloader): for batch_index, batch in enumerate(dataloader):
...@@ -273,7 +283,8 @@ def fit(use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable, epochs: ...@@ -273,7 +283,8 @@ def fit(use_gpu, model: Module, optimiser: optimizer, loss_fn: Callable, epochs:
x, y = prepare_batch(batch) x, y = prepare_batch(batch)
loss, y_pred = fit_function(use_gpu, model, optimiser, loss_fn, x, y, **fit_function_kwargs) loss, y_pred = fit_function(binary_model, slope, use_gpu, model, optimiser, loss_fn, x, y,
**fit_function_kwargs)
batch_logs['loss'] = loss.item() batch_logs['loss'] = loss.item()
# Loops through all metrics # Loops through all metrics
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment