Commit
Showing 6 changed files with 534 additions and 3 deletions.
One file was renamed without changes.
140 changes: 140 additions & 0 deletions
fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
@@ -0,0 +1,140 @@
import functools
import itertools
import logging

import torch
from omegaconf import DictConfig, open_dict
from torch import Tensor
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.dataset import CLIPDataset, load_dataset_from_config
from fusion_bench.modelpool.huggingface_clip_vision import HuggingFaceClipVisionPool
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
from fusion_bench.utils import timeit_context

from .layer_wise_adamerging import LayerWiseAdaMergingAlgorithm
import os

log = logging.getLogger(__name__)


class InfiniteDataLoader:
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)  # Reset the data loader
            data = next(self.data_iter)
        return data


class CLIPLayerWiseAdaMergingAlgorithm(LayerWiseAdaMergingAlgorithm):
    modelpool: HuggingFaceClipVisionPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    def get_task_config(self, task):
        for task_config in self.modelpool.config.tta_datasets:
            if task_config.name == task:
                return task_config
        raise ValueError(f"Task {task} not found in config")

    def prepare_dataset_config(self, dataset_config: DictConfig):
        if not hasattr(dataset_config, "type"):
            with open_dict(dataset_config):
                dataset_config["type"] = self.modelpool.config.dataset_type
        return dataset_config

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.
        """
        dataset_config = self.get_task_config(task)["dataset"]
        dataset_config = self.prepare_dataset_config(dataset_config)
        log.info(f"Loading test dataset: {dataset_config.name}")
        dataset = load_dataset_from_config(dataset_config)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        clip_model_config = self.modelpool.get_model_config("_pretrained_")

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
            clip_model = CLIPModel.from_pretrained(clip_model_config.path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale = self._fabric.to_device(self.logit_scale)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(
                    self.get_task_config(task)["dataset"].name
                )
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
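For readers unfamiliar with CLIP-style zero-shot heads: `compute_logits` above scores each image against the cached per-class text embeddings (`zeroshot_weights[task]`) by scaled cosine similarity, and only the image embeddings are explicitly L2-normalized there. The standalone sketch below is not part of the commit; it uses made-up tensor shapes and a placeholder temperature of 100.0 purely to illustrate the shape flow of that computation.

```python
import torch

# Made-up shapes: 10 classes, a batch of 4 images, a 512-dimensional shared embedding space.
num_classes, batch_size, embed_dim = 10, 4, 512
text_embeds = torch.randn(num_classes, embed_dim)   # stands in for zeroshot_weights[task]
image_embeds = torch.randn(batch_size, embed_dim)   # stands in for visual_projection(module(images)[1])
logit_scale = torch.tensor(100.0)                   # stands in for clip_model.logit_scale.exp()

# L2-normalize the image embeddings, as compute_logits does above.
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

# Scaled cosine similarity: shape (classes, batch), transposed to (batch, classes).
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()

print(logits_per_image.shape)  # torch.Size([4, 10]): one row of class scores per image
```

Because the text embeddings and the visual projection are kept frozen (`requires_grad_(False)`), gradients from any loss on `logits_per_image` flow only through the merged vision model being adapted.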
171 changes: 171 additions & 0 deletions
fusion_bench/method/adamerging/layer_wise_adamerging.py
@@ -0,0 +1,171 @@
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import List, Mapping, Union

import lightning as L
import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor, nn

from fusion_bench.models.wrappers.layer_wise_fusion import (
    LayerWiseMergedModel,
    get_layer_wise_weights,
)
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub

from ...modelpool import ModelPool
from ...utils.type import _StateDict
from ..base_algorithm import ModelFusionAlgorithm
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader

log = logging.getLogger(__name__)


def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.
    Args:
        logits (Tensor): The logits to compute the entropy loss of.
    Returns:
        Tensor: The entropy loss of the logits.
    """
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


class LayerWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.devices)
            self._fabric.launch()

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: ModelPool):
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    list(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a .np or .pt file
                if self.config.weights.endswith(".pt"):
                    layer_wise_weight = torch.load(
                        self.config.weights, map_location="cpu"
                    ).detach_()
                elif self.config.weights.endswith(".np"):
                    layer_wise_weight = torch.from_numpy(
                        np.load(self.config.weights)
                    ).detach_()
                else:
                    raise ValueError(f"Unsupported file format: {self.config.weights}")
            else:
                try:
                    layer_wise_weight = torch.tensor(
                        list(self.config.weights), dtype=torch.float32
                    )
                except ValueError:
                    raise ValueError(
                        f"Unsupported weights format: {self.config.weights}"
                    )

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        return module

    def fuse(self, modelpool: ModelPool):
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool

        module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload(module)
        else:
            module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                torch.save(module.merge_weight, self.config.save_merging_weights)
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        pass

    @abstractmethod
    def compute_logits(self, module, batch, task) -> Tensor:
        pass

    def test_time_adaptation(self, module: LayerWiseMergedModel):
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()
        module.merge_weights()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            pbar = tqdm(
                range(1),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        else:
            pbar = tqdm(
                range(self.config.max_steps),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        for step_idx in pbar:
            for task in self.modelpool.model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                logits = self.compute_logits(module, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                loss = entropy_loss(logits)
                # .backward() accumulates when .zero_grad() wasn't called
                # this can save memory
                self._fabric.backward(loss, retain_graph=True)

            optimizer.step()
            optimizer.zero_grad()
            module.merge_weights()

        return module
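A note on the objective: `test_time_adaptation` never touches labels. Each step draws one unlabeled test batch per task, computes logits with the merged model, and accumulates gradients of `entropy_loss` into the single trainable tensor `module.merge_weight` before one `optimizer.step()`. The toy example below is not part of the commit; it reuses the same loss formulation on hand-picked logits to show why minimizing it pushes the merged model toward confident predictions: a peaked distribution gives a near-zero loss, while a uniform one over three classes gives roughly ln 3 ≈ 1.0986.

```python
import torch
from torch import Tensor


def entropy_loss(logits: Tensor) -> Tensor:
    # Same formulation as in layer_wise_adamerging.py:
    # mean Shannon entropy of the softmax distribution over classes.
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


confident = torch.tensor([[10.0, 0.0, 0.0]])  # peaked prediction -> low entropy
uniform = torch.tensor([[1.0, 1.0, 1.0]])     # uniform prediction -> high entropy

print(entropy_loss(confident).item())  # ~0.001
print(entropy_loss(uniform).item())    # ~1.0986 (= ln 3)
```

This also explains the comment about `.backward()` accumulation in the training loop: `optimizer.zero_grad()` is deliberately called only after the losses for all tasks in a step have been backpropagated, so their gradients sum in `merge_weight.grad`.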