diff --git a/labelCloud/control/controller.py b/labelCloud/control/controller.py index 8fa96f3..ef6784f 100644 --- a/labelCloud/control/controller.py +++ b/labelCloud/control/controller.py @@ -52,12 +52,16 @@ def startup(self, view: "GUI") -> None: # Read labels from folders self.pcd_manager.read_pointcloud_folder() self.next_pcd(save=False) + if LabelConfig().type == LabelingMode.SEMANTIC_SEGMENTATION: + self.pcd_manager.populate_segmentation_list() def loop_gui(self) -> None: """Function collection called during each event loop iteration.""" self.set_crosshair() self.set_selected_side() self.view.gl_widget.updateGL() + if LabelConfig().type == LabelingMode.SEMANTIC_SEGMENTATION: + self.pcd_manager.loop_seg_list_check_state() # POINT CLOUD METHODS def next_pcd(self, save: bool = True) -> None: diff --git a/labelCloud/control/pcd_manager.py b/labelCloud/control/pcd_manager.py index c0093ed..0e3ddea 100644 --- a/labelCloud/control/pcd_manager.py +++ b/labelCloud/control/pcd_manager.py @@ -10,12 +10,17 @@ import numpy as np import open3d as o3d import pkg_resources +from PyQt5 import QtCore +from PyQt5.QtWidgets import QCheckBox, QLabel -from ..definitions import LabelingMode, Point3D +from ..definitions import LabelingMode +from ..definitions.types import Point3D from ..io.labels.config import LabelConfig from ..io.pointclouds import BasePointCloudHandler, Open3DHandler from ..model import BBox, Perspective, PointCloud +from ..utils.color import rgb_to_hex from ..utils.logger import blue, green, print_column +from ..view.startup.color_button import ColorButton from .config_manager import config from .label_manager import LabelManager @@ -263,7 +268,8 @@ def rotate_pointcloud( def assign_point_label_in_box(self, box: BBox) -> None: assert self.pointcloud is not None - points = self.pointcloud.points + points = self.pointcloud.points.copy() + points[~self.pointcloud.visible] = np.finfo(np.float32).max points_inside = box.is_inside(points) # Relabel the points if its inside the box @@ -306,3 +312,53 @@ def update_pcd_infos(self, pointcloud_label: Optional[str] = None) -> None: else: self.view.button_next_pcd.setEnabled(True) self.view.button_prev_pcd.setEnabled(True) + + def populate_segmentation_list(self) -> None: + assert self.pointcloud is not None + assert self.pointcloud.labels is not None + self.seg_list_label: List[QLabel] = [] + self.seg_list_check_box: List[QCheckBox] = [] + self.seg_list_check_state: List[QtCore.Qt.CheckState] = [] + for idx, label_class in enumerate(LabelConfig().classes, start=1): + self.seg_list_label.append(QLabel(label_class.name)) + check_box = QCheckBox() + check_box.setCheckState(QtCore.Qt.Checked) + color_button = ColorButton( + color=rgb_to_hex(label_class.color), changeable=False + ) + self.seg_list_check_box.append(check_box) + self.seg_list_check_state.append(self.seg_list_check_box[-1].checkState()) + self.view.segmentation_list.addWidget(self.seg_list_label[-1], idx, 0) + self.view.segmentation_list.addWidget(color_button, idx, 1) + self.view.segmentation_list.addWidget(self.seg_list_check_box[-1], idx, 2) + + def loop_seg_list_check_state(self): + curr_checked_status = self.seg_list_check_state.copy() + any_changed = False + + move_back = [] + move_away = [] + for idx, (box, prev_status) in enumerate( + zip( + self.seg_list_check_box, + self.seg_list_check_state, + ) + ): + if box.checkState() != prev_status: + any_changed = True + curr_checked_status[idx] = box.checkState() + changed_id = LabelConfig().classes[idx].id + if box.checkState() == QtCore.Qt.Checked: + move_back.append(changed_id) + else: + move_away.append(changed_id) + + if any_changed: + points_move_away = ( + np.isin(self.pointcloud.labels, move_away) if move_away else None + ) + points_move_back = ( + np.isin(self.pointcloud.labels, move_back) if move_back else None + ) + self.pointcloud.update_position_vbo(points_move_away, points_move_back) + self.seg_list_check_state = curr_checked_status diff --git a/labelCloud/model/point_cloud.py b/labelCloud/model/point_cloud.py index 7f0db60..097d508 100644 --- a/labelCloud/model/point_cloud.py +++ b/labelCloud/model/point_cloud.py @@ -65,6 +65,8 @@ def __init__( self.validate_segmentation_label() self.mix_ratio = config.getfloat("POINTCLOUD", "label_color_mix_ratio") + self.visible = np.ones((self.points.shape[0],), dtype=np.bool_) + self.vbo = None self.center: Point3D = tuple(np.sum(points[:, i]) / len(points) for i in range(3)) # type: ignore self.pcd_mins: npt.NDArray[np.float32] = np.amin(points, axis=0) @@ -233,11 +235,51 @@ def color_with_label(self) -> bool: def has_label(self) -> bool: return self.labels is not None + def update_position_vbo( + self, + points_move_away: Optional[npt.NDArray[np.bool_]], + points_move_back: Optional[npt.NDArray[np.bool_]], + ): + GL.glBindBuffer(GL.GL_ARRAY_BUFFER, self.position_vbo) + # Move points to super far away + if points_move_away is not None and np.any(points_move_away): + move_away_idxs = np.where(points_move_away)[0] + self.visible[move_away_idxs] = False + arrays = consecutive(move_away_idxs) + stride = self.points.shape[1] * SIZE_OF_FLOAT + for arr in arrays: + super_far: npt.NDArray[np.float32] = ( + np.ones((arr.shape[0], 3), dtype=np.float32) + * np.finfo(np.float32).max + ) + # partially update label_vbo from positions arr[0] to arr[-1] + GL.glBufferSubData( + GL.GL_ARRAY_BUFFER, + offset=arr[0] * stride, + size=super_far.nbytes, + data=super_far, + ) + # Move points back + if points_move_back is not None and np.any(points_move_back): + move_back_idxs = np.where(points_move_back)[0] + self.visible[move_back_idxs] = True + arrays = consecutive(move_back_idxs) + stride = self.points.shape[1] * SIZE_OF_FLOAT + + for arr in arrays: + points: npt.NDArray[np.float32] = self.points[arr] + GL.glBufferSubData( + GL.GL_ARRAY_BUFFER, + offset=arr[0] * stride, + size=points.nbytes, + data=points, + ) + def update_selected_points_in_label_vbo( self, points_inside: npt.NDArray[np.bool_] ) -> None: - """Send the selected updated label colors to label vbo. This function - assumes the `self.label_colors[points_inside]` have been altered. + """Send the selected updated label colors to `self.label_vbo`. This + function assumes the `self.label_colors[points_inside]` have been altered. This function only partially updates the label vbo to minimise the data sent to gpu. It leverages `glBufferSubData` method to perform partial update and `consecutive` method to find consecutive indexes diff --git a/labelCloud/resources/interfaces/interface.ui b/labelCloud/resources/interfaces/interface.ui index 3f965c5..d8e181a 100644 --- a/labelCloud/resources/interfaces/interface.ui +++ b/labelCloud/resources/interfaces/interface.ui @@ -1460,6 +1460,97 @@ + + + + + DejaVu Sans,Arial + 75 + true + + + + Segmentation Controls + + + + QLayout::SetDefaultConstraint + + + 0 + + + 6 + + + + + + 0 + 0 + + + + + DejaVu Sans,Arial + 12 + 50 + true + + + + Class Name + + + + + + + + 0 + 0 + + + + + DejaVu Sans,Arial + 12 + 50 + true + + + + Color + + + + + + + + 0 + 0 + + + + + DejaVu Sans,Arial + 12 + 50 + true + + + + Visible + + + + + + + + + diff --git a/labelCloud/view/color_button.py b/labelCloud/view/color_button.py new file mode 100644 index 0000000..9170038 --- /dev/null +++ b/labelCloud/view/color_button.py @@ -0,0 +1,59 @@ +from PyQt5 import QtGui, QtWidgets +from PyQt5.QtCore import Qt, pyqtSignal + + +class ColorButton(QtWidgets.QPushButton): + """ + Custom Qt Widget to show a chosen color. + + Left-clicking the button shows the color-chooser, while + right-clicking resets the color to None (no-color). + + Source: https://www.pythonguis.com/widgets/qcolorbutton-a-color-selector-tool-for-pyqt/ + """ + + colorChanged = pyqtSignal(object) + + def __init__(self, *args, color="#FF0000", changeable: bool = True, **kwargs): + super(ColorButton, self).__init__(*args, **kwargs) + + self._color = None + self._default = color + if changeable: + self.pressed.connect(self.onColorPicker) + + # Set the initial/default state. + self.setColor(self._default) + + def setColor(self, color): + if color != self._color: + self._color = color + self.colorChanged.emit(color) + + if self._color: + self.setStyleSheet("background-color: %s;" % self._color) + else: + self.setStyleSheet("") + + def color(self): + return self._color + + def onColorPicker(self): + """ + Show color-picker dialog to select color. + + Qt will use the native dialog by default. + + """ + dlg = QtWidgets.QColorDialog(self) + if self._color: + dlg.setCurrentColor(QtGui.QColor(self._color)) + + if dlg.exec_(): + self.setColor(dlg.currentColor().name()) + + def mousePressEvent(self, e): + if e.button() == Qt.RightButton: + self.setColor(self._default) + + return super(ColorButton, self).mousePressEvent(e) diff --git a/labelCloud/view/gui.py b/labelCloud/view/gui.py index 631b70a..b5bbb4d 100644 --- a/labelCloud/view/gui.py +++ b/labelCloud/view/gui.py @@ -20,6 +20,8 @@ QMessageBox, ) +from labelCloud.view.startup.dialog import StartupDialog + from ..control.config_manager import config from ..definitions import Color3f, LabelingMode from ..io.labels.config import LabelConfig @@ -194,6 +196,8 @@ def __init__(self, control: "Controller") -> None: self.button_save_label: QtWidgets.QPushButton # RIGHT PANEL + self.segmentation_list_group: QtWidgets.QGroupBox + self.segmentation_list: QtWidgets.QGridLayout self.label_list: QtWidgets.QListWidget self.current_class_dropdown: QtWidgets.QComboBox self.button_deselect_label: QtWidgets.QPushButton @@ -257,6 +261,7 @@ def __init__(self, control: "Controller") -> None: if LabelConfig().type == LabelingMode.OBJECT_DETECTION: self.button_assign_label.setVisible(False) self.act_color_with_label.setVisible(False) + self.segmentation_list_group.setVisible(False) # Connect with controller self.controller.startup(self) @@ -521,9 +526,6 @@ def init_progress(self, min_value, max_value): def update_progress(self, value) -> None: self.progressbar_pcds.setValue(value) - def update_current_class_dropdown(self) -> None: - self.controller.pcd_manager.populate_class_dropdown() - def update_bbox_stats(self, bbox) -> None: viewing_precision = config.getint("USER_INTERFACE", "viewing_precision") if bbox and not self.line_edited_activated(): diff --git a/labelCloud/view/startup/color_button.py b/labelCloud/view/startup/color_button.py index d1d6bae..9170038 100644 --- a/labelCloud/view/startup/color_button.py +++ b/labelCloud/view/startup/color_button.py @@ -14,12 +14,13 @@ class ColorButton(QtWidgets.QPushButton): colorChanged = pyqtSignal(object) - def __init__(self, *args, color="#FF0000", **kwargs): + def __init__(self, *args, color="#FF0000", changeable: bool = True, **kwargs): super(ColorButton, self).__init__(*args, **kwargs) self._color = None self._default = color - self.pressed.connect(self.onColorPicker) + if changeable: + self.pressed.connect(self.onColorPicker) # Set the initial/default state. self.setColor(self._default) diff --git a/labelCloud/view/startup_dialog.py b/labelCloud/view/startup_dialog.py new file mode 100644 index 0000000..a75bdbc --- /dev/null +++ b/labelCloud/view/startup_dialog.py @@ -0,0 +1,246 @@ +import random +from typing import List, Optional, Tuple + +import pkg_resources +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QIcon, QPixmap, QValidator +from PyQt5.QtWidgets import ( + QButtonGroup, + QDesktopWidget, + QDialog, + QDialogButtonBox, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QScrollArea, + QSizePolicy, + QSpinBox, + QVBoxLayout, + QWidget, +) + +from ..definitions.labeling_mode import LabelingMode +from ..io.labels.config import ClassConfig, LabelConfig +from ..utils.color import get_distinct_colors, hex_to_rgb, rgb_to_hex +from ..view.color_button import ColorButton + + +class LabelNameValidator(QValidator): + def validate(self, a0: str, a1: int) -> Tuple["QValidator.State", str, int]: + if a0 != "": + return (QValidator.Acceptable, a0, a1) + return (QValidator.Invalid, a0, a1) + + +class StartupDialog(QDialog): + NAME_VALIDATOR = LabelNameValidator() + + def __init__(self, parent=None) -> None: + super().__init__(parent) + self.parent_gui = parent + + self.setWindowTitle("Welcome to labelCloud") + screen_size = QDesktopWidget().availableGeometry(self).size() + self.resize(screen_size * 0.5) + self.setWindowIcon( + QIcon( + pkg_resources.resource_filename( + "labelCloud.resources.icons", "labelCloud.ico" + ) + ) + ) + self.setContentsMargins(50, 10, 50, 10) + + self.colors: List[str] = [] + + main_layout = QVBoxLayout() + main_layout.setSpacing(15) + main_layout.setAlignment(Qt.AlignTop) + self.setLayout(main_layout) + + # 1. Row: Selection of labeling mode via checkable buttons + self.button_semantic_segmentation: QPushButton + self.add_labeling_mode_row(main_layout) + + # 2. Row: Definition of class labels + self.add_class_definition_rows(main_layout) + + # 3. Row: Addition of new class labels + self.button_add_label = QPushButton(text="Add new label") + self.button_add_label.clicked.connect( + lambda: self.add_label(id=self.next_label_id) + ) + self.delete_buttons.buttonClicked.connect(self.delete_label) + main_layout.addWidget(self.button_add_label) + + # 4. Row: Buttons to save or cancel + self.buttonBox = QDialogButtonBox(QDialogButtonBox.Save) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.rejected.connect(self.reject) + + main_layout.addWidget(self.buttonBox) + + # ---------------------------------------------------------------------------- # + # SETUP # + # ---------------------------------------------------------------------------- # + + def add_labeling_mode_row(self, parent_layout: QVBoxLayout) -> None: + """ + Add a row to the dialog to select the labeling mode with two exclusive buttons. + """ + parent_layout.addWidget(QLabel("Select labeling mode:")) + + row_buttons = QHBoxLayout() + + self.button_object_detection = QPushButton( + text=LabelingMode.OBJECT_DETECTION.title().replace("_", " ") + ) + self.button_object_detection.setCheckable(True) + self.button_object_detection.setToolTip( + "This will result in a label file for each point cloud\n" + "with a bounding box for each annotated object." + ) + row_buttons.addWidget(self.button_object_detection) + + self.button_semantic_segmentation = QPushButton( + text=LabelingMode.SEMANTIC_SEGMENTATION.title().replace("_", " ") + ) + self.button_semantic_segmentation.setCheckable(True) + self.button_semantic_segmentation.setToolTip( + "This will result in a *.bin file for each point cloud\n" + "with a label for each annotated point of an object." + ) + row_buttons.addWidget(self.button_semantic_segmentation) + + parent_layout.addLayout(row_buttons) + + # Click callbacks to switch between the two modes + def select_object_detection(): + self.button_object_detection.setChecked(True) + self.button_semantic_segmentation.setChecked(False) + + self.button_object_detection.clicked.connect(select_object_detection) + + def select_semantic_segmentation(): + self.button_semantic_segmentation.setChecked(True) + self.button_object_detection.setChecked(False) + + self.button_semantic_segmentation.clicked.connect(select_semantic_segmentation) + + def add_class_definition_rows(self, parent_layout: QVBoxLayout) -> None: + scroll_area = QScrollArea() + widget = QWidget() + self.class_labels = QVBoxLayout() + self.class_labels.addStretch() + + widget.setLayout(self.class_labels) + self.delete_buttons = QButtonGroup() + + for class_label in LabelConfig().classes: + self.add_label( + class_label.id, class_label.name, rgb_to_hex(class_label.color) + ) + + parent_layout.addWidget(QLabel("Change class labels:")) + + scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(widget) + scroll_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + + parent_layout.addWidget(scroll_area) + # Load annotation mode + if LabelConfig().type == LabelingMode.OBJECT_DETECTION: + self.button_object_detection.setChecked(True) + else: + self.button_semantic_segmentation.setChecked(True) + + # ---------------------------------------------------------------------------- # + # PROPERTIES # + # ---------------------------------------------------------------------------- # + + @property + def get_labeling_mode(self) -> LabelingMode: + if self.button_object_detection.isChecked(): + return LabelingMode.OBJECT_DETECTION + return LabelingMode.SEMANTIC_SEGMENTATION + + @property + def nb_of_labels(self) -> int: + return len(self.class_labels.children()) + + @property + def next_label_id(self) -> int: + max_class_id = 0 + for i in range(self.nb_of_labels): + label_id = int(self.class_labels.itemAt(i).itemAt(0).widget().text()) # type: ignore + max_class_id = max(max_class_id, label_id) + return max_class_id + 1 + + @property + def distinct_color(self) -> str: + if not self.colors: + self.colors = get_distinct_colors(25) + random.shuffle(self.colors) + return self.colors.pop() + + # ---------------------------------------------------------------------------- # + # LOGIC # + # ---------------------------------------------------------------------------- # + + def add_label( + self, id: int, name: Optional[str] = None, hex_color: Optional[str] = None + ) -> None: + row_label = QHBoxLayout() + row_label.setSpacing(15) + + label_id = QSpinBox() + label_id.setMinimum(0) + label_id.setMaximum(255) + label_id.setValue(id) + row_label.addWidget(label_id) + + label_name = QLineEdit(name or f"label_{id}") + label_name.setValidator(self.NAME_VALIDATOR) + row_label.addWidget(label_name, stretch=2) + + label_color = ColorButton(color=hex_color or self.distinct_color) + row_label.addWidget(label_color) + + label_delete = QPushButton( + icon=QIcon( + QPixmap( + pkg_resources.resource_filename( + "labelCloud.resources.icons", "delete-outline.svg" + ) + ) + ), + text="", + ) + self.delete_buttons.addButton(label_delete) + row_label.addWidget(label_delete) + + self.class_labels.insertLayout(self.nb_of_labels, row_label) + + def delete_label(self, delete_button: QPushButton) -> None: + row_label: QHBoxLayout + for row_index, row_label in enumerate(self.class_labels.children()): # type: ignore + if row_label.itemAt(3).widget() == delete_button: + for _ in range(row_label.count()): + row_label.removeWidget(row_label.itemAt(0).widget()) + break + + self.class_labels.removeItem(self.class_labels.itemAt(row_index)) # type: ignore + + def save_class_labels(self) -> None: + classes = [] + for i in range(self.nb_of_labels): + row: QHBoxLayout = self.class_labels.itemAt(i) # type: ignore + class_id = int(row.itemAt(0).widget().text()) # type: ignore + class_name = row.itemAt(1).widget().text() # type: ignore + class_color = hex_to_rgb(row.itemAt(2).widget().color()) # type: ignore + classes.append(ClassConfig(id=class_id, name=class_name, color=class_color)) + LabelConfig().classes = classes + LabelConfig().type = self.get_labeling_mode + LabelConfig().save_config()