diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py index d6ff8b4c8a7f585707959ad055accbe074afc8d8..f30aa6a43fbde1f73a8be8191ed0a0f4ad0d3d9f 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py @@ -137,11 +137,15 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier): nb_clfs = nb_monoview_per_view else: nb_clfs = nb_view + if isinstance(self.classifiers_names, ClassifierDistribution): self.classifiers_names = self.classifiers_names.draw(nb_clfs, self.rs) elif self.classifiers_names is None: self.classifiers_names = ["decision_tree" for _ in range(nb_clfs)] + elif isinstance(self.classifiers_names, str): + self.classifiers_names = [self.classifiers_names + for _ in range(nb_clfs)] if isinstance(self.classifier_configs, ConfigDistribution): self.classifier_configs = self.classifier_configs.draw(nb_clfs,