SurgeryV2: Bridging the Gap Between Model Merging and Multi-Task Learning with Deep Representation Surgery. Arxiv, 2024.

EnnengYang/SurgeryV2

Surgery V2

Code repository for the paper 'SurgeryV2: Bridging the Gap Between Model Merging and Multi-Task Learning with Deep Representation Surgery'.

Abstract

[Figure: Surgery V2]

Citation

If you find our paper or this resource helpful, please consider citing:

@article{SurgeryV2_Arxiv_2024,
  title={SurgeryV2: Bridging the Gap Between Model Merging and Multi-Task Learning with Deep Representation Surgery},
  author={Yang, Enneng and Shen, Li and Wang, Zhenyi and Guo, Guibing and Wang, Xingwei and Cao, Xiaochun and Zhang, Jie and Tao, Dacheng},
  journal={arXiv preprint arXiv:2410.14389},
  year={2024}
}

@inproceedings{RepresentationSurgery_ICML_2024,
  title={Representation Surgery for Multi-Task Model Merging},
  author={Yang, Enneng and Shen, Li and Wang, Zhenyi and Guo, Guibing and Chen, Xiaojun and Wang, Xingwei and Tao, Dacheng},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}

Thanks!

Datasets

To prepare the datasets, follow the dataset processing instructions in the task_vectors repository.

Alternatively, download the already processed data from Baidu Cloud disk.

Task Vectors / Checkpoints

The fine-tuned checkpoints can be downloaded from task_vectors#checkpoints. The corresponding Google Drive folder is task_vectors_checkpoints.

Note: If torch.load(xxx_checkpoint).state_dict() fails, try pickle.load(open(xxx_checkpoint, 'rb')).state_dict() instead.
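The fallback described in the note can be wrapped in a small helper. This is only a sketch and not part of this repository: `load_with_fallback` and `load_pickle` are hypothetical names, and the helper takes the loaders as arguments so that it does not itself depend on PyTorch. Plain pickle executes arbitrary code on load, so use it only with checkpoints from a source you trust.

```python
import pickle
from typing import Any, Callable, Sequence


def load_pickle(path: str) -> Any:
    """Fallback loader: read the checkpoint with plain pickle."""
    with open(path, "rb") as f:  # the context manager closes the file handle
        return pickle.load(f)


def load_with_fallback(path: str, loaders: Sequence[Callable[[str], Any]]) -> Any:
    """Try each loader in order and return the first result that succeeds."""
    errors = []
    for loader in loaders:
        try:
            return loader(path)
        except Exception as exc:  # broad on purpose: failure modes differ per loader
            name = getattr(loader, "__name__", repr(loader))
            errors.append(f"{name}: {exc!r}")
    raise RuntimeError(f"Could not load {path}: " + "; ".join(errors))


# Hypothetical usage with PyTorch installed (xxx_checkpoint is the placeholder
# path from the note above):
#   import torch
#   model = load_with_fallback(xxx_checkpoint, [torch.load, load_pickle])
#   state_dict = model.state_dict()
```

If every loader fails, the raised error lists each attempt, which makes it easier to see why the checkpoint could not be read.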

Train

  • Model merging methods (e.g., Weight Averaging, Task Arithmetic, Ties-Merging, AdaMerging):
    python src/main_tv.py
  • Model merging methods with Surgery (e.g., Weight Averaging w/ Surgery, Task Arithmetic w/ Surgery, Ties-Merging w/ Surgery, AdaMerging w/ Surgery):
    python src/main_tv_surgery_v1.py
  • Model merging methods with our SurgeryV2 (e.g., Weight Averaging w/ Surgery V2, Task Arithmetic w/ Surgery V2, Ties-Merging w/ Surgery V2, AdaMerging w/ Surgery V2):
    python src/main_tv_surgery_v2.py
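For readers new to the baselines named above, the two simplest ones can be illustrated on toy parameters. This is a minimal sketch and not the code in src/main_tv.py: parameters are plain Python lists standing in for tensors, and the names `weight_average`, `task_arithmetic`, and `scale` are assumptions made for this example. It does not cover Ties-Merging, AdaMerging, or the surgery modules.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]  # toy stand-in for a dict of tensors


def weight_average(finetuned: List[StateDict]) -> StateDict:
    """Weight Averaging: element-wise mean of the fine-tuned parameters."""
    n = len(finetuned)
    return {
        name: [sum(vals) / n for vals in zip(*(sd[name] for sd in finetuned))]
        for name in finetuned[0]
    }


def task_arithmetic(pretrained: StateDict, finetuned: List[StateDict],
                    scale: float) -> StateDict:
    """Task Arithmetic: pretrained + scale * sum of (finetuned - pretrained)."""
    merged = {}
    for name, base in pretrained.items():
        merged[name] = [
            b + scale * sum(sd[name][i] - b for sd in finetuned)
            for i, b in enumerate(base)
        ]
    return merged


# Toy parameters: one pretrained model and two fine-tuned models.
pre = {"w": [0.0, 1.0]}
task_a = {"w": [1.0, 1.0]}
task_b = {"w": [0.0, 3.0]}

print(weight_average([task_a, task_b]))             # {'w': [0.5, 2.0]}
print(task_arithmetic(pre, [task_a, task_b], 1.0))  # {'w': [1.0, 3.0]}
```

With `scale` set to 1/n for n fine-tuned models, task arithmetic reduces to weight averaging, which is a quick sanity check for either function.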

Note: Because of memory limitations on our machine, our implementation reloads the dataset at every step, which adds a significant amount of time. If your machine has enough memory, load all the data once before optimizing the surgery module; this speeds up training significantly.
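One way to apply the suggestion in this note is to read all batches once and then cycle over the in-memory copy. The sketch below is an assumption about how this could look, not the repository's code: `cached_batches` and `make_loader` are hypothetical names, and the toy loader stands in for a real data loader.

```python
from itertools import cycle
from typing import Any, Callable, Iterable, Iterator, List


def cached_batches(make_loader: Callable[[], Iterable[Any]]) -> Iterator[Any]:
    """Read every batch once, keep them in memory, and cycle over them forever."""
    batches: List[Any] = list(make_loader())  # single pass; needs enough RAM
    if not batches:
        raise ValueError("loader produced no batches")
    return cycle(batches)


# Toy demonstration: count how often the (hypothetical) loader is built.
load_calls = 0


def make_toy_loader() -> List[List[int]]:
    global load_calls
    load_calls += 1
    return [[0, 1], [2, 3], [4, 5]]


stream = cached_batches(make_toy_loader)
seen = [next(stream) for _ in range(5)]  # five optimization steps
print(load_calls)  # 1
print(seen)        # [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3]]
```

The trade-off is that the cached batches repeat in a fixed order; if reshuffling between passes matters, shuffle the cached list at the start of each pass instead of using `cycle`.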

Acknowledgement

Our implementation builds on the code listed below. Many thanks to the authors.
