[GroundingDino] Fix grounding dino loss 🚨 #31828

Open
wants to merge 48 commits into
base: main

Conversation

EduardoPach
Contributor

@EduardoPach EduardoPach commented Jul 7, 2024

What does this PR do?

Fixes #31434

As the original repo doesn't provide a loss implementation, I'm using the one implemented here as a baseline, since the original repo mentioned it in IDEA-Research/GroundingDINO#241 as a reliable source for anyone who wants to train a GroundingDino model.

TODO:

  • Test GroundingDinoMatcher and GroundingDinoLoss are working properly

Explanation of the Issue and Solution

The issue was that GroundingDinoLoss and GroundingDinoHungarianMatcher were just copies of the DeformableDetr versions, which are meant for closed-set object detection (i.e. a fixed set of categories). In GroundingDino there is no fixed set of categories, and the output logits are d_model-dimensional: the first seq_len elements hold actual values and the remaining ones are nan. The main differences are:

  1. class_labels are associated with the text prompt used
  2. The logits are associated with the tokens of the text so it's not necessarily 1-to-1

For instance, given an image with bounding boxes for fish and jellyfish and the prompt "fish. jellyfish.", fish should be assigned class_label 0 and jellyfish class_label 1. If the positions of jellyfish and fish in the prompt were swapped, the class_labels would swap as well. Moreover, jellyfish is represented by two tokens ([20919, 7529]) and fish by one token ([3869]), so we need to select the appropriate logits for each class.
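A minimal sketch of the category-to-token mapping described above, in plain Python rather than torch. The helper name build_label_maps_sketch is hypothetical and is not the PR's build_label_maps. It also assumes that 101, 102 and 1012 are the [CLS], [SEP] and "." token ids, which is what the ids quoted in this thread suggest.

```python
def build_label_maps_sketch(input_ids, delimiter_ids=(101, 102, 1012)):
    """Return one 0/1 row per category, marking the token positions that spell it.

    Hypothetical helper for illustration only. Categories are taken in prompt
    order, and delimiter_ids are assumed to be the [CLS], [SEP] and "." tokens.
    """
    label_maps = []
    current_positions = []
    for position, token_id in enumerate(input_ids):
        if token_id in delimiter_ids:
            # A delimiter closes the category collected so far (if any).
            if current_positions:
                row = [0] * len(input_ids)
                for p in current_positions:
                    row[p] = 1
                label_maps.append(row)
                current_positions = []
        else:
            current_positions.append(position)
    return label_maps


# Prompt "fish. jellyfish." using the token ids quoted above.
input_ids = [101, 3869, 1012, 20919, 7529, 1012, 102]
label_maps = build_label_maps_sketch(input_ids)
# label_maps[0] marks the single "fish" token (class_label 0),
# label_maps[1] marks the two "jellyfish" tokens (class_label 1).
```

Swapping the two categories in the prompt swaps the rows, which is why class_labels depend on the prompt order.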

As the original implementation provides neither the training loop nor the loss implementation, but does recommend other implementations for training GroundingDino in IDEA-Research/GroundingDINO#241, I took the implementation from Open-GroundingDino as the baseline: it supports both visual grounding and object detection, and its authors trained their own GroundingDino with that code base and achieved good performance.

Things added in this PR are:

  • build_label_maps, which generates a list of torch.Tensor of length batch_size, mapping each category to its corresponding tokens based on the input_ids
  • build_text_mask, which just expands the attention_mask to select the appropriate tokens when computing GroundingDino.loss_labels
  • Added enc_topk_proposals, encoder_logits and encoder_pred_boxes to GroundingDinoModelOutput and GroundingDinoObjectDetectionOutput to compute the first-stage loss
  • Added class_loss_coefficient (with the correct default value) and class_loss_reduction to GroundingDinoConfig. class_loss_reduction was added because the sigmoid_focal_loss in the baseline implementation reduces loss_ce with a simple sum, which leaves the losses imbalanced most of the time, while the original implementation does have a sigmoid_focal_loss, but with mean reduction. I therefore decided to make it configurable and use sum for testing purposes
  • Modifications to GroundingDinoLoss and GroundingDinoHungarianMatcher
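A rough sketch of the text-mask idea from the list above, using nested Python lists instead of tensors. The name build_text_mask_sketch and its parameters are hypothetical. It assumes the logits have shape (batch_size, num_queries, logit_dim) and that only the first seq_len positions of the last dimension correspond to real text tokens.

```python
def build_text_mask_sketch(attention_mask, num_queries, logit_dim):
    """Expand a (batch, seq_len) attention mask to (batch, num_queries, logit_dim).

    Hypothetical helper: positions beyond seq_len are padded with False so that
    only logits tied to real text tokens are selected.
    """
    text_mask = []
    for row in attention_mask:
        padded = [bool(v) for v in row] + [False] * (logit_dim - len(row))
        # Every query shares the same per-token mask.
        text_mask.append([list(padded) for _ in range(num_queries)])
    return text_mask


attention_mask = [[1, 1, 1, 0]]  # one sample, three real tokens, one pad token
text_mask = build_text_mask_sketch(attention_mask, num_queries=2, logit_dim=6)
```

With tensors, the same effect would typically be obtained by broadcasting the attention mask over the query dimension.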

Also added a new integration test called test_grounding_dino_loss where I compare the loss obtained from 2 sample images with the baseline implementation from Open-GroundingDino.

c.c. @amyeroberts

@EduardoPach EduardoPach changed the title WIP - [GroundingDino] Fix grounding dino loss [GroundingDino] Fix grounding dino loss Jul 14, 2024
@EduardoPach
Contributor Author

@amyeroberts FYI for some reason, when testing locally, test_cross_attention_mask is failing on this branch, but when I tested using the main branch it was also failing (locally)

@EduardoPach
Contributor Author

c.c. @amyeroberts

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for fixing!

Overall looks good. A breaking change, so we should append the PR title with 🚨, but I think this is acceptable as it's aligning ourselves with the recommended loss calculation.

Comment on lines 155 to 157
input_ids = torch.tensor([101, 3869, 1012, 11420, 1012, 1012, 102])
input_ids = input_ids.unsqueeze(0).expand(self.batch_size, -1)
Collaborator

Why switch to hard coded input ids?

Contributor Author

Otherwise tests that use labels and, therefore, compute a loss would complain as build_label_maps (here) would return None

Contributor

Perhaps you can add a comment to explain this

@SangbumChoi
Contributor

@EduardoPach Thanks for working on this loss. Just sharing more fully developed code for fine-tuning GroundingDINO: https://github.com/open-mmlab/mmdetection/blob/main/configs/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365.py

@EduardoPach EduardoPach requested a review from amyeroberts July 23, 2024 11:30
@EduardoPach EduardoPach changed the title [GroundingDino] Fix grounding dino loss [GroundingDino] Fix grounding dino loss 🚨 Jul 23, 2024
@EduardoPach
Contributor Author

c.c. @amyeroberts

@EduardoPach
Contributor Author

Maybe @NielsRogge could have a look?

Contributor

@NielsRogge NielsRogge left a comment

Thanks for fixing!

@EduardoPach
Contributor Author

Cough cough, c.c @amyeroberts

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for the continued work on this!

Main comments are about the precision and clarity of language in the docstrings and comments. It's important to write messages such that someone new to the code can understand.

Contributor

@SangbumChoi SangbumChoi left a comment

Overall it looks good. Let me check also in this PR #32483 for stable training convergence.

def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
# Similar to the one used in `DeformableDetr` but we pass `num_queries`, as `logits` are flattened
# due to masked selection, and support different `reduction` modes.
def sigmoid_focal_loss(
Contributor

https://github.com/longzw1997/Open-GroundingDino/blob/main/models/GroundingDINO/utils.py#L138

Even though there are customizations in this code, I like the current version of sigmoid_focal_loss 👍🏼

@SangbumChoi
Contributor

[Screenshot 2024-08-23 at 9:54:40 PM]
This is the result of the current commit

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@EduardoPach
Contributor Author

c.c. @amyeroberts

@@ -741,3 +780,53 @@ def test_cross_attention_mask(self):
self.assertTrue(torch.allclose(outputs1.logits, outputs_batched.logits[:1], atol=1e-3))
# For some reason 12 elements are > 1e-3, but the rest are fine
self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))

def test_grounding_dino_loss(self):
ds = load_dataset("EduardoPacheco/aquarium-sample", split="train")
Contributor

Maybe we should move this to huggingface?

Contributor Author

Maybe, wdyt @amyeroberts

Member

@qubvel qubvel Sep 10, 2024

Thanks for working on this, it is great to have it tested!
Do we actually need to load the whole dataset here? Can't we copy annotations and upload one image somewhere?

Contributor Author

@qubvel by "upload one image somewhere", do you mean to an HF dataset or to the test fixtures?

Contributor Author

yes, no, maybe?

Member

@ydshieh can you please help, do we have any place for test assets?

Contributor

@EduardoPach @qubvel I think we can do it like this

# We will verify our results on an image of cute cats

Collaborator

@ydshieh ydshieh Sep 23, 2024

Hi. IIRC, that dataset contains only 2 examples, right?

The easiest way is to have it as

hf-internal-testing/aquarium-sample

Collaborator

@qubvel I added you to hf-internal-testing (you have to accept the invitation) then you can create a copy of the above dataset.

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for fixing and all the work iterating on this!

Just two things before we're ready to merge:

  • This convo
  • Slow model tests -- Could you push an empty commit with the message [run_slow] grounding_dino?

@SangbumChoi
Contributor

SangbumChoi commented Oct 11, 2024

do you mean the loss test is random-seed dependent? In that case, we can either slightly increase the tolerance or mark it as @is_flaky()

@qubvel Still not sure about it. I will debug this further tomorrow, but the fact is that sometimes the CI passes and sometimes it fails 😭 (Or there might be a computing issue with torch.backends.cudnn.deterministic = True)
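To illustrate the retry idea behind marking a test as flaky, here is a small standalone decorator. It is a hypothetical sketch (retry_flaky_sketch is an invented name) and is not the implementation of the @is_flaky() decorator mentioned above. It simply reruns a check until it passes or the attempts run out.

```python
import functools


def retry_flaky_sketch(max_attempts=3):
    """Rerun a test function on AssertionError, up to max_attempts times."""

    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return test_fn(*args, **kwargs)
                except AssertionError:
                    # Re-raise only once every attempt has failed.
                    if attempt == max_attempts:
                        raise

        return wrapper

    return decorator


attempts = {"count": 0}


@retry_flaky_sketch(max_attempts=3)
def sometimes_failing_check():
    attempts["count"] += 1
    # Fails on the first two calls, passes on the third.
    assert attempts["count"] >= 3


sometimes_failing_check()
```

A retry like this hides nondeterminism rather than explaining it, so it is a stopgap compared with finding the source of the variation.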

@qubvel
Member

qubvel commented Oct 11, 2024

Ok, sure!

@stevenwudi

stevenwudi commented Oct 12, 2024

Hi @SangbumChoi, in test_grounding_dino_loss you have "loss_ce_enc": torch.tensor(16226.3145).
The scale of this loss is way too big compared to the other losses. Shouldn't a trained network have a smaller loss, or is there something wrong with the implementation?

Or does the loss actually need to include loss_ce_enc, since the first stage is just used for the region proposals?

@SangbumChoi
Contributor

SangbumChoi commented Oct 12, 2024

@stevenwudi

does the loss actually need to include loss_ce_enc, since the first stage is just used for the region proposals?

Yeah. Otherwise, can you explain in more detail why "loss_ce_enc" should be low? (I am also open to discussing this!)

@stevenwudi

stevenwudi commented Oct 12, 2024

Yeah. Otherwise, can you explain in more detail why "loss_ce_enc" should be low? (I am also open to discussing this!)

With the current sum reduction, loss_ce_enc will be orders of magnitude larger than the rest of the losses, which will make the other losses negligible and hinder training.

Maybe a better question is why we are not using mean as the class_loss_reduction default, since Open-GroundingDino seems to use mean?
https://github.com/longzw1997/Open-GroundingDino/blob/main/models/GroundingDINO/utils.py#L168
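To make the sum-versus-mean scale difference concrete, here is a plain-Python sketch of a sigmoid focal loss with a configurable reduction. It follows the standard focal loss formula with the alpha=0.25, gamma=2 defaults, operates on flat lists instead of tensors, and leaves out any normalization by the number of boxes, so it is an illustration and not the PR's implementation. The name sigmoid_focal_loss_sketch is hypothetical.

```python
import math


def sigmoid_focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0, reduction="sum"):
    """Focal loss over flat lists of logits and 0/1 targets (illustrative only)."""
    losses = []
    for x, t in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))
        ce = -(t * math.log(p) + (1 - t) * math.log(1 - p))
        p_t = p * t + (1 - p) * (1 - t)  # probability assigned to the true class
        alpha_t = alpha * t + (1 - alpha) * (1 - t)
        losses.append(alpha_t * ce * (1 - p_t) ** gamma)
    total = sum(losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(losses)
    raise ValueError(f"Unsupported reduction: {reduction}")


# 1000 logit entries, only one of which is a positive target.
logits = [-2.0] * 1000
targets = [0.0] * 1000
targets[0] = 1.0

loss_sum = sigmoid_focal_loss_sketch(logits, targets, reduction="sum")
loss_mean = sigmoid_focal_loss_sketch(logits, targets, reduction="mean")
# loss_sum is exactly len(logits) times loss_mean, so it grows with the
# number of (query, token) entries that contribute to the loss.
```

Under this reading, a sum over many encoder proposals and text tokens can be large even when each individual term is small.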

@SangbumChoi
Contributor

SangbumChoi commented Oct 12, 2024

@stevenwudi Yeah, good point. Even though the loss scale is large, it eventually trains well in my experiments, but as you suggested it could be better to use mean. Do you have some time to experiment with the effectiveness of mean versus sum? I will change the default to mean. cc. @EduardoPach

@stevenwudi

stevenwudi commented Oct 12, 2024

@stevenwudi Yeah, good point. Even though the loss scale is large, it eventually trains well in my experiments, but as you suggested it could be better to use mean. Do you have some time to experiment with the effectiveness of mean versus sum? I will change the default to mean. cc. @EduardoPach

(Actually, I don't have any experimental stats to support this, and I am happy to know that it trained well in your experiments.)
But what I observe is that with the sum reduction method, the first-stage loss_ce is really large, whereas the second-stage loss_ce lies in a more sensible range. Any ideas? @SangbumChoi

@EduardoPach
Contributor Author

I will change the default to mean

@SangbumChoi I did this before but switched back to sum as the default. Why? Because using mean yielded terrible results in my experiments with the aquarium dataset, and because Open-GroundingDino uses sum by default.

With the current sum reduction, loss_ce_enc will be orders of magnitude larger than the rest of the losses, which will make the other losses negligible and hinder training.

@stevenwudi first of all, thanks for spending your time looking into the code! I had the same thought, but after some experiments (as mentioned above) I set sum as the default. Why? Because sum is the default in Open-GroundingDino.

https://github.com/longzw1997/Open-GroundingDino/blob/main/models/GroundingDINO/utils.py#L168

This is not used in their code. See here. Thus the expected value, since we're basing ourselves on Open-GroundingDino, is pretty high, but it converges reasonably fast (see attached img).

Btw, if you actually look at the encoder_logits, they're pretty similar to logits, because the encoder_logits are the first-stage output, which then gets refined in the second stage to generate the final logits.

image

@stevenwudi

stevenwudi commented Oct 14, 2024

@EduardoPach thanks for the detailed explanation, code link and the training loss, this is super helpful.

nit: I'm curious about the extremely large value of loss_ce_enc for a pretrained model. Why do you think the scale is so much larger (hundreds of times the scale of the other losses)?
My understanding is that the first-stage and second-stage loss_ce are doing very similar things (quote from Deformable DETR):

Two-Stage Deformable DETR. In the original DETR, object queries in the decoder are irrelevant
to the current image. Inspired by two-stage object detectors, we explore a variant of Deformable
DETR for generating region proposals as the first stage. The generated region proposals will be fed
into the decoder as object queries for further refinement, forming a two-stage Deformable DETR.
In the first stage, to achieve high-recall proposals, each pixel in the multi-scale feature maps would
serve as an object query. However, directly setting object queries as pixels will bring unacceptable
computational and memory cost for the self-attention modules in the decoder, whose complexity
grows quadratically with the number of queries. To avoid this problem, we remove the decoder and
form an encoder-only Deformable DETR for region proposal generation. In it, each pixel is assigned
as an object query, which directly predicts a bounding box. Top scoring bounding boxes are picked
as region proposals. No NMS is applied before feeding the region proposals to the second stage.

Hence, I cannot help but wonder: does the high loss_ce_enc mean that GroundingDino generalizes very little with respect to the cross-attention logit input? Could there be something fundamentally flawed in the GroundingDino formulation?

This is not used in their code. See here.

Hmm, good to know. It seems that the Open-GroundingDino code is kind of messy; glad we will have this HF implementation soon 👍

@SangbumChoi
Contributor

SangbumChoi commented Oct 15, 2024

@ydshieh @qubvel @EduardoPach The CI test was failing because of differences between CPU and GPU. There is a slight difference from the very beginning of the architecture (e.g. the load_backbone part). (There is also a slight difference in the text part when we change atol to 1e-7.) This difference keeps growing toward the head and becomes non-negligible. It also causes different top-k proposals to be picked.

for i in outputs.keys():
    try:
        difference = cpu_outputs[i] - outputs[i].cpu()
        print(f"Difference: {(difference>1e-6).sum():10} Name: {i:40} Size: {cpu_outputs[i].size()}")
    except:
        continue
Difference:          1 Name: loss                                     Size: torch.Size([])
Difference:      15992 Name: logits                                   Size: torch.Size([2, 900, 256])
Difference:       1839 Name: pred_boxes                               Size: torch.Size([2, 900, 4])
Difference:     201506 Name: last_hidden_state                        Size: torch.Size([2, 900, 256])
Difference:         57 Name: init_reference_points                    Size: torch.Size([2, 900, 4])
Difference:    1056660 Name: intermediate_hidden_states               Size: torch.Size([2, 6, 900, 256])
Difference:      10381 Name: intermediate_reference_points            Size: torch.Size([2, 6, 900, 4])
Difference:      84492 Name: encoder_last_hidden_state_vision         Size: torch.Size([2, 23890, 256])
Difference:          0 Name: encoder_last_hidden_state_text           Size: torch.Size([2, 21, 256])
Difference:          2 Name: encoder_topk_proposals                   Size: torch.Size([2, 900])
Difference:     223762 Name: enc_outputs_class                        Size: torch.Size([2, 23890, 256])
Difference:      40176 Name: enc_outputs_coord_logits                 Size: torch.Size([2, 23890, 4])
Difference:      14181 Name: encoder_logits                           Size: torch.Size([2, 900, 256])
Difference:         57 Name: encoder_pred_boxes                       Size: torch.Size([2, 900, 4])

for i in outputs.keys():
    try:
        difference = cpu_outputs[i] - outputs[i].cpu()
        print(f"Difference: {(difference>1e-3).sum():10} Name: {i:40} Size: {cpu_outputs[i].size()}")
    except:
        continue
Difference:          1 Name: loss                                     Size: torch.Size([])
Difference:       4110 Name: logits                                   Size: torch.Size([2, 900, 256])
Difference:         97 Name: pred_boxes                               Size: torch.Size([2, 900, 4])
Difference:      38021 Name: last_hidden_state                        Size: torch.Size([2, 900, 256])
Difference:          8 Name: init_reference_points                    Size: torch.Size([2, 900, 4])
Difference:     120993 Name: intermediate_hidden_states               Size: torch.Size([2, 6, 900, 256])
Difference:        329 Name: intermediate_reference_points            Size: torch.Size([2, 6, 900, 4])
Difference:          0 Name: encoder_last_hidden_state_vision         Size: torch.Size([2, 23890, 256])
Difference:          0 Name: encoder_last_hidden_state_text           Size: torch.Size([2, 21, 256])
Difference:          2 Name: encoder_topk_proposals                   Size: torch.Size([2, 900])
Difference:          0 Name: enc_outputs_class                        Size: torch.Size([2, 23890, 256])
Difference:          0 Name: enc_outputs_coord_logits                 Size: torch.Size([2, 23890, 4])
Difference:          0 Name: encoder_logits                           Size: torch.Size([2, 900, 256])
Difference:          8 Name: encoder_pred_boxes                       Size: torch.Size([2, 900, 4])

@EduardoPach I think the test_grounding_dino_loss failure is the same issue as I stated above. As a solution, we can recompute the expected values on GPU, enlarge the tolerance, or remove this test function.
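A toy illustration of why a tiny numerical difference can change which proposals are selected: when two scores are nearly tied, a perturbation on the order of 1e-4 is enough to swap their ranking. The scores below are made up for illustration, and topk_indices_sketch is a hypothetical plain-Python stand-in for a tensor top-k.

```python
def topk_indices_sketch(scores, k):
    """Return the indices of the k largest scores, highest first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


# Made-up proposal scores: entries 1 and 2 are nearly tied.
cpu_scores = [0.9000, 0.5001, 0.5000, 0.1000]
gpu_scores = [0.9000, 0.5000, 0.5001, 0.1000]  # same values up to 1e-4

cpu_topk = topk_indices_sketch(cpu_scores, k=2)
gpu_topk = topk_indices_sketch(gpu_scores, k=2)
# The two devices agree on the best proposal but pick a different second one,
# so everything computed downstream from the selected proposals can diverge.
```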

@SangbumChoi
Contributor

Additional analysis for cpu/gpu difference

for i in range(len(backbone[0])):
    i, j = backbone[0][i][0], cpu_backbone[0][i][0]
    try:
        difference = j - i.cpu()
        print(f"Difference: {(difference>4e-5).sum():10}")
    except:
        continue

Difference: 0
Difference: 118
Difference: 11

The backbone also shows some large differences at the second and last layers.

Also, SwinEmbeddings shows a 1e-6 difference even though it is at the very beginning of the architecture

class SwinEmbeddings(nn.Module):
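One caveat when reading the counts above: the debug snippets compare the signed difference against the threshold, so elements where the GPU value is larger than the CPU value are not counted. A sketch of a symmetric count using the absolute difference, on plain lists with made-up values (count_mismatches_sketch is a hypothetical helper):

```python
def count_mismatches_sketch(reference, candidate, atol):
    """Count elements whose absolute difference exceeds atol."""
    if len(reference) != len(candidate):
        raise ValueError("Inputs must have the same length")
    return sum(1 for r, c in zip(reference, candidate) if abs(r - c) > atol)


# Made-up values: one element is lower on the second device, one is higher.
cpu_values = [0.10, 0.20, 0.30, 0.40]
gpu_values = [0.10, 0.2050, 0.2980, 0.40]

# Signed comparison, as in the snippets above: only counts cpu > gpu.
signed_count = sum(1 for r, c in zip(cpu_values, gpu_values) if (r - c) > 1e-3)
# Absolute comparison: counts deviations in both directions.
abs_count = count_mismatches_sketch(cpu_values, gpu_values, atol=1e-3)
```

With tensors, the equivalent change would be to take the absolute value of the difference before thresholding.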

@SangbumChoi
Contributor

@qubvel requesting a review

@ydshieh
Collaborator

ydshieh commented Oct 16, 2024

So for the solution we can recalculate based on the gpu, enlarge the tolerance, or remove this testing function

We are running CI on T4 GPU: we can update the expected values accordingly on our side.

@ydshieh
Collaborator

ydshieh commented Oct 17, 2024

Updated the values for test_grounding_dino_loss and it passes now.
I will leave @qubvel to take a final look and merge when they are back.

@qubvel
Member

qubvel commented Oct 30, 2024

Hi! I suppose we can merge it, as the test failures are unrelated!

The only thing it would be nice to do before merging: I see @ArthurZucker started the initiative of moving losses to a separate module, let's update this branch and move this loss there too. Thanks!

@EduardoPach
Contributor Author

Hi! I suppose we can merge it, as the test failures are unrelated!

The only thing it would be nice to do before merging: I see @ArthurZucker started the initiative of moving losses to a separate module, let's update this branch and move this loss there too. Thanks!

Basically, just moving the loss implementation to a new file in transformers/loss and properly registering grounding dino in LOSS_MAPPING, right? @SangbumChoi do you have time to do this? Otherwise I can find some time during the weekend
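As a rough picture of what registering a loss in a mapping could look like, here is a standalone sketch. Everything in it is hypothetical: LOSS_MAPPING_SKETCH, resolve_loss_sketch, the placeholder loss function and the key name are assumptions for illustration, and the actual layout of the transformers loss module is not shown in this thread.

```python
def grounding_dino_loss_sketch(logits, labels):
    """Placeholder loss standing in for the real GroundingDino loss function."""
    return sum(abs(l - t) for l, t in zip(logits, labels))


# Hypothetical registry: model class name -> loss function.
LOSS_MAPPING_SKETCH = {
    "GroundingDinoForObjectDetection": grounding_dino_loss_sketch,
}


def resolve_loss_sketch(model_class_name):
    """Look up the loss function registered for a model class name."""
    try:
        return LOSS_MAPPING_SKETCH[model_class_name]
    except KeyError:
        raise ValueError(f"No loss registered for {model_class_name}") from None


loss_fn = resolve_loss_sketch("GroundingDinoForObjectDetection")
value = loss_fn([0.5, 1.0], [0.0, 1.0])
```

The design point is that the model file no longer needs to define the loss itself and only needs a name the registry can resolve.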

@SangbumChoi
Contributor

It would be great if you could handle this over the weekend. @EduardoPach

@NielsRogge
Contributor

Hi @SangbumChoi @EduardoPach let us know if you need any help to finish this one :)

@SangbumChoi
Contributor

@NielsRogge Nothing special, I just haven't had time to look at it... I will try to finish this during the week.

@EduardoPach
Contributor Author

Had to travel on the weekend I was supposed to finish this 😅 , should probably be able to give the final push this weekend. We can coordinate on discord if you want @SangbumChoi

@qubvel
Member

qubvel commented Nov 25, 2024

Hi @EduardoPach, thanks for the update! Is it ready or do you have something to finish?

@EduardoPach
Contributor Author

EduardoPach commented Nov 25, 2024

Hi @EduardoPach, thanks for the update! Is it ready or do you have something to finish?

Hey @qubvel, I'll take a final look into the test issues, and it will be ready. I should be able to do that this weekend.

@SangbumChoi
Contributor

@EduardoPach Any future plans for working on this? Otherwise I can take a look

@EduardoPach
Contributor Author

@EduardoPach Any future plans for working on this? Otherwise I can take a look

@SangbumChoi probably not in the near future. If you have the bandwidth I'd appreciate that 😄

Successfully merging this pull request may close these issues.

GroundingDino - Loss calculation exceptions
9 participants