update task-wise adamerging
tanganke committed May 16, 2024
1 parent 34de1b0 commit 15f5469
Showing 7 changed files with 160 additions and 46 deletions.
10 changes: 8 additions & 2 deletions config/method/task_wise_adamerging.yaml
@@ -1,11 +1,12 @@
name: ??? # one of "clip_task_wise_adamerging"
# this option can be "clip_task_wise_adamerging"
name: ???

# the weights can be a list of floats, or a string that points to a *.np or *.pt file containing the weights
# if weights is specified, skip the test-time adaptation training
weights: null

# learning rate
lr: 0.0001
lr: 1e-3
optimizer: adam

init_values: 0.3
@@ -20,3 +21,8 @@ devices: 1
batch_size: 16
num_workers: 4
max_steps: 1000
fast_dev_run: ${fast_dev_run}

# the path for saving the merging weights
save_merging_weights: false
cache_dir: outputs
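
If `weights` is set, the test-time adaptation training is skipped and the given merging weights are used directly. A minimal sketch of how that option might be resolved, assuming the saved file comes from the `save_merging_weights` path above (the helper name `load_task_wise_weights` is hypothetical):

import numpy as np
import torch

def load_task_wise_weights(weights):
    # `weights` may be a list of floats, or a path to a *.np / *.pt file
    if isinstance(weights, str):
        if weights.endswith(".pt"):
            return torch.load(weights, map_location="cpu")
        if weights.endswith((".np", ".npy")):
            return torch.from_numpy(np.load(weights))
    return torch.tensor(list(weights), dtype=torch.float32)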
42 changes: 42 additions & 0 deletions config/modelpool/clip-vit-large-patch14_TA8.yaml
@@ -18,3 +18,45 @@ models:
path: tanganke/clip-vit-large-patch14_mnist
- name: dtd
path: tanganke/clip-vit-large-patch14_dtd

# The following datasets are used for test-time adaptation
dataset_type: huggingface_image_classification
tta_datasets:
- name: svhn
dataset:
type: instantiate
name: svhn
object:
_target_: datasets.load_dataset
_args_:
- svhn
- cropped_digits
split: test
- name: stanford_cars
dataset:
name: tanganke/stanford_cars
split: test
- name: resisc45
dataset:
name: tanganke/resisc45
split: test
- name: eurosat
dataset:
name: tanganke/eurosat
split: test
- name: gtsrb
dataset:
name: tanganke/gtsrb
split: test
- name: mnist
dataset:
name: mnist
split: test
- name: dtd
dataset:
name: tanganke/dtd
split: test
- name: sun397
dataset:
name: tanganke/sun397
split: test
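
Roughly what these `tta_datasets` entries resolve to at adaptation time (a sketch; it assumes entries with `type: instantiate` call the `_target_` with `_args_` plus keyword fields, while plain entries fall back to the pool-level `dataset_type`):

from datasets import load_dataset

# the `svhn` entry instantiates datasets.load_dataset directly
svhn_test = load_dataset("svhn", "cropped_digits", split="test")
# plain entries such as `stanford_cars` load a Hub dataset by name
cars_test = load_dataset("tanganke/stanford_cars", split="test")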
57 changes: 45 additions & 12 deletions fusion_bench/method/adamerging/clip_task_wise_adamerging.py
@@ -15,10 +15,28 @@
from fusion_bench.utils import timeit_context

from .task_wise_adamerging import TaskWiseAdaMergingAlgorithm
import os

log = logging.getLogger(__name__)


class InfiniteDataLoader:
def __init__(self, data_loader):
self.data_loader = data_loader
self.data_iter = iter(data_loader)

def __iter__(self):
return self

def __next__(self):
try:
data = next(self.data_iter)
except StopIteration:
self.data_iter = iter(self.data_loader) # Reset the data loader
data = next(self.data_iter)
return data


class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
modelpool: HuggingFaceClipVisionPool = None
_clip_processor: CLIPProcessor = None
@@ -28,18 +46,18 @@ def __init__(self, algorithm_config: DictConfig):
super().__init__(algorithm_config)

def get_task_config(self, task):
for task_config in self.config.tta_datasets:
for task_config in self.modelpool.config.tta_datasets:
if task_config.name == task:
return task_config
raise ValueError(f"Task {task} not found in config")

def prepare_dataset_config(self, dataset_config: DictConfig):
if not hasattr(dataset_config, "type"):
with open_dict(dataset_config):
dataset_config["type"] = self.config.dataset_type
dataset_config["type"] = self.modelpool.config.dataset_type
return dataset_config

@functools.cache()
@functools.cache
def get_test_dataset(self, task: str):
"""
Load the test dataset for the task.
@@ -52,7 +70,7 @@ def get_test_dataset(self, task: str):
dataset = CLIPDataset(dataset, self._clip_processor)
return dataset

@functools.cache()
@functools.cache
def get_shuffled_test_loader_iter(self, task: str):
loader = DataLoader(
self.get_test_dataset(task),
@@ -63,9 +81,12 @@ def get_shuffled_test_loader_iter(self, task: str):
)
if self._fabric is not None:
loader = self._fabric.setup_dataloaders(loader)
return iter(itertools.cycle(loader))
return iter(InfiniteDataLoader(loader))

def on_test_time_adaptation_start(self):
"""
Here we load the CLIP processor and construct the zero-shot classification head for each task.
"""
clip_model_config = self.modelpool.get_model_config("_pretrained_")

with timeit_context("Loading CLIP processor and pretrained CLIP model."):
@@ -74,17 +95,29 @@ def on_test_time_adaptation_start(self):

clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
self.visual_projection = clip_model.visual_projection.requires_grad_(False)
self.logit_scale = clip_model.logit_scale.exp()
if self._fabric is not None:
self.visual_projection = self._fabric.to_device(self.visual_projection)
self.logit_scale = self._fabric.to_device(clip_model.logit_scale.exp())
self.logit_scale = self._fabric.to_device(self.logit_scale)

for task in self.modelpool.model_names():
log.info(f"Construct zero shot classification head for task {task}")
classnames, templates = get_classnames_and_templates(
self.get_task_config(task)["dataset"].name
for task in self.modelpool.model_names:
cache_file = os.path.join(
self.config.cache_dir,
f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
)
clip_classifier.set_classification_task(classnames, templates)
self.zeroshot_weights[task] = clip_classifier.zeroshot_weights
if os.path.exists(cache_file):
log.info(f"Loading cached zeroshot weights for task: {task}")
zeroshot_weights = torch.load(cache_file, map_location="cpu")
else:
log.info(f"Construct zero shot classification head for task: {task}")
classnames, templates = get_classnames_and_templates(
self.get_task_config(task)["dataset"].name
)
clip_classifier.set_classification_task(classnames, templates)
zeroshot_weights = clip_classifier.zeroshot_weights
log.info(f"save zeroshot weights to {cache_file}")
torch.save(zeroshot_weights, cache_file)
self.zeroshot_weights[task] = zeroshot_weights
if self._fabric is not None:
self.zeroshot_weights[task] = self._fabric.to_device(
self.zeroshot_weights[task]
38 changes: 26 additions & 12 deletions fusion_bench/method/adamerging/task_wise_adamerging.py
@@ -44,7 +44,7 @@ class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
def __init__(self, algorithm_config: DictConfig):
super().__init__(algorithm_config)

if self._fabric is not None and torch.cuda.is_available():
if self._fabric is None and torch.cuda.is_available():
self._fabric = L.Fabric(devices=self.config.devices)
self._fabric.launch()

@@ -103,7 +103,9 @@ def fuse(self, modelpool: ModelPool):
# skip the test-time adaptation
return module.merge_and_unload(module)
else:
module = self.test_time_adaptation()
module = self.test_time_adaptation(module)
if self.config.get("save_merging_weights", False):
torch.save(module.task_wise_weight, self.config.save_merging_weights)
return module.merge_and_unload()

def on_test_time_adaptation_start(self):
@@ -122,9 +124,7 @@ def test_time_adaptation(self, module: TaskWiseMergedModel):

# configure optimizer
if self.config.optimizer == "adam":
optimizer = torch.optim.Adam(
[self.module.task_wise_weight], lr=self.config.lr
)
optimizer = torch.optim.Adam([module.task_wise_weight], lr=self.config.lr)
else:
raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

@@ -133,20 +133,34 @@ def test_time_adaptation(self, module: TaskWiseMergedModel):

module.train()
module.merge_weights()
for step_idx in tqdm(
self.config.max_steps, "AdaMerging Test-time adaptation", dynamic_ncols=True
):
loss = 0

if self.config.get("fast_dev_run", False):
log.info("Running fast_dev_run, only one step")
pbar = tqdm(
range(1),
"AdaMerging Test-time adaptation",
dynamic_ncols=True,
)
else:
pbar = tqdm(
range(self.config.max_steps),
"AdaMerging Test-time adaptation",
dynamic_ncols=True,
)
for step_idx in pbar:
for task in self.modelpool.model_names:
batch = next(self.get_shuffled_test_loader_iter(task))
logits = self.compute_logits(module, batch, task)
assert (
logits.dim() == 2
), f"Expected logits to be 2D, got {logits.dim()}"
loss = loss + entropy_loss(logits)
optimizer.zero_grad()
self._fabric.backward(loss)
loss = entropy_loss(logits)
# gradients accumulate across tasks because zero_grad() is only called
# after optimizer.step(); backpropagating each task's loss separately
# saves memory compared to summing all task losses before backward()
self._fabric.backward(loss, retain_graph=True)

optimizer.step()
optimizer.zero_grad()
module.merge_weights()

return module
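
The adaptation objective above is the entropy of each task's predictions on unlabeled test batches. `entropy_loss` is imported from elsewhere in fusion_bench; a minimal sketch, assuming the usual mean Shannon entropy over softmax probabilities:

import torch

def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    # mean entropy of the predicted class distribution per sample
    probs = torch.softmax(logits, dim=-1)
    return -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()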
40 changes: 20 additions & 20 deletions fusion_bench/models/wrappers/task_wise_fusion.py
@@ -184,12 +184,12 @@ def __init__(
del_attr(m, name.split("."))
else:
for m in finetuned_models:
set_attr(
m,
name.split("."),
get_attr(pretrained_model, name.split(".")) - param,
get_attr(m, name.split(".")).data = (
get_attr(m, name.split(".")) - param
)
self.pretrained_model = pretrained_model.requires_grad_(False)
for m in finetuned_models:
m.requires_grad_(False)
self.task_vectors = nn.ModuleList(finetuned_models)

@property
@@ -224,19 +224,19 @@ def forward(self, *args, **kwargs):
self.merge_weights()
return self.forward_model(args=args, kwargs=kwargs)

def __getattr__(self, name: str) -> Any:
try:
return super().__getattr__(name)
except AttributeError:
attr = getattr(self.model, name)
if isinstance(attr, Callable):
warnings.warn(
f"forwarding `{name}` to the underlying model", UserWarning
)
return attr

def __setattr__(self, name: str, value: Any) -> None:
try:
super().__setattr__(name, value)
except AttributeError:
setattr(self.model, name, value)
# def __getattr__(self, name: str) -> Any:
# try:
# return super().__getattr__(name)
# except AttributeError:
# attr = getattr(self.pretrained_model, name)
# if isinstance(attr, Callable):
# warnings.warn(
# f"forwarding `{name}` to the underlying model", UserWarning
# )
# return attr

# def __setattr__(self, name: str, value: Any) -> None:
# try:
# super().__setattr__(name, value)
# except AttributeError:
# setattr(self.pretrained_model, name, value)
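
For reference, the wrapper stores the pretrained model plus one task vector (finetuned minus pretrained) per model, and merging recombines them with the learnable task-wise weights. A rough per-parameter sketch, assuming the standard task-arithmetic formulation (names are illustrative, not the wrapper's API):

import torch

def merge_param(pretrained_param: torch.Tensor, task_vectors, task_wise_weight):
    # task_vectors[i] = finetuned_param_i - pretrained_param
    merged = pretrained_param.clone()
    for weight, task_vector in zip(task_wise_weight, task_vectors):
        merged = merged + weight * task_vector
    return merged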
15 changes: 15 additions & 0 deletions fusion_bench/utils/data.py
@@ -0,0 +1,15 @@
class InfiniteDataLoader:
def __init__(self, data_loader):
self.data_loader = data_loader
self.data_iter = iter(data_loader)

def __iter__(self):
return self

def __next__(self):
try:
data = next(self.data_iter)
except StopIteration:
self.data_iter = iter(self.data_loader) # Reset the data loader
data = next(self.data_iter)
return data
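
Example usage (a sketch): unlike `itertools.cycle`, which caches every batch from its first pass, this wrapper re-iterates the underlying DataLoader, so shuffling is re-applied on each epoch:

import torch
from torch.utils.data import DataLoader, TensorDataset

from fusion_bench.utils.data import InfiniteDataLoader

loader = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=2, shuffle=True)
batches = iter(InfiniteDataLoader(loader))
for _ in range(10):  # more steps than a single epoch provides
    (x,) = next(batches)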
4 changes: 4 additions & 0 deletions offline_mode.sh
@@ -0,0 +1,4 @@
#!/bin/bash
# This script is used to set the environment variables for offline mode
export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1
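Since the script only exports environment variables, it should be sourced (`source offline_mode.sh` or `. offline_mode.sh`) rather than executed, so the variables persist in the shell that later launches training; with `TRANSFORMERS_OFFLINE=1` and `HF_DATASETS_OFFLINE=1` set, transformers and datasets read only from the local cache instead of contacting the Hub.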
