Merge pull request #49 from hetailang/add_surgery
add method 'surgery'
tanganke authored Dec 20, 2024
2 parents 68cc9b9 + c4f0d97 commit e25b3d0
Showing 15 changed files with 539 additions and 42 deletions.
27 changes: 27 additions & 0 deletions config/method/surgery/adamerging_surgery.yaml
@@ -0,0 +1,27 @@
# name of the method to run (this config uses the layer-wise AdaMerging + Surgery variant)
name: clip_layer_wise_adamerging_surgery
# `weights` can be a list of floats, or a string pointing to a *.np or *.pt file containing the weights.
# If `weights` is specified, the test-time adaptation training is skipped.
weights: null
# optimizer and learning rate for test-time adaptation
optimizer: adam
lr: 1e-3
init_values: 0.3
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
clamp_weights: false
# arguments of `functional_call`
tie_weights: true
strict: false
# this is overridden by `fabric.devices` if launched from the `fusion_bench` CLI.
devices: 1
batch_size: 16
num_workers: 8
max_steps: 1000
fast_dev_run: ${fast_dev_run}
# the path for saving the merging weights
save_merging_weights: 'merging_weights.pt'
cache_dir: outputs

# parameters of Surgery
eval_iterations: 200
surgery_steps: 1000
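
For reference, a minimal sketch of reading this config programmatically with OmegaConf (the configuration library fusion_bench uses via Hydra); the relative path below assumes the repository root as the working directory:

```python
from omegaconf import OmegaConf

# Load the method config and inspect the Surgery-specific fields.
cfg = OmegaConf.load("config/method/surgery/adamerging_surgery.yaml")
print(cfg.name)             # clip_layer_wise_adamerging_surgery
print(cfg.surgery_steps)    # 1000
print(cfg.eval_iterations)  # 200
```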
1 change: 1 addition & 0 deletions fusion_bench/compat/method/__init__.py
@@ -21,6 +21,7 @@ class AlgorithmFactory:
"clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
"clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
"singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
"clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
# plug-and-play model merging methods
"clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
"clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",
8 changes: 7 additions & 1 deletion fusion_bench/compat/method/base_algorithm.py
@@ -1,8 +1,11 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, TYPE_CHECKING

from omegaconf import DictConfig

if TYPE_CHECKING:
from fusion_bench.programs.base_program import BaseHydraProgram

__all__ = ["ModelFusionAlgorithm"]


@@ -18,6 +21,9 @@ class ModelFusionAlgorithm(ABC):
config (DictConfig): Configuration for the algorithm.
"""

_program: "BaseHydraProgram" = None
"""A reference to the program that is running the algorithm."""

def __init__(self, algorithm_config: Optional[DictConfig] = None):
"""
Initialize the model fusion algorithm with the given configuration.
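
The new `_program` field gives every algorithm a back-reference to the program that runs it, so an algorithm can call back into program-level services such as evaluation. A hedged sketch of the pattern (the subclass is illustrative; the `evaluate_merged_model(taskpool, model)` call mirrors the one made in the Surgery implementation later in this commit):

```python
from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm


class MyFusionAlgorithm(ModelFusionAlgorithm):
    def run(self, modelpool):
        merged_model = ...  # fuse the models in `modelpool` somehow
        # `_program` is injected by the runner; guard against standalone use.
        if self._program is not None:
            self._program.evaluate_merged_model(self._program.taskpool, merged_model)
        return merged_model
```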
15 changes: 11 additions & 4 deletions fusion_bench/method/adamerging/layer_wise_adamerging.py
@@ -1,12 +1,12 @@
import logging
import os
from abc import abstractmethod
from typing import Any, List, Mapping, Union, cast # noqa: F401
from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, cast # noqa: F401

import torch
from lightning.fabric.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig
from torch import Tensor
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

@@ -19,10 +19,14 @@
get_layer_wise_weights,
)
from fusion_bench.utils.data import load_tensor_from_file
from fusion_bench.utils.type import TorchModelType

from .entropy_loss import entropy_loss
from .utils import get_memory_usage

if TYPE_CHECKING:
from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram

log = logging.getLogger(__name__)


@@ -31,6 +35,9 @@ class LayerWiseAdaMergingAlgorithm(
LightningFabricMixin,
SimpleProfilerMixin,
):
_program: "FabricModelFusionProgram"
"""The program that this algorithm is running on."""

"""
Implements the Layer-Wise AdaMerging Algorithm.
@@ -48,7 +55,7 @@ def __init__(self, algorithm_config: DictConfig):
super().__init__(algorithm_config)

@torch.no_grad()
def construct_layer_wise_merged_model(self, modelpool: ModelPool):
def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
"""
Constructs a wrapped layer-wise merged model from model pool.
@@ -183,7 +190,7 @@ def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
"""
pass

def test_time_adaptation(self, module: LayerWiseMergedModel):
def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]"):
"""
Perform test-time adaptation on the merged model.
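
The new `TorchModelType` import makes annotations like `LayerWiseMergedModel[TorchModelType]` possible: the wrapper is generic over the class of the model it merges, so `merge_and_unload()` can be typed to return that same class. A minimal sketch of the pattern, assuming `TorchModelType` is a `TypeVar` bound to `nn.Module` (the bound is inferred from usage, not confirmed from the source):

```python
from typing import Generic, TypeVar

from torch import nn

TorchModelType = TypeVar("TorchModelType", bound=nn.Module)


class MergedModelWrapper(nn.Module, Generic[TorchModelType]):
    def __init__(self, model: TorchModelType):
        super().__init__()
        self.model = model

    def merge_and_unload(self) -> TorchModelType:
        # Type checkers infer that e.g. MergedModelWrapper[CLIPVisionModel]
        # returns a CLIPVisionModel here.
        return self.model
```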
8 changes: 4 additions & 4 deletions fusion_bench/method/adamerging/min_norm_solvers.py
@@ -49,7 +49,7 @@ def _min_norm_element_from2(v1v1, v1v2, v2v2):
return gamma, cost

def _min_norm_2d(vecs, dps):
"""
R"""
Find the minimum norm solution as combination of two points
This is correct only in 2D
ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
@@ -85,7 +85,7 @@ def _min_norm_2d(vecs, dps):
return sol, dps

def _projection2simplex(y):
"""
R"""
Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
"""
m = len(y)
@@ -117,7 +117,7 @@ def _next_point(cur_val, grad, n):
return next_point

def find_min_norm_element(vecs):
"""
R"""
Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
@@ -163,7 +163,7 @@ def find_min_norm_element(vecs):
sol_vec = new_sol_vec

def find_min_norm_element_FW(vecs):
"""
R"""
Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
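
These four hunks only add an `R` prefix to docstrings containing LaTeX-style backslash sequences such as `\sum`; in an ordinary string literal those are invalid escape sequences, which recent Python versions flag (DeprecationWarning, and SyntaxWarning from 3.12). A minimal illustration:

```python
def plain():
    "min_c |\sum c_i x_i|_2^2"  # warning: invalid escape sequence '\s'


def raw():
    R"min_c |\sum c_i x_i|_2^2"  # raw string: backslash kept literally, no warning
```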
3 changes: 3 additions & 0 deletions fusion_bench/method/surgery/__init__.py
@@ -0,0 +1,3 @@
from .clip_layer_wise_adamerging_surgery import (
CLIPLayerWiseAdaMergingSurgeryAlgorithm,
)
157 changes: 157 additions & 0 deletions fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py
@@ -0,0 +1,157 @@
"""
Implementation of the Layer-Wise AdaMerging+Surgery Algorithm.
For more details, please refer to:
- (ICLR 2024) Yang et al. AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
- (ICML 2024) Yang et al. Representation Surgery for Multi-Task Model Merging. https://arxiv.org/abs/2402.02705
Basic Example:
```shell
fusion_bench \
method=surgery/adamerging_surgery \
modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
```
"""

import copy
import functools
import gc
import logging
from typing import TYPE_CHECKING, cast

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPVisionModel

from fusion_bench.dataset.clip_dataset import CLIPDataset
from fusion_bench.method.adamerging.layer_wise_adamerging import (
LayerWiseAdaMergingAlgorithm,
)
from fusion_bench.method.adamerging.utils import get_memory_usage
from fusion_bench.mixins import CLIPClassificationMixin
from fusion_bench.modelpool import CLIPVisionModelPool
from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel

log = logging.getLogger(__name__)


class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
CLIPClassificationMixin,
LayerWiseAdaMergingAlgorithm,
):

def on_test_time_adaptation_start(self):
"""
Here we load the CLIP processor and construct the zero-shot classification head for each task.
"""
self.setup_zero_shot_classification_head()

@functools.cache
def get_shuffled_test_loader_iter(self, task: str):
return super().get_shuffled_test_loader_iter(
task,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers,
)

def run(self, modelpool: CLIPVisionModelPool, **kwargs):
"""
Run the Layer-Wise AdaMerging+Surgery Algorithm.
This method constructs the wrapped model, performs test-time adaptation if needed, and then applies representation surgery.
Args:
modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
Returns:
LayerWiseMergedModel: The merged model after test-time adaptation.
"""
log.info("Fusing models using layer-wise adaptive merging.")
self.modelpool = modelpool
self.log_hyperparams(self.config)

# === Start of the AdaMerging Algorithm ===
with self.profile("construct the wrapped model"):
module = cast(
LayerWiseMergedModel[CLIPVisionModel],
self.construct_layer_wise_merged_model(modelpool),
)

if self.config.weights is not None:
# skip the test-time adaptation
merged_model = copy.deepcopy(module.merge_and_unload())
else:
with self.profile("test-time adaptation"):
module = self.test_time_adaptation(module)
if self.config.get("save_merging_weights", False):
self.save_merging_weights(
self.config.save_merging_weights, module.merge_weight
)
merged_model = copy.deepcopy(module.merge_and_unload())

# free memory
del module
gc.collect()
torch.cuda.empty_cache()

# === Start of the Surgery Algorithm ===
log.info("start performing Surgery")
alpha_model = SurgeryModelWrapper(
merged_model,
modelpool.model_names,
projection_dim=merged_model.config.projection_dim,
)
alpha_model = self.fabric.setup(alpha_model)
log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))

optimizer = torch.optim.Adam(
alpha_model.collect_trainable_params(),
lr=1e-3,
betas=(0.9, 0.999),
weight_decay=0.0,
)

finetuned_models = {
model_name: modelpool.load_model(model_name)
for model_name in modelpool.model_names
}
for name, model in finetuned_models.items():
model.requires_grad_(False)
model = self.fabric.to_device(model)
model.eval()

for iteration in tqdm(
range(self.config.surgery_steps),
"surgery",
dynamic_ncols=True,
):
for dataset_name in modelpool.model_names:
batch = next(self.get_shuffled_test_loader_iter(dataset_name))
finetuned_feature = self.compute_features(
finetuned_models[dataset_name], batch[0]
)
features, _, _ = alpha_model.compute_surgery_features(
lambda model: self.compute_features(model, batch[0]),
dataset_name,
)

loss = F.l1_loss(features, finetuned_feature)

optimizer.zero_grad()
loss.backward()
optimizer.step()

if ((iteration + 1) % self.config.eval_iterations) == 0:
# print(list(alpha_model.collect_trainable_params()))
# evaluate the surgery model using the taskpool of the running program
log.info(f"iteration: {iteration+1}")
self._program.evaluate_merged_model(self._program.taskpool, alpha_model)

log.info("test the result of Adamerging")
return merged_model
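
For intuition, below is a minimal sketch of the representation-surgery idea behind `SurgeryModelWrapper`: a task-specific low-rank adapter subtracts an estimated "representation bias" from the merged model's features, and is trained, as in the loop above, with an L1 loss against the corresponding fine-tuned model's features. The rank, initialization, and residual form here are illustrative assumptions rather than fusion_bench's exact implementation:

```python
import torch
import torch.nn.functional as F
from torch import nn


class MinimalSurgeryAdapter(nn.Module):
    def __init__(self, dim: int, rank: int = 16):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.zeros_(self.up.weight)  # start as an identity mapping

    def forward(self, merged_features: torch.Tensor) -> torch.Tensor:
        # Remove the estimated task-specific bias from the merged features.
        return merged_features - self.up(self.down(merged_features))


# Toy training step mirroring the surgery loop above (shapes are illustrative).
adapter = MinimalSurgeryAdapter(dim=512)
optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-3)
merged_feat = torch.randn(16, 512)     # stand-in for features of the merged model
finetuned_feat = torch.randn(16, 512)  # stand-in for features of a fine-tuned model
loss = F.l1_loss(adapter(merged_feat), finetuned_feat)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```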