Bug in teacher weights calculation #21

Open
ShpihanVlad opened this issue Sep 3, 2023 · 6 comments

Comments

@ShpihanVlad

Hi, your code contains an error in the teacher weight update.
If I recall correctly, the current implementation of WeightEMA in utils/torch_utils.py breaks the statistics that the model stores for batch norm. Because of this, after about the 5th epoch the teacher model begins making invalid predictions, which further hurts training, and the final result ends up a little worse than training without the teacher model at all.

To fix this, you can refer to the original YOLOv5 EMA class, ModelEMA, defined just above it in the same file. I rewrote the code that way and was then able to reproduce results close to the ones in the paper, even at an image size of 640.

My old issue #18 was caused by this bug, and the still-open issue #6 appears to have run into it too.

I can make a pull request next week, if you wish.

@hnuzhy
Owner

hnuzhy commented Sep 3, 2023

OK. It is very nice of you to fix this bug. I also suggest you add the environment details of your machine, because I did not encounter this bug on my server. It may be closely related to the PyTorch version or other libraries.
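PyTorch ships its own full report for this purpose (`python -m torch.utils.collect_env`). As a lighter alternative, the minimal sketch below collects only the versions most likely to matter here; the helper name `environment_report` is hypothetical and not part of this repository.

```python
import platform

import torch


def environment_report() -> dict:
    """Hypothetical helper: gather the versions most relevant to this bug report."""
    cuda_ok = torch.cuda.is_available()
    return {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "cuda": torch.version.cuda,  # None on CPU-only builds
        "cudnn": torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None,
        "gpu": torch.cuda.get_device_name(0) if cuda_ok else None,
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```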

@ShpihanVlad
Author

ShpihanVlad commented Sep 3, 2023

@hnuzhy OK, I will. I worked on this project in May and some of my logs are lost, so it may take some time. As I remember, you can verify that some of the model's state-dict entries are integers with something like:

```python
for k, v in model.state_dict().items():
    if not v.dtype.is_floating_point:
        print(k, v.dtype)
```

This is why that check is present in ModelEMA. I'm not sure whether these values become floats or zeros, but that was the root of the issue. I'll try to write a report later.
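To illustrate the "floats or zeros" question, here is a minimal sketch using a single `nn.BatchNorm2d`, whose state dict holds an int64 `num_batches_tracked` counter next to its float weights and running statistics. Both update functions are hypothetical naive variants written for this illustration, not the repository's actual WeightEMA: blending in place with `copy_` truncates integer entries (small values collapse to zero), while building a new dict silently turns them into floats. The sketch only shows how integer entries get mangled; it does not by itself prove that this is what degraded the teacher.

```python
import torch
import torch.nn as nn


def naive_inplace_ema(teacher: nn.Module, student: nn.Module, alpha: float = 0.99) -> None:
    """Hypothetical buggy variant: blend EVERY state-dict entry in place, integers included."""
    msd = student.state_dict()
    with torch.no_grad():
        for k, v in teacher.state_dict().items():
            # alpha * v promotes int64 to float32; copy_ then casts back, truncating toward zero
            v.copy_(alpha * v + (1.0 - alpha) * msd[k])


def naive_rebuilt_ema(teacher: nn.Module, student: nn.Module, alpha: float = 0.99) -> dict:
    """Hypothetical buggy variant: build a new state dict, so integer entries come back as floats."""
    msd = student.state_dict()
    return {k: alpha * v + (1.0 - alpha) * msd[k] for k, v in teacher.state_dict().items()}


if __name__ == "__main__":
    torch.manual_seed(0)
    student, teacher = nn.BatchNorm2d(3), nn.BatchNorm2d(3)
    for _ in range(10):  # training-mode forwards advance student.num_batches_tracked to 10
        student(torch.randn(4, 3, 8, 8))
    rebuilt = naive_rebuilt_ema(teacher, student)
    print(rebuilt["num_batches_tracked"].dtype)  # a float dtype, no longer int64
    naive_inplace_ema(teacher, student)
    print(teacher.num_batches_tracked)  # 0.99 * 0 + 0.01 * 10 = 0.1, truncated to 0
```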

Also, as far as I remember, there were some other issues; I'll try to find them and hopefully fix them again.

@nice3310

@ShpihanVlad Hello, I'm currently facing the same issue as you, but after attempting to modify the calculation of the teacher model weight, there hasn't been a significant improvement.

Could you please share the source code with me?

Additionally, I'd like to know the final mAP50 you obtained when training with an image size of 640.

I greatly appreciate any help you can provide.

@ShpihanVlad
Author

@nice3310 Hi. About the source code: I'll need to check my repo for credentials, since as far as I remember it may contain some info from the GPU cloud service I used. The final mAP@0.5 wasn't saved either, but it was really close to the results reported in the paper, within 1.0. As I remember it was slightly lower, but that may just be due to random initialization.

Here is a sample of the code I used:
```python
import torch

# is_parallel() and copy_attr() are the helpers already defined in YOLOv5's utils/torch_utils.py


class WeightEMA:
    """
    Exponential moving average weight optimizer for the mean teacher model.
    Based mainly on the ModelEMA class by the Ultralytics team.
    """

    def __init__(self, teacher_model, alpha=0.99):
        self.model = teacher_model  # FP32 teacher EMA
        self.alpha = alpha
        self.decay = lambda epoch: self.alpha  # constant decay; child classes can override it
        for p in self.model.parameters():  # the teacher is never trained directly
            p.requires_grad_(False)

    def update(self, stud_model, epoch):
        # Update EMA parameters from the student model
        with torch.no_grad():
            alpha = self.decay(epoch)
            msd = stud_model.module.state_dict() if is_parallel(stud_model) else stud_model.state_dict()
            for k, v in self.model.state_dict().items():
                if v.dtype.is_floating_point:  # weights, biases, BN running stats; integer entries are skipped
                    v *= alpha
                    v += (1. - alpha) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.model, model, include, exclude)
```

This was mainly copy-pasted from the original YOLOv5 implementation, and as far as I remember I didn't even need update_attr. I needed to change alpha based on the epoch, so I structured the code this way, which lets child classes simply override the decay function self.decay. I also changed some parts of the training code, but I can't share those for now. Just make sure the teacher weight update is called after each student optimization step.
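A minimal sketch of those two points in plain PyTorch rather than the YOLOv5 training loop: a compact copy of the WeightEMA class (with the `is_parallel` helper replaced by an `isinstance` check and `update_attr` omitted), a child class that overrides `self.decay` with an epoch-based warm-up, and a toy loop that updates the teacher right after every `optimizer.step()`. `RampedWeightEMA`, its linear warm-up schedule, and `train_one_epoch` are hypothetical illustrations, not the author's actual code.

```python
import copy

import torch
import torch.nn as nn


class WeightEMA:
    """Compact copy of the WeightEMA class above (update_attr omitted)."""

    def __init__(self, teacher_model, alpha=0.99):
        self.model = teacher_model
        self.alpha = alpha
        self.decay = lambda epoch: self.alpha  # constant by default
        for p in self.model.parameters():
            p.requires_grad_(False)

    def update(self, stud_model, epoch):
        with torch.no_grad():
            alpha = self.decay(epoch)
            parallel = isinstance(stud_model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
            msd = (stud_model.module if parallel else stud_model).state_dict()
            for k, v in self.model.state_dict().items():
                if v.dtype.is_floating_point:  # skip integer entries such as num_batches_tracked
                    v *= alpha
                    v += (1.0 - alpha) * msd[k].detach()


class RampedWeightEMA(WeightEMA):
    """Hypothetical child class: alpha grows linearly to its full value over warmup_epochs."""

    def __init__(self, teacher_model, alpha=0.99, warmup_epochs=5):
        super().__init__(teacher_model, alpha)
        self.decay = lambda epoch: self.alpha * min(1.0, (epoch + 1) / warmup_epochs)


def train_one_epoch(student, ema, optimizer, batches, epoch):
    """Hypothetical toy loop: the teacher is updated right after every student step."""
    student.train()
    for x, y in batches:
        loss = nn.functional.mse_loss(student(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(student, epoch)  # teacher follows the freshly updated student


if __name__ == "__main__":
    torch.manual_seed(0)
    student = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
    teacher = copy.deepcopy(student)  # teacher starts as an exact copy of the student
    ema = RampedWeightEMA(teacher, alpha=0.99, warmup_epochs=5)
    optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
    batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(3)]
    for epoch in range(2):
        train_one_epoch(student, ema, optimizer, batches, epoch)
```

Putting the update inside the batch loop, rather than once per epoch, keeps the teacher tracking the student at the same granularity that ModelEMA uses in YOLOv5.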

@daydreamertu

@ShpihanVlad @hnuzhy Hello, I used my own dataset for the experiment and found that the validation results of best_tacher.pt were poor, while the results of best_student.pt met expectations. What is the reason for this? And have you provided visualization code for the pseudo-labels?

@ShpihanVlad
Author

> @ShpihanVlad @hnuzhy Hello, I used my own dataset for the experiment and found that the validation results of best_tacher.pt were poor, while the results of best_student.pt met expectations. What is the reason for this? And have you provided visualization code for the pseudo-labels?

Hi. First of all, no, I haven't published the code with the fixes. I don't fully remember why; maybe I couldn't find my repo because it was deleted, or it contained some sensitive credentials and I didn't bother making another repository.

As for why this happens: if I remember correctly, the EMA (exponential moving average) in this implementation is basically a FLOAT CONSTANT multiplied by the current teacher's weights. However, this is invalid for YOLOv5, because the model contains some integer entries (as far as I remember, some sort of batch statistics). After this multiplication the models were broken, which leads to invalid performance for the teacher model and worse performance for the student too. In fact, after fixing this bug I was able to roughly replicate the results reported in the paper.
In the code I provided earlier, this issue is addressed by the if statement that checks the data type. You should replace the existing teacher weight update with this custom class. I don't remember exactly where that update happens, so check where EMA is used for the teacher model and dig into the training code.
In my use case I may have overcomplicated the decay function, but this was part of my university coursework research.
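Written out, the guarded update described here applies the EMA only to floating-point state-dict entries. With $\theta_T^{(k)}$ the teacher's entry $k$, $\theta_S^{(k)}$ the corresponding student entry, and $\alpha$ the decay (0.99 by default in the WeightEMA class shared earlier):

$$
\theta_T^{(k)} \leftarrow
\begin{cases}
\alpha\,\theta_T^{(k)} + (1-\alpha)\,\theta_S^{(k)}, & \text{if entry } k \text{ is floating point},\\
\theta_T^{(k)}, & \text{otherwise (integer entries are left untouched)}.
\end{cases}
$$

For example, with $\alpha = 0.99$, a teacher weight of $1.0$ and a student weight of $2.0$ give $0.99 \cdot 1.0 + 0.01 \cdot 2.0 = 1.01$.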

As for visualization, I used the detect.py script from the original Ultralytics YOLOv5, although I think you should take it from an older YOLOv5 release, because this project uses an older version of the model and training code.

Also, I'd recommend validating both the student and teacher models during training (some additional changes to checkpoint saving may be required; I don't remember exactly, but I remember changing some code in that regard too).

If anything I wrote is unclear, feel free to ask, and remind me if I don't answer. But don't expect help with code from me, as I worked on this project mainly in spring 2023 and opened this issue some time later.
Best regards,
Vladyslav
