diff --git a/labelme/utils.py b/labelme/utils.py index 7bf9e87b9..614185b2a 100644 --- a/labelme/utils.py +++ b/labelme/utils.py @@ -119,14 +119,31 @@ def draw_label(label, img, label_names, colormap=None): def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'): - lbl = np.zeros(img_shape[:2], dtype=np.int32) + assert type in ['class', 'instance'] + + cls = np.zeros(img_shape[:2], dtype=np.int32) + if type == 'instance': + ins = np.zeros(img_shape[:2], dtype=np.int32) + instance_names = ['__background__'] for shape in shapes: polygons = shape['points'] - label_name = shape['label'] - label_value = label_name_to_value[label_name] + label = shape['label'] + if type == 'class': + cls_name = label + elif type == 'instance': + cls_name = label.split('-')[0] + if label not in instance_names: + instance_names.append(label) + ins_id = len(instance_names) - 1 + cls_id = label_name_to_value[cls_name] mask = polygons_to_mask(img_shape[:2], polygons) - lbl[mask] = label_value - return lbl + cls[mask] = cls_id + if type == 'instance': + ins[mask] = ins_id + + if type == 'instance': + return cls, ins + return cls def labelme_shapes_to_label(img_shape, shapes):