From 3aced0d558b642498412cb75ecdf4fac7b19cafd Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Mon, 1 Apr 2019 23:35:11 -0400 Subject: [PATCH] forgotten fits --- .../MonoviewClassifiers/SCMPregen.py | 14 +++++++++++--- .../MonoviewClassifiers/SCMPregenTree.py | 14 +++++++++++--- .../MonoviewClassifiers/SCMSparsity.py | 14 +++++++++++--- .../MonoviewClassifiers/SCMSparsityTree.py | 14 +++++++++++--- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregen.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregen.py index b79606df..6ee4f62c 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregen.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregen.py @@ -49,9 +49,17 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier): def predict(self, X): pregen_X, _ = self.pregen_voters(X) - np.savetxt("pregen_x.csv", pregen_X, delimiter=',') - place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') - os.remove("pregen_x.csv") + list_files = os.listdir(".") + if "pregen_x.csv" in list_files: + i = 0 + file_name = "pregen_x" + str(i) + ".csv" + while file_name in list_files: + i += 1 + else: + file_name = "pregen_x.csv" + np.savetxt(file_name, pregen_X, delimiter=',') + place_holder = np.genfromtxt(file_name, delimiter=',') + os.remove(file_name) return self.classes_[self.model_.predict(place_holder)] def get_params(self, deep=True): diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregenTree.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregenTree.py index 234ca0c3..b2f59c39 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregenTree.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMPregenTree.py @@ -35,9 +35,17 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier): def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params): pregen_X, _ = self.pregen_voters(X, y, generator="Trees") - np.savetxt("pregen_x.csv", pregen_X, delimiter=',') - place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') - os.remove("pregen_x.csv") + list_files = os.listdir(".") + if "pregen_x.csv" in list_files: + i = 0 + file_name = "pregen_x" + str(i) + ".csv" + while file_name in list_files: + i += 1 + else: + file_name = "pregen_x.csv" + np.savetxt(file_name, pregen_X, delimiter=',') + place_holder = np.genfromtxt(file_name, delimiter=',') + os.remove(file_name) super(SCMPregenTree, self).fit(place_holder, y, tiebreaker=tiebreaker, iteration_callback=iteration_callback, **fit_params) return self diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsity.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsity.py index bd354136..05f7e257 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsity.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsity.py @@ -42,9 +42,17 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier): def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params): pregen_X, _ = self.pregen_voters(X, y) - np.savetxt("pregen_x.csv", pregen_X, delimiter=',') - place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') - os.remove("pregen_x.csv") + list_files = os.listdir(".") + if "pregen_x.csv" in list_files: + i = 0 + file_name = "pregen_x" + str(i) + ".csv" + while file_name in list_files: + i += 1 + else: + file_name = "pregen_x.csv" + np.savetxt(file_name, pregen_X, delimiter=',') + place_holder = np.genfromtxt(file_name, delimiter=',') + os.remove(file_name) for scm_estimator in self.scm_estimators: beg = time.time() scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params) diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsityTree.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsityTree.py index 2fbd7c39..e7cbad62 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsityTree.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/SCMSparsityTree.py @@ -42,9 +42,17 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier): def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params): pregen_X, _ = self.pregen_voters(X, y, generator="Trees") - np.savetxt("pregen_x.csv", pregen_X, delimiter=',') - place_holder = np.genfromtxt("pregen_x.csv", delimiter=',') - os.remove("pregen_x.csv") + list_files = os.listdir(".") + if "pregen_x.csv" in list_files: + i = 0 + file_name = "pregen_x" + str(i) + ".csv" + while file_name in list_files: + i += 1 + else: + file_name = "pregen_x.csv" + np.savetxt(file_name, pregen_X, delimiter=',') + place_holder = np.genfromtxt(file_name, delimiter=',') + os.remove(file_name) for scm_estimator in self.scm_estimators: beg = time.time() scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params) -- GitLab