Merge pull request #32 from tanganke/llama
Fix bug in config/clip-vit-base-patch32_robustness_corrupted.yaml
tanganke authored Nov 19, 2024
2 parents 0355f54 + 1f8b39b commit 23c1218
Showing 5 changed files with 144 additions and 23 deletions.
1 change: 1 addition & 0 deletions config/method/pruning/magnitude_diff_pruning.yaml
@@ -2,3 +2,4 @@ _target_: fusion_bench.method.MagnitudeDiffPruningAlgorithm
prune_ratio: 0.5
rescale: false
extract_names: null
prune_type: minor
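
Not part of the commit: a minimal sketch of constructing the same algorithm directly in Python with the values from this config. The import path is inferred from the _target_ entry above and the keyword names from the __init__ signature further down, so treat both as assumptions.

from fusion_bench.method import MagnitudeDiffPruningAlgorithm

# Mirrors config/method/pruning/magnitude_diff_pruning.yaml after this change.
algo = MagnitudeDiffPruningAlgorithm(
    prune_ratio=0.5,
    rescale=False,
    extract_names=None,
    prune_type="minor",  # new option; assumed to drop the smallest-magnitude diff entries ("major" drops the largest)
)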
53 changes: 53 additions & 0 deletions fusion_bench/dataset/llama/collate.py
@@ -0,0 +1,53 @@
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


def padded_collate_sft(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    input_ids_key: str = "input_ids",
    attention_mask_key: Optional[str] = "attention_mask",
    labels_key: Optional[str] = "labels",
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Pad a batch of sequences to the longest sequence length in the batch, and
    convert integer lists to tensors.
    Args:
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.
    Returns:
        Dict[str, torch.Tensor]: Collated input and label tensors.
    """
    input_ids = pad_sequence(
        [torch.tensor(x[input_ids_key]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    if attention_mask_key is not None and attention_mask_key in batch[0]:
        attention_mask = pad_sequence(
            [torch.tensor(x[attention_mask_key]) for x in batch],
            batch_first=True,
            padding_value=0,
        )
    else:
        attention_mask = None
    labels = pad_sequence(
        [torch.tensor(x[labels_key]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )

    if attention_mask is not None:
        return {
            input_ids_key: input_ids,
            attention_mask_key: attention_mask,
            labels_key: labels,
        }
    else:
        return {input_ids_key: input_ids, labels_key: labels}
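
Not part of the diff: a short usage sketch showing how this collate function might be wired into a PyTorch DataLoader; the toy samples and batch size are illustrative, and the import path simply mirrors the new file's location.

from functools import partial

from torch.utils.data import DataLoader

from fusion_bench.dataset.llama.collate import padded_collate_sft

# Toy tokenized samples in the format padded_collate_sft expects.
samples = [
    {"input_ids": [1, 5, 9], "attention_mask": [1, 1, 1], "labels": [-100, 5, 9]},
    {"input_ids": [1, 7], "attention_mask": [1, 1], "labels": [-100, 7]},
]

loader = DataLoader(
    samples,
    batch_size=2,
    collate_fn=partial(padded_collate_sft, padding_idx=0, ignore_idx=-100),
)
batch = next(iter(loader))
# input_ids are padded with 0 and labels with -100, both to the longest length in the batch.
print(batch["input_ids"].shape)  # torch.Size([2, 3])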
40 changes: 19 additions & 21 deletions fusion_bench/method/pruning/magnitude_diff_pruning.py
@@ -1,7 +1,7 @@
import logging
import re
from copy import deepcopy
from typing import Dict, List, Optional, Union # noqa: F401
from typing import Dict, List, Literal, Optional, Union # noqa: F401

import torch
from torch import Tensor, nn
@@ -10,27 +10,12 @@
from fusion_bench.method import BaseAlgorithm
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.modelpool import BaseModelPool
import functools
from .prune_utils import unstructured_magnitude_prune_

log = logging.getLogger(__name__)


def _magnitude_prune(weight: Tensor, prune_ratio: float) -> Tensor:
    """
    Prune the weights by setting values below a certain quantile to zero.
    Args:
        weight (Tensor): The weight tensor to be pruned.
        prune_ratio (float): The ratio of weights to prune.
    Returns:
        Tensor: The pruned weight tensor.
    """
    weight_abs = weight.abs()
    mask = weight_abs > weight_abs.quantile(prune_ratio)
    weight = weight * mask
    return weight


def _is_name_matched(name: str, extract_names: List[str]):
    """
    Check if the parameter name matches any of the provided regular expressions.
@@ -77,6 +62,7 @@ def __init__(
        prune_ratio: float,
        rescale: Optional[Union[bool, float]] = None,
        extract_names: List[str] = None,
        prune_type: Literal["minor", "major"] = "minor",
        **kwargs,
    ):
        """
@@ -90,6 +76,7 @@
        self.prune_ratio = prune_ratio
        self.rescale = rescale
        self.extract_names = extract_names
        self.prune_type = prune_type
        super().__init__(**kwargs)

    @torch.no_grad()
@@ -173,9 +160,20 @@ def magnitude_prune(
            # Prune the diff parameter if its name matches
            if _is_name_matched(name, extract_names):
                w_diff = ft_state_dict[name] - param
                w_diff = _magnitude_prune(w_diff, prune_ratio=self.prune_ratio)
                if self.rescale is not None and self.rescale:
                    w_diff = w_diff * self.rescale
                w_diff = unstructured_magnitude_prune_(
                    w_diff,
                    (
                        torch.abs
                        if self.prune_type == "minor"
                        else lambda x: -torch.abs(x)
                    ),
                    sparsity_ratio=self.prune_ratio,
                )
                if self.rescale is not None:
                    rescale = (
                        1 / self.prune_ratio if self.rescale == True else self.rescale
                    )
                    w_diff = w_diff * rescale
                param.data = param + w_diff

        return model
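
To make the new prune_type option concrete, here is an editor's sketch of minor- versus major-magnitude pruning of a difference tensor. It is not the unstructured_magnitude_prune_ implementation; it only assumes that the metric function ranks entries and the lowest-ranked fraction is zeroed, which is what passing torch.abs versus a negated absolute value suggests.

import torch


def prune_diff_sketch(w_diff: torch.Tensor, prune_ratio: float, prune_type: str = "minor") -> torch.Tensor:
    # Rank entries by magnitude; negating the score inverts the ranking for "major".
    metric = w_diff.abs() if prune_type == "minor" else -w_diff.abs()
    k = int(prune_ratio * w_diff.numel())  # number of entries to zero out
    if k == 0:
        return w_diff
    threshold = metric.flatten().kthvalue(k).values
    # "minor" zeroes the smallest-magnitude diffs; "major" zeroes the largest.
    return w_diff * (metric > threshold)


print(prune_diff_sketch(torch.randn(4, 4), prune_ratio=0.5, prune_type="minor"))

In the committed code the surviving diff can additionally be rescaled (by 1 / prune_ratio when rescale is simply set to True) before it is added back onto the base parameter.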
23 changes: 22 additions & 1 deletion fusion_bench/utils/devices.py
@@ -160,7 +160,7 @@ def get_current_device() -> torch.device:
    return torch.device(device)


def get_device_memory_info(device: torch.device) -> dict:
def get_device_memory_info(device: torch.device, reset_stats: bool = True) -> dict:
    """
    Get memory information for a given device.
@@ -174,10 +174,22 @@ def get_device_memory_info(device: torch.device) -> dict:
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
        peak_memory_active = torch.cuda.memory_stats(device).get(
            "active_bytes.all.peak", 0
        )
        peak_mem_alloc = torch.cuda.max_memory_allocated(device)
        peak_mem_reserved = torch.cuda.max_memory_reserved(device)

        if reset_stats:
            torch.cuda.reset_peak_memory_stats(device)

        return {
            "total_memory": total_memory,
            "reserved_memory": reserved_memory,
            "allocated_memory": allocated_memory,
            "peak_memory_active": peak_memory_active,
            "peak_memory_allocated": peak_mem_alloc,
            "peak_memory_reserved": peak_mem_reserved,
        }
    else:
        raise ValueError(
@@ -208,3 +220,12 @@ def get_device_capabilities(device: torch.device) -> dict:
        raise ValueError(
            f"Capabilities information not available for device type: {device.type}"
        )


def cleanup_cuda():
    """
    Call gc collect, empty CUDA cache, and reset peak memory stats.
    """
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
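
Outside the diff: a hedged usage sketch of the new peak-memory fields and cleanup_cuda, assuming the module is importable as fusion_bench.utils.devices and a CUDA device is available.

import torch

from fusion_bench.utils.devices import cleanup_cuda, get_device_memory_info

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # Some workload whose memory footprint we want to inspect.
    x = torch.randn(4096, 4096, device=device)
    y = x @ x
    stats = get_device_memory_info(device)  # reset_stats=True also clears the peak counters
    print(f"peak allocated: {stats['peak_memory_allocated'] / 2**30:.2f} GiB")
    del x, y
    cleanup_cuda()  # gc.collect() + empty_cache() + reset_peak_memory_stats()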
50 changes: 49 additions & 1 deletion fusion_bench/utils/dtype.py
@@ -1,4 +1,5 @@
from typing import Dict, Optional
import contextlib
from typing import Dict, Generator, Iterable, Optional, Tuple

import torch
from transformers.utils import (
@@ -78,6 +79,33 @@ def get_dtype(obj) -> torch.dtype:
raise ValueError(f"Unsupported object type: {type(obj)}")


@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """
    Context manager to set torch's default dtype.
    Args:
        dtype (torch.dtype): The desired default dtype inside the context manager.
    Returns:
        ContextManager: context manager for setting default dtype.
    Example:
        >>> with set_default_dtype(torch.bfloat16):
        >>>     x = torch.tensor([1, 2, 3])
        >>>     x.dtype
        torch.bfloat16
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)


def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
Expand All @@ -96,3 +124,23 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
        return torch.float16
    else:
        return torch.float32


def validate_expected_param_dtype(
    named_params: Iterable[Tuple[str, torch.nn.Parameter]], dtype: torch.dtype
) -> None:
    """
    Validates that all input parameters have the expected dtype.
    Args:
        named_params (Iterable[Tuple[str, torch.nn.Parameter]]): Iterable of named parameters.
        dtype (torch.dtype): Expected dtype.
    Raises:
        ValueError: If any parameter has a different dtype than `dtype`.
    """
    for name, param in named_params:
        if param.dtype != dtype:
            raise ValueError(
                f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
            )
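
Likewise outside the diff: a small sketch combining set_default_dtype and validate_expected_param_dtype, assuming the import path fusion_bench.utils.dtype matches the file shown above.

import torch

from fusion_bench.utils.dtype import set_default_dtype, validate_expected_param_dtype

# Parameters created inside the context pick up the temporary default dtype.
with set_default_dtype(torch.bfloat16):
    model = torch.nn.Linear(8, 8)

# Raises ValueError if any parameter ended up with a different dtype.
validate_expected_param_dtype(model.named_parameters(), dtype=torch.bfloat16)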
