Skip to content

Commit

Permalink
Merge pull request #48 from tanganke/develop
Browse files Browse the repository at this point in the history
merge develop into main
  • Loading branch information
tanganke authored Dec 10, 2024
2 parents 299a481 + 437276d commit 68cc9b9
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/modelpool/nyuv2.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fusion_bench --config-name nyuv2_config \
### Ties-Merging

```bash
fusion --config-name nyuv2_config \
fusion_bench --config-name nyuv2_config \
method=ties_merging \
method.scaling_factor=0.3
```
Expand Down
2 changes: 2 additions & 0 deletions fusion_bench/dataset/llama/stanford_shp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def tokenize(sample):
sample["rejected_input_ids"].append(tokenizer.eos_token_id)
sample["rejected_attention_mask"].append(1)

return sample

dataset = dataset.map(tokenize, num_proc=num_proc)

if cache_path is not None and rank_zero_only.rank == 0:
Expand Down
4 changes: 2 additions & 2 deletions fusion_bench/method/pruning/llama_magnitude_prune.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Literal, Optional, Union
from typing import Dict, Literal, Optional, Union

import torch
from torch import Dict, nn
from torch import nn
from tqdm.auto import tqdm
from transformers import LlamaForCausalLM, LlamaModel

Expand Down
4 changes: 2 additions & 2 deletions fusion_bench/method/pruning/llama_random_prune.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Literal, Optional, Union # noqa: F401
from typing import Dict, Literal, Optional, Union # noqa: F401

import torch
from torch import Dict, nn
from torch import nn
from tqdm.auto import tqdm
from transformers import LlamaForCausalLM, LlamaModel

Expand Down
5 changes: 4 additions & 1 deletion fusion_bench/mixins/clip_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,13 @@ def compute_logits(
module: Union[nn.Module, CLIPVisionModel],
images: torch.Tensor,
task: str,
image_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_embeds = self.zeroshot_weights[task]

image_embeds = module(images)[1]
if image_embeds is None:
image_embeds = module(images)[1]
assert isinstance(image_embeds, torch.Tensor), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
image_embeds = self.visual_projection(image_embeds)

# normalize embeddings
Expand Down

0 comments on commit 68cc9b9

Please sign in to comment.