from sklearn.externals.six import iteritems
from pyscm.scm import SetCoveringMachineClassifier as scm
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np
import os

from ..Monoview.MonoviewUtils import CustomRandint, CustomUniform, BaseMonoviewClassifier, change_label_to_minus, change_label_to_zero
from ..Monoview.Additions.BoostUtils import StumpsClassifiersGenerator, BaseBoost
from ..Monoview.Additions.PregenUtils import PregenClassifier
# Module authorship metadata
__author__ = "Baptiste Bauvin"
__status__ = "Prototype"  # listed options: Production, Development, Prototype

class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
    """Set Covering Machine fitted on the output of ``pregen_voters``.

    Both ``fit`` and ``predict`` first transform the samples with
    ``self.pregen_voters`` and hand the resulting matrix to the underlying
    SCM implementation.
    """

    def __init__(self, random_state=None, model_type="conjunction",
                 max_rules=10, p=0.1, n_stumps=10, self_complemented=True, **kwargs):
        """Configure the classifier and its hyper-parameter search description.

        Parameters
        ----------
        random_state, model_type, max_rules, p :
            Forwarded unchanged to the parent constructor.
        n_stumps : int
            Stored on the instance (presumably consumed by ``pregen_voters``
            -- confirm in ``PregenClassifier``).
        self_complemented : bool
            Stored on the instance (presumably consumed by ``pregen_voters``
            -- confirm in ``PregenClassifier``).
        **kwargs :
            Accepted for call compatibility and ignored.
        """
        super(SCMPregen, self).__init__(
            random_state=random_state,
            model_type=model_type,
            max_rules=max_rules,
            p=p
            )
        # ``distribs`` is listed in the same order as ``param_names``: one
        # candidate list or distribution object per hyper-parameter name.
        self.param_names = ["model_type", "max_rules", "p", "n_stumps", "random_state"]
        self.distribs = [["conjunction", "disjunction"],
                         CustomRandint(low=1, high=15),
                         CustomUniform(loc=0, state=1), [n_stumps], [random_state]]
        self.classed_params = []
        self.weird_strings = {}
        self.self_complemented = self_complemented
        self.n_stumps = n_stumps
        self.estimators_generator = "Stumps"

    @staticmethod
    def _unused_csv_name(directory="."):
        """Return a ``pregen_x*.csv`` name that is not present in ``directory``.

        ``pregen_x.csv`` is preferred; when it is taken, ``pregen_x0.csv``,
        ``pregen_x1.csv``, ... are tried in turn.
        """
        existing = set(os.listdir(directory))
        candidate = "pregen_x.csv"
        index = 0
        while candidate in existing:
            # Recompute the candidate on every pass; the previous code only
            # bumped the counter, so the loop could never terminate.
            candidate = "pregen_x" + str(index) + ".csv"
            index += 1
        return candidate

    def _csv_round_trip(self, matrix):
        """Write ``matrix`` to a scratch CSV, read it back and delete the file.

        NOTE(review): why the matrix has to go through a CSV file is not
        visible from this module; the round trip is kept exactly as before
        (``np.savetxt`` then ``np.genfromtxt`` with a comma delimiter).
        """
        file_name = self._unused_csv_name()
        try:
            np.savetxt(file_name, matrix, delimiter=',')
            return np.genfromtxt(file_name, delimiter=',')
        finally:
            # Always clean up, even when writing or reading failed.
            if os.path.exists(file_name):
                os.remove(file_name)

    def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
        """Fit the SCM on the pre-generated representation of ``X``.

        ``tiebreaker``, ``iteration_callback`` and ``fit_params`` are passed
        through to the parent ``fit``. Returns ``self``.
        """
        pregen_X, _ = self.pregen_voters(X, y)
        place_holder = self._csv_round_trip(pregen_X)
        super(SCMPregen, self).fit(place_holder, y, tiebreaker=tiebreaker, iteration_callback=iteration_callback, **fit_params)
        return self

    def predict(self, X):
        """Predict labels for ``X`` using the fitted model.

        The model's predictions are used as indices into ``self.classes_``.
        """
        pregen_X, _ = self.pregen_voters(X)
        # Use the same collision-free scratch file as ``fit`` instead of
        # unconditionally overwriting and deleting ``pregen_x.csv``.
        place_holder = self._csv_round_trip(pregen_X)
        return self.classes_[self.model_.predict(place_holder)]

    def get_params(self, deep=True):
        """Return the hyper-parameters as a dict (``deep`` is accepted but unused)."""
        return {"p": self.p, "model_type": self.model_type,
         "max_rules": self.max_rules,
         "random_state": self.random_state, "n_stumps":self.n_stumps}

    def canProbas(self):
        """Used to know if the classifier can return label probabilities"""
        return False

    def getInterpret(self, directory, y_test):
        """Return a short textual description of the fitted model.

        ``directory`` and ``y_test`` are accepted but not used here.
        """
        interpretString = "Model used : " + str(self.model_)
        return interpretString


def formatCmdArgs(args):
    """Translate the parsed ``SCP_*`` command-line options into classifier kwargs."""
    return dict(
        model_type=args.SCP_model_type,
        p=args.SCP_p,
        max_rules=args.SCP_max_rules,
        n_stumps=args.SCP_stumps,
    )


def paramsToSet(nIter, randomState):
    """Draw ``nIter`` random hyper-parameter dicts using ``randomState``.

    Each dict holds a ``model_type`` picked among "conjunction" and
    "disjunction", a ``max_rules`` from ``randint(1, 15)`` and a ``p`` from
    ``random_sample()``, drawn in that order.
    """

    def draw_one():
        # Keep the draw order (choice, randint, random_sample) so the random
        # stream is consumed exactly as before.
        kind = randomState.choice(["conjunction", "disjunction"])
        rules = randomState.randint(1, 15)
        trade_off = randomState.random_sample()
        return {"model_type": kind, "max_rules": rules, "p": trade_off}

    return [draw_one() for _ in range(nIter)]