Skip to content
Snippets Groups Projects
Commit c1e977c7 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Merge branch 'develop'

parents d2d642d0 b37578bd
Branches
Tags
No related merge requests found
Pipeline #9475 passed
...@@ -78,8 +78,10 @@ def publish_feature_importances(feature_importances, directory, database_name, ...@@ -78,8 +78,10 @@ def publish_feature_importances(feature_importances, directory, database_name,
columns=feature_std.columns, columns=feature_std.columns,
data=np.zeros((1, len( data=np.zeros((1, len(
feature_std.columns))))) feature_std.columns)))))
if len(importance_dfs)>0:
feature_importances_df = pd.concat(importance_dfs[:-1]) feature_importances_df = pd.concat(importance_dfs[:-1])
feature_importances_df = feature_importances_df/feature_importances_df.sum(axis=0) feature_importances_df = feature_importances_df/feature_importances_df.sum(axis=0)
feature_std_df = pd.concat(std_dfs[:-1]) feature_std_df = pd.concat(std_dfs[:-1])
if "mv" in feature_importances: if "mv" in feature_importances:
feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0) feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0)
......
...@@ -216,7 +216,7 @@ class Grid(GridSearchCV, HPSearch): ...@@ -216,7 +216,7 @@ class Grid(GridSearchCV, HPSearch):
random_state=None, track_tracebacks=True): random_state=None, track_tracebacks=True):
scoring = HPSearch.get_scoring(self, scoring) scoring = HPSearch.get_scoring(self, scoring)
GridSearchCV.__init__(self, estimator, param_grid, scoring=scoring, GridSearchCV.__init__(self, estimator, param_grid, scoring=scoring,
n_jobs=n_jobs, iid='deprecated', refit=refit, n_jobs=n_jobs, refit=refit,
cv=cv) cv=cv)
self.framework = framework self.framework = framework
self.available_indices = learning_indices self.available_indices = learning_indices
......
...@@ -57,28 +57,6 @@ class Test_initBenchmark(unittest.TestCase): ...@@ -57,28 +57,6 @@ class Test_initBenchmark(unittest.TestCase):
self.assertEqual(benchmark_output, self.assertEqual(benchmark_output,
{'monoview': ['decision_tree'], {'monoview': ['decision_tree'],
'multiview': ['weighted_linear_late_fusion']}) 'multiview': ['weighted_linear_late_fusion']})
benchmark_output = exec_classif.init_benchmark(
cl_type=["monoview", "multiview"], monoview_algos=["all"],
multiview_algos=["all"])
self.assertEqual(benchmark_output, {'monoview': ['adaboost',
'decision_tree',
'gradient_boosting',
'knn',
'lasso',
'random_forest',
'sgd',
'svm_linear',
'svm_poly',
'svm_rbf'],
'multiview': ['bayesian_inference_fusion',
'difficulty_fusion',
'disagree_fusion',
'double_fault_fusion',
'entropy_fusion',
'majority_voting_fusion',
'svm_jumbo_fusion',
'weighted_linear_early_fusion',
'weighted_linear_late_fusion']})
class Test_Functs(unittest.TestCase): class Test_Functs(unittest.TestCase):
...@@ -250,13 +228,13 @@ def fakeBenchmarkExec_mutlicore(nb_cores=-1, a=6, args=1): ...@@ -250,13 +228,13 @@ def fakeBenchmarkExec_mutlicore(nb_cores=-1, a=6, args=1):
def fakeBenchmarkExec_monocore( def fakeBenchmarkExec_monocore(
dataset_var=1, a=4, args=1, track_tracebacks=False): dataset_var=1, a=4, args=1, track_tracebacks=False, nb_cores=1):
return [a] return [a]
def fakegetResults(results, stats_iter, def fakegetResults(results, stats_iter,
benchmark_arguments_dictionaries, metrics, directory, benchmark_arguments_dictionaries, metrics, directory,
sample_ids, labels): sample_ids, labels, feat_ids, view_names):
return 3 return 3
...@@ -264,7 +242,8 @@ def fakeDelete(a, b, c): ...@@ -264,7 +242,8 @@ def fakeDelete(a, b, c):
return 9 return 9
def fake_analyze(a, b, c, d, sample_ids=None, labels=None): def fake_analyze(a, b, c, d, sample_ids=None, labels=None, feature_ids=None,
view_names=None):
pass pass
......
...@@ -177,6 +177,7 @@ class Test_exec_monoview(unittest.TestCase): ...@@ -177,6 +177,7 @@ class Test_exec_monoview(unittest.TestCase):
np.random.RandomState(42), np.random.RandomState(42),
"Random", "Random",
n_iter=2, n_iter=2,
feature_ids=[str(i) for i in range(test_dataset.get_v(0).shape[1])],
**{"classifier_name": "decision_tree", **{"classifier_name": "decision_tree",
"view_index": 0, "view_index": 0,
"decision_tree": {}}) "decision_tree": {}})
......
...@@ -14,7 +14,7 @@ class Test_get_sample_errors(unittest.TestCase): ...@@ -14,7 +14,7 @@ class Test_get_sample_errors(unittest.TestCase):
results = [MultiviewResult("mv", "", {"accuracy_score": [0.7, 0.75], results = [MultiviewResult("mv", "", {"accuracy_score": [0.7, 0.75],
"f1_score": [0.71, 0.76]}, "f1_score": [0.71, 0.76]},
np.array([0, 0, 0, 0, 1, 1, 1, 1, 1]), np.array([0, 0, 0, 0, 1, 1, 1, 1, 1]),
0, 0, 0, {}), 0, 0, 0, {}, "clf"),
MonoviewResult(0, MonoviewResult(0,
"dt", "dt",
"1", "1",
......
...@@ -9,6 +9,11 @@ from summit.multiview_platform.multiview.multiview_utils import MultiviewResult ...@@ -9,6 +9,11 @@ from summit.multiview_platform.multiview.multiview_utils import MultiviewResult
from summit.multiview_platform.result_analysis.execution import format_previous_results, get_arguments, analyze_iterations from summit.multiview_platform.result_analysis.execution import format_previous_results, get_arguments, analyze_iterations
from summit.tests.utils import rm_tmp, tmp_path, test_dataset from summit.tests.utils import rm_tmp, tmp_path, test_dataset
class FakeClf():
def __init__(self):
self.feature_importances_ = [0.01,0.99]
class FakeClassifierResult: class FakeClassifierResult:
...@@ -18,6 +23,7 @@ class FakeClassifierResult: ...@@ -18,6 +23,7 @@ class FakeClassifierResult:
self.hps_duration = i self.hps_duration = i
self.fit_duration = i self.fit_duration = i
self.pred_duration = i self.pred_duration = i
self.clf=FakeClf()
def get_classifier_name(self): def get_classifier_name(self):
return self.classifier_name return self.classifier_name
...@@ -84,16 +90,16 @@ class Test_format_previous_results(unittest.TestCase): ...@@ -84,16 +90,16 @@ class Test_format_previous_results(unittest.TestCase):
columns=["ada-1", "mvm"]) columns=["ada-1", "mvm"])
# Testing # Testing
np.testing.assert_array_equal(metric_analysis["acc"]["mean"].loc["train"], np.testing.assert_array_almost_equal(metric_analysis["acc"]["mean"].loc["train"],
mean_df.loc["train"]) mean_df.loc["train"])
np.testing.assert_array_equal(metric_analysis["acc"]["mean"].loc["test"], np.testing.assert_array_almost_equal(metric_analysis["acc"]["mean"].loc["test"],
mean_df.loc["test"]) mean_df.loc["test"])
np.testing.assert_array_equal(metric_analysis["acc"]["std"].loc["train"], np.testing.assert_array_almost_equal(metric_analysis["acc"]["std"].loc["train"],
std_df.loc["train"]) std_df.loc["train"])
np.testing.assert_array_equal(metric_analysis["acc"]["std"].loc["test"], np.testing.assert_array_almost_equal(metric_analysis["acc"]["std"].loc["test"],
std_df.loc["test"]) std_df.loc["test"])
np.testing.assert_array_equal(ada_sum, error_analysis["ada-1"]) np.testing.assert_array_almost_equal(ada_sum, error_analysis["ada-1"])
np.testing.assert_array_equal(mv_sum, error_analysis["mv"]) np.testing.assert_array_almost_equal(mv_sum, error_analysis["mv"])
self.assertEqual(durations_mean.at["ada-1", 'plif'], 0.5) self.assertEqual(durations_mean.at["ada-1", 'plif'], 0.5)
...@@ -129,6 +135,8 @@ class Test_analyze_iterations(unittest.TestCase): ...@@ -129,6 +135,8 @@ class Test_analyze_iterations(unittest.TestCase):
cls.metrics = {} cls.metrics = {}
cls.sample_ids = ['ex1', 'ex5', 'ex4', 'ex3', 'ex2', ] cls.sample_ids = ['ex1', 'ex5', 'ex4', 'ex3', 'ex2', ]
cls.labels = np.array([0, 1, 2, 1, 1]) cls.labels = np.array([0, 1, 2, 1, 1])
cls.feature_ids = [["a", "b"]]
cls.view_names = [""]
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
...@@ -140,7 +148,9 @@ class Test_analyze_iterations(unittest.TestCase): ...@@ -140,7 +148,9 @@ class Test_analyze_iterations(unittest.TestCase):
self.stats_iter, self.stats_iter,
self.metrics, self.metrics,
self.sample_ids, self.sample_ids,
self.labels) self.labels,
self.feature_ids,
self.view_names)
res, iter_res, tracebacks, labels_names = analysis res, iter_res, tracebacks, labels_names = analysis
self.assertEqual(labels_names, ['zero', 'one', 'two']) self.assertEqual(labels_names, ['zero', 'one', 'two'])
...@@ -152,7 +162,6 @@ class Test_analyze_iterations(unittest.TestCase): ...@@ -152,7 +162,6 @@ class Test_analyze_iterations(unittest.TestCase):
data=np.array([1, 1, 1, 2, 2, 2]).reshape((2, 3)), dtype=object)) data=np.array([1, 1, 1, 2, 2, 2]).reshape((2, 3)), dtype=object))
np.testing.assert_array_equal( np.testing.assert_array_equal(
iter_res['sample_errors'][0]['test1'], np.array([1, 1, 0, 0, 1])) iter_res['sample_errors'][0]['test1'], np.array([1, 1, 0, 0, 1]))
self.assertEqual(iter_res["feature_importances"], [{}, {}])
np.testing.assert_array_equal( np.testing.assert_array_equal(
iter_res['labels'], np.array([0, 1, 2, 1, 1])) iter_res['labels'], np.array([0, 1, 2, 1, 1]))
self.assertEqual(iter_res['metrics_scores'], [{}, {}]) self.assertEqual(iter_res['metrics_scores'], [{}, {}])
...@@ -9,6 +9,7 @@ from summit.multiview_platform.monoview.monoview_utils import MonoviewResult ...@@ -9,6 +9,7 @@ from summit.multiview_platform.monoview.monoview_utils import MonoviewResult
class FakeClassifier: class FakeClassifier:
def __init__(self, i=0): def __init__(self, i=0):
self.feature_importances_ = [i, i + 1] self.feature_importances_ = [i, i + 1]
self.view_index=0
class FakeClassifierResult(MonoviewResult): class FakeClassifierResult(MonoviewResult):
...@@ -21,6 +22,7 @@ class FakeClassifierResult(MonoviewResult): ...@@ -21,6 +22,7 @@ class FakeClassifierResult(MonoviewResult):
self.clf = FakeClassifier(i) self.clf = FakeClassifier(i)
self.view_name = 'testview' + str(i) self.view_name = 'testview' + str(i)
self.classifier_name = "test" + str(i) self.classifier_name = "test" + str(i)
self.view_index = 0
def get_classifier_name(self): def get_classifier_name(self):
return self.classifier_name return self.classifier_name
...@@ -30,9 +32,9 @@ class Test_get_duration(unittest.TestCase): ...@@ -30,9 +32,9 @@ class Test_get_duration(unittest.TestCase):
def test_simple(self): def test_simple(self):
results = [FakeClassifierResult(), FakeClassifierResult(i=1)] results = [FakeClassifierResult(), FakeClassifierResult(i=1)]
feat_importance = feature_importances.get_feature_importances(results) feat_importance = feature_importances.get_feature_importances(results, feature_ids=[["a", "b"]], view_names=["v"])
pd.testing.assert_frame_equal(feat_importance["testview1"], pd.testing.assert_frame_equal(feat_importance["testview1"],
pd.DataFrame(index=None, columns=['test1'], pd.DataFrame(index=["a", "b"], columns=['test1'],
data=np.array( data=np.array(
[1, 2]).reshape((2, 1)), [1, 2]).reshape((2, 1)),
), check_dtype=False) ), check_dtype=False)
...@@ -66,7 +66,7 @@ class Test_get_metrics_scores(unittest.TestCase): ...@@ -66,7 +66,7 @@ class Test_get_metrics_scores(unittest.TestCase):
hps_duration=0, hps_duration=0,
fit_duration=0, fit_duration=0,
pred_duration=0, pred_duration=0,
class_metric_scores={}) class_metric_scores={},)
] ]
metrics_scores, class_met = get_metrics_scores(metrics, metrics_scores, class_met = get_metrics_scores(metrics,
results, []) results, [])
...@@ -91,7 +91,7 @@ class Test_get_metrics_scores(unittest.TestCase): ...@@ -91,7 +91,7 @@ class Test_get_metrics_scores(unittest.TestCase):
def test_mutiview_result(self): def test_mutiview_result(self):
metrics = {"accuracy_score*": {}, "f1_score": {}} metrics = {"accuracy_score*": {}, "f1_score": {}}
results = [MultiviewResult("mv", "", {"accuracy_score*": [0.7, 0.75], results = [MultiviewResult("mv", "", {"accuracy_score*": [0.7, 0.75],
"f1_score": [0.71, 0.76]}, "", 0, 0, 0, {}), "f1_score": [0.71, 0.76]}, "", 0, 0, 0, {}, ""),
MonoviewResult(view_index=0, MonoviewResult(view_index=0,
classifier_name="dt", classifier_name="dt",
view_name="1", view_name="1",
......
...@@ -97,7 +97,7 @@ class Test_ResultAnalyzer(unittest.TestCase): ...@@ -97,7 +97,7 @@ class Test_ResultAnalyzer(unittest.TestCase):
self.pred, self.directory, self.pred, self.directory,
self.base_file_name, self.labels, self.base_file_name, self.labels,
self.database_name, self.nb_cores, self.database_name, self.nb_cores,
self.duration) self.duration, [""])
def test_get_metric_scores(self): def test_get_metric_scores(self):
RA = base.ResultAnalyser(self.classifier, self.classification_indices, RA = base.ResultAnalyser(self.classifier, self.classification_indices,
...@@ -107,7 +107,7 @@ class Test_ResultAnalyzer(unittest.TestCase): ...@@ -107,7 +107,7 @@ class Test_ResultAnalyzer(unittest.TestCase):
self.pred, self.pred,
self.directory, self.base_file_name, self.directory, self.base_file_name,
self.labels, self.database_name, self.labels, self.database_name,
self.nb_cores, self.duration) self.nb_cores, self.duration, [""])
cl_train, cl_test, train_score, test_score = RA.get_metric_score( cl_train, cl_test, train_score, test_score = RA.get_metric_score(
"accuracy_score", {}) "accuracy_score", {})
np.testing.assert_array_equal(train_score, self.train_accuracy) np.testing.assert_array_equal(train_score, self.train_accuracy)
...@@ -121,7 +121,7 @@ class Test_ResultAnalyzer(unittest.TestCase): ...@@ -121,7 +121,7 @@ class Test_ResultAnalyzer(unittest.TestCase):
self.pred, self.pred,
self.directory, self.base_file_name, self.directory, self.base_file_name,
self.labels, self.database_name, self.labels, self.database_name,
self.nb_cores, self.duration) self.nb_cores, self.duration, [""])
RA.get_all_metrics_scores() RA.get_all_metrics_scores()
self.assertEqual(RA.metric_scores["accuracy_score"][0], self.assertEqual(RA.metric_scores["accuracy_score"][0],
self.train_accuracy) self.train_accuracy)
...@@ -140,7 +140,7 @@ class Test_ResultAnalyzer(unittest.TestCase): ...@@ -140,7 +140,7 @@ class Test_ResultAnalyzer(unittest.TestCase):
self.pred, self.pred,
self.directory, self.base_file_name, self.directory, self.base_file_name,
self.labels, self.database_name, self.labels, self.database_name,
self.nb_cores, self.duration) self.nb_cores, self.duration, [''])
RA.get_all_metrics_scores() RA.get_all_metrics_scores()
string = RA.print_metric_score() string = RA.print_metric_score()
self.assertEqual(string, '\n\n\tFor Accuracy score using {}, (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\n\tFor F1 score using average: micro, {} (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\nTest set confusion matrix : \n\n╒════════╤══════════╤══════════╤══════════╕\n│ │ class1 │ class2 │ class3 │\n╞════════╪══════════╪══════════╪══════════╡\n│ class1 │ 3 │ 1 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class2 │ 3 │ 2 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class3 │ 3 │ 8 │ 2 │\n╘════════╧══════════╧══════════╧══════════╛\n\n') self.assertEqual(string, '\n\n\tFor Accuracy score using {}, (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\n\tFor F1 score using average: micro, {} (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\nTest set confusion matrix : \n\n╒════════╤══════════╤══════════╤══════════╕\n│ │ class1 │ class2 │ class3 │\n╞════════╪══════════╪══════════╪══════════╡\n│ class1 │ 3 │ 1 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class2 │ 3 │ 2 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class3 │ 3 │ 8 │ 2 │\n╘════════╧══════════╧══════════╧══════════╛\n\n')
...@@ -153,7 +153,7 @@ class Test_ResultAnalyzer(unittest.TestCase): ...@@ -153,7 +153,7 @@ class Test_ResultAnalyzer(unittest.TestCase):
self.pred, self.pred,
self.directory, self.base_file_name, self.directory, self.base_file_name,
self.labels, self.database_name, self.labels, self.database_name,
self.nb_cores, self.duration) self.nb_cores, self.duration, [''])
self.assertEqual( self.assertEqual(
RA.get_db_config_string(), RA.get_db_config_string(),
'Database configuration : \n\t- Database name : test_database\ntest\t- Learning Rate : 0.48\n\t- Labels used : class1, class2, class3\n\t- Number of cross validation folds : 5\n\n') 'Database configuration : \n\t- Database name : test_database\ntest\t- Learning Rate : 0.48\n\t- Labels used : class1, class2, class3\n\t- Number of cross validation folds : 5\n\n')
...@@ -166,7 +166,7 @@ class Test_ResultAnalyzer(unittest.TestCase): ...@@ -166,7 +166,7 @@ class Test_ResultAnalyzer(unittest.TestCase):
self.pred, self.pred,
self.directory, self.base_file_name, self.directory, self.base_file_name,
self.labels, self.database_name, self.labels, self.database_name,
self.nb_cores, self.duration) self.nb_cores, self.duration, [''])
self.assertEqual( self.assertEqual(
RA.get_classifier_config_string(), RA.get_classifier_config_string(),
'Classifier configuration : \n\t- FakeClassifier with test1 : 10, test2 : test\n\t- Executed on 0.5 core(s) \n\t- Got configuration using randomized search with 6 iterations \n') 'Classifier configuration : \n\t- FakeClassifier with test1 : 10, test2 : test\n\t- Executed on 0.5 core(s) \n\t- Got configuration using randomized search with 6 iterations \n')
...@@ -179,7 +179,7 @@ class Test_ResultAnalyzer(unittest.TestCase): ...@@ -179,7 +179,7 @@ class Test_ResultAnalyzer(unittest.TestCase):
self.pred, self.pred,
self.directory, self.base_file_name, self.directory, self.base_file_name,
self.labels, self.database_name, self.labels, self.database_name,
self.nb_cores, self.duration) self.nb_cores, self.duration, [""])
str_analysis, img_analysis, metric_scores, class_metric_scores, conf_mat = RA.analyze() str_analysis, img_analysis, metric_scores, class_metric_scores, conf_mat = RA.analyze()
self.assertEqual(str_analysis, 'test2Database configuration : \n\t- Database name : test_database\ntest\t- Learning Rate : 0.48\n\t- Labels used : class1, class2, class3\n\t- Number of cross validation folds : 5\n\nClassifier configuration : \n\t- FakeClassifier with test1 : 10, test2 : test\n\t- Executed on 0.5 core(s) \n\t- Got configuration using randomized search with 6 iterations \n\n\n\tFor Accuracy score using {}, (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\n\tFor F1 score using average: micro, {} (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\nTest set confusion matrix : \n\n╒════════╤══════════╤══════════╤══════════╕\n│ │ class1 │ class2 │ class3 │\n╞════════╪══════════╪══════════╪══════════╡\n│ class1 │ 3 │ 1 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class2 │ 3 │ 2 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class3 │ 3 │ 8 │ 2 │\n╘════════╧══════════╧══════════╧══════════╛\n\n\n\n Classification took -1 day, 23:59:56\n\n Classifier Interpretation : \n') self.assertEqual(str_analysis, 'test2Database configuration : \n\t- Database name : test_database\ntest\t- Learning Rate : 0.48\n\t- Labels used : class1, class2, class3\n\t- Number of cross validation folds : 5\n\nClassifier configuration : \n\t- FakeClassifier with test1 : 10, test2 : test\n\t- Executed on 0.5 core(s) \n\t- Got configuration using randomized search with 6 iterations \n\n\n\tFor Accuracy score using {}, (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\n\tFor F1 score using average: micro, {} (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\nTest set confusion matrix : \n\n╒════════╤══════════╤══════════╤══════════╕\n│ │ class1 │ class2 │ class3 │\n╞════════╪══════════╪══════════╪══════════╡\n│ class1 │ 3 │ 1 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class2 │ 3 │ 2 │ 2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class3 │ 3 │ 8 │ 2 │\n╘════════╧══════════╧══════════╧══════════╛\n\n\n\n Classification took -1 day, 23:59:56\n\n Classifier Interpretation : \n')
...@@ -223,7 +223,7 @@ class Test_BaseClassifier(unittest.TestCase): ...@@ -223,7 +223,7 @@ class Test_BaseClassifier(unittest.TestCase):
def test_get_iterpret(self): def test_get_iterpret(self):
fake_class = FakeClassifier() fake_class = FakeClassifier()
self.assertEqual("", fake_class.get_interpretation("", "", "",)) self.assertEqual("", fake_class.get_interpretation("", "", "",""))
def test_accepts_mutliclass(self): def test_accepts_mutliclass(self):
accepts = FakeClassifier().accepts_multi_class(self.rs) accepts = FakeClassifier().accepts_multi_class(self.rs)
......
...@@ -19,6 +19,7 @@ class FakeEstim(BaseEstimator): ...@@ -19,6 +19,7 @@ class FakeEstim(BaseEstimator):
def __init__(self, param1=None, param2=None, random_state=None): def __init__(self, param1=None, param2=None, random_state=None):
self.param1 = param1 self.param1 = param1
self.param2 = param2 self.param2 = param2
self.random_state="1"
def fit(self, X, y,): def fit(self, X, y,):
return self return self
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment