Skip to content

Commit

Permalink
add task wise adamerging
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed May 16, 2024
1 parent fba1f89 commit 34de1b0
Show file tree
Hide file tree
Showing 8 changed files with 570 additions and 1 deletion.
22 changes: 22 additions & 0 deletions config/method/task_wise_adamerging.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: ??? # one of "clip_task_wise_adamerging"

# this weights can be a list of float, or a string that points to a *.np, *.pt file containing the weights
# if weights is specified, skip the test-time adaptation training
weights: null

# learning rate
lr: 0.0001
optimizer: adam

init_values: 0.3
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
clamp_weights: false

# arguments of `functional_call`
tie_weights: true
strict: false

devices: 1
batch_size: 16
num_workers: 4
max_steps: 1000
43 changes: 43 additions & 0 deletions config/modelpool/clip-vit-base-patch32_TA8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,46 @@ models:
path: tanganke/clip-vit-base-patch32_mnist
- name: dtd
path: tanganke/clip-vit-base-patch32_dtd


# The following datasets are used for test-time adaptation
dataset_type: huggingface_image_classification
tta_datasets:
- name: svhn
dataset:
type: instantiate
name: svhn
object:
_target_: datasets.load_dataset
_args_:
- svhn
- cropped_digits
split: test
- name: stanford_cars
dataset:
name: tanganke/stanford_cars
split: test
- name: resisc45
dataset:
name: tanganke/resisc45
split: test
- name: eurosat
dataset:
name: tanganke/eurosat
split: test
- name: gtsrb
dataset:
name: tanganke/gtsrb
split: test
- name: mnist
dataset:
name: mnist
split: test
- name: dtd
dataset:
name: tanganke/dtd
split: test
- name: sun397
dataset:
name: tanganke/sun397
split: test
3 changes: 3 additions & 0 deletions fusion_bench/method/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .weighted_average import WeightedAverageAlgorithm
from .task_arithmetic import TaskArithmeticAlgorithm
from .ties_merging.ties_merging import TiesMergingAlgorithm
from .adamerging.clip_task_wise_adamerging import CLIPTaskWiseAdaMergingAlgorithm


def load_algorithm_from_config(method_config: DictConfig):
Expand All @@ -18,5 +19,7 @@ def load_algorithm_from_config(method_config: DictConfig):
return TaskArithmeticAlgorithm(method_config)
elif method_config.name == "ties_merging":
return TiesMergingAlgorithm(method_config)
elif method_config.name == "clip_task_wise_adamerging":
return CLIPTaskWiseAdaMergingAlgorithm(method_config)
else:
raise ValueError(f"Unknown algorithm: {method_config.name}")
Empty file.
107 changes: 107 additions & 0 deletions fusion_bench/method/adamerging/clip_task_wise_adamerging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import functools
import itertools
import logging

import torch
from omegaconf import DictConfig, open_dict
from torch import Tensor
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.dataset import CLIPDataset, load_dataset_from_config
from fusion_bench.modelpool.huggingface_clip_vision import HuggingFaceClipVisionPool
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
from fusion_bench.utils import timeit_context

from .task_wise_adamerging import TaskWiseAdaMergingAlgorithm

log = logging.getLogger(__name__)


class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
modelpool: HuggingFaceClipVisionPool = None
_clip_processor: CLIPProcessor = None
zeroshot_weights = {}

def __init__(self, algorithm_config: DictConfig):
super().__init__(algorithm_config)

def get_task_config(self, task):
for task_config in self.config.tta_datasets:
if task_config.name == task:
return task_config
raise ValueError(f"Task {task} not found in config")

def prepare_dataset_config(self, dataset_config: DictConfig):
if not hasattr(dataset_config, "type"):
with open_dict(dataset_config):
dataset_config["type"] = self.config.dataset_type
return dataset_config

@functools.cache()
def get_test_dataset(self, task: str):
"""
Load the test dataset for the task.
This method is cached, so the dataset is loaded only once.
"""
dataset_config = self.get_task_config(task)["dataset"]
dataset_config = self.prepare_dataset_config(dataset_config)
log.info(f"Loading test dataset: {dataset_config.name}")
dataset = load_dataset_from_config(dataset_config)
dataset = CLIPDataset(dataset, self._clip_processor)
return dataset

@functools.cache()
def get_shuffled_test_loader_iter(self, task: str):
loader = DataLoader(
self.get_test_dataset(task),
batch_size=self.config.batch_size,
shuffle=True,
num_workers=self.config.num_workers,
pin_memory=True,
)
if self._fabric is not None:
loader = self._fabric.setup_dataloaders(loader)
return iter(itertools.cycle(loader))

def on_test_time_adaptation_start(self):
clip_model_config = self.modelpool.get_model_config("_pretrained_")

with timeit_context("Loading CLIP processor and pretrained CLIP model."):
self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
clip_model = CLIPModel.from_pretrained(clip_model_config.path)

clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
self.visual_projection = clip_model.visual_projection.requires_grad_(False)
if self._fabric is not None:
self.visual_projection = self._fabric.to_device(self.visual_projection)
self.logit_scale = self._fabric.to_device(clip_model.logit_scale.exp())

for task in self.modelpool.model_names():
log.info(f"Construct zero shot classification head for task {task}")
classnames, templates = get_classnames_and_templates(
self.get_task_config(task)["dataset"].name
)
clip_classifier.set_classification_task(classnames, templates)
self.zeroshot_weights[task] = clip_classifier.zeroshot_weights
if self._fabric is not None:
self.zeroshot_weights[task] = self._fabric.to_device(
self.zeroshot_weights[task]
)

def compute_logits(self, module, batch, task) -> Tensor:
images, _ = batch
text_embeds = self.zeroshot_weights[task]

image_embeds = module(images)[1]
image_embeds = self.visual_projection(image_embeds)

# normalize embeddings
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

# cosine similarity
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
logits_per_image = logits_per_text.t()

return logits_per_image
152 changes: 152 additions & 0 deletions fusion_bench/method/adamerging/task_wise_adamerging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import List, Mapping, Union

import lightning as L
import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor, nn

from fusion_bench.models.wrappers.task_wise_fusion import (
TaskWiseMergedModel,
get_task_wise_weights,
)
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub

from ...modelpool import ModelPool
from ...utils.type import _StateDict
from ..base_algorithm import ModelFusionAlgorithm
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader

log = logging.getLogger(__name__)


def entropy_loss(logits: Tensor) -> Tensor:
"""
Compute the entropy loss of a set of logits.
Args:
logits (Tensor): The logits to compute the entropy loss of.
Returns:
Tensor: The entropy loss of the logits.
"""
probs = torch.softmax(logits, dim=-1)
return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
_fabric: L.Fabric = None

def __init__(self, algorithm_config: DictConfig):
super().__init__(algorithm_config)

if self._fabric is not None and torch.cuda.is_available():
self._fabric = L.Fabric(devices=self.config.devices)
self._fabric.launch()

@torch.no_grad()
def construct_task_wise_merged_model(self, modelpool: ModelPool):
if self.config.weights is None:
task_wise_weight = get_task_wise_weights(
num_models=len(modelpool.model_names),
init_values=self.config.init_values,
)
else:
if isinstance(self.config.weights, str):
# self.config.weights is a path to a .np or .pt file
if self.config.weights.endswith(".pt"):
task_wise_weight = torch.load(
self.config.weights, map_location="cpu"
).detach_()
elif self.config.weights.endswith(".np"):
task_wise_weight = torch.from_numpy(
np.load(self.config.weights)
).detach_()
else:
raise ValueError(f"Unsupported file format: {self.config.weights}")
else:
try:
task_wise_weight = torch.tensor(
list(self.config.weights), dtype=torch.float32
)
except ValueError:
raise ValueError(
f"Unsupported weights format: {self.config.weights}"
)

pretrained_model = modelpool.load_model("_pretrained_")
finetuned_models = [
modelpool.load_model(name) for name in modelpool.model_names
]

module = TaskWiseMergedModel(
task_wise_weight=task_wise_weight,
pretrained_model=pretrained_model,
finetuned_models=finetuned_models,
clamp_weights=self.config.clamp_weights,
tie_weights=self.config.tie_weights,
strict=self.config.strict,
)
return module

def fuse(self, modelpool: ModelPool):
log.info("Fusing models using task-wise adaptive merging.")
self.modelpool = modelpool

module = self.construct_task_wise_merged_model(modelpool)

if self.config.weights is not None:
# skip the test-time adaptation
return module.merge_and_unload(module)
else:
module = self.test_time_adaptation()
return module.merge_and_unload()

def on_test_time_adaptation_start(self):
pass

@abstractmethod
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
pass

@abstractmethod
def compute_logits(self, module, batch, task) -> Tensor:
pass

def test_time_adaptation(self, module: TaskWiseMergedModel):
self.on_test_time_adaptation_start()

# configure optimizer
if self.config.optimizer == "adam":
optimizer = torch.optim.Adam(
[self.module.task_wise_weight], lr=self.config.lr
)
else:
raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

if self._fabric is not None:
module, optimizer = self._fabric.setup(module, optimizer)

module.train()
module.merge_weights()
for step_idx in tqdm(
self.config.max_steps, "AdaMerging Test-time adaptation", dynamic_ncols=True
):
loss = 0
for task in self.modelpool.model_names:
batch = next(self.get_shuffled_test_loader_iter(task))
logits = self.compute_logits(module, batch, task)
assert (
logits.dim() == 2
), f"Expected logits to be 2D, got {logits.dim()}"
loss = loss + entropy_loss(logits)
optimizer.zero_grad()
self._fabric.backward(loss)
optimizer.step()
module.merge_weights()

return module
Loading

0 comments on commit 34de1b0

Please sign in to comment.