diff --git a/docs/modelpool/nyuv2.md b/docs/modelpool/nyuv2.md
index 84a2948a..db6d02e5 100644
--- a/docs/modelpool/nyuv2.md
+++ b/docs/modelpool/nyuv2.md
@@ -23,7 +23,7 @@ fusion_bench --config-name nyuv2_config \
 ### Ties-Merging
 
 ```bash
-fusion --config-name nyuv2_config \
+fusion_bench --config-name nyuv2_config \
     method=ties_merging \
     method.scaling_factor=0.3
 ```
diff --git a/fusion_bench/dataset/llama/stanford_shp.py b/fusion_bench/dataset/llama/stanford_shp.py
index f829abba..07d30e24 100644
--- a/fusion_bench/dataset/llama/stanford_shp.py
+++ b/fusion_bench/dataset/llama/stanford_shp.py
@@ -81,6 +81,8 @@ def tokenize(sample):
         sample["rejected_input_ids"].append(tokenizer.eos_token_id)
         sample["rejected_attention_mask"].append(1)
 
+        return sample
+
     dataset = dataset.map(tokenize, num_proc=num_proc)
 
     if cache_path is not None and rank_zero_only.rank == 0:
diff --git a/fusion_bench/method/pruning/llama_magnitude_prune.py b/fusion_bench/method/pruning/llama_magnitude_prune.py
index 8b6fbd31..dba39997 100644
--- a/fusion_bench/method/pruning/llama_magnitude_prune.py
+++ b/fusion_bench/method/pruning/llama_magnitude_prune.py
@@ -1,7 +1,7 @@
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union
 
 import torch
-from torch import Dict, nn
+from torch import nn
 from tqdm.auto import tqdm
 from transformers import LlamaForCausalLM, LlamaModel
 
diff --git a/fusion_bench/method/pruning/llama_random_prune.py b/fusion_bench/method/pruning/llama_random_prune.py
index 64c8532a..32992476 100644
--- a/fusion_bench/method/pruning/llama_random_prune.py
+++ b/fusion_bench/method/pruning/llama_random_prune.py
@@ -1,7 +1,7 @@
-from typing import Literal, Optional, Union  # noqa: F401
+from typing import Dict, Literal, Optional, Union  # noqa: F401
 
 import torch
-from torch import Dict, nn
+from torch import nn
 from tqdm.auto import tqdm
 from transformers import LlamaForCausalLM, LlamaModel
 
diff --git a/fusion_bench/mixins/clip_classification.py b/fusion_bench/mixins/clip_classification.py
index ef8fdc38..83604dc0 100644
--- a/fusion_bench/mixins/clip_classification.py
+++ b/fusion_bench/mixins/clip_classification.py
@@ -178,10 +178,13 @@ def compute_logits(
         module: Union[nn.Module, CLIPVisionModel],
         images: torch.Tensor,
         task: str,
+        image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         text_embeds = self.zeroshot_weights[task]
 
-        image_embeds = module(images)[1]
+        if image_embeds is None:
+            image_embeds = module(images)[1]
+        assert isinstance(image_embeds, torch.Tensor), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
         image_embeds = self.visual_projection(image_embeds)
 
         # normalize embeddings
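The `compute_logits` hunk above turns `image_embeds` into an optional keyword argument, so a caller can pass precomputed vision features and skip the backbone forward pass. Below is a minimal, self-contained sketch of that caching pattern; the `TinyClassifier` class, layer sizes, and variable names are illustrative stand-ins, not fusion_bench's actual `CLIPClassificationMixin` API.

```python
import torch
from torch import nn


class TinyClassifier:
    """Illustrative stand-in for the patched compute_logits contract."""

    def __init__(self, embed_dim=32, proj_dim=16, num_classes=10):
        self.visual_projection = nn.Linear(embed_dim, proj_dim, bias=False)
        self.text_embeds = torch.randn(num_classes, proj_dim)

    def compute_logits(self, module, images, image_embeds=None):
        # Run the vision backbone only when no cached embeddings are given,
        # mirroring the `if image_embeds is None:` branch in the diff.
        if image_embeds is None:
            image_embeds = module(images)
        assert isinstance(image_embeds, torch.Tensor)
        image_embeds = self.visual_projection(image_embeds)
        # Normalize embeddings before the similarity matmul.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        return image_embeds @ self.text_embeds.t()


backbone = nn.Linear(64, 32)  # stand-in vision encoder
clf = TinyClassifier()
images = torch.randn(8, 64)

logits_old = clf.compute_logits(backbone, images)  # backbone runs inside
cached = backbone(images)  # compute the features once...
logits_new = clf.compute_logits(backbone, images, image_embeds=cached)  # ...and reuse them
assert torch.allclose(logits_old, logits_new)
```

Reusing cached embeddings this way is useful when the same batch is scored against several task heads, since the expensive encoder pass then happens only once.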