Skip to content
Snippets Groups Projects
Commit 3aced0d5 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

forgotten fits

parent 981e0c78
No related branches found
No related tags found
No related merge requests found
...@@ -49,9 +49,17 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier): ...@@ -49,9 +49,17 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
def predict(self, X): def predict(self, X):
pregen_X, _ = self.pregen_voters(X) pregen_X, _ = self.pregen_voters(X)
np.savetxt("pregen_x.csv", pregen_X, delimiter=',') list_files = os.listdir(".")
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') if "pregen_x.csv" in list_files:
os.remove("pregen_x.csv") i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
return self.classes_[self.model_.predict(place_holder)] return self.classes_[self.model_.predict(place_holder)]
def get_params(self, deep=True): def get_params(self, deep=True):
......
...@@ -35,9 +35,17 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier): ...@@ -35,9 +35,17 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params): def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
pregen_X, _ = self.pregen_voters(X, y, generator="Trees") pregen_X, _ = self.pregen_voters(X, y, generator="Trees")
np.savetxt("pregen_x.csv", pregen_X, delimiter=',') list_files = os.listdir(".")
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') if "pregen_x.csv" in list_files:
os.remove("pregen_x.csv") i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
super(SCMPregenTree, self).fit(place_holder, y, tiebreaker=tiebreaker, iteration_callback=iteration_callback, **fit_params) super(SCMPregenTree, self).fit(place_holder, y, tiebreaker=tiebreaker, iteration_callback=iteration_callback, **fit_params)
return self return self
......
...@@ -42,9 +42,17 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier): ...@@ -42,9 +42,17 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params): def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
pregen_X, _ = self.pregen_voters(X, y) pregen_X, _ = self.pregen_voters(X, y)
np.savetxt("pregen_x.csv", pregen_X, delimiter=',') list_files = os.listdir(".")
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') if "pregen_x.csv" in list_files:
os.remove("pregen_x.csv") i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
for scm_estimator in self.scm_estimators: for scm_estimator in self.scm_estimators:
beg = time.time() beg = time.time()
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params) scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
......
...@@ -42,9 +42,17 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier): ...@@ -42,9 +42,17 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params): def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
pregen_X, _ = self.pregen_voters(X, y, generator="Trees") pregen_X, _ = self.pregen_voters(X, y, generator="Trees")
np.savetxt("pregen_x.csv", pregen_X, delimiter=',') list_files = os.listdir(".")
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') if "pregen_x.csv" in list_files:
os.remove("pregen_x.csv") i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
for scm_estimator in self.scm_estimators: for scm_estimator in self.scm_estimators:
beg = time.time() beg = time.time()
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params) scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment