diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/MinCQGraalpy.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/MinCQGraalpy.py index b638c6fbd47584dc1894c85fe84db83b77b7c9c8..b1bf1dfed06f3f55efd2bf043f397e6e55efbc6a 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/MinCQGraalpy.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/MinCQGraalpy.py @@ -331,8 +331,10 @@ class MinCQGraalpy(RegularizedBinaryMinCqClassifier, BaseMonoviewClassifier): self.param_names = ["mu"] self.distribs = [CustomUniform(loc=0.5, state=2.0, multiplier="e-"), ] + self.n_stumps_per_attribute = n_stumps_per_attribute self.classed_params = [] self.weird_strings = {} + self.random_state = random_state if "nbCores" not in kwargs: self.nbCores = 1 else: @@ -342,6 +344,12 @@ class MinCQGraalpy(RegularizedBinaryMinCqClassifier, BaseMonoviewClassifier): """Used to know if the classifier can return label probabilities""" return True + def set_params(self, **params): + self.mu = params["mu"] + + def get_params(self, deep=True): + return {"random_state":self.random_state, "mu":self.mu} + def getInterpret(self, directory, y_test): interpret_string = "" # interpret_string += "Train C_bound value : "+str(self.cbound_train) diff --git a/multiview_platform/MonoMultiViewClassifiers/utils/execution.py b/multiview_platform/MonoMultiViewClassifiers/utils/execution.py index 9a3a91d49af6345c2c0254d404cee54918a6f8b0..2fc233f2a1e8725c13d431b6d45ba851a48a2f58 100644 --- a/multiview_platform/MonoMultiViewClassifiers/utils/execution.py +++ b/multiview_platform/MonoMultiViewClassifiers/utils/execution.py @@ -125,7 +125,7 @@ def parseTheArgs(arguments): groupAdaboostPregen = parser.add_argument_group('AdaboostPregen arguments') groupAdaboostPregen.add_argument('--AdP_n_est', metavar='INT', type=int, action='store', help='Number of estimators', - default=2) + default=100) groupAdaboostPregen.add_argument('--AdP_b_est', metavar='STRING', action='store', help='Estimators', default='DecisionTreeClassifier')