
Cannot fine-tune after reloading the model weights #168

Open
wubizhi opened this issue Sep 24, 2024 · 6 comments
wubizhi commented Sep 24, 2024

```python
n_bases = 2
softGBM = SoftGradientBoostingRegressor(
    estimator=MLP,
    n_estimators=n_bases,
    shrinkage_rate=1.00,
    cuda=True,
)

io.load(softGBM, save_dir='./torch_ensemble_results/softGBM/')  # reload

criterion = StepwiseMSELoss()
softGBM.set_criterion(criterion)
softGBM.set_optimizer('Adam', lr=0.001, weight_decay=5e-4)
softGBM.set_scheduler("ReduceLROnPlateau")

# Re-training
softGBM.fit(train_loader=new_train_loader,
            log_interval=128,
            epochs=20,
            test_loader=new_vali_loader,
            save_model=True,
            save_dir='./torch_ensemble_results/softGBM/')
```

I want to know whether my code above can work. I have just trained the model for 20 epochs, and I reload the model weights to continue training for more epochs. If this makes sense, why does it report a bug like the one below?

```
sKAN_softGBM.fit(train_loader=new_train_loader,

  File "/home/WuBizhi/anaconda3/envs/torch-ensemble/lib/python3.9/site-packages/torchensemble/soft_gradient_boosting.py", line 514, in fit
    super().fit(
  File "/home/WuBizhi/anaconda3/envs/torch-ensemble/lib/python3.9/site-packages/torchensemble/soft_gradient_boosting.py", line 261, in fit
    loss += criterion(output[idx], rets[idx])
IndexError: list index out of range
```

xuyxu (Member) commented Sep 25, 2024

Hi @wubizhi, what is the value of n_bases in your code snippet, and in the model located at ./torch_ensemble_results/softGBM/?

wubizhi (Author) commented Sep 25, 2024

n_bases = 2

xuyxu (Member) commented Sep 26, 2024

Is it the same as the model located at ./torch_ensemble_results/softGBM/ ?

wubizhi (Author) commented Sep 26, 2024

Yes, the trained model path and the model name are kept the same.

First, I trained the model for 20 epochs; the trained weights were stored in the path ./torch_ensemble_results/softGBM/.

Then, I used io.load to reload the model weights from the path ./torch_ensemble_results/softGBM/.

Third, I just wanted to run the same model for another 50 epochs, but it reported the issue I pasted above.

Does reloading and re-running a torch-ensemble model work well for you? If yes, can you give me a demo so I can find out what happened in my own code? Or do you have any tips or ideas for this issue?

Best wishes, thanks very much!

xuyxu (Member) commented Sep 26, 2024

Sure, I will try to reproduce your problem first, and then get back to you.

GautamSharma11 commented
It doesn't seem like the io.save() method actually saves anything like optimizer.state_dict() or scheduler.state_dict(), which seems to mess up my trained ckpt.pth when I load it and call .fit() to continue training. Can you suggest any workaround for it?
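If io.save() really only stores the model weights, one possible workaround is to checkpoint the optimizer and scheduler states yourself with plain PyTorch, next to whatever the library saves. The sketch below uses a stand-in nn.Linear model and an illustrative checkpoint path; in this thread the model would be the fitted softGBM ensemble, and whether its internal optimizer/scheduler objects are accessible after .fit() is an assumption, not something the library documents here.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained model (in the thread: the softGBM ensemble).
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# Save model, optimizer, and scheduler state together (path is illustrative).
ckpt_path = os.path.join(tempfile.gettempdir(), "ckpt_full.pth")
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}
torch.save(checkpoint, ckpt_path)

# Later: rebuild the objects the same way, then restore all three states
# before resuming training, so Adam's moment buffers and the scheduler's
# patience counters survive the restart.
model2 = nn.Linear(4, 1)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001, weight_decay=5e-4)
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer2)

loaded = torch.load(ckpt_path)
model2.load_state_dict(loaded["model"])
optimizer2.load_state_dict(loaded["optimizer"])
scheduler2.load_state_dict(loaded["scheduler"])
```

The key point is that load_state_dict() is called on freshly constructed objects wired to each other (the scheduler must wrap the same optimizer instance it will step), which is the standard PyTorch resume pattern rather than anything specific to torchensemble.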
