diff --git a/summit/multiview_platform/utils/dataset.py b/summit/multiview_platform/utils/dataset.py index beaf5829d86231854630d0f935968b7e450e343a..98dc7a39ea552e70b5f555fb0728b650aa26fa5b 100644 --- a/summit/multiview_platform/utils/dataset.py +++ b/summit/multiview_platform/utils/dataset.py @@ -106,9 +106,9 @@ class Dataset(): return concat_views, view_limits def select_labels(self, selected_label_names): - selected_labels = [self.get_label_names().index(label_name.decode()) + selected_labels = [self.get_label_names(decode=True).index(label_name.decode()) if isinstance(label_name, bytes) - else self.get_label_names().index(label_name) + else self.get_label_names(decode=True).index(label_name) for label_name in selected_label_names] selected_indices = np.array([index for index, label in