
Commit

update
tanganke committed Nov 25, 2024
1 parent 953f9bb commit c62f011
Showing 2 changed files with 36 additions and 12 deletions.
36 changes: 36 additions & 0 deletions fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
@@ -0,0 +1,36 @@
import functools
import logging
import os
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union, cast # noqa: F401

import torch
from lightning.fabric.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers.data import default_data_collator

from fusion_bench.method import BaseAlgorithm
from fusion_bench.method.simple_average import simple_average
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
from fusion_bench.models.wrappers.layer_wise_fusion import (
LayerWiseMergedModel,
get_layer_wise_weights,
)
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
from fusion_bench.utils.instantiate import instantiate

from .entropy_loss import entropy_loss
from .min_norm_solvers import MinNormSolver
from .utils import get_memory_usage

class FlanT5LayerWiseAdaMergingAlgorithm(
    BaseAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    """Layer-wise AdaMerging for Flan-T5 models."""

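For context, layer-wise AdaMerging learns one merging coefficient per layer and per task vector by minimizing the entropy of the merged model's predictions on unlabeled test data, which is what the `entropy_loss` and `LayerWiseMergedModel` imports above point to. The sketch below is a minimal, self-contained illustration of that idea in plain PyTorch; the two-layer toy model, `merged_forward`, and the 0.3 initialization are hypothetical stand-ins, not the fusion_bench API.

```python
import torch

# Toy stand-in: a "model" is a list of per-layer weight matrices.
# Layer-wise AdaMerging merges parameters as
#   theta_merged[l] = theta_pre[l] + sum_k w[l, k] * (theta_k[l] - theta_pre[l])
# and trains w by minimizing prediction entropy on unlabeled test batches.

torch.manual_seed(0)
num_layers, num_models, dim, num_classes = 2, 3, 8, 4
pretrained = [torch.randn(dim, dim), torch.randn(num_classes, dim)]
task_vectors = [  # finetuned minus pretrained, per model and per layer
    [0.1 * torch.randn_like(p) for p in pretrained] for _ in range(num_models)
]

# One trainable coefficient per (layer, model) pair.
weights = torch.full((num_layers, num_models), 0.3, requires_grad=True)
optimizer = torch.optim.Adam([weights], lr=1e-3)

def merged_forward(x):
    # Rebuild the merged weights from the current coefficients, then run a
    # two-layer forward pass (a ReLU MLP as a stand-in for a transformer).
    merged = [
        p + sum(weights[l, k] * task_vectors[k][l] for k in range(num_models))
        for l, p in enumerate(pretrained)
    ]
    h = torch.relu(x @ merged[0].T)
    return h @ merged[1].T  # logits

x = torch.randn(16, dim)  # an unlabeled "test" batch
for step in range(100):
    logits = merged_forward(x)
    probs = torch.softmax(logits, dim=-1)
    loss = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()  # entropy
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```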
12 changes: 0 additions & 12 deletions fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
@@ -95,18 +95,6 @@ def construct_layer_wise_merged_model(
modelpool.load_model(name) for name in modelpool.model_names
]

for name in (
["wte", "wpe", "ln_f"]
+ [f"h.{i}.ln_1" for i in range(len(pretrained_model.h))]
+ [f"h.{i}.ln_2" for i in range(len(pretrained_model.h))]
+ [f"h.{i}.attn" for i in range(len(pretrained_model.h))]
):
simple_average(
[model.get_submodule(name) for model in finetuned_models],
base_module=pretrained_model.get_submodule(name),
)
pretrained_model.get_submodule(name).requires_grad_(False)

# initialize the layer-wise weights from the configured `init_values`, or load them from file if `weights` is provided
if self.merging_weights_load_path is None:
layer_wise_weight = get_layer_wise_weights(
...

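The lines removed from gpt2_layer_wise_adamerging.py had been simple-averaging the GPT-2 embeddings (`wte`, `wpe`), the final LayerNorm (`ln_f`), and the per-block LayerNorms and attention modules across the fine-tuned models, then freezing the result, so only the remaining submodules were merged layer-wise; with this commit those submodules presumably participate in the layer-wise merge like everything else. For reference, here is a minimal sketch of what such a parameter average does. The semantics are assumed; the actual `fusion_bench.method.simple_average.simple_average` helper may differ.

```python
import torch
from torch import nn

@torch.no_grad()
def simple_average(modules, base_module):
    """Average the parameters of `modules` element-wise and write the
    result into `base_module` in place (assumed behavior)."""
    for name, param in base_module.named_parameters():
        stacked = torch.stack([m.get_parameter(name) for m in modules])
        param.copy_(stacked.mean(dim=0))

# Example: average two fine-tuned LayerNorms into a base LayerNorm.
base = nn.LayerNorm(8)
finetuned = [nn.LayerNorm(8), nn.LayerNorm(8)]
simple_average(finetuned, base_module=base)
base.requires_grad_(False)  # freeze, as the removed code did
```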