-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
36 additions
and
12 deletions.
There are no files selected for viewing
36 changes: 36 additions & 0 deletions
36
fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import functools | ||
import logging | ||
import os | ||
from abc import abstractmethod | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Mapping, Optional, Union, cast # noqa: F401 | ||
|
||
import torch | ||
from lightning.fabric.utilities.rank_zero import rank_zero_only | ||
from omegaconf import DictConfig | ||
from torch import Tensor, nn | ||
from torch.utils.data import DataLoader | ||
from tqdm.autonotebook import tqdm | ||
from transformers.data import default_data_collator | ||
|
||
from fusion_bench.method import BaseAlgorithm | ||
from fusion_bench.method.simple_average import simple_average | ||
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin | ||
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin | ||
from fusion_bench.modelpool import GPT2ForSequenceClassificationPool | ||
from fusion_bench.models.wrappers.layer_wise_fusion import ( | ||
LayerWiseMergedModel, | ||
get_layer_wise_weights, | ||
) | ||
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file | ||
from fusion_bench.utils.instantiate import instantiate | ||
|
||
from .entropy_loss import entropy_loss | ||
from .min_norm_solvers import MinNormSolver | ||
from .utils import get_memory_usage | ||
|
||
class FlanT5LayerWiseAdaMergingAlgorithm(BaseAlgorithm, | ||
LightningFabricMixin, | ||
SimpleProfilerMixin, | ||
): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters