From fbe5c82f539adfe4da1c0afc40a51e4e841172d7 Mon Sep 17 00:00:00 2001 From: jiangtann <654523039@qq.com> Date: Fri, 26 Apr 2024 17:40:11 +0800 Subject: [PATCH] Fix: RandomSamplingNegPos forget to remove gt_ignore_flags --- mmdet/datasets/transforms/text_transformers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mmdet/datasets/transforms/text_transformers.py b/mmdet/datasets/transforms/text_transformers.py index 12a0e57db3d..46f537d0872 100644 --- a/mmdet/datasets/transforms/text_transformers.py +++ b/mmdet/datasets/transforms/text_transformers.py @@ -60,7 +60,7 @@ def check_for_positive_overflow(gt_bboxes, gt_labels, text, tokenizer, keep_gt_labels.append(gt_labels[i]) return gt_bboxes[keep_box_index], np.array( - keep_gt_labels, dtype=np.long), length + keep_gt_labels, dtype=np.long), length, keep_box_index def generate_senetence_given_labels(positive_label_list, negative_label_list, @@ -164,7 +164,7 @@ def od_aug(self, results): if '/' in value: text[key] = random.choice(value.split('/')).strip() - gt_bboxes, gt_labels, positive_caption_length = \ + gt_bboxes, gt_labels, positive_caption_length, keep_box_index = \ check_for_positive_overflow(gt_bboxes, gt_labels, text, self.tokenizer, self.max_tokens) @@ -232,6 +232,8 @@ def od_aug(self, results): results['gt_bboxes'] = gt_bboxes results['gt_bboxes_labels'] = gt_labels + if results.get('gt_ignore_flags', None) is not None: + results['gt_ignore_flags'] = results['gt_ignore_flags'][keep_box_index] results['text'] = pheso_caption results['tokens_positive'] = label_to_positions @@ -252,4 +254,4 @@ def transform(self, results: dict) -> dict: else: text = results['text'] results['text'] = list(text.values()) - return results + return results \ No newline at end of file