
Commit

add layer-wise model adamerging
tanganke committed May 16, 2024
1 parent 15f5469 commit 53f8df9
Showing 6 changed files with 534 additions and 3 deletions.
File renamed without changes.
3 changes: 3 additions & 0 deletions fusion_bench/method/__init__.py
@@ -6,6 +6,7 @@
from .task_arithmetic import TaskArithmeticAlgorithm
from .ties_merging.ties_merging import TiesMergingAlgorithm
from .adamerging.clip_task_wise_adamerging import CLIPTaskWiseAdaMergingAlgorithm
from .adamerging.clip_layer_wise_adamerging import CLIPLayerWiseAdaMergingAlgorithm


def load_algorithm_from_config(method_config: DictConfig):
@@ -21,5 +22,7 @@ def load_algorithm_from_config(method_config: DictConfig):
return TiesMergingAlgorithm(method_config)
elif method_config.name == "clip_task_wise_adamerging":
return CLIPTaskWiseAdaMergingAlgorithm(method_config)
elif method_config.name == "clip_layer_wise_adamerging":
return CLIPLayerWiseAdaMergingAlgorithm(method_config)
else:
raise ValueError(f"Unknown algorithm: {method_config.name}")
140 changes: 140 additions & 0 deletions fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
@@ -0,0 +1,140 @@
import functools
import itertools
import logging

import torch
from omegaconf import DictConfig, open_dict
from torch import Tensor
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.dataset import CLIPDataset, load_dataset_from_config
from fusion_bench.modelpool.huggingface_clip_vision import HuggingFaceClipVisionPool
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
from fusion_bench.utils import timeit_context

from .layer_wise_adamerging import LayerWiseAdaMergingAlgorithm
import os

log = logging.getLogger(__name__)


class InfiniteDataLoader:
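    """Wrap a DataLoader so that iteration never raises StopIteration: when the
    underlying iterator is exhausted, it is re-created and iteration continues."""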
def __init__(self, data_loader):
self.data_loader = data_loader
self.data_iter = iter(data_loader)

def __iter__(self):
return self

def __next__(self):
try:
data = next(self.data_iter)
except StopIteration:
self.data_iter = iter(self.data_loader) # Reset the data loader
data = next(self.data_iter)
return data


class CLIPLayerWiseAdaMergingAlgorithm(LayerWiseAdaMergingAlgorithm):
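    """CLIP-specific layer-wise AdaMerging: supplies shuffled test dataloaders and
    cached zero-shot classification heads for the tasks in the model pool."""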
modelpool: HuggingFaceClipVisionPool = None
_clip_processor: CLIPProcessor = None
zeroshot_weights = {}

def __init__(self, algorithm_config: DictConfig):
super().__init__(algorithm_config)

def get_task_config(self, task):
for task_config in self.modelpool.config.tta_datasets:
if task_config.name == task:
return task_config
raise ValueError(f"Task {task} not found in config")

def prepare_dataset_config(self, dataset_config: DictConfig):
if not hasattr(dataset_config, "type"):
with open_dict(dataset_config):
dataset_config["type"] = self.modelpool.config.dataset_type
return dataset_config

@functools.cache
def get_test_dataset(self, task: str):
"""
Load the test dataset for the task.
This method is cached, so the dataset is loaded only once.
"""
dataset_config = self.get_task_config(task)["dataset"]
dataset_config = self.prepare_dataset_config(dataset_config)
log.info(f"Loading test dataset: {dataset_config.name}")
dataset = load_dataset_from_config(dataset_config)
dataset = CLIPDataset(dataset, self._clip_processor)
return dataset

@functools.cache
def get_shuffled_test_loader_iter(self, task: str):
loader = DataLoader(
self.get_test_dataset(task),
batch_size=self.config.batch_size,
shuffle=True,
num_workers=self.config.num_workers,
pin_memory=True,
)
if self._fabric is not None:
loader = self._fabric.setup_dataloaders(loader)
return iter(InfiniteDataLoader(loader))

def on_test_time_adaptation_start(self):
"""
Here we load the CLIP processor and construct the zero-shot classification head for each task.
"""
clip_model_config = self.modelpool.get_model_config("_pretrained_")

with timeit_context("Loading CLIP processor and pretrained CLIP model."):
self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
clip_model = CLIPModel.from_pretrained(clip_model_config.path)

clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
self.visual_projection = clip_model.visual_projection.requires_grad_(False)
self.logit_scale = clip_model.logit_scale.exp()
if self._fabric is not None:
self.visual_projection = self._fabric.to_device(self.visual_projection)
self.logit_scale = self._fabric.to_device(self.logit_scale)

for task in self.modelpool.model_names:
cache_file = os.path.join(
self.config.cache_dir,
f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
)
if os.path.exists(cache_file):
log.info(f"Loading cached zeroshot weights for task: {task}")
zeroshot_weights = torch.load(cache_file, map_location="cpu")
else:
log.info(f"Construct zero shot classification head for task: {task}")
classnames, templates = get_classnames_and_templates(
self.get_task_config(task)["dataset"].name
)
clip_classifier.set_classification_task(classnames, templates)
zeroshot_weights = clip_classifier.zeroshot_weights
log.info(f"save zeroshot weights to {cache_file}")
torch.save(zeroshot_weights, cache_file)
self.zeroshot_weights[task] = zeroshot_weights
if self._fabric is not None:
self.zeroshot_weights[task] = self._fabric.to_device(
self.zeroshot_weights[task]
)

def compute_logits(self, module, batch, task) -> Tensor:
images, _ = batch
text_embeds = self.zeroshot_weights[task]

image_embeds = module(images)[1]
image_embeds = self.visual_projection(image_embeds)

# normalize embeddings
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

# cosine similarity
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
logits_per_image = logits_per_text.t()

return logits_per_image
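As a side note, `compute_logits` follows CLIP's zero-shot classification recipe: image features from the merged vision encoder are passed through the frozen visual projection, L2-normalized, and compared by cosine similarity against the cached class text embeddings (assumed to be normalized by `HFCLIPClassifier`), scaled by the exponentiated logit scale. A self-contained sketch of that arithmetic with random tensors, independent of the classes in this commit, is below.

# Standalone sketch of the cosine-similarity logits computed above. Random
# tensors stand in for real CLIP embeddings; shapes are assumptions:
# text_embeds is (num_classes, dim), image_embeds is (batch, dim).
import torch

batch, num_classes, dim = 4, 10, 512
logit_scale = torch.tensor(100.0)  # stands in for clip_model.logit_scale.exp()

text_embeds = torch.randn(num_classes, dim)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

image_embeds = torch.randn(batch, dim)
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale  # (num_classes, batch)
logits_per_image = logits_per_text.t()                                       # (batch, num_classes)
print(logits_per_image.shape)  # torch.Size([4, 10])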
171 changes: 171 additions & 0 deletions fusion_bench/method/adamerging/layer_wise_adamerging.py
@@ -0,0 +1,171 @@
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import List, Mapping, Union

import lightning as L
import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor, nn

from fusion_bench.models.wrappers.layer_wise_fusion import (
LayerWiseMergedModel,
get_layer_wise_weights,
)
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub

from ...modelpool import ModelPool
from ...utils.type import _StateDict
from ..base_algorithm import ModelFusionAlgorithm
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader

log = logging.getLogger(__name__)


def entropy_loss(logits: Tensor) -> Tensor:
"""
Compute the entropy loss of a set of logits.
Args:
logits (Tensor): The logits to compute the entropy loss of.
Returns:
Tensor: The entropy loss of the logits.
"""
probs = torch.softmax(logits, dim=-1)
return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


class LayerWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
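    """Layer-wise AdaMerging: merges a pretrained model with several fine-tuned
    models using one learnable weight per model and per layer, optimized by
    entropy minimization on unlabeled test data (test-time adaptation)."""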
_fabric: L.Fabric = None

def __init__(self, algorithm_config: DictConfig):
super().__init__(algorithm_config)

if self._fabric is None and torch.cuda.is_available():
self._fabric = L.Fabric(devices=self.config.devices)
self._fabric.launch()

@torch.no_grad()
def construct_layer_wise_merged_model(self, modelpool: ModelPool):
pretrained_model = modelpool.load_model("_pretrained_")
finetuned_models = [
modelpool.load_model(name) for name in modelpool.model_names
]

if self.config.weights is None:
layer_wise_weight = get_layer_wise_weights(
num_models=len(modelpool.model_names),
num_layers=len(
list(
filter(lambda p: p.requires_grad, pretrained_model.parameters())
)
),
init_values=self.config.init_values,
)
else:
if isinstance(self.config.weights, str):
# self.config.weights is a path to a .np or .pt file
if self.config.weights.endswith(".pt"):
layer_wise_weight = torch.load(
self.config.weights, map_location="cpu"
).detach_()
elif self.config.weights.endswith(".np"):
layer_wise_weight = torch.from_numpy(
np.load(self.config.weights)
).detach_()
else:
raise ValueError(f"Unsupported file format: {self.config.weights}")
else:
try:
layer_wise_weight = torch.tensor(
list(self.config.weights), dtype=torch.float32
)
except ValueError:
raise ValueError(
f"Unsupported weights format: {self.config.weights}"
)

module = LayerWiseMergedModel(
layer_wise_weight=layer_wise_weight,
pretrained_model=pretrained_model,
finetuned_models=finetuned_models,
clamp_weights=self.config.clamp_weights,
tie_weights=self.config.tie_weights,
strict=self.config.strict,
)
return module

def fuse(self, modelpool: ModelPool):
log.info("Fusing models using layer-wise adaptive merging.")
self.modelpool = modelpool

module = self.construct_layer_wise_merged_model(modelpool)

if self.config.weights is not None:
# skip the test-time adaptation
return module.merge_and_unload(module)
else:
module = self.test_time_adaptation(module)
if self.config.get("save_merging_weights", False):
torch.save(module.merge_weight, self.config.save_merging_weights)
return module.merge_and_unload()

def on_test_time_adaptation_start(self):
pass

@abstractmethod
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
pass

@abstractmethod
def compute_logits(self, module, batch, task) -> Tensor:
pass

def test_time_adaptation(self, module: LayerWiseMergedModel):
self.on_test_time_adaptation_start()

# configure optimizer
if self.config.optimizer == "adam":
optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
else:
raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

if self._fabric is not None:
module, optimizer = self._fabric.setup(module, optimizer)

module.train()
module.merge_weights()

if self.config.get("fast_dev_run", False):
log.info("Running fast_dev_run, only one step")
pbar = tqdm(
range(1),
"AdaMerging Test-time adaptation",
dynamic_ncols=True,
)
else:
pbar = tqdm(
range(self.config.max_steps),
"AdaMerging Test-time adaptation",
dynamic_ncols=True,
)
for step_idx in pbar:
for task in self.modelpool.model_names:
batch = next(self.get_shuffled_test_loader_iter(task))
logits = self.compute_logits(module, batch, task)
assert (
logits.dim() == 2
), f"Expected logits to be 2D, got {logits.dim()}"
loss = entropy_loss(logits)
# .backward() accumulates when .zero_grad() wasn't called
# this can save memory
self._fabric.backward(loss, retain_graph=True)

optimizer.step()
optimizer.zero_grad()
module.merge_weights()

return module
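To make the optimization concrete: each merged parameter is the pretrained parameter plus a weighted sum of per-model task vectors (fine-tuned minus pretrained), with one learnable weight per model and per layer, and only those weights receive gradients from the entropy objective on unlabeled test batches. The toy sketch below illustrates this idea with plain tensors; it does not use `LayerWiseMergedModel`, and the merging rule is written out as an assumption based on the AdaMerging formulation and the task-vector arithmetic imported above.

# Toy sketch of layer-wise merging plus entropy minimization, with plain tensors
# standing in for model layers. Illustrative only; it does not use the
# LayerWiseMergedModel wrapper from fusion_bench.
import torch

torch.manual_seed(0)
num_models, num_layers, dim = 2, 3, 8

pretrained = [torch.randn(dim) for _ in range(num_layers)]
task_vectors = [[torch.randn(dim) * 0.1 for _ in range(num_layers)] for _ in range(num_models)]

# one learnable merging weight per (model, layer), as in get_layer_wise_weights
merge_weight = torch.full((num_models, num_layers), 0.3, requires_grad=True)
optimizer = torch.optim.Adam([merge_weight], lr=1e-2)

def merged_layer(l: int) -> torch.Tensor:
    # theta_l = theta_l^pre + sum_k lambda_{k,l} * (theta_l^k - theta_l^pre)
    return pretrained[l] + sum(merge_weight[k, l] * task_vectors[k][l] for k in range(num_models))

def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    # same definition as entropy_loss above
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()

for step in range(10):
    x = torch.randn(16, dim)        # unlabeled "test" batch
    h = x
    for l in range(num_layers):     # pretend each layer is an element-wise scaling
        h = h * merged_layer(l)
    loss = entropy_loss(h)          # minimize prediction entropy w.r.t. merge_weight
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

The real implementation performs the analogous re-materialization of merged parameters via `module.merge_weights()` before the loop and after each `optimizer.step()`.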
