Showing 8 changed files with 570 additions and 1 deletion.
@@ -0,0 +1,22 @@
name: ??? # one of "clip_task_wise_adamerging"

# the weights can be a list of floats, or a string pointing to a *.np or *.pt file containing the weights
# if weights is specified, the test-time adaptation training is skipped
weights: null

# learning rate
lr: 0.0001
optimizer: adam

init_values: 0.3
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
clamp_weights: false

# arguments of `functional_call`
tie_weights: true
strict: false

devices: 1
batch_size: 16
num_workers: 4
max_steps: 1000
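As a rough sketch of how the `weights` field might be filled in (the config filename and the override values below are illustrative, not taken from this commit), using OmegaConf as the rest of the code does:

from omegaconf import OmegaConf

# Hypothetical filename; the commit does not show where this config file lives.
cfg = OmegaConf.load("clip_task_wise_adamerging.yaml")

# Either give one merging coefficient per fine-tuned model (values are illustrative) ...
cfg.weights = [0.3, 0.3, 0.3]

# ... or point at a saved tensor of task-wise weights, which skips test-time adaptation.
# cfg.weights = "outputs/task_wise_weights.pt"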
Empty file.
107 changes: 107 additions & 0 deletions
fusion_bench/method/adamerging/clip_task_wise_adamerging.py
@@ -0,0 +1,107 @@
import functools
import itertools
import logging

import torch
from omegaconf import DictConfig, open_dict
from torch import Tensor
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.dataset import CLIPDataset, load_dataset_from_config
from fusion_bench.modelpool.huggingface_clip_vision import HuggingFaceClipVisionPool
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
from fusion_bench.utils import timeit_context

from .task_wise_adamerging import TaskWiseAdaMergingAlgorithm

log = logging.getLogger(__name__)


class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
    modelpool: HuggingFaceClipVisionPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    def get_task_config(self, task):
        for task_config in self.config.tta_datasets:
            if task_config.name == task:
                return task_config
        raise ValueError(f"Task {task} not found in config")

    def prepare_dataset_config(self, dataset_config: DictConfig):
        if not hasattr(dataset_config, "type"):
            with open_dict(dataset_config):
                dataset_config["type"] = self.config.dataset_type
        return dataset_config

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.
        """
        dataset_config = self.get_task_config(task)["dataset"]
        dataset_config = self.prepare_dataset_config(dataset_config)
        log.info(f"Loading test dataset: {dataset_config.name}")
        dataset = load_dataset_from_config(dataset_config)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        # build an infinitely cycling iterator over the shuffled test set of the task
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(itertools.cycle(loader))

    def on_test_time_adaptation_start(self):
        clip_model_config = self.modelpool.get_model_config("_pretrained_")

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
            clip_model = CLIPModel.from_pretrained(clip_model_config.path)

        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
        self.logit_scale = clip_model.logit_scale.exp()
        if self._fabric is not None:
            self.visual_projection = self._fabric.to_device(self.visual_projection)
            self.logit_scale = self._fabric.to_device(self.logit_scale)

        for task in self.modelpool.model_names:
            log.info(f"Constructing zero-shot classification head for task {task}")
            classnames, templates = get_classnames_and_templates(
                self.get_task_config(task)["dataset"].name
            )
            clip_classifier.set_classification_task(classnames, templates)
            self.zeroshot_weights[task] = clip_classifier.zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
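For intuition, `compute_logits` reproduces the usual CLIP image-text similarity: L2-normalized image embeddings are compared against the per-task zero-shot text embeddings and scaled by the logit scale. A minimal, self-contained sketch with dummy tensors (shapes and values are illustrative only):

import torch

# 8 images, 10 classes, 512-dimensional embeddings (illustrative shapes)
image_embeds = torch.randn(8, 512)
text_embeds = torch.randn(10, 512)  # assumed pre-normalized, as zeroshot weights typically are
logit_scale = torch.tensor(100.0)   # corresponds to clip_model.logit_scale.exp()

# normalize image embeddings, as in compute_logits above
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

# scaled cosine similarity: [num_classes, batch], then transpose to [batch, num_classes]
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()  # shape [8, 10]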
@@ -0,0 +1,152 @@
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import List, Mapping, Union

import lightning as L
import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

from fusion_bench.models.wrappers.task_wise_fusion import (
    TaskWiseMergedModel,
    get_task_wise_weights,
)
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub

from ...modelpool import ModelPool
from ...utils.type import _StateDict
from ..base_algorithm import ModelFusionAlgorithm

log = logging.getLogger(__name__)


def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: The entropy loss of the logits.
    """
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        # launch a Fabric instance if one has not been provided and a GPU is available
        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.devices)
            self._fabric.launch()

    @torch.no_grad()
    def construct_task_wise_merged_model(self, modelpool: ModelPool):
        if self.config.weights is None:
            task_wise_weight = get_task_wise_weights(
                num_models=len(modelpool.model_names),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a .np or .pt file
                if self.config.weights.endswith(".pt"):
                    task_wise_weight = torch.load(
                        self.config.weights, map_location="cpu"
                    ).detach_()
                elif self.config.weights.endswith(".np"):
                    task_wise_weight = torch.from_numpy(
                        np.load(self.config.weights)
                    ).detach_()
                else:
                    raise ValueError(f"Unsupported file format: {self.config.weights}")
            else:
                try:
                    task_wise_weight = torch.tensor(
                        list(self.config.weights), dtype=torch.float32
                    )
                except ValueError:
                    raise ValueError(
                        f"Unsupported weights format: {self.config.weights}"
                    )

        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        return module

    def fuse(self, modelpool: ModelPool):
        log.info("Fusing models using task-wise adaptive merging.")
        self.modelpool = modelpool

        module = self.construct_task_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            module = self.test_time_adaptation(module)
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        pass

    @abstractmethod
    def compute_logits(self, module, batch, task) -> Tensor:
        pass

    def test_time_adaptation(self, module: TaskWiseMergedModel):
        self.on_test_time_adaptation_start()

        # configure optimizer over the task-wise merging weights only
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(
                [module.task_wise_weight], lr=self.config.lr
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()
        module.merge_weights()
        for step_idx in tqdm(
            range(self.config.max_steps),
            "AdaMerging Test-time adaptation",
            dynamic_ncols=True,
        ):
            # accumulate the entropy loss over all tasks for this step
            loss = 0
            for task in self.modelpool.model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                logits = self.compute_logits(module, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                loss = loss + entropy_loss(logits)
            optimizer.zero_grad()
            if self._fabric is not None:
                self._fabric.backward(loss)
            else:
                loss.backward()
            optimizer.step()
            module.merge_weights()

        return module
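End-to-end wiring is not part of this commit, but as a rough, hypothetical sketch of how the algorithm is intended to be used (the config path and the model pool construction below are assumptions, not taken from the diff):

from omegaconf import OmegaConf

from fusion_bench.method.adamerging.clip_task_wise_adamerging import (
    CLIPTaskWiseAdaMergingAlgorithm,
)

# Hypothetical config path; the commit does not show where the YAML above is stored.
algorithm_config = OmegaConf.load("config/method/clip_task_wise_adamerging.yaml")

# A HuggingFaceClipVisionPool holding "_pretrained_" plus one fine-tuned model per task;
# its construction is outside the scope of this commit.
modelpool = ...

algorithm = CLIPTaskWiseAdaMergingAlgorithm(algorithm_config)
merged_model = algorithm.fuse(modelpool)  # runs test-time adaptation unless weights is set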