diff --git a/main/experiments/few_shot_training_procedure.py b/main/experiments/few_shot_training_procedure.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fb01cd3a50506a8450290606221d5cb163ea84b
--- /dev/null
+++ b/main/experiments/few_shot_training_procedure.py
@@ -0,0 +1,592 @@
+"""
+Few-shot learning procedure: benchmarking deepstrom on the few-shot classification task
+
+Usage:
+    benchmark_vgg dense -D reprdim -l name
+        [-f name] [-d val] [-B name] [-a value] [-v number] [-e numepoch] [-j numepisode]
+        [-s numepisode] [-b nbclass] [-c nbclass] [-k nbsupp] [-i nbquery] [-q] [-o name] [-p nb]
+    benchmark_vgg deepfriedconvnet -N nbstack -l name
+        [-f name] [-d val] [-B name] [-a value] [-v number] [-e numepoch] [-s batchsize] [-g gammavalue]
+        [-b nbclass] [-c nbclass] [-k nbsupp] [-i nbquery] [-j numepisode] [-q] [-o name] [-p nb]
+    benchmark_vgg deepstrom (-R|-L|-C|-E|-P|-A) -m size -l name
+        [-f name] [-d val] [-B name] [-r] [-a value] [-v number] [-e numepoch] [-s numepisode] [-n] [-D reprdim]
+        [-g gammavalue] [-b nbclass] [-c nbclass] [-k nbsupp] [-i nbquery] [-j numepisode] [-q] [-o name] [-p nb]
+    benchmark_vgg none -l name
+        [-d val] [-B name] [-a value] [-v number] [-e numepoch] [-s numepisode]
+        [-b nbclass] [-c nbclass] [-k nbsupp] [-i nbquery] [-j numepisode] [-q] [-o name] [-p nb]
+
+Options:
+    --help -h                                       Display help and exit.
+    -q --quiet                                      Set logging level to info.
+    -a value --seed value                           The seed value used for all randomization processes [default: 0]
+    -v number --nb-class-val number                 The number of classes used for the validation set
+    -s numepisode --num-episode-train               The number of episodes during training [default: 10000]
+    -j numepisode --num-episode-test                The number of episodes during evaluation [default: 1000]
+    -e numepoch --num-epoch=numepoch                The number of epochs for training on the support set [default: 300]
+    -b nbclass --num-class-ep-train                 The number of classes for each episode during training [default: 60]
+    -c nbclass --num-class-ep-test                  The number of classes for each episode during test [default: 20]
+    -k nbsupp --num-supp-ex                         The number of support examples for each class [default: 5]
+    -i nbquery --num-query-ex                       The number of query examples for each class [default: 5]
+    -d --dropout val                                Keep probability of neurons before the classification layer [default: 1.0]
+    -D reprdim --out-dim=reprdim                    The dimension of the final representation
+    -f --non-linearity name                         Tell the model which non-linearity to use when necessary (possible values: "relu", "tanh") [default: relu]
+    -l --classification-method name                 Tell the model which classification method should be used (possible values: "lc", "knn", "proto") [default: lc]
+
+Deepfried convnet:
+    -N nbstack --nb-stack nbstack                   The number of fastfood stacks for deepfriedconvnet
+
+Deepstrom:
+    -r --real-nystrom                               Use the real Nystrom matrix K^(-1/2) for deepstrom.
+    -m size --nys-size size                         The number of examples in the nystrom subsample.
+    -n --non-linear                                 Tell Nystrom to use the non-linear activation function on its output.
+
+KNN:
+    -o --distance name                              The distance measure to use for KNN (possible values: "L1", "L2") [default: L1]
+    -p --nb-neighbour nb                            The number of neighbours for the KNN [default: 1]
+
+Dataset related:
+    -B --cut-layer name                             The name of the last convolutional layer when loading the VinyalsTransformer [default: activation_4]
+
+Possible kernels:
+    -R --rbf-kernel                                 Use the RBF kernel for Nystrom.
+    -L --linear-kernel                              Use the linear kernel for Nystrom.
+    -C --chi-square-kernel                          Use the basic additive chi-square kernel for Nystrom.
+    -E --exp-chi-square-kernel                      Use the exponential chi-square kernel for Nystrom.
+    -P --chi-square-PD-kernel                       Use the positive definite version of the basic additive chi-square kernel for Nystrom.
+    -A --laplacian-kernel                           Use the laplacian kernel for Nystrom.
+
+Kernel related:
+    -g gammavalue --gamma gammavalue                The value of gamma for the rbf, chi-square or hyperbolic tangent kernels (deepstrom and deepfriedconvnet)
+"""
+import time
+import sys
+import docopt
+import numpy as np
+import scipy.stats
+from keras.layers import Dense, BatchNormalization
+import tensorflow as tf
+
+from skluc.main.data.mldatasets import OmniglotDataset
+from skluc.main.data.transformation.ResizeTransformer import ResizeTransformer
+from skluc.main.data.transformation.VinyalsTransformer import VinyalsTransformer
+from skluc.main.tensorflow_.kernel_approximation.fastfood_layer import FastFoodLayer
+from skluc.main.tensorflow_.kernel_approximation.nystrom_layer import DeepstromLayer
+from skluc.main.utils import LabeledData, logger, compute_euristic_sigma_chi2, compute_euristic_sigma
+
+
+def get_episode_data(x_train, dict_class_index, unique_classes_total, nb_class_episode, nb_examples_support, nb_examples_query, shuffle=True):
+    """Draw one few-shot episode: sample nb_class_episode classes, then for each class
+    sample nb_examples_support support examples and nb_examples_query query examples
+    (drawn from the remaining examples of that class). Returns a (support, query) pair
+    of LabeledData with episode-local one-hot labels."""
+    # unique_classes_total = np.unique(y_train, axis=0)
+    classes_episode_indices = np.random.choice(len(unique_classes_total), nb_class_episode, replace=False)
+    classes_episodes = unique_classes_total[classes_episode_indices]
+
+    # data selection for episode
+    lst_indices_support = []
+    y_support_episode = np.zeros((nb_examples_support*nb_class_episode, nb_class_episode))
+    y_query_episode = np.zeros((nb_examples_query*nb_class_episode, nb_class_episode))
+    lst_indices_query = []
+    for i, _class in enumerate(classes_episodes):
+        indices_examples_of_class_bool = dict_class_index[_class.tostring()]
+        # indices_examples_of_class_bool = OmniglotDataset.get_bool_idx_label(_class, y_train)
+        indices_examples_of_class = np.where(indices_examples_of_class_bool)[0]
+        indices_examples_of_class_support = np.random.choice(indices_examples_of_class, nb_examples_support, replace=False)
+        indices_examples_of_class_not_in_support = np.setdiff1d(indices_examples_of_class, indices_examples_of_class_support)
+        # note: query examples are drawn with replacement (numpy's default), so a class
+        # with fewer remaining examples than nb_examples_query is still handled
+        indices_examples_of_class_query = np.random.choice(indices_examples_of_class_not_in_support, nb_examples_query)
+
+        lst_indices_support.extend(indices_examples_of_class_support)
+        y_support_episode[i*nb_examples_support:(i+1)*nb_examples_support, i] = 1
+        lst_indices_query.extend(indices_examples_of_class_query)
+        y_query_episode[i*nb_examples_query:(i+1)*nb_examples_query, i] = 1
+
+    x_support_episode = x_train[np.array(lst_indices_support)]
+    x_query_episode = x_train[np.array(lst_indices_query)]
+    # y_train[np.array(lst_indices_support)]
+    # y_query_episode = y_train[np.array(lst_indices_query)]
+    if shuffle:
+        shuffled_indices_support = np.random.permutation(len(x_support_episode))
+        shuffled_indices_query = np.random.permutation(len(x_query_episode))
+    else:
+        shuffled_indices_support = np.arange(len(x_support_episode))
+        shuffled_indices_query = np.arange(len(x_query_episode))
+
+    return (LabeledData(data=x_support_episode[shuffled_indices_support], labels=y_support_episode[shuffled_indices_support]),
+            LabeledData(data=x_query_episode[shuffled_indices_query], labels=y_query_episode[shuffled_indices_query]))
+
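+# Illustrative sketch (comments only, not executed): shapes of one 5-way 5-shot
+# episode with 5 query examples per class, where D is the feature dimension of x_train.
+#   support, query = get_episode_data(x_train, dict_class_index,
+#                                     unique_classes_total,
+#                                     nb_class_episode=5,
+#                                     nb_examples_support=5,
+#                                     nb_examples_query=5)
+#   support.data.shape   == (25, D)  # 5 classes x 5 support examples
+#   support.labels.shape == (25, 5)  # one-hot over the 5 episode classes
+#   query.data.shape     == (25, D)  # 5 classes x 5 query examples
+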
+
+def print_result():
+    to_print_headers = [
+        "accuracy_test",
+        "total_time"
+    ]
+    to_print_list = [resman[k] for k in to_print_headers]
+    to_print_list.extend(paraman.get_ordered_values())
+    print(",".join(to_print_headers + paraman.get_ordered_keys()))
+    print(",".join([str(v) for v in to_print_list]))
+
+
+def get_gamma_value(arguments, dat, chi2=False):  # todo add to skluc.main.utils?
+    if arguments["--gamma"] is None:
+        logger.debug("Gamma arguments is None. Need to compute it.")
+        if chi2:
+            gamma_value = 1./compute_euristic_sigma_chi2(dat)
+
+        else:
+            gamma_value = 1./compute_euristic_sigma(dat)
+    else:
+        gamma_value = eval(arguments["--gamma"])
+
+    logger.debug("Gamma value is {}".format(gamma_value))
+    return gamma_value
+
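+# Note on --gamma (illustrative): the argument is eval'ed, so python expressions work:
+#   --gamma 0.01     -> gamma_value = 0.01
+#   --gamma "1./64"  -> gamma_value = 0.015625
+# When omitted, gamma falls back to the heuristic sigma computed from the data.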
+
+class ParameterManager:
+    # todo: ParameterManager could be a class which has the options as attributes
+
+    def __init__(self, docopt_dict):
+        def init_kernel():
+            if docopt_dict["--rbf-kernel"]:
+                return "rbf"
+            elif docopt_dict["--linear-kernel"]:
+                return "linear"
+            elif docopt_dict["--chi-square-kernel"]:
+                return "chi2_cpd"
+            elif docopt_dict["--exp-chi-square-kernel"]:
+                return "chi2_exp_cpd"
+            elif docopt_dict["--chi-square-PD-kernel"]:
+                return "chi2_pd"
+            elif docopt_dict["--laplacian-kernel"]:
+                return "laplacian"
+            else:
+                return None
+
+        def init_network():
+            if docopt_dict["dense"]:
+                return "dense"
+            elif docopt_dict["deepfriedconvnet"]:
+                return "deepfriedconvnet"
+            elif docopt_dict["deepstrom"]:
+                return "deepstrom"
+            elif docopt_dict["none"]:
+                return "none"
+
+        def init_non_linearity():
+            if docopt_dict["--non-linearity"] == "tanh":
+                return tf.nn.tanh
+            elif docopt_dict["--non-linearity"] == "relu":
+                return tf.nn.relu
+            elif docopt_dict["--non-linearity"] == "None":
+                return None
+
+        self.__dict = docopt_dict
+        self.__dict["--out-dim"] = int(self.__dict["--out-dim"]) if eval(str(self.__dict["--out-dim"])) is not None else None
+        self.__dict["kernel"] = init_kernel()
+        self.__dict["network"] = init_network()
+        self.__dict["--non-linearity"] = init_non_linearity()
+        self.__dict["--nb-stack"] = int(self.__dict["--nb-stack"]) if self.__dict["--nb-stack"] is not None else None
+        self.__dict["--dropout"] = float(self.__dict["--dropout"])
+        self.__dict["--num-class-ep-train"] = int(self.__dict["--num-class-ep-train"])
+        self.__dict["--num-class-ep-test"] = int(self.__dict["--num-class-ep-test"])
+        self.__dict["--num-supp-ex"] = int(self.__dict["--num-supp-ex"])
+        self.__dict["--num-query-ex"] = int(self.__dict["--num-query-ex"])
+        self.__dict["--num-episode-test"] = int(self.__dict["--num-episode-test"])
+        self.__dict["--num-episode-train"] = int(self.__dict["--num-episode-train"])
+        self.__dict["--nys-size"] = int(self.__dict["--nys-size"]) if self.__dict["--nys-size"] is not None else None
+        self.__dict["--num-epoch"] = int(self.__dict["--num-epoch"])
+        self.__dict["--nb-neighbour"] = int(self.__dict["--nb-neighbour"])
+
+    def __getitem__(self, item):
+        return self.__dict[item]
+
+    def get_ordered_values(self):
+        return [self.__dict[k] for k in sorted(self.__dict.keys())]
+
+    def get_ordered_keys(self):
+        return sorted([k for k in self.__dict.keys()])
+
+
+class ResultManager:
+    def __init__(self):
+        self.__dict = {
+            "accuracy_test": None,
+            "total_time": None,
+        }
+
+    def __getitem__(self, item):
+        return self.__dict[item]
+
+    def __setitem__(self, item, value):
+        self.__dict[item] = value
+
+
+def get_mapping(total_labels, name):
+    logger.debug(f"Finding unique class in {name}")
+    unique_classes_total = np.unique(total_labels, axis=0)
+    logger.debug(f"Finding bool array for each class in {name}")
+    dict_class_index = {}
+    for _class in unique_classes_total:
+        # the one-hot label row is serialized with tostring() to obtain a hashable dict key
+        dict_class_index[_class.tostring()] = OmniglotDataset.get_bool_idx_label(_class, total_labels)
+    logger.debug(f"Number of {name} classes {len(dict_class_index.keys())}")
+    return unique_classes_total, dict_class_index
+
+
+def tf_euclidean_distance(a, b):
+    # a.shape = N x D
+    # b.shape = M x D
+    # returns the N x M matrix of mean squared differences along the D axis;
+    # this is proportional to the squared euclidean distance, so it yields the
+    # same rankings for nearest-neighbour and nearest-prototype decisions
+    N, D = tf.shape(a)[0], tf.shape(a)[1]
+    M = tf.shape(b)[0]
+    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
+    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
+    return tf.reduce_mean(tf.square(a - b), axis=2)
+
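+# Sanity-check sketch (illustrative, comments only): the same pairwise computation
+# in NumPy, for small arrays.
+#   a = np.random.rand(3, 4); b = np.random.rand(2, 4)
+#   d = ((a[:, None, :] - b[None, :, :]) ** 2).mean(axis=2)  # shape (3, 2)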
+
+if __name__ == "__main__":
+    logger.debug(f"Command line: {sys.argv}")
+    # initialize parameter manager with command line arguments
+    paraman = ParameterManager(docopt.docopt(__doc__))
+    resman = ResultManager()
+    # todo: the seed parameter is unused
+    # todo: the validation-size parameter is unused
+    # todo: the quiet parameter is unused
+
+    START = time.time()
+    # loading dataset and apply preprocessing
+    data = OmniglotDataset()
+    data.load()
+    data.data_astype(np.float32)
+    data.labels_astype(np.float32)
+    data.to_image()
+    resize_trans = ResizeTransformer(data.s_name, output_shape=(28, 28))
+    data.apply_transformer(resize_trans)
+    data.normalize()
+    transformer = VinyalsTransformer(data_name="omniglot_28x28", cut_layer_name=paraman["--cut-layer"])
+    data.apply_transformer(transformer)
+    data.normalize()
+    data.flatten()
+    data.to_one_hot()
+
+    X_train, Y_train = data.train.data, data.train.labels
+    # todo: add a function that augments the number of classes as in prototypical networks (90 degree rotations)
+    X_test, Y_test = data.test.data, data.test.labels
+
+    # get info about data
+    nb_examples_N = len(X_train)
+    nb_classes_train_K = Y_train.shape[1]
+    nb_classes_test_K = Y_test.shape[1]
+    input_dim, output_dim_train, output_dim_test = X_train.shape[1], \
+                                                   paraman["--num-class-ep-train"], \
+                                                   paraman["--num-class-ep-test"]
+
+    # get unique classes and build mapping between classes and corresponding indexes in data train and data test
+    unique_classes_total_train, dict_class_index_train = get_mapping(Y_train, "background set")
+    unique_classes_total_test, dict_class_index_test = get_mapping(Y_test, "evaluation set")
+
+    logger.info("Start building network")
+
+    # Input layers of the network
+    #######################
+    x = tf.placeholder(tf.float32, shape=[None, input_dim], name="x")  # the input to classify
+    # the number of classes during train and test are different parameters -> a double graph is needed
+    y_train = tf.placeholder(tf.float32, shape=[None, paraman["--num-class-ep-train"]], name="label_train")  # labels of input during train
+    y_test = tf.placeholder(tf.float32, shape=[None, paraman["--num-class-ep-test"]], name="label_test")  # labels of input during test
+    x_support = None
+    x_support_train = None
+    x_support_test = None
+    y_support_train = None
+    y_support_test = None
+    if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+
+        y_support_train = tf.placeholder(tf.float32, shape=[paraman["--num-supp-ex"] * paraman["--num-class-ep-train"], paraman["--num-class-ep-train"]],
+                                         name="label_support_train")
+        y_support_test = tf.placeholder(tf.float32, shape=[paraman["--num-supp-ex"] * paraman["--num-class-ep-test"], paraman["--num-class-ep-test"]],
+                                        name="label_support_test")
+    if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+        # knn classification method need a support set as input + the corresponding labels for train and test procedure
+        x_support = tf.placeholder(tf.float32, shape=[None, input_dim],
+                                   name="x_support")
+    # elif paraman["--classification-method"] == "proto":
+    #     # prototy
+    #     x_support_train = tf.placeholder(tf.float32, shape=(paraman["--num-class-ep-train"], paraman["--num-supp-ex"], input_dim),
+    #                                      name="x_support_train")
+    #     x_support_test = tf.placeholder(tf.float32, shape=(paraman["--num-class-ep-test"], paraman["--num-supp-ex"], input_dim),
+    #                                     name="x_support_test")
+
+    # Representation layers of the network
+    #####################################
+    repr_support = None
+    # repr_support_train = None
+    # repr_support_test = None
+    if paraman["network"] == "dense":
+        with tf.variable_scope("dense_method"):
+            representation_layer = Dense(paraman["--out-dim"], activation=paraman["--non-linearity"])
+            input_classif = representation_layer(x)  # compute the representation of the input, according to the representation_layer definition
+            if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+                repr_support = representation_layer(x_support)
+            # elif paraman["--classification-method"] == "proto":
+            #     repr_support_train = representation_layer(x_support_train)
+            #     repr_support_test = representation_layer(x_support_test)
+    elif paraman["network"] == "deepstrom":
+        with tf.variable_scope("deepstrom_method"):
+            subsample_indexes = data.get_uniform_class_rand_indices_train(paraman["--nys-size"])
+            nys_subsample = data.train.data[subsample_indexes]
+            logger.debug("Chosen subsample: {}".format(nys_subsample))
+            representation_layer = DeepstromLayer(subsample=nys_subsample,
+                                                  out_dim=paraman["--out-dim"],
+                                                  activation=paraman["--non-linearity"] if paraman["--non-linear"] else None,
+                                                  kernel_name=paraman["kernel"])
+
+            input_classif = representation_layer(x)  # compute the representation of the input, according to the representation_layer definition
+
+            if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+                repr_support = representation_layer(x_support)
+
+            # elif paraman["--classification-method"] == "proto":
+            #     repr_support_train = representation_layer(x_support_train)
+            #     repr_support_test = representation_layer(x_support_test)
+    elif paraman["network"] == "deepfriedconvnet":
+        with tf.variable_scope("deepfried_method"):
+            representation_layer = FastFoodLayer(sigma=1./get_gamma_value(paraman, X_train), nbr_stack=paraman["--nb-stack"], trainable=True)
+            input_classif = representation_layer(x)  # compute the representation of the input, according to the representation_layer definition
+            if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+                repr_support = representation_layer(x_support)
+            # elif paraman["--classification-method"] == "proto":
+            #     repr_support_train = representation_layer(x_support_train)
+            #     repr_support_test = representation_layer(x_support_test)
+    elif paraman["network"] == "none":
+            input_classif = x
+            if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+                repr_support = x_support
+    else:
+        raise Exception("Not recognized network")
+
+    # bn = BatchNormalization()
+    # input_classif = bn(input_classif)
+    # if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+    #     repr_support = bn(repr_support)
+
+    # Classification layers of the network
+    ######################################
+    if paraman["--classification-method"] == "lc":
+        # build one linear classifier layer for each train and test setting with different numbers of output units
+        with tf.name_scope("classification_train"):
+            classif_train = Dense(paraman["--num-class-ep-train"])(input_classif)
+        with tf.name_scope("classification_test"):
+            classif_test = Dense(paraman["--num-class-ep-test"])(input_classif)
+
+    elif paraman["--classification-method"] == "knn":
+        if paraman["--distance"] == "L1":
+            distance = tf.reduce_sum(tf.abs(tf.subtract(repr_support, tf.expand_dims(input_classif, 1))), axis=2)
+        elif paraman["--distance"] == "L2":
+            distance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(repr_support, tf.expand_dims(input_classif, 1))), axis=2))
+            # distance = tf_euclidean_distance(input_classif, repr_support)
+        else:
+            raise ValueError(f"Distance parameter not understood: {paraman['--distance']}")
+        top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=paraman["--nb-neighbour"], sorted=False)
+        # returns the (negated) distance values and the indices of the k nearest neighbours
+
+        # top_k_indices contains the indices of the support examples with the best distance scores
+        # with tf.name_scope("classification_train"):
+            # nearest_neighbours_train = []
+            # for i in range(paraman["--nb-neighbour"]):
+            #     nearest_neighbours_train.append(tf.argmax(tf.gather(y_support_train, top_k_indices), axis=1))
+            # neighbor_tensor_train = tf.stack(nearest_neighbours_train)
+            # y, _, count = tf.unique_with_counts(neighbor_tensor_train)
+            # pred_train = tf.slice(y, begin=[tf.argmax(count, 0)], size=tf.constant([1], dtype=tf.int64))[0]
+            # classif_train = tf.one_hot(pred_train,
+            #                            depth=paraman["--num-class-ep-train"])
+
+        x_sums = tf.expand_dims(tf.reduce_sum(top_k_xvals, 1), 1)
+        x_sums_repeated = tf.matmul(x_sums, tf.ones([1, paraman["--nb-neighbour"]], tf.float32))
+        x_val_weights = tf.expand_dims(tf.div(top_k_xvals, x_sums_repeated), 1)
+
+        # x_val_weights gives, for each example and each of its nearest neighbours,
+        # a weight expressing how close that neighbour is
+        with tf.name_scope("classification_train"):
+            top_k_yvals_train = tf.gather(y_support_train, top_k_indices)
+            classif_train = tf.squeeze(tf.matmul(x_val_weights, top_k_yvals_train))
+        with tf.name_scope("classification_test"):
+            top_k_yvals_test = tf.gather(y_support_test, top_k_indices)
+            classif_test = tf.squeeze(tf.matmul(x_val_weights, top_k_yvals_test))
+        # the gather returns the labels of the k nearest neighbours; the matmul combines them as a weighted vote
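+        # Worked example (illustrative, k=2): if a query's two nearest support points
+        # have negated-distance scores [-1., -3.] and one-hot labels [1, 0] and [0, 1],
+        # then x_val_weights = [0.25, 0.75] and the weighted vote gives
+        # 0.25*[1, 0] + 0.75*[0, 1] = [0.25, 0.75] -> predicted class 1.
+        # Caveat: with negated distances the farther neighbour receives the larger
+        # weight; with the default k=1 the single weight is 1.0 and this is harmless.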
+
+        # with tf.name_scope("classification_test"):
+        #     nearest_neighbours_test = []
+        #     for i in range(paraman["--nb-neighbour"]):
+        #         nearest_neighbours_test.append(tf.argmax(tf.gather(y_support_train, top_k_indices), axis=1))
+        #     neighbor_tensor_test = tf.stack(nearest_neighbours_test)
+        #     y, _, count = tf.unique_with_counts(neighbor_tensor_test)
+        #     pred_test = tf.slice(y, begin=[tf.argmax(count, 0)], size=tf.constant([1], dtype=tf.int64))[0]
+        #     classif_test = tf.one_hot(pred_test,
+        #                               depth=paraman["--num-class-ep-test"])
+        # pred_* is the index of the element that appears the most often in neighbor_tensor
+        # classif_* is the one hot encoding of the counterpart pred_*
+
+    elif paraman["--classification-method"] == "proto":
+        emb_dim = repr_support.shape[-1].value
+        protos_train = tf.reduce_mean(tf.reshape(repr_support, [paraman["--num-class-ep-train"], paraman["--num-supp-ex"], emb_dim]), axis=1)
+        protos_test = tf.reduce_mean(tf.reshape(repr_support, [paraman["--num-class-ep-test"], paraman["--num-supp-ex"], emb_dim]), axis=1)
+        classif_train = tf.negative(tf_euclidean_distance(input_classif, protos_train))
+        classif_test = tf.negative(tf_euclidean_distance(input_classif, protos_test))
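+        # Shape sketch (illustrative): with an unshuffled 5-way 5-shot support set,
+        # repr_support has shape (25, emb_dim); the reshape groups the 5 consecutive
+        # support embeddings of each class and reduce_mean yields 5 prototypes of
+        # shape (5, emb_dim). This is why episodes are drawn with shuffle=False for
+        # the "proto" method (see the training and evaluation loops below).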
+    else:
+        raise ValueError(f"Classification method not recognized {paraman['--classification-method']}")
+
+
+    # loss computation (train)
+    with tf.name_scope("xent_train"):
+        cross_entropy_train = tf.reduce_mean(
+            tf.nn.softmax_cross_entropy_with_logits(labels=y_train, logits=classif_train, name="xentropy"),
+            name="xentropy_mean")
+        tf.summary.scalar('loss-xent-train', cross_entropy_train)
+
+    # loss computation (test)
+    with tf.name_scope("xent_test"):
+        cross_entropy_test = tf.reduce_mean(
+            tf.nn.softmax_cross_entropy_with_logits(labels=y_test, logits=classif_test, name="xentropy"),
+            name="xentropy_mean")
+        tf.summary.scalar('loss-xent-test', cross_entropy_test)
+
+    logger.debug(f"All variables {tf.trainable_variables()}")
+    # gradient computation
+    if paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+        logger.debug(f"As classification method is {paraman['--classification-method']}, no need to optimize a linear classifier")
+    else:
+        with tf.name_scope("train_support_train"):
+            variables_representation_support_train = [var for var in tf.trainable_variables() if "classification_train" in var.name]
+            logger.debug(f"Variables Support train {variables_representation_support_train}")
+            train_support_optimizer_train = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(cross_entropy_train, var_list=variables_representation_support_train)
+        with tf.name_scope("train_support_test"):
+            variables_representation_support_test = [var for var in tf.trainable_variables() if "classification_test" in var.name]
+            logger.debug(f"Variables Support test {variables_representation_support_test}")
+            train_support_optimizer_test = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(cross_entropy_test, var_list=variables_representation_support_test)
+
+    if paraman["--out-dim"] == 0:
+        logger.debug("Representation dimension is 0, there is nothing to train on background, not building optimizer query train")
+    elif paraman["network"] == "none":
+        logger.debug("No representation layer, there is nothing to train on background, not building optimizer query train")
+    else:
+        with tf.name_scope("train_query_train"):
+            variables_representation_query_train = [var for var in tf.trainable_variables() if "_method" in var.name]
+            logger.debug(f"Variables Query train {variables_representation_query_train}")
+            train_query_optimizer_train = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy_train, var_list=variables_representation_query_train)
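+    # Training scheme (summary): the episode-specific linear classifiers (scopes
+    # "classification_train"/"classification_test") are fitted on the support set with
+    # Adam at lr 1e-3, while the shared representation (scopes ending in "_method")
+    # is updated on the query set with Adam at lr 1e-4.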
+
+    # accuracy computation
+    with tf.name_scope("accuracy_train"):
+        predictions_train = tf.argmax(classif_train, 1)
+        correct_prediction_train = tf.equal(predictions_train, tf.argmax(y_train, 1))
+        accuracy_op_train = tf.reduce_mean(tf.cast(correct_prediction_train, tf.float32))
+        tf.summary.scalar("accuracy_train", accuracy_op_train)
+    with tf.name_scope("accuracy_test"):
+        predictions_test = tf.argmax(classif_test, 1)
+        correct_prediction_test = tf.equal(predictions_test, tf.argmax(y_test, 1))
+        accuracy_op_test = tf.reduce_mean(tf.cast(correct_prediction_test, tf.float32))
+        tf.summary.scalar("accuracy_test", accuracy_op_test)
+
+    merged_summary = tf.summary.merge_all()
+
+    init = tf.global_variables_initializer()
+    # Create a session for running Ops on the Graph.
+    # Instantiate a SummaryWriter to output summaries and the Graph.
+    # summary_writer = tf.summary.FileWriter("debug_few_shot")
+    # Initialize all Variable objects
+    # actual learning
+
+    logger.info("Start Tensorflow session")
+    with tf.Session() as sess:
+        # summary_writer.add_graph(sess.graph)
+        sess.run(init)
+        if paraman["--real-nystrom"] and paraman["network"] == "deepstrom":
+            logger.debug("Using real nystrom doesn't need training W on background, skip train onbackground")
+        elif paraman["--out-dim"] == 0:
+            logger.debug(
+                "Representation dimension is 0, there is nothing to train on background, skip train on background")
+        elif paraman["network"] == "none":
+            logger.debug("No representation function, there is nothing to train on background, skip train on background")
+        else:
+            logger.debug("Start training on background")
+            j = 0
+            while j < paraman["--num-episode-train"]:
+                # no reinitialization of the linear classifier between episodes, deemed unnecessary
+                loss_supp, acc_supp = None, None
+                if paraman["--classification-method"] == "proto":
+                    shuffle_bool = False
+                else:
+                    shuffle_bool = True
+                support, query = get_episode_data(X_train,
+                                                  dict_class_index=dict_class_index_train,
+                                                  unique_classes_total=unique_classes_total_train,
+                                                  nb_class_episode=paraman["--num-class-ep-train"],
+                                                  nb_examples_support=paraman["--num-supp-ex"],
+                                                  nb_examples_query=paraman["--num-query-ex"],
+                                                  shuffle=shuffle_bool)
+                if paraman["--classification-method"] == "lc":
+                    feed_dict_support = {x: support.data, y_train: support.labels}
+                    k = 0
+                    while k < paraman["--num-epoch"]:
+                        _, loss_supp, acc_supp = sess.run([train_support_optimizer_train, cross_entropy_train, accuracy_op_train], feed_dict=feed_dict_support)
+                        k += 1
+                if paraman["--classification-method"] == "lc":
+                    feed_dict_query = {x: query.data, y_train: query.labels}
+                elif paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+                    feed_dict_query = {x: query.data,
+                                       y_train: query.labels,
+                                       x_support: support.data,
+                                       y_support_train: support.labels}
+                # elif paraman["--classification-method"] == "proto":
+                #     feed_dict_query = {
+                #         x: query.data,
+                #         y_train: query.labels,
+                #         x_support_train: support.data,
+                #         y_support_train: support.labels
+                #     }
+                else:
+                    raise ValueError
+                _, loss_query, acc_query = sess.run([train_query_optimizer_train, cross_entropy_train, accuracy_op_train], feed_dict=feed_dict_query)
+                if (j+1) % 10 == 0:
+                    logger.debug("Episode train {}; Loss support: {}; acc support: {}; Loss query: {}; acc query: {}".format(j, loss_supp, acc_supp, loss_query, acc_query))
+                    # summary_str = sess.run(merged_summary, feed_dict=feed_dict_query)
+                    # summary_writer.add_summary(summary_str, j)
+                j += 1
+
+        if paraman["--classification-method"] == "proto":
+            shuffle_bool = False
+        else:
+            shuffle_bool = True
+        logger.debug("Start evaluation")
+        accuracies = []
+        j = 0
+        while j < paraman["--num-episode-test"]:
+            # no reinitialization of the linear classifier between episodes, deemed unnecessary
+            loss_supp, acc_supp = None, None
+            support, query = get_episode_data(X_test,
+                                              dict_class_index=dict_class_index_test,
+                                              unique_classes_total=unique_classes_total_test,
+                                              nb_class_episode=paraman["--num-class-ep-test"],
+                                              nb_examples_support=paraman["--num-supp-ex"],
+                                              nb_examples_query=paraman["--num-query-ex"],
+                                              shuffle=shuffle_bool)
+            if paraman["--classification-method"] == "lc":
+                feed_dict_support = {x: support.data, y_test: support.labels}
+                k = 0
+                while k < paraman["--num-epoch"]:
+                    _, loss_supp, acc_supp = sess.run([train_support_optimizer_test, cross_entropy_test, accuracy_op_test],
+                                                      feed_dict=feed_dict_support)
+                    k += 1
+            if paraman["--classification-method"] == "lc":
+                feed_dict_query = {x: query.data, y_test: query.labels}
+            elif paraman["--classification-method"] == "knn" or paraman["--classification-method"] == "proto":
+                feed_dict_query = {x: query.data, y_test: query.labels,
+                                   x_support: support.data,
+                                   y_support_test: support.labels}
+
+            _, acc_query = sess.run([cross_entropy_test, accuracy_op_test], feed_dict=feed_dict_query)
+            accuracies.append(acc_query)
+            if (j + 1) % 10 == 0:
+                logger.debug(
+                    "Episode test {}; Loss support: {}; acc support: {}; acc query: {}".format(j, loss_supp,
+                                                                                               acc_supp,
+                                                                                               acc_query))
+            j += 1
+
+        resman["accuracy_test"] = np.mean(accuracies)
+        STOP = time.time()
+        resman["total_time"] = STOP-START
+        print_result()
\ No newline at end of file