diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 00000000..e69de29b diff --git a/404.html b/404.html new file mode 100644 index 00000000..b6ca89e3 --- /dev/null +++ b/404.html @@ -0,0 +1,2188 @@ + + + + + + + + + + + + + + + + + + + FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ +

404 - Not found

+ +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/adamerging/index.html b/algorithms/adamerging/index.html new file mode 100644 index 00000000..6d6c70fc --- /dev/null +++ b/algorithms/adamerging/index.html @@ -0,0 +1,5921 @@ + + + + + + + + + + + + + + + + + + + + + + + AdaMerging - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

AdaMerging

+
+ alt text +
Task Vector, Task Arithmetic, and AdaMerging. Credit to 1
+
+

In the complex landscape of multi-task learning, AdaMerging has emerged as a potent method for adaptively merging model parameters to optimize performance across tasks. Unlike traditional fixed-coefficient methods, AdaMerging autonomously learns merging coefficients, offering a more refined and responsive approach1.

+

The cornerstone of AdaMerging lies in its adaptive nature, where it learns the coefficients for merging either on a task-wise or layer-wise basis. This adaptability is driven by an entropy minimization strategy applied to unlabeled test samples as a surrogate objective function, which serves to refine the merging coefficients for optimal performance.

+

Task-wise AdaMerging is formulated as:

+
\[ +\theta = \theta_0 + \sum_{i=1}^{n} \lambda_i \tau_i +\]
+

where \(\lambda_i\) represents the merging coefficient for the \(i\)-th task, and \(\tau_i\) denotes the task vector for the \(i\)-th task.

+

On the other hand, Layer-wise AdaMerging is articulated as:

+
\[ +\theta^l = \theta_0^l + \sum_{i=1}^{n} \lambda^{l}_{i} \tau^{l}_{i} +\]
+

where the merging coefficient \(\lambda^{l}_{i}\) and task vector \(\tau^{l}_{i}\) are specific to each layer \(l\) of the model.

+

By leveraging this adaptive learning approach, AdaMerging significantly enhances the model's ability to generalize across tasks and layers, resulting in a more robust and finely-tuned performance profile. The method’s reliance on entropy minimization ensures that the merging process continually seeks the most informative and stable configuration, adapting to the specific needs of the dataset and tasks at hand.

+

AdaMerging Analysis

+

Task-wise Coefficients. +The below Figure shows the changes during the iteration process of merging coefficient optimization of each task vector in Task-wise AdaMerging and AdaMerging++, which is shown every ten steps. We consistently observe that the merging coefficients of each task vector are inconsistent. When the number of tasks is relatively large, it is obviously undesirable to grid search the coefficients of each task, but our AdaMerging avoids this manual search process.

+
+alt text +
+Model merging coefficients \(\{λ_k\}_{k=1}^K\) change with respect to training steps on ViT-B/32:
+(a) Task-wise AdaMerging; (b) Task-wise AdaMerging++. Each line represents the change process of the coefficient \(λ_k\) of a task vector \(T_k (k \in \{1, 2, . . . , K\})\). +
+
+

Layer-wise Coefficients. +The following Figure shows the merging coefficients learned by Layer-wise AdaMerging and AdaMerging++ on ViT-B/32 respectively. We observed that:

+
    +
  1. The coefficients learned by each layer of each task vector are different, which shows that the importance of each layer in the model merging process is different.
  2. +
  3. The coefficients learned by shallow layers are generally smaller than those of deep layers, which indicates that shallow layers rely more on the weights of the pre-trained model rather than the weights provided by task vectors, while the deep layers rely more on the weights provided by the task vectors. This may be since the shallow layer learns general features, which are cross-task, while the deep layer learns task-specific features 2. This finding is also consistent with routing analysis in 3.
  4. +
+
+alt text +
+Learned model merging coefficients \(\{λ_l^k\}^{K,L}_{k=1,l=1}\) of Layer-wise AdaMerging (Above) and AdaMerging++ (Below) on ViT-B/32. +The \(k\)-th row represents the \(k\)-th task vector, the \(l\)-th column represents the \(l\)-th layer, and the intersection point represents the coefficient \(λ^l_k\). +
+
+

Code Integration

+

Merge CLIP-ViT-B/32 models from eight downstream image classification tasks:

+
fusion_bench \
+    method=adamerging \
+        method.name=clip_layer_wise_adamerging \
+        method.save_merging_weights=merging_weights.pt \
+    modelpool=clip-vit-base-patch32_TA8 \
+    taskpool=clip-vit-classification_TA8 \
+    fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
+    fabric.loggers.name=clip_layer_wise_adamerging_adam
+
+

Part of the output:

+
Profiler Report
+
+----------------------------------------------------------------------------------------------------------------------------------
+|  Action                       |  Mean duration (s)    |  Num calls            |  Total time (s)       |  Percentage %         |
+----------------------------------------------------------------------------------------------------------------------------------
+|  Total                        |  -                    |  26001                |  724.65               |  100 %                |
+----------------------------------------------------------------------------------------------------------------------------------
+|  backward pass                |  0.060172             |  8000                 |  481.38               |  66.429               |
+|  forward pass                 |  0.016124             |  8000                 |  128.99               |  17.801               |
+|  data loading                 |  0.0063443            |  8000                 |  50.754               |  7.004                |
+|  merging weights              |  0.050735             |  1000                 |  50.735               |  7.0013               |
+|  construct the wrapped model  |  7.2558               |  1                    |  7.2558               |  1.0013               |
+|  optimizer step               |  0.00098186           |  1000                 |  0.98186              |  0.13549              |
+----------------------------------------------------------------------------------------------------------------------------------
+
+

Reference

+

Task-Wise AdaMerging

+ + +
+ + + +

+ task_wise_adamerging + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ TaskWiseAdaMergingAlgorithm + + +
+ + +
+

+ Bases: ModelFusionAlgorithm

+ + + + + + + +
+ Source code in fusion_bench/method/adamerging/task_wise_adamerging.py +
class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
+    _fabric: L.Fabric = None
+
+    def __init__(self, algorithm_config: DictConfig):
+        super().__init__(algorithm_config)
+
+        if self._fabric is None and torch.cuda.is_available():
+            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
+            self._fabric.launch()
+
+    @torch.no_grad()
+    def construct_task_wise_merged_model(self, modelpool: ModelPool):
+        if self.config.weights is None:
+            task_wise_weight = get_task_wise_weights(
+                num_models=len(modelpool.model_names),
+                init_values=self.config.init_values,
+            )
+        else:
+            if isinstance(self.config.weights, str):
+                # self.config.weights is a path to a .np or .pt file
+                if self.config.weights.endswith(".pt"):
+                    task_wise_weight = torch.load(
+                        self.config.weights, map_location="cpu"
+                    ).detach_()
+                elif self.config.weights.endswith(".np"):
+                    task_wise_weight = torch.from_numpy(
+                        np.load(self.config.weights)
+                    ).detach_()
+                else:
+                    raise ValueError(f"Unsupported file format: {self.config.weights}")
+            else:
+                try:
+                    task_wise_weight = torch.tensor(
+                        list(self.config.weights), dtype=torch.float32
+                    )
+                except ValueError:
+                    raise ValueError(
+                        f"Unsupported weights format: {self.config.weights}"
+                    )
+
+        pretrained_model = modelpool.load_model("_pretrained_")
+        finetuned_models = [
+            modelpool.load_model(name) for name in modelpool.model_names
+        ]
+
+        module = TaskWiseMergedModel(
+            task_wise_weight=task_wise_weight,
+            pretrained_model=pretrained_model,
+            finetuned_models=finetuned_models,
+            clamp_weights=self.config.clamp_weights,
+            tie_weights=self.config.tie_weights,
+            strict=self.config.strict,
+        )
+        return module
+
+    def run(self, modelpool: ModelPool):
+        log.info("Fusing models using task-wise adaptive merging.")
+        self.modelpool = modelpool
+
+        module = self.construct_task_wise_merged_model(modelpool)
+
+        if self.config.weights is not None:
+            # skip the test-time adaptation
+            return module.merge_and_unload()
+        else:
+            module = self.test_time_adaptation(module)
+            if self.config.get("save_merging_weights", False):
+                torch.save(module.merge_weight, self.config.save_merging_weights)
+            return module.merge_and_unload()
+
+    def on_test_time_adaptation_start(self):
+        pass
+
+    @abstractmethod
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        pass
+
+    @abstractmethod
+    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module (nn.Module): The model module.
+            batch (tuple): A batch of input data.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The classification logits for the batch.
+        """
+        pass
+
+    def test_time_adaptation(self, module: TaskWiseMergedModel):
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.config.optimizer == "adam":
+            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+        if self._fabric is not None:
+            module, optimizer = self._fabric.setup(module, optimizer)
+
+        module.train()
+        module.merge_weights()
+
+        if self.config.get("fast_dev_run", False):
+            log.info("Running fast_dev_run, only one step")
+            pbar = tqdm(
+                range(1),
+                "AdaMerging Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        else:
+            pbar = tqdm(
+                range(self.config.max_steps),
+                "AdaMerging Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        for step_idx in pbar:
+            for task in self.modelpool.model_names:
+                batch = next(self.get_shuffled_test_loader_iter(task))
+                logits = self.compute_logits(module, batch, task)
+                assert (
+                    logits.dim() == 2
+                ), f"Expected logits to be 2D, got {logits.dim()}"
+                loss = entropy_loss(logits)
+                # .backward() accumulates when .zero_grad() wasn't called
+                # this can save memory
+                self._fabric.backward(loss, retain_graph=True)
+
+            optimizer.step()
+            optimizer.zero_grad()
+            module.merge_weights()
+
+        return module
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ compute_logits(module, batch, task) + + + abstractmethod + + +
+ + +
+ +

Compute the logits for the given batch and task.

+ + +

Parameters:

+
    +
  • + module + (Module) + – +
    +

    The model module.

    +
    +
  • +
  • + batch + (tuple) + – +
    +

    A batch of input data.

    +
    +
  • +
  • + task + (str) + – +
    +

    The name of the task.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Tensor ( Tensor +) – +
    +

    The classification logits for the batch.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/task_wise_adamerging.py +
@abstractmethod
+def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
+    """
+    Compute the logits for the given batch and task.
+
+    Args:
+        module (nn.Module): The model module.
+        batch (tuple): A batch of input data.
+        task (str): The name of the task.
+
+    Returns:
+        Tensor: The classification logits for the batch.
+    """
+    pass
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +
+ entropy_loss(logits) + +
+ + +
+ +

Compute the entropy loss of a set of logits.

+ + +

Parameters:

+
    +
  • +
    logits +
    (Tensor) + – +
    +

    The logits to compute the entropy loss of.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Tensor ( Tensor +) – +
    +

    The entropy loss of the logits.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/task_wise_adamerging.py +
24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
def entropy_loss(logits: Tensor) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    probs = torch.softmax(logits, dim=-1)
+    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ clip_task_wise_adamerging + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ CLIPTaskWiseAdaMergingAlgorithm + + +
+ + +
+

+ Bases: TaskWiseAdaMergingAlgorithm

+ + +

A class for task-wise adaptive merging of CLIP models.

+

This class extends the TaskWiseAdaMergingAlgorithm to provide specific +functionality for CLIP models, including loading datasets, constructing +zero-shot classification heads, and computing logits.

+ + +

Attributes:

+
    +
  • + modelpool + (CLIPVisionModelPool) + – +
    +

    The model pool containing CLIP models.

    +
    +
  • +
  • + _clip_processor + (CLIPProcessor) + – +
    +

    The CLIP processor for preparing inputs.

    +
    +
  • +
  • + zeroshot_weights + (dict) + – +
    +

    A dictionary to store zero-shot weights for each task.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py +
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
+    """
+    A class for task-wise adaptive merging of CLIP models.
+
+    This class extends the TaskWiseAdaMergingAlgorithm to provide specific
+    functionality for CLIP models, including loading datasets, constructing
+    zero-shot classification heads, and computing logits.
+
+    Attributes:
+        modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
+        _clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
+        zeroshot_weights (dict): A dictionary to store zero-shot weights for each task.
+    """
+
+    modelpool: CLIPVisionModelPool = None
+    _clip_processor: CLIPProcessor = None
+    zeroshot_weights = {}
+
+    def __init__(self, algorithm_config: DictConfig):
+        super().__init__(algorithm_config)
+
+    @functools.cache
+    def get_test_dataset(self, task: str):
+        """
+        Load the test dataset for the task.
+        This method is cached, so the dataset is loaded only once.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            CLIPDataset: The test dataset for the task.
+        """
+        log.info(f"Loading test dataset: {task}")
+        dataset = self.modelpool.load_test_dataset(task)
+        dataset = CLIPDataset(dataset, self._clip_processor)
+        return dataset
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str):
+        """
+        Get an iterator over the shuffled test DataLoader for the task.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            iterator: An iterator over the shuffled test DataLoader.
+        """
+        loader = DataLoader(
+            self.get_test_dataset(task),
+            batch_size=self.config.batch_size,
+            shuffle=True,
+            num_workers=self.config.num_workers,
+            pin_memory=True,
+        )
+        if self._fabric is not None:
+            loader = self._fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def on_test_time_adaptation_start(self):
+        """
+        Prepare for test-time adaptation.
+
+        This method loads the CLIP processor and constructs the zero-shot
+        classification head for each task.
+        """
+        clip_model_config = self.modelpool.get_model_config("_pretrained_")
+        pretrained_path = (
+            clip_model_config.pretrained_model_name_or_path
+            if hasattr(clip_model_config, "pretrained_model_name_or_path")
+            else clip_model_config.path
+        )
+
+        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
+            self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
+            clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)
+
+            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
+            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
+            self.logit_scale_exp = clip_model.logit_scale.exp()
+            if self._fabric is not None:
+                self.visual_projection = self._fabric.to_device(self.visual_projection)
+                self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)
+
+        for task in self.modelpool.model_names:
+            cache_file = os.path.join(
+                self.config.cache_dir,
+                f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
+            )
+            if os.path.exists(cache_file):
+                log.info(f"Loading cached zeroshot weights for task: {task}")
+                zeroshot_weights = torch.load(cache_file, map_location="cpu")
+            else:
+                log.info(f"Construct zero shot classification head for task: {task}")
+                classnames, templates = get_classnames_and_templates(task)
+                clip_classifier.set_classification_task(classnames, templates)
+                zeroshot_weights = clip_classifier.zeroshot_weights
+                log.info(f"save zeroshot weights to {cache_file}")
+                torch.save(zeroshot_weights, cache_file)
+            self.zeroshot_weights[task] = zeroshot_weights
+            if self._fabric is not None:
+                self.zeroshot_weights[task] = self._fabric.to_device(
+                    self.zeroshot_weights[task]
+                )
+
+    def compute_logits(self, module, batch, task: str) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        This method computes the image embeddings, normalizes them, and calculates
+        the cosine similarity with the text embeddings to produce classification logits.
+
+        Args:
+            module (nn.Module): The model module.
+            batch (tuple): A batch of input data.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The classification logits for the batch.
+        """
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ compute_logits(module, batch, task) + +
+ + +
+ +

Compute the logits for the given batch and task.

+

This method computes the image embeddings, normalizes them, and calculates +the cosine similarity with the text embeddings to produce classification logits.

+ + +

Parameters:

+
    +
  • + module + (Module) + – +
    +

    The model module.

    +
    +
  • +
  • + batch + (tuple) + – +
    +

    A batch of input data.

    +
    +
  • +
  • + task + (str) + – +
    +

    The name of the task.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Tensor ( Tensor +) – +
    +

    The classification logits for the batch.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py +
def compute_logits(self, module, batch, task: str) -> Tensor:
+    """
+    Compute the logits for the given batch and task.
+
+    This method computes the image embeddings, normalizes them, and calculates
+    the cosine similarity with the text embeddings to produce classification logits.
+
+    Args:
+        module (nn.Module): The model module.
+        batch (tuple): A batch of input data.
+        task (str): The name of the task.
+
+    Returns:
+        Tensor: The classification logits for the batch.
+    """
+    images, _ = batch
+    text_embeds = self.zeroshot_weights[task]
+
+    image_embeds = module(images)[1]
+    image_embeds = self.visual_projection(image_embeds)
+
+    # normalize embeddings
+    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+    # cosine similarity
+    logits_per_text = (
+        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+    )
+    logits_per_image = logits_per_text.t()
+
+    return logits_per_image
+
+
+
+ +
+ +
+ + +
+ get_shuffled_test_loader_iter(task) + + + cached + + +
+ + +
+ +

Get an iterator over the shuffled test DataLoader for the task.

+ + +

Parameters:

+
    +
  • + task + (str) + – +
    +

    The name of the task.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +iterator – +
    +

    An iterator over the shuffled test DataLoader.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py +
@functools.cache
+def get_shuffled_test_loader_iter(self, task: str):
+    """
+    Get an iterator over the shuffled test DataLoader for the task.
+
+    Args:
+        task (str): The name of the task.
+
+    Returns:
+        iterator: An iterator over the shuffled test DataLoader.
+    """
+    loader = DataLoader(
+        self.get_test_dataset(task),
+        batch_size=self.config.batch_size,
+        shuffle=True,
+        num_workers=self.config.num_workers,
+        pin_memory=True,
+    )
+    if self._fabric is not None:
+        loader = self._fabric.setup_dataloaders(loader)
+    return iter(InfiniteDataLoader(loader))
+
+
+
+ +
+ +
+ + +
+ get_test_dataset(task) + + + cached + + +
+ + +
+ +

Load the test dataset for the task. +This method is cached, so the dataset is loaded only once.

+ + +

Parameters:

+
    +
  • + task + (str) + – +
    +

    The name of the task.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +CLIPDataset – +
    +

    The test dataset for the task.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py +
72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
@functools.cache
+def get_test_dataset(self, task: str):
+    """
+    Load the test dataset for the task.
+    This method is cached, so the dataset is loaded only once.
+
+    Args:
+        task (str): The name of the task.
+
+    Returns:
+        CLIPDataset: The test dataset for the task.
+    """
+    log.info(f"Loading test dataset: {task}")
+    dataset = self.modelpool.load_test_dataset(task)
+    dataset = CLIPDataset(dataset, self._clip_processor)
+    return dataset
+
+
+
+ +
+ +
+ + +
+ on_test_time_adaptation_start() + +
+ + +
+ +

Prepare for test-time adaptation.

+

This method loads the CLIP processor and constructs the zero-shot +classification head for each task.

+ +
+ Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py +
def on_test_time_adaptation_start(self):
+    """
+    Prepare for test-time adaptation.
+
+    This method loads the CLIP processor and constructs the zero-shot
+    classification head for each task.
+    """
+    clip_model_config = self.modelpool.get_model_config("_pretrained_")
+    pretrained_path = (
+        clip_model_config.pretrained_model_name_or_path
+        if hasattr(clip_model_config, "pretrained_model_name_or_path")
+        else clip_model_config.path
+    )
+
+    with timeit_context("Loading CLIP processor and pretrained CLIP model."):
+        self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
+        clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)
+
+        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
+        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
+        self.logit_scale_exp = clip_model.logit_scale.exp()
+        if self._fabric is not None:
+            self.visual_projection = self._fabric.to_device(self.visual_projection)
+            self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)
+
+    for task in self.modelpool.model_names:
+        cache_file = os.path.join(
+            self.config.cache_dir,
+            f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
+        )
+        if os.path.exists(cache_file):
+            log.info(f"Loading cached zeroshot weights for task: {task}")
+            zeroshot_weights = torch.load(cache_file, map_location="cpu")
+        else:
+            log.info(f"Construct zero shot classification head for task: {task}")
+            classnames, templates = get_classnames_and_templates(task)
+            clip_classifier.set_classification_task(classnames, templates)
+            zeroshot_weights = clip_classifier.zeroshot_weights
+            log.info(f"save zeroshot weights to {cache_file}")
+            torch.save(zeroshot_weights, cache_file)
+        self.zeroshot_weights[task] = zeroshot_weights
+        if self._fabric is not None:
+            self.zeroshot_weights[task] = self._fabric.to_device(
+                self.zeroshot_weights[task]
+            )
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +
+ InfiniteDataLoader + + +
+ + +
+ + +

A wrapper class for DataLoader to create an infinite data loader. +This is useful in case we are only interested in the number of steps and not the number of epochs.

+

This class wraps a DataLoader and provides an iterator that resets +when the end of the dataset is reached, creating an infinite loop.

+ + +

Attributes:

+
    +
  • + data_loader + (DataLoader) + – +
    +

    The DataLoader to wrap.

    +
    +
  • +
  • + data_iter + (iterator) + – +
    +

    An iterator over the DataLoader.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py +
22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
class InfiniteDataLoader:
+    """
+    A wrapper class for DataLoader to create an infinite data loader.
+    This is useful in case we are only interested in the number of steps and not the number of epochs.
+
+    This class wraps a DataLoader and provides an iterator that resets
+    when the end of the dataset is reached, creating an infinite loop.
+
+    Attributes:
+        data_loader (DataLoader): The DataLoader to wrap.
+        data_iter (iterator): An iterator over the DataLoader.
+    """
+
+    def __init__(self, data_loader):
+        self.data_loader = data_loader
+        self.data_iter = iter(data_loader)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            data = next(self.data_iter)
+        except StopIteration:
+            self.data_iter = iter(self.data_loader)  # Reset the data loader
+            data = next(self.data_iter)
+        return data
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +

Layer-Wise AdaMerging

+ + +
+ + + +

+ layer_wise_adamerging + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ LayerWiseAdaMergingAlgorithm + + +
+ + +
+

+ Bases: ModelFusionAlgorithm, LightningFabricMixin, SimpleProfilerMixin

+ + +

Implements the Layer-Wise AdaMerging Algorithm.

+

This class merges the layers of a pretrained model with those of several fine-tuned models. +The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.

+ + + + + + +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
class LayerWiseAdaMergingAlgorithm(
+    ModelFusionAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
+    """
+    Implements the Layer-Wise AdaMerging Algorithm.
+
+    This class merges the layers of a pretrained model with those of several fine-tuned models.
+    The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.
+    """
+
+    def __init__(self, algorithm_config: DictConfig):
+        """
+        Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.
+
+        Args:
+            algorithm_config (DictConfig): The configuration for the algorithm.
+        """
+        super().__init__(algorithm_config)
+
+    @torch.no_grad()
+    def construct_layer_wise_merged_model(self, modelpool: ModelPool):
+        """
+        Constructs a wrapped layer-wise merged model from model pool.
+
+        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
+        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
+        The merging weights can be initialized based on a provided configuration or loaded from a file.
+
+        Args:
+            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.
+
+        Returns:
+            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
+        """
+        pretrained_model = modelpool.load_model("_pretrained_")
+        finetuned_models = [
+            modelpool.load_model(name) for name in modelpool.model_names
+        ]
+
+        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
+        if self.config.weights is None:
+            layer_wise_weight = get_layer_wise_weights(
+                num_models=len(modelpool.model_names),
+                num_layers=len(
+                    tuple(
+                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
+                    )
+                ),
+                init_values=self.config.init_values,
+            )
+        else:
+            if isinstance(self.config.weights, str):
+                # self.config.weights is a path to a saved tensor
+                layer_wise_weight = load_tensor_from_file(self.config.weights)
+            else:
+                raise ValueError(f"Unsupported weights format: {self.config.weights}")
+
+        module = LayerWiseMergedModel(
+            layer_wise_weight=layer_wise_weight,
+            pretrained_model=pretrained_model,
+            finetuned_models=finetuned_models,
+            clamp_weights=self.config.clamp_weights,
+            tie_weights=self.config.tie_weights,
+            strict=self.config.strict,
+        )
+        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
+        return module
+
+    @rank_zero_only
+    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
+        """
+        Save the merging weights to a file.
+
+        Args:
+            file_path (str): The path to save the merging weights.
+            merging_weights (torch.Tensor): The merging weights to save.
+        """
+        if self.fabric.is_global_zero and self.config.get(
+            "save_merging_weights", False
+        ):
+            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
+                # if the file path is not absolute or relative to current working directory, save it in the log directory
+                save_path = os.path.join(self.log_dir, file_path)
+            else:
+                save_path = file_path
+            log.info(f"saving merging weights to {save_path}.")
+            if os.path.dirname(save_path):
+                os.makedirs(os.path.dirname(save_path), exist_ok=True)
+            torch.save(merging_weights.detach().cpu(), save_path)
+
+    def run(self, modelpool: ModelPool):
+        """
+        Run the Layer-Wise AdaMerging Algorithm.
+
+        This method constructs the wrapped model and performs test-time adaptation if necessary.
+
+        Args:
+            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
+
+        Returns:
+            LayerWiseMergedModel: The merged model after test-time adaptation.
+        """
+        log.info("Fusing models using layer-wise adaptive merging.")
+        self.modelpool = modelpool
+        self.log_hyperparams(self.config)
+
+        with self.profile("construct the wrapped model"):
+            module = self.construct_layer_wise_merged_model(modelpool)
+
+        if self.config.weights is not None:
+            # skip the test-time adaptation
+            return module.merge_and_unload()
+        else:
+            with self.profile("test-time adaptation"):
+                module = self.test_time_adaptation(module)
+            if self.config.get("save_merging_weights", False):
+                self.save_merging_weights(
+                    self.config.save_merging_weights, module.merge_weight
+                )
+            return module.merge_and_unload()
+
+    def on_test_time_adaptation_start(self):
+        """
+        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
+        """
+        pass
+
+    @abstractmethod
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        """
+        Loader of test dataset for test-time adaptation. labels are not needed.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            DataLoader: The data loader for the test dataset.
+        """
+        pass
+
+    @abstractmethod
+    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
+        """
+        Compute the logits for the given images and task.
+
+        Args:
+            module: The model module.
+            images (Tensor): The input images.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        pass
+
+    def test_time_adaptation(self, module: LayerWiseMergedModel):
+        """
+        Perform test-time adaptation on the merged model.
+
+        This method adapts the merging weights during test-time to improve performance.
+
+        Args:
+            module (LayerWiseMergedModel): The merged model.
+
+        Returns:
+            LayerWiseMergedModel: The adapted merged model.
+        """
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.config.optimizer == "adam":
+            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
+            print(f"{optimizer=}")
+            module, optimizer = self.fabric.setup(module, optimizer)
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+        module.train()
+        module.merge_weights()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.config.max_steps if not self.is_debug_mode else 1),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "AdaMerging Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        ):
+            # default behavior for first-order optimizers
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, batch[0], task)
+                    loss = entropy_loss(logits)
+                with self.profile("backward pass"):
+                    self.fabric.backward(loss, retain_graph=True)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+            with self.profile("merging weights"):
+                module.merge_weights()
+
+            metrics = {
+                "train/loss": loss.item(),
+                "train/weight_max": module.merge_weight.max().item(),
+                "train/weight_min": module.merge_weight.min().item(),
+                "train/weight_mean": module.merge_weight.mean().item(),
+            }
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+        self.print_profile_summary()
+        return module
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ __init__(algorithm_config) + +
+ + +
+ +

Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

+ + +

Parameters:

+
    +
  • + algorithm_config + (DictConfig) + – +
    +

    The configuration for the algorithm.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
40
+41
+42
+43
+44
+45
+46
+47
def __init__(self, algorithm_config: DictConfig):
+    """
+    Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.
+
+    Args:
+        algorithm_config (DictConfig): The configuration for the algorithm.
+    """
+    super().__init__(algorithm_config)
+
+
+
+ +
+ +
+ + +
+ compute_logits(module, images, task) + + + abstractmethod + + +
+ + +
+ +

Compute the logits for the given images and task.

+ + +

Parameters:

+
    +
  • + module + – +
    +

    The model module.

    +
    +
  • +
  • + images + (Tensor) + – +
    +

    The input images.

    +
    +
  • +
  • + task + (str) + – +
    +

    The name of the task.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Tensor ( Tensor +) – +
    +

    The computed logits.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
@abstractmethod
+def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
+    """
+    Compute the logits for the given images and task.
+
+    Args:
+        module: The model module.
+        images (Tensor): The input images.
+        task (str): The name of the task.
+
+    Returns:
+        Tensor: The computed logits.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ construct_layer_wise_merged_model(modelpool) + +
+ + +
+ +

Constructs a wrapped layer-wise merged model from model pool.

+

This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models. +The merging is controlled by layer-wise weights, which is a torch.Tensor of the shape (num_models, num_layers). +The merging weights can be initialized based on a provided configuration or loaded from a file.

+ + +

Parameters:

+
    +
  • + modelpool + (ModelPool) + – +
    +

    An object containing the pretrained model and fine-tuned models to be merged.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +LayerWiseMergedModel – +
    +

    An instance of the merged model with layer-wise weights applied.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
@torch.no_grad()
+def construct_layer_wise_merged_model(self, modelpool: ModelPool):
+    """
+    Constructs a wrapped layer-wise merged model from model pool.
+
+    This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
+    The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
+    The merging weights can be initialized based on a provided configuration or loaded from a file.
+
+    Args:
+        modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.
+
+    Returns:
+        LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
+    """
+    pretrained_model = modelpool.load_model("_pretrained_")
+    finetuned_models = [
+        modelpool.load_model(name) for name in modelpool.model_names
+    ]
+
+    # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
+    if self.config.weights is None:
+        layer_wise_weight = get_layer_wise_weights(
+            num_models=len(modelpool.model_names),
+            num_layers=len(
+                tuple(
+                    filter(lambda p: p.requires_grad, pretrained_model.parameters())
+                )
+            ),
+            init_values=self.config.init_values,
+        )
+    else:
+        if isinstance(self.config.weights, str):
+            # self.config.weights is a path to a saved tensor
+            layer_wise_weight = load_tensor_from_file(self.config.weights)
+        else:
+            raise ValueError(f"Unsupported weights format: {self.config.weights}")
+
+    module = LayerWiseMergedModel(
+        layer_wise_weight=layer_wise_weight,
+        pretrained_model=pretrained_model,
+        finetuned_models=finetuned_models,
+        clamp_weights=self.config.clamp_weights,
+        tie_weights=self.config.tie_weights,
+        strict=self.config.strict,
+    )
+    print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
+    return module
+
+
+
+ +
+ +
+ + +
+ get_shuffled_test_loader_iter(task) + + + abstractmethod + + +
+ + +
+ +

Loader of test dataset for test-time adaptation. labels are not needed.

+ + +

Parameters:

+
    +
  • + task + (str) + – +
    +

    The name of the task.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +DataLoader ( DataLoader +) – +
    +

    The data loader for the test dataset.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
@abstractmethod
+def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+    """
+    Loader of test dataset for test-time adaptation. labels are not needed.
+
+    Args:
+        task (str): The name of the task.
+
+    Returns:
+        DataLoader: The data loader for the test dataset.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ on_test_time_adaptation_start() + +
+ + +
+ +

Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.

+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
def on_test_time_adaptation_start(self):
+    """
+    Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Run the Layer-Wise AdaMerging Algorithm.

+

This method constructs the wrapped model and performs test-time adaptation if necessary.

+ + +

Parameters:

+
    +
  • + modelpool + (ModelPool) + – +
    +

    The model pool containing the pretrained and fine-tuned models.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +LayerWiseMergedModel – +
    +

    The merged model after test-time adaptation.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
def run(self, modelpool: ModelPool):
+    """
+    Run the Layer-Wise AdaMerging Algorithm.
+
+    This method constructs the wrapped model and performs test-time adaptation if necessary.
+
+    Args:
+        modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
+
+    Returns:
+        LayerWiseMergedModel: The merged model after test-time adaptation.
+    """
+    log.info("Fusing models using layer-wise adaptive merging.")
+    self.modelpool = modelpool
+    self.log_hyperparams(self.config)
+
+    with self.profile("construct the wrapped model"):
+        module = self.construct_layer_wise_merged_model(modelpool)
+
+    if self.config.weights is not None:
+        # skip the test-time adaptation
+        return module.merge_and_unload()
+    else:
+        with self.profile("test-time adaptation"):
+            module = self.test_time_adaptation(module)
+        if self.config.get("save_merging_weights", False):
+            self.save_merging_weights(
+                self.config.save_merging_weights, module.merge_weight
+            )
+        return module.merge_and_unload()
+
+
+
+ +
+ +
+ + +
+ save_merging_weights(file_path, merging_weights) + +
+ + +
+ +

Save the merging weights to a file.

+ + +

Parameters:

+
    +
  • + file_path + (str) + – +
    +

    The path to save the merging weights.

    +
    +
  • +
  • + merging_weights + (Tensor) + – +
    +

    The merging weights to save.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
@rank_zero_only
+def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
+    """
+    Save the merging weights to a file.
+
+    Args:
+        file_path (str): The path to save the merging weights.
+        merging_weights (torch.Tensor): The merging weights to save.
+    """
+    if self.fabric.is_global_zero and self.config.get(
+        "save_merging_weights", False
+    ):
+        if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
+            # if the file path is not absolute or relative to current working directory, save it in the log directory
+            save_path = os.path.join(self.log_dir, file_path)
+        else:
+            save_path = file_path
+        log.info(f"saving merging weights to {save_path}.")
+        if os.path.dirname(save_path):
+            os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        torch.save(merging_weights.detach().cpu(), save_path)
+
+
+
+ +
+ +
+ + +
+ test_time_adaptation(module) + +
+ + +
+ +

Perform test-time adaptation on the merged model.

+

This method adapts the merging weights during test-time to improve performance.

+ + +

Parameters:

+
    +
  • + module + (LayerWiseMergedModel) + – +
    +

    The merged model.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +LayerWiseMergedModel – +
    +

    The adapted merged model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py +
def test_time_adaptation(self, module: LayerWiseMergedModel):
+    """
+    Perform test-time adaptation on the merged model.
+
+    This method adapts the merging weights during test-time to improve performance.
+
+    Args:
+        module (LayerWiseMergedModel): The merged model.
+
+    Returns:
+        LayerWiseMergedModel: The adapted merged model.
+    """
+    self.on_test_time_adaptation_start()
+
+    # configure optimizer
+    if self.config.optimizer == "adam":
+        optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
+        print(f"{optimizer=}")
+        module, optimizer = self.fabric.setup(module, optimizer)
+    else:
+        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+    module.train()
+    module.merge_weights()
+    for step_idx in (
+        pbar := tqdm(
+            range(self.config.max_steps if not self.is_debug_mode else 1),
+            ("[DEBUG MODE] " if self.is_debug_mode else "")
+            + "AdaMerging Test-time adaptation",
+            dynamic_ncols=True,
+        )
+    ):
+        # default behavior for first-order optimizers
+        for task in self.modelpool.model_names:
+            with self.profile("data loading"):
+                batch = next(self.get_shuffled_test_loader_iter(task))
+            with self.profile("forward pass"):
+                logits = self.compute_logits(module, batch[0], task)
+                loss = entropy_loss(logits)
+            with self.profile("backward pass"):
+                self.fabric.backward(loss, retain_graph=True)
+
+        with self.profile("optimizer step"):
+            optimizer.step()
+            optimizer.zero_grad()
+        with self.profile("merging weights"):
+            module.merge_weights()
+
+        metrics = {
+            "train/loss": loss.item(),
+            "train/weight_max": module.merge_weight.max().item(),
+            "train/weight_min": module.merge_weight.min().item(),
+            "train/weight_mean": module.merge_weight.mean().item(),
+        }
+        self.fabric.log_dict(metrics, step=step_idx)
+        pbar.set_postfix(metrics)
+
+    self.print_profile_summary()
+    return module
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ +
+ + + +

+ clip_layer_wise_adamerging + + +

+ +
+ +

Example Usage:

+
fusion_bench     method=adamerging         method.name=clip_layer_wise_adamerging         method.save_merging_weights=merging_weights.pt     modelpool=clip-vit-base-patch32_TA8     taskpool=clip-vit-classification_TA8     fabric.loggers.root_dir=outputs/logs/ViT-B-32     fabric.loggers.name=clip_layer_wise_adamerging_adam
+
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ CLIPLayerWiseAdaMergingAlgorithm + + +
+ + +
+

+ Bases: CLIPClassificationMixin, LayerWiseAdaMergingAlgorithm

+ + + + + + + +
+ Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +
30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
class CLIPLayerWiseAdaMergingAlgorithm(
+    CLIPClassificationMixin,
+    LayerWiseAdaMergingAlgorithm,
+):
+    def on_test_time_adaptation_start(self):
+        """
+        Here we load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str):
+        return super().get_shuffled_test_loader_iter(
+            task,
+            batch_size=self.config.batch_size,
+            num_workers=self.config.num_workers,
+        )
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ on_test_time_adaptation_start() + +
+ + +
+ +

Here we load the CLIP processor and construct the zero-shot classification head for each task.

+ +
+ Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +
34
+35
+36
+37
+38
def on_test_time_adaptation_start(self):
+    """
+    Here we load the CLIP processor and construct the zero-shot classification head for each task.
+    """
+    self.setup_zero_shot_classification_head()
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+
+
    +
  1. +

    (ICLR 2024) AdaMerging: Adaptive Model Merging for Multi-Task Learning. https://openreview.net/pdf?id=nZP6NgD3QY 

    +
  2. +
  3. +

    Jason Yosinski, Jeff Clune, Yoshua Bengio, and Hod Lipson. How transferable are features in deep neural networks? Advances in neural information processing systems, 27, 2014. 

    +
  4. +
  5. +

    A. Tang, L. Shen, Y. Luo, N. Yin, L. Zhang, and D. Tao, “Merging Multi-Task Models via Weight-Ensembling Mixture of Experts,” ICML 2024. doi: 10.48550/arXiv.2402.00433. 

    +
  6. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/concrete_subspace/index.html b/algorithms/concrete_subspace/index.html new file mode 100644 index 00000000..fc40c601 --- /dev/null +++ b/algorithms/concrete_subspace/index.html @@ -0,0 +1,2632 @@ + + + + + + + + + + + + + + + + + + + + + + + Concrete Subspace - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Concrete Subspace Learning

+
+alt text +
+(a) Framework overview. Our proposed framework comprises two main steps: first, establishing a common subspace for task vectors across various tasks using a shared mask, and second, merging the models within this shared subspace.
+(b) Mask sampling. Here we illustrate the procedure for sampling discrete binary masks and our differentiable Concrete mask. It's important to note that while a Concrete mask can also be binarized, this binarization process is non-differentiable. +
+
+

Contrete Masking

+

The Gumbel-Max Trick

+

Consider a discrete categorical distribution parameterized by logits \(\mathbf{x} = (x_1, \dots, x_n) \in \mathbb{R}^{n}\), where \(x_i\) is the logit of the \(i\)-th category. The Gumbel-Max trick 123 states a reparameterization trick to sample from the categorical distribution by sampling from the standard Gumbel distribution \(\text{Gumbel}(\mu=0,\beta=1)\) and taking the argmax of the sum of the Gumbel random variables and the logits.

+

This trick proceeds as follows: +sample \(n\) Gumbel random variables \(g_1, \dots, g_n\) independently from the standard Gumbel distribution \(\text{Gumbel}(\mu=0,\beta=1)\) (We can draw a random sample \(u\) from a unifrom distribution on the interval \((0,1)\) and then transform it into a Gumbel-distributed variable \(g\) using the formula \(g=-\log(-\log u)\).), find the index \(i\) of that maximizes \(x_i + g_i\), then we have

+
\[ + {\arg\max}_{i\in[n]} (x_i + g_i) \sim \text{Categorical}(\text{softmax}(\mathbf{x})). +\]
+

If we represent the categorical distribution as a one-hot vector \(\mathbf{y} = (y_1, \dots, y_n) \in \{0,1\}^n\), where \(y_i=1\) indicates that the \(i\)-th category is sampled and for all \(j\neq i\), \(y_j=0\), then we have

+
\[ + \mathbb{P}(y_k=1) = \mathbb{P}\left({\arg\max}_{i\in[n]} (x_i + g_i) = k\right) = \frac{\exp(x_k)}{\sum_{i=1}^n \exp(x_i)}. +\]
+

Continuous Relaxation of the discrete Categorical Distribution

+

Since the derivative of the \({\arg\max}\) function is not defined, we cannot backpropagate the gradients through it. +To address this issue, (Maddison et al., 2017)4 proposed to use a continuous relaxation of the discrete categorical distribution. +A CONCRETE random variable (CONtinuous relaxation of disCRETE random variable) relax the condition that the one-hot vector \(\mathbf{y}\) must be located at the vertices of the \((n-1)\)-dimensional simplex \(\Delta^{n-1}\), and instead, it allows \(\mathbf{y}\) to be located anywhere inside the simplex \(\Delta^{n-1}\), i.e. \(\{ y\in \mathbb{R}^n | y_i \in [0,1], \sum_{i=1}^n y_i =1 \}\).

+

To sample a Concrete random variable \(\mathbf{y}\) from a distribution that is parameterized by a temperature hyperparameter \(\lambda > 0\) and a vector of logits \(\mathbf{x} = (x_1, \dots, x_n) \in \mathbb{R}^{n}\), we have

+
\[ + \mathbf{y} = \text{softmax}\left(\frac{\mathbf{x} + \mathbf{g}}{\lambda}\right), \quad + y_i = \frac{\exp\left((x_i + g_i)/{\lambda}\right)}{\sum_{j=1}^n \exp\left(({x_j + g_j})/{\lambda}\right)} \quad \text{for} \,\, i\in[n]. +\]
+

where \(\mathbf{g} = (g_1, \dots, g_n)\) is a vector of Gumbel random variables that are independently sampled from the standard Gumbel distribution \(\text{Gumbel}(\mu=0,\beta=1)\).

+

Concrete Masking

+

A subspace mask \(\mathbf{m}\) is a binary vector that identifies a subspace of the parameter space. +For a neural network parametrized by \(\theta\), we can use a subspace mask \(\mathbf{m}\) to identify a subspace of the parameter space \(\mathbf{\theta}\) by setting the parameters that are not in the subspace to zero, i.e. \(\mathbf{\theta} \circ \mathbf{m}\), where \(\circ\) denotes the element-wise product. +We can draw a random sample \(\mathbf{m}\) from a Bernoulli distribution \(\text{Bernoulli}(\mathbf{p}=\sigma(\mathbf{x}))\), where \(\mathbf{p}\) is the probability (\(\mathbf{x}\) denotes the logits) of each parameter being activated. However, the discrete Bernoulli distribution is not differentiable, so we cannot backpropagate the gradients through it to optimize the parameters \(\mathbf{p}\) or \(\mathbf{x}\).

+

To address this issue, we introduce the Concrete mask which can be drawn from a continuous relaxation of Bernoulli distribution. Before we introduce the Concrete mask, we first review the Gumbel-Max trick in the two-class case.

+

Let \(p_0\) and \(p_1\) denote the unnormalized probabilities of a Bernoulli random variable being 0 and 1, respectively, with \(x\) representing the logits. Then, the probability of the event \(m=1\) is given by

+
\[ + \mathbb{P}(m=1) = \frac{p_1}{p_0 + p_1} = \sigma(x), +\]
+

where \(\sigma\) denotes the sigmoid function. +In the context of the Gumbel-Max trick, the occurrence of the event \(m=1\) is determined by the condition \(g_1 + \log p_1 > g_0 + \log p_0\), where \(g_0\) and \(g_1\) are two independent standard Gumbel random variables. +Thus we have

+
\[ + \mathbb{P}(m=1) = \mathbb{P}(g_1 + \log p_1 > g_0 + \log p_0) + = \mathbb{P}\left((g_1 - g_0) + (\log p_1 - \log p_0)> 0\right). +\]
+

Because the difference of two standard Gumbel random variables is a Logistic random variable, we can replace \(g_1 - g_0\) by \(\log u - \log(1-u)\) where \(u\) is a random variable sampled from a uniform distribution on the interval \((0,1)\). +Substitute this into Eq.(\ref{eq:appendix_P_m_1}) and express the probability in terms of the logits \(x\) to simplify the expression, we have

+
\[ + \mathbb{P}(m=1) = \mathbb{P}\left(\log \frac{u}{1-u} + \log \frac{\sigma(x)}{1-\sigma(x)} > 0\right), \quad u \sim \text{Uniform}(0,1). +\]
+

The binary Concrete distribution offers a continuous relaxation of the discrete Bernoulli random variables, which is beneficial for gradient-based optimization as it allows for the backpropagation of gradients even through the sampling process. +Instead of making a hard decision as the above equation, we use a temperature parameter \(\lambda\) to control the steepness of the sigmoid function, and hence control how close our 'soft' decisions are to being 'hard' decisions. The continuous version of the Bernoulli random variable is then given by

+
\[ + \hat{m} = \sigma\left(\left(\log \frac{u}{1 - u} + \log \frac{\sigma(x)}{1 - \sigma(x)}\right) / \lambda\right). +\]
+

As the temperature \(\lambda\) approaches zero, the sigmoid function becomes a step function, and the Concrete random variable \(\hat{m}\) becomes a Bernoulli random variable, as shown in the following Figure. In the limit when \(\lambda \to 0\), this results in sampling \(m=1\) if \(\log \frac{\sigma(x)}{1 - \sigma(x)} > -\log \frac{u}{1 - u}\), consistent with the original Gumbel-Max trick. +The binary Concrete distribution thus provides a differentiable approximation to Bernoulli random variables. +We can further binarize the Concrete mask by setting the entries with values greater than 0.5 to 1 and the rest to 0.

+
+ alt text +
+ The sigmoid function \(\sigma(\cdot/\lambda)\) with different temperatures \(\lambda\). +
+
+

Method Analysis

+

Concrete AdaMerging

+
+alt text +
+Performance comparison between AdaMerging and Concrete AdaMerging. Here we show the whole process of applying AdaMerging and Concrete AdaMerging to CLIP-ViT-B/32, the y-axes are shared by these two subfigures: (a) shows the performance of the merged model during the meta-learning phase of the Concrete AdaMerging; (b) illustrates the comparison between AdaMerging with and without the Concrete mask. +
+
+

Code Integration

+

Merging CLIP models on eight image classification tasks, using the concrete task arithmetic algorithm

+
# tensorboard logs and learned checkpoints of the shared mask can be found at https://huggingface.co/tanganke/clip-vit-base-patch32_concrete-task-arithmetic_tblogs
+fusion_bench \
+    fabric.loggers.name=ViT-B-32/concrete_task_arithmetic \
+    method=concrete_subspace/clip_concrete_task_arithmetic \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+
+

results

+
{
+    "svhn": {
+        "accuracy": 0.903003990650177,
+        "loss": 0.37700024247169495
+    },
+    "stanford_cars": {
+        "accuracy": 0.6326327323913574,
+        "loss": 1.2553859949111938
+    },
+    "resisc45": {
+        "accuracy": 0.7558730244636536,
+        "loss": 1.017554759979248
+    },
+    "eurosat": {
+        "accuracy": 0.9407407641410828,
+        "loss": 0.20871955156326294
+    },
+    "gtsrb": {
+        "accuracy": 0.8285035490989685,
+        "loss": 0.5861473679542542
+    },
+    "mnist": {
+        "accuracy": 0.9800000190734863,
+        "loss": 0.08148527890443802
+    },
+    "dtd": {
+        "accuracy": 0.5249999761581421,
+        "loss": 2.2731478214263916
+    },
+    "sun397": {
+        "accuracy": 0.6421158909797668,
+        "loss": 1.4108904600143433
+    }
+}
+
+

Concrete AdaMerging (Layer-wise)

+
# tensorboard logs and learned checkpoints of the shared mask can be found at https://huggingface.co/tanganke/clip-vit-base-patch32_concrete-layer-wise_adamerging_tblogs
+fusion_bench \
+    fabric.loggers.name=ViT-B-32/clip_concrete_layer_wise_adamerging \
+    method=concrete_subspace/clip_concrete_layer_wise_adamerging \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+
+

Further Reading

+
    +
  • +

    🦙 + X. Yi, S. Zheng, L. Wang, X. Wang, and L. He, “A safety realignment framework via subspace-oriented model fusion for large language models.” arXiv, May 14, 2024. doi: 10.48550/arXiv.2405.09055.

    +
    +

    The paper introduces a safety realignment framework for large language models via subspace-oriented model fusion (SOMF, the authors learn a shared mask on the weight space of large language model), which combines safeguard capabilities of initially aligned models with fine-tuned models to ensure safety without compromising performance on downstream tasks.

    +
    +
  • +
+
+
+
    +
  1. +

    E. J. Gumbel. Statistical Theory of Extreme Values and Some Practical Applications. A Series of Lectures. Technical +Report PB175818, National Bureau of Standards, Washington, D. C. Applied Mathematics Div., 1954. URL +https://ntrl.ntis.gov/NTRL/dashboard/searchResults/titleDetail/PB175818.xhtml. 

    +
  2. +
  3. +

    R. Duncan Luce. Individual Choice Behavior. Individual Choice Behavior. John Wiley, Oxford, England, 1959 

    +
  4. +
  5. +

    Chris J Maddison, Daniel Tarlow, and Tom Minka. A* sampling. Advances in neural information processing systems, +27, 2014. 

    +
  6. +
  7. +

    Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution: A Continuous Relaxation of Discrete +Random Variables, March 2017. URL http://arxiv.org/abs/1611.00712. 

    +
  8. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/depth_upscaling/index.html b/algorithms/depth_upscaling/index.html new file mode 100644 index 00000000..3ee32a05 --- /dev/null +++ b/algorithms/depth_upscaling/index.html @@ -0,0 +1,2992 @@ + + + + + + + + + + + + + + + + + + + + + + + Depth Upscaling - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Depth Upscaling

+

Usage

+

The DepthUpscalingAlgorithm is used to upscale the depth of PyTorch models. Here's a basic guide on how to use it:

+

First, import the necessary modules:

+
from omegaconf import DictConfig
+from torch import nn
+from fusion_bench.method.depth_upscaling import DepthUpscalingAlgorithm
+from fusion_bench.modelpool import to_modelpool
+
+

Create an instance of DepthUpscalingAlgorithm by passing a configuration dictionary. +This dictionary should contain the name of the method ("depth_upscaling") and a list of layer indices that determine the upscaling pattern.

+
method_config = {"name": "depth_upscaling", "layer_indices": [0, 1, 1, 0]}
+algorithm = DepthUpscalingAlgorithm(DictConfig(method_config))
+
+

Assume we have a list of PyTorch models (nn.ModuleList instances) that we want to upscale. Here, we're creating a list of linear models as an example:

+
model = nn.ModuleList([nn.Linear(10, 10) for _ in range(2)])
+
+

Then, we can the model to the run method of our algorithm:

+
upscaled_model = algorithm.run(model)
+
+

The run method will return an upscaled model. The type of the returned model will be the same as the input models (in this case, nn.ModuleList), and its length will be determined by the layer indices specified in the method configuration.

+

Examples

+

Here we provide an example of how to use the DepthUpscalingAlgorithm to upscale the depth of a Mistral model 1.

+
+ alt text +
Credit to "SOLAR 10.7B: Scaling Large Language Models with Simple yet Effective Depth Up-Scaling"
+
+
from omegaconf import DictConfig
+from torch import nn
+from transformers import AutoModelForCausalLM, MistralConfig, MistralForCausalLM
+from fusion_bench.method.depth_upscaling import DepthUpscalingAlgorithm
+
+# create a Mistral model
+# here we randomly initialize the model for demonstration purposes
+# in practice, you would load a pretrained model
+model_config = MistralConfig(
+    # https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json
+    **{
+        "architectures": ["MistralForCausalLM"],
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": 4096,
+        "initializer_range": 0.02,
+        "intermediate_size": 14336,
+        "max_position_embeddings": 32768,
+        "model_type": "mistral",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 32,
+        "num_key_value_heads": 8,
+        "rms_norm_eps": 1e-05,
+        "rope_theta": 10000.0,
+        "sliding_window": 4096,
+        "tie_word_embeddings": False,
+        "torch_dtype": "bfloat16",
+        "transformers_version": "4.34.0.dev0",
+        "use_cache": True,
+        "vocab_size": 32000,
+    }
+)
+print('creating model')
+model: MistralForCausalLM = AutoModelForCausalLM.from_config(model_config)
+
+method_config = {
+    "name": "depth_upscaling",
+    "layer_indices": ["range(0,24)", "range(8,32)"],
+}
+algorithm = DepthUpscalingAlgorithm(DictConfig(method_config))
+print('upscaling model')
+upscaled_model = algorithm.run(model.model.layers)
+
+# substitute the model with the upscaled model
+model.model.layers = upscaled_model
+
+

Code Integration

+

The DepthUpscalingAlgorithm is integrated into the fusion_bench package. You can use it by specifying "depth_upscaling" as the method name in the command line or configuration file.

+
config/method/depth_upscaling.yaml
name: depth_upscaling
+# this should be a list of integers or string, indicating the sequence of layers. If the entry is an integer, it will use the n-th layer of the model. If the entry is a string, it will use the layers specified by the string. The string should be a valid python expression that evaluates to a list of integers.
+# for example, ["range(0,12)", "range(6,12)"] will use the first 12 layers and the last 6 layers of the model to construct the new model
+# [0, 2, 4, "range(6,12)"] will use the 1st, 3rd, 5th, and the 7th to 12th layers of the model to construct the new model
+layer_indices: null
+
+

You can then run the fusion_bench command with the specified configuration file:

+
fusion_bench method=depth_upscaling ...
+
+

References

+ + +
+ + + +

+ DepthUpscalingAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + +

Implements the Depth Upscaling Algorithm.

+
    +
  • Kim et al. SOLAR 10.7B: Scaling Large Language Models with Simple yet Effective Depth Up-Scaling. http://arxiv.org/abs/2312.15166
  • +
+

This class extends the BaseModelFusionAlgorithm to handle depth upscaling of models. +It supports upscaling the depth of a model by duplicating specified layers.

+ + +

Parameters:

+
    +
  • +
    layer_indices +
    (list) + – +
    +

    List of layer indices to duplicate.

    +
    +
  • +
  • +
    **kwargs +
    – +
    +

    Additional keyword arguments.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/depth_upscaling/depth_upscaling.py +
15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
class DepthUpscalingAlgorithm(BaseAlgorithm):
+    R"""
+    Implements the Depth Upscaling Algorithm.
+
+    - Kim et al. SOLAR 10.7B: Scaling Large Language Models with Simple yet Effective Depth Up-Scaling. http://arxiv.org/abs/2312.15166
+
+    This class extends the `BaseModelFusionAlgorithm` to handle depth upscaling of models.
+    It supports upscaling the depth of a model by duplicating specified layers.
+
+    Args:
+        layer_indices (list): List of layer indices to duplicate.
+        **kwargs: Additional keyword arguments.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "layer_indices": "layer_indices",
+    }
+
+    def __init__(self, layer_indices: Union[str, List[int]], **kwargs):
+        self.layer_indices = layer_indices
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def run(self, modelpool: nn.ModuleList | BaseModelPool) -> nn.ModuleList:
+        """
+        Executes the depth upscaling algorithm on a given model pool.
+
+        This method checks the type of the model pool, ensures that it contains only one model, and verifies that the model is an instance of `nn.ModuleList`.
+
+        Args:
+            modelpool (nn.ModuleList | ModelPool): The pool of models to upscale. Must contain only one model.
+
+        Returns:
+            nn.ModuleList: The upscaled model.
+
+        Raises:
+            AssertionError: If the model pool contains more than one model or if the model is not an instance of `nn.ModuleList`.
+            ValueError: If an invalid layer specification is provided in the configuration.
+        """
+        # check the modelpool type
+        if isinstance(modelpool, BaseModelPool):
+            assert len(modelpool) == 1, "DepthUpscaling only support one model"
+            model = modelpool.load_model(modelpool.model_names[0])
+            assert isinstance(
+                model, nn.ModuleList
+            ), f"The model should be a `nn.ModuleList`, but got {type(model)}"
+        elif isinstance(modelpool, nn.ModuleList):
+            model = modelpool
+        else:
+            raise AssertionError(
+                f"Invalid modelpool type: {type(modelpool)}. Expected `ModelPool` or `nn.ModuleList`."
+            )
+
+        # parse the layers
+        layer_indices = self.layer_indices
+        parsed_layer_indices = []
+        for layer in layer_indices:
+            if isinstance(layer, int):
+                parsed_layer_indices.append(layer)
+            elif isinstance(layer, str):
+                parsed_layer_indices.extend(eval(layer))
+            else:
+                raise ValueError("Invalid layer specification: {}".format(layer))
+
+        # create a new model with the specified layers
+        new_model = nn.ModuleList(
+            [
+                deepcopy(model[i])
+                for i in tqdm(
+                    parsed_layer_indices, desc="constructing depth-upscaled model"
+                )
+            ]
+        )
+
+        return new_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + +
+ + +
+ +

Executes the depth upscaling algorithm on a given model pool.

+

This method checks the type of the model pool, ensures that it contains only one model, and verifies that the model is an instance of nn.ModuleList.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (ModuleList | ModelPool) + – +
    +

    The pool of models to upscale. Must contain only one model.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + ModuleList + – +
    +

    nn.ModuleList: The upscaled model.

    +
    +
  • +
+ + +

Raises:

+
    +
  • + AssertionError + – +
    +

    If the model pool contains more than one model or if the model is not an instance of nn.ModuleList.

    +
    +
  • +
  • + ValueError + – +
    +

    If an invalid layer specification is provided in the configuration.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/depth_upscaling/depth_upscaling.py +
37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
@torch.no_grad()
+def run(self, modelpool: nn.ModuleList | BaseModelPool) -> nn.ModuleList:
+    """
+    Executes the depth upscaling algorithm on a given model pool.
+
+    This method checks the type of the model pool, ensures that it contains only one model, and verifies that the model is an instance of `nn.ModuleList`.
+
+    Args:
+        modelpool (nn.ModuleList | ModelPool): The pool of models to upscale. Must contain only one model.
+
+    Returns:
+        nn.ModuleList: The upscaled model.
+
+    Raises:
+        AssertionError: If the model pool contains more than one model or if the model is not an instance of `nn.ModuleList`.
+        ValueError: If an invalid layer specification is provided in the configuration.
+    """
+    # check the modelpool type
+    if isinstance(modelpool, BaseModelPool):
+        assert len(modelpool) == 1, "DepthUpscaling only support one model"
+        model = modelpool.load_model(modelpool.model_names[0])
+        assert isinstance(
+            model, nn.ModuleList
+        ), f"The model should be a `nn.ModuleList`, but got {type(model)}"
+    elif isinstance(modelpool, nn.ModuleList):
+        model = modelpool
+    else:
+        raise AssertionError(
+            f"Invalid modelpool type: {type(modelpool)}. Expected `ModelPool` or `nn.ModuleList`."
+        )
+
+    # parse the layers
+    layer_indices = self.layer_indices
+    parsed_layer_indices = []
+    for layer in layer_indices:
+        if isinstance(layer, int):
+            parsed_layer_indices.append(layer)
+        elif isinstance(layer, str):
+            parsed_layer_indices.extend(eval(layer))
+        else:
+            raise ValueError("Invalid layer specification: {}".format(layer))
+
+    # create a new model with the specified layers
+    new_model = nn.ModuleList(
+        [
+            deepcopy(model[i])
+            for i in tqdm(
+                parsed_layer_indices, desc="constructing depth-upscaled model"
+            )
+        ]
+    )
+
+    return new_model
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/dummy/index.html b/algorithms/dummy/index.html new file mode 100644 index 00000000..1ba0a8b7 --- /dev/null +++ b/algorithms/dummy/index.html @@ -0,0 +1,2470 @@ + + + + + + + + + + + + + + + + + + + + + + + Dummy Algorithm - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Dummy Algorithm

+

The Dummy Algorithm is a simple algorithm that does not perform any fusion operation. Instead, it returns a pretrained model if one is available in the model pool. If no pretrained model is available, it returns the first model in the model pool. +This algorithm is useful for testing and debugging purposes, as it allows you to quickly check if the model pool is set up correctly and the fusion process is working as expected.

+

Usage

+

To use the Dummy Algorithm, you need to specify "dummy" as the algorithm name.

+
fusion_bench method=dummy ...
+
+

Implementation

+

The implementation of the Dummy Algorithm is straightforward. Here is the main method of the DummyAlgorithm class:

+ + +
+ + + +

+ DummyAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + + + + + + +
+ Source code in fusion_bench/method/dummy.py +
15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
class DummyAlgorithm(BaseAlgorithm):
+    def run(self, modelpool: BaseModelPool):
+        """
+        This method returns the pretrained model from the model pool.
+        If the pretrained model is not available, it returns the first model from the model pool.
+
+        Args:
+            modelpool (BaseModelPool): The pool of models to fuse.
+
+        Raises:
+            AssertionError: If the model is not found in the model pool.
+        """
+        if isinstance(modelpool, nn.Module):
+            return modelpool
+        elif not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+
+        model = modelpool.load_pretrained_or_first_model()
+
+        assert model is not None, "Model is not found in the model pool."
+        return model
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/fisher_merging/index.html b/algorithms/fisher_merging/index.html new file mode 100644 index 00000000..91d47772 --- /dev/null +++ b/algorithms/fisher_merging/index.html @@ -0,0 +1,3246 @@ + + + + + + + + + + + + + + + + + + + + + + + Fisher Merging - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

(Diagonal) Fisher Merging

+

The Fisher merging algorithm 1 is a per-parameter weighed averaging method that assigns weights to the models based on the Fisher information matrix of the models on some labeled data. +The Fisher information matrix \(F_\theta\) of a model with parameters \(\theta\) can be expressed as:

+
\[ F_\theta = \mathbb{E}_{x \sim p(x)} \left[ \nabla_\theta \log p(y|x, \theta) \nabla_\theta \log p(y|x, \theta)^T \right] \]
+

where \(p(x)\) is the data distribution, \(p(y|x, \theta)\) is the model's output distribution, for example, the softmax output of a classification model, and \(\nabla_\theta\) is the gradient with respect to the model's parameters \(\theta\). +The Fisher information matrix can be used to estimate the importance of each parameter in the model and thus assign weights to the models based on their Fisher information. +In addition, the Fisher information matrix can be used to estimate the similarity between tasks, which can be useful in auxiliary-task learning and multi-task learning scenarios 2.

+

As the full Fisher information matrix is often computationally expensive to compute and memory-intensive to store, we approximate using the diagonal Fisher information matrix, which is the diagonal of the full Fisher information matrix. +The diagonal Fisher information matrix can be computed as:

+
\[ \hat{F}_\theta = \mathbb{E}_{x \sim p(x)} \left[ \left(\nabla_\theta \log p(y|x, \theta)\right)^2 \right] \]
+

Assuming we have \(n\) models with parameters \(\theta_i\) and diagonal Fisher information matrices \(\hat{F}_{\theta_i}\), the Fisher merging algorithm computes the merged model's parameters \(\theta\) as follows:

+
\[ \theta^{(j)} = \frac{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)} \theta_i^{(j)}}{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)}} \]
+

where \(\theta_i\) are the parameters of the individual models, \(\hat{F}_{\theta_i}\) are the diagonal Fisher information matrices of the individual models, and \(j\) indexes the parameters of the models. +The Fisher merging algorithm can be considered a per-weight weighed averaging method, where the weights are determined by the Fisher information of each parameter in the models.

+

Code Integration

+

Example of merging eight CLIP-ViT-B/32 models using Fisher merging:

+
fusion_bench method=clip_fisher_merging \
+  modelpool=clip-vit-base-patch32_TA8 \
+  taskpool=clip-vit-classification_TA8
+
+

Merge eight CLIP-ViT-L/14 models using Fisher merging:

+
fusion_bench \
+  method=clip_fisher_merging \
+    method.batch_size=8 method.num_workers=4 \
+  modelpool=clip-vit-large-patch14_TA8 \
+  taskpool=clip-vit-classification_TA8 \
+    taskpool.clip_model=openai/clip-vit-large-patch14
+
+

Merge GPT-2 models for text classification tasks:

+
fusion_bench \
+  method=gpt2_fisher_merging \
+    method.num_fisher_examples=512 method.batch_size=8 \
+  modelpool=gpt-2_glue \
+  taskpool=gpt-2_glue
+
+

References

+ + +
+ + + +

+ FisherMergingAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + +

Implements the Fisher Merging Algorithm.

+

This class extends the BaseModelFusionAlgorithm to handle merging of models using Fisher weights. +It supports excluding certain parameters, normalizing Fisher weights, and setting a minimal value for Fisher weights.

+ + +

Methods:

+
    +
  • + run + – +
    +

    BaseModelPool) -> nn.Module: +Executes the Fisher merging process on the model pool and returns the merged model.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/fisher_merging/fisher_merging.py +
class FisherMergingAlgorithm(BaseAlgorithm):
+    """
+    Implements the Fisher Merging Algorithm.
+
+    This class extends the BaseModelFusionAlgorithm to handle merging of models using Fisher weights.
+    It supports excluding certain parameters, normalizing Fisher weights, and setting a minimal value for Fisher weights.
+
+    Methods:
+        run(modelpool: BaseModelPool) -> nn.Module:
+            Executes the Fisher merging process on the model pool and returns the merged model.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "exclude_param_names_regex": "exclude_param_names_regex",
+        "normalize_fisher_weight": "normalize_fisher_weight",
+        "minimal_fisher_weight": "minimal_fisher_weight",
+        "num_fisher_examples": "num_fisher_examples",
+    }
+
+    def __init__(
+        self,
+        *,
+        exclude_param_names_regex: list,
+        normalize_fisher_weight: bool,
+        minimal_fisher_weight: float,
+        num_fisher_examples: int,
+    ):
+        super().__init__()
+        self.exclude_param_names_regex = exclude_param_names_regex
+        self.normalize_fisher_weight = normalize_fisher_weight
+        self.minimal_fisher_weight = minimal_fisher_weight
+        self.num_fisher_examples = num_fisher_examples
+
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
+        """
+        Run the Fisher Merging Algorithm.
+
+        This method constructs the wrapped model and performs test-time adaptation if necessary.
+
+        Args:
+            modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.
+
+        Returns:
+            nn.Module: The merged model after test-time adaptation.
+        """
+        log.info("Running Fisher Merging Algorithm")
+        if isinstance(modelpool, (dict, list, tuple)):
+            modelpool = BaseModelPool(modelpool)
+
+        assert len(modelpool) > 0, "model pool is empty"
+        assert (
+            modelpool.has_pretrained
+        ), "no pretrained model (base model) in the model pool"
+
+        self.modelpool = modelpool
+        self.on_fisher_merging_start()
+
+        # dictionary of list, where key is the parameter name,
+        # value is a list of the corresponding parameters of all the models that need to be merged
+        models_to_merge_param_dict = defaultdict(list)
+
+        # list of dictionaries with length len(models_to_merge),
+        # each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
+        models_to_merge_fisher_weights_list = []
+
+        param_names_to_merge = None
+
+        for name, model in modelpool.named_models():
+            param_dict = model.state_dict()
+            if param_names_to_merge is None:
+                param_names_to_merge = get_param_names_to_merge(
+                    input_param_names=list(param_dict.keys()),
+                    exclude_param_names_regex=self.config.get(
+                        "exclude_param_names_regex", []
+                    ),
+                )
+
+            for param_name in param_names_to_merge:
+                models_to_merge_param_dict[param_name].append(param_dict[param_name])
+
+            model_to_merge_fisher_weights = self.get_fisher_weights(
+                model_name=name,
+                model=model,
+                train_dataset=modelpool.load_train_dataset(name),
+                param_names_to_merge=param_names_to_merge,
+            )
+
+            models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)
+
+        merged_params = merging_with_fisher_weights(
+            models_to_merge_param_dict=models_to_merge_param_dict,
+            models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
+            fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
+            normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
+            minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
+        )
+
+        merged_model = modelpool.load_model("_pretrained_")
+        merged_model.load_state_dict(merged_params, strict=False)
+        return merged_model
+
+    def get_fisher_weights(
+        self,
+        model_name: str,
+        model: nn.Module,
+        train_dataset,
+        param_names_to_merge: List[str],
+    ) -> Dict[str, Tensor]:
+        """
+        Compute the Fisher weights for the given model and training dataset.
+
+        Args:
+            model_name (str): The name of the model.
+            model (nn.Module): The model module.
+            train_dataset: The training dataset.
+            param_names_to_merge (List[str]): List of parameter names to merge.
+
+        Returns:
+            Dict[str, Tensor]: The computed Fisher weights for each parameter.
+        """
+        # this function is used to compute fisher weights for a model
+        # it should be implemented in the subclass
+        raise NotImplementedError
+
+    def on_fisher_merging_start(self):
+        """
+        Setup the zero-shot classification head before starting the Fisher merging process.
+        """
+        # this function is used to initialize some variables before running fisher merging
+        pass
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ get_fisher_weights(model_name, model, train_dataset, param_names_to_merge) + +
+ + +
+ +

Compute the Fisher weights for the given model and training dataset.

+ + +

Parameters:

+
    +
  • +
    model_name +
    (str) + – +
    +

    The name of the model.

    +
    +
  • +
  • +
    model +
    (Module) + – +
    +

    The model module.

    +
    +
  • +
  • +
    train_dataset +
    – +
    +

    The training dataset.

    +
    +
  • +
  • +
    param_names_to_merge +
    (List[str]) + – +
    +

    List of parameter names to merge.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + Dict[str, Tensor] + – +
    +

    Dict[str, Tensor]: The computed Fisher weights for each parameter.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/fisher_merging/fisher_merging.py +
def get_fisher_weights(
+    self,
+    model_name: str,
+    model: nn.Module,
+    train_dataset,
+    param_names_to_merge: List[str],
+) -> Dict[str, Tensor]:
+    """
+    Compute the Fisher weights for the given model and training dataset.
+
+    Args:
+        model_name (str): The name of the model.
+        model (nn.Module): The model module.
+        train_dataset: The training dataset.
+        param_names_to_merge (List[str]): List of parameter names to merge.
+
+    Returns:
+        Dict[str, Tensor]: The computed Fisher weights for each parameter.
+    """
+    # this function is used to compute fisher weights for a model
+    # it should be implemented in the subclass
+    raise NotImplementedError
+
+
+
+ +
+ +
+ + +
+ on_fisher_merging_start() + +
+ + +
+ +

Setup the zero-shot classification head before starting the Fisher merging process.

+ +
+ Source code in fusion_bench/method/fisher_merging/fisher_merging.py +
def on_fisher_merging_start(self):
+    """
+    Setup the zero-shot classification head before starting the Fisher merging process.
+    """
+    # this function is used to initialize some variables before running fisher merging
+    pass
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Run the Fisher Merging Algorithm.

+

This method constructs the wrapped model and performs test-time adaptation if necessary.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (BaseModelPool) + – +
    +

    The model pool containing the pretrained and fine-tuned models.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + Module + – +
    +

    nn.Module: The merged model after test-time adaptation.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/fisher_merging/fisher_merging.py +
def run(self, modelpool: BaseModelPool) -> nn.Module:
+    """
+    Run the Fisher Merging Algorithm.
+
+    This method constructs the wrapped model and performs test-time adaptation if necessary.
+
+    Args:
+        modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.
+
+    Returns:
+        nn.Module: The merged model after test-time adaptation.
+    """
+    log.info("Running Fisher Merging Algorithm")
+    if isinstance(modelpool, (dict, list, tuple)):
+        modelpool = BaseModelPool(modelpool)
+
+    assert len(modelpool) > 0, "model pool is empty"
+    assert (
+        modelpool.has_pretrained
+    ), "no pretrained model (base model) in the model pool"
+
+    self.modelpool = modelpool
+    self.on_fisher_merging_start()
+
+    # dictionary of list, where key is the parameter name,
+    # value is a list of the corresponding parameters of all the models that need to be merged
+    models_to_merge_param_dict = defaultdict(list)
+
+    # list of dictionaries with length len(models_to_merge),
+    # each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
+    models_to_merge_fisher_weights_list = []
+
+    param_names_to_merge = None
+
+    for name, model in modelpool.named_models():
+        param_dict = model.state_dict()
+        if param_names_to_merge is None:
+            param_names_to_merge = get_param_names_to_merge(
+                input_param_names=list(param_dict.keys()),
+                exclude_param_names_regex=self.config.get(
+                    "exclude_param_names_regex", []
+                ),
+            )
+
+        for param_name in param_names_to_merge:
+            models_to_merge_param_dict[param_name].append(param_dict[param_name])
+
+        model_to_merge_fisher_weights = self.get_fisher_weights(
+            model_name=name,
+            model=model,
+            train_dataset=modelpool.load_train_dataset(name),
+            param_names_to_merge=param_names_to_merge,
+        )
+
+        models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)
+
+    merged_params = merging_with_fisher_weights(
+        models_to_merge_param_dict=models_to_merge_param_dict,
+        models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
+        fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
+        normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
+        minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
+    )
+
+    merged_model = modelpool.load_model("_pretrained_")
+    merged_model.load_state_dict(merged_params, strict=False)
+    return merged_model
+
+
+
+ +
+ + + +
+ +
+ +
+
+
    +
  1. +

    M. Matena, C. Raffel. "Merging Models with Fisher-Weighted Averaging" http://arxiv.org/abs/2111.09832 

    +
  2. +
  3. +

    C. Wu, et al. "Pi-Tuning: Transferring Multimodal Foundation Models with Optimal Multi-task Interpolation". https://github.com/TencentARC/pi-Tuning 

    +
  4. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/images/Task Arithmetic.png b/algorithms/images/Task Arithmetic.png new file mode 100644 index 00000000..924dd987 Binary files /dev/null and b/algorithms/images/Task Arithmetic.png differ diff --git a/algorithms/images/adamerging.png b/algorithms/images/adamerging.png new file mode 100644 index 00000000..2620ed7e Binary files /dev/null and b/algorithms/images/adamerging.png differ diff --git a/algorithms/images/adamerging_layerwise_coefficients.png b/algorithms/images/adamerging_layerwise_coefficients.png new file mode 100644 index 00000000..4518b381 Binary files /dev/null and b/algorithms/images/adamerging_layerwise_coefficients.png differ diff --git a/algorithms/images/adamerging_model_merging_coefficients.png b/algorithms/images/adamerging_model_merging_coefficients.png new file mode 100644 index 00000000..ba0a8d5b Binary files /dev/null and b/algorithms/images/adamerging_model_merging_coefficients.png differ diff --git a/algorithms/images/concrete_adamerging_vs_adamerging.png b/algorithms/images/concrete_adamerging_vs_adamerging.png new file mode 100644 index 00000000..ed0e1781 Binary files /dev/null and b/algorithms/images/concrete_adamerging_vs_adamerging.png differ diff --git a/algorithms/images/concrete_subspace_learning.png b/algorithms/images/concrete_subspace_learning.png new file mode 100644 index 00000000..16a75380 Binary files /dev/null and b/algorithms/images/concrete_subspace_learning.png differ diff --git a/algorithms/images/ewemoe.png b/algorithms/images/ewemoe.png new file mode 100644 index 00000000..9642d3cb Binary files /dev/null and b/algorithms/images/ewemoe.png differ diff --git a/algorithms/images/ewemoe_1.png b/algorithms/images/ewemoe_1.png new file mode 100644 index 00000000..6f9f38d4 Binary files /dev/null and b/algorithms/images/ewemoe_1.png differ diff --git a/algorithms/images/ewemoe_2.png b/algorithms/images/ewemoe_2.png new file mode 100644 index 00000000..6b8d1a2c Binary files /dev/null and b/algorithms/images/ewemoe_2.png differ diff --git a/algorithms/images/fedmr_model_recombination.jpg b/algorithms/images/fedmr_model_recombination.jpg new file mode 100644 index 00000000..abee7e17 Binary files /dev/null and b/algorithms/images/fedmr_model_recombination.jpg differ diff --git a/algorithms/images/max-model_predictor.png b/algorithms/images/max-model_predictor.png new file mode 100644 index 00000000..dbaad45e Binary files /dev/null and b/algorithms/images/max-model_predictor.png differ diff --git a/algorithms/images/pwe_moe.png b/algorithms/images/pwe_moe.png new file mode 100644 index 00000000..c3161def Binary files /dev/null and b/algorithms/images/pwe_moe.png differ diff --git a/algorithms/images/sigmoid.png b/algorithms/images/sigmoid.png new file mode 100644 index 00000000..5797db77 Binary files /dev/null and b/algorithms/images/sigmoid.png differ diff --git a/algorithms/images/smile_upscaling.png b/algorithms/images/smile_upscaling.png new file mode 100644 index 00000000..eb5aa883 Binary files /dev/null and b/algorithms/images/smile_upscaling.png differ diff --git a/algorithms/images/solar10.7B.png b/algorithms/images/solar10.7B.png new file mode 100644 index 00000000..ec6f6ba4 Binary files /dev/null and b/algorithms/images/solar10.7B.png differ diff --git a/algorithms/images/sparse_upcycling.png b/algorithms/images/sparse_upcycling.png new file mode 100644 index 00000000..c6bbe596 Binary files /dev/null and b/algorithms/images/sparse_upcycling.png differ diff --git a/algorithms/images/ties_merging.jpg b/algorithms/images/ties_merging.jpg new file mode 100644 index 00000000..0dc77957 Binary files /dev/null and b/algorithms/images/ties_merging.jpg differ diff --git a/algorithms/images/ties_merging_hyperparameter_tuning.png b/algorithms/images/ties_merging_hyperparameter_tuning.png new file mode 100644 index 00000000..b850bc3b Binary files /dev/null and b/algorithms/images/ties_merging_hyperparameter_tuning.png differ diff --git a/algorithms/images/wemoe.png b/algorithms/images/wemoe.png new file mode 100644 index 00000000..30ca4dbd Binary files /dev/null and b/algorithms/images/wemoe.png differ diff --git a/algorithms/images/wemoe_loss_landscape.png b/algorithms/images/wemoe_loss_landscape.png new file mode 100644 index 00000000..dd2f8ac8 Binary files /dev/null and b/algorithms/images/wemoe_loss_landscape.png differ diff --git a/algorithms/images/wemoe_lr_tuning.png b/algorithms/images/wemoe_lr_tuning.png new file mode 100644 index 00000000..f6f3af69 Binary files /dev/null and b/algorithms/images/wemoe_lr_tuning.png differ diff --git a/algorithms/index.html b/algorithms/index.html new file mode 100644 index 00000000..3a577f38 --- /dev/null +++ b/algorithms/index.html @@ -0,0 +1,2817 @@ + + + + + + + + + + + + + + + + + + + + + + + Introduction to Algorithm Module - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Introduction to Algorithm Module

+

The Fusion Algorithm module is a core component of the FusionBench project, dedicated to the implementation and execution of various model fusion techniques. +This module provides the mechanisms necessary to combine multiple models from the Model Pool, enabling nuanced and optimized model merging operations.

+

Key Points of the Fusion Algorithm Module

+
    +
  • Adaptive Fusion: The module supports advanced fusion techniques, such as AdaMerging, that adaptively learn the best coefficients for model merging using sophisticated methods like entropy minimization.
  • +
  • Algorithm Configuration: Algorithms are defined and loaded based on configuration files, ensuring flexibility and ease of experimentation. This modular approach allows researchers to switch between different fusion methods seamlessly.
  • +
  • Model Integration: It facilitates the integration of multiple models, combining their strengths and mitigating individual weaknesses. The result is a single, merged model that ideally performs better than any individual model alone or has multitasking capability.
  • +
  • Evaluation Support: Once the model fusion process is completed, the merged model can interface with the TaskPool to evaluate the performance of the merged model across various tasks, providing a comprehensive assessment of its capabilities.
  • +
+

Example Capabilities

+
    +
  • Entropy Minimization: Some algorithms in this module utilize entropy minimization on unlabeled test samples to refine merging coefficients, ensuring that the fusion process is data-driven and optimized.
  • +
  • Layer-wise and Task-wise Fusion: It allows both layer-wise and task-wise model fusion, where merging coefficients can be learned for individual layers or entire tasks, respectively.
  • +
+

Code Integration

+

The module is typically invoked through a configuration-driven approach in CLI scripts, enabling users to specify fusion algorithms and parameters via YAML configuration files. This method ensures reproducibility and ease of use. +For more information, see the document of fusion_bench CLI.

+

ModelFusionAlgorithm is the base class for all fusion algorithms in the Fusion Algorithm module. +It provides a common interface for different fusion techniques, allowing for seamless integration and execution of various algorithms.

+

Example Usage

+

Implement your own model fusion algorithm:

+
from fusion_bench.method import BaseModelFusionAlgorithm
+from fusion_bench.modelpool import BaseModelPool
+
+class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
+    """
+    An example of a derived model fusion algorithm.
+    """
+
+    # _config_mapping maps the attribution to the corresponding key in the configuration file.
+    _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
+        "hyperparam_attr_1": "hyperparam_1",
+        "hyperparam_attr_2": "hyperparam_2",
+    }
+
+    def __init__(self, hyperparam_1, hyperparam_2, **kwargs):
+        self.hyperparam_attr_1 = hyperparam_1
+        self.hyperparam_attr_2 = hyperparam_2
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        # implement the fusion algorithm here
+        raise NotImplementedError(
+            "DerivedModelFusionAlgorithm.run() is not implemented."
+        )
+
+

We provide a simple example to illustrate how the algorithm is used in the FusionBench as follows:

+
import logging
+from typing import Dict, Optional
+from omegaconf import DictConfig
+
+from fusion_bench.utils import instantiate
+
+log = logging.getLogger(__name__)
+
+def run_model_fusion(
+    method_config: DictConfig,
+    modelpool_config: DictConfig,
+    taskpool_config: Optional[DictConfig] = None,
+    seed: Optional[int] = None,
+    print_config: bool = True,
+    **kwargs
+):
+    """
+    Run the model fusion process.
+
+    Args:
+        method_config: Configuration for the fusion method.
+        modelpool_config: Configuration for the model pool.
+        taskpool_config: Configuration for the task pool (optional).
+    """
+    # Instantiate components: modelpool, method, and taskpool
+    modelpool = instantiate(modelpool_config)
+    method = instantiate(method_config)
+    taskpool = None
+    if taskpool_config is not None:
+        taskpool = instantiate(taskpool_config)
+
+    # Run fusion
+    merged_model = method.run(modelpool)
+
+    # Evaluate if taskpool is provided
+    if taskpool is not None:
+        report = taskpool.evaluate(merged_model)
+
+

In summary, the Fusion Algorithm module is vital for the model merging operations within FusionBench, leveraging sophisticated techniques to ensure optimal fusion and performance evaluation of deep learning models. This capability makes it an indispensable tool for researchers and practitioners focusing on model fusion strategies.

+

References

+ + +
+ + + +

+ BaseAlgorithm + + +

+ + +
+

+ Bases: BaseYAMLSerializableModel

+ + +

Base class for model fusion algorithms.

+

This class provides a template for implementing model fusion algorithms. +Subclasses must implement the run method to define the fusion logic.

+ + + + + + +
+ Source code in fusion_bench/method/base_algorithm.py +
13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
class BaseAlgorithm(BaseYAMLSerializableModel):
+    """
+    Base class for model fusion algorithms.
+
+    This class provides a template for implementing model fusion algorithms.
+    Subclasses must implement the `run` method to define the fusion logic.
+    """
+
+    _program = None
+
+    @abstractmethod
+    def run(self, modelpool: BaseModelPool):
+        """
+        Fuse the models in the given model pool.
+
+        This method must be implemented by subclasses to define the fusion logic.
+
+        Examples:
+            >>> algorithm = SimpleAverageAlgorithm()
+            >>> modelpool = ModelPool()
+            >>> merged_model = algorithm.run(modelpool)
+
+        Args:
+            modelpool (BaseModelPool): The pool of models to fuse.
+        """
+        pass
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + + + abstractmethod + + +
+ + +
+ +

Fuse the models in the given model pool.

+

This method must be implemented by subclasses to define the fusion logic.

+ + +

Examples:

+
>>> algorithm = SimpleAverageAlgorithm()
+>>> modelpool = ModelPool()
+>>> merged_model = algorithm.run(modelpool)
+
+ + +

Parameters:

+
    +
  • +
    modelpool +
    (BaseModelPool) + – +
    +

    The pool of models to fuse.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/base_algorithm.py +
23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
@abstractmethod
+def run(self, modelpool: BaseModelPool):
+    """
+    Fuse the models in the given model pool.
+
+    This method must be implemented by subclasses to define the fusion logic.
+
+    Examples:
+        >>> algorithm = SimpleAverageAlgorithm()
+        >>> modelpool = ModelPool()
+        >>> merged_model = algorithm.run(modelpool)
+
+    Args:
+        modelpool (BaseModelPool): The pool of models to fuse.
+    """
+    pass
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ BaseModelFusionAlgorithm = BaseAlgorithm + + + module-attribute + + +

+ + +
+ +

Alias for BaseAlgorithm.

+
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/layer_recombination/index.html b/algorithms/layer_recombination/index.html new file mode 100644 index 00000000..a0ca6535 --- /dev/null +++ b/algorithms/layer_recombination/index.html @@ -0,0 +1,2213 @@ + + + + + + + + + + + + + + + + + + + Layer Recombination - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Layer Recombination

+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/max-model_predictor/index.html b/algorithms/max-model_predictor/index.html new file mode 100644 index 00000000..80344241 --- /dev/null +++ b/algorithms/max-model_predictor/index.html @@ -0,0 +1,2428 @@ + + + + + + + + + + + + + + + + + + + + + + + Max-Model Predictor - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Max-Model Predictor

+

The max-model predictor algorithm is a type of ensemble method. +Formally, a max-model predictor is defined as follows:

+

Definition (Max-Model Predictor) 1 +Given a set of predictors \(H = \{h_1, h_2, \ldots, h_n\}\), with \(h_i: \mathcal{X} \times \mathcal{Y}_i \mapsto \mathbb{R}\), the max-model predictor \(h_H\) is defined as:

+
\[h_H(x,y) = \max_{h_i\in H} h_i(x,y).\]
+

Take the flu detection problem as an example 1. +Doctors want to build a learning model to detect what type of virus one patient is affected based on her symptoms, for appropriate treatment. However, the types of influenza diverse geographically (Rejmanek et al., 2015), which means the distribution of patient records collected by a hospital in California may be different from those in Florida. In an extreme case, some types are unknown to the other hospital. Assume there are 4 types of influenza in the United States. In California, 2 of 4 are commonly detected, while in Florida 3 of 4 types are often detected. We assume in the two states, doctors separately trained two models \(h_{CA}\) and \(h_{FL}\) which work locally well in California and Florida respectively. However, a direct ensemble of the two local models may not work well on all the patients. Let \(h_{US}\) denote the ideal global model trained on the combination of local datasets. When we input a patient record \(x\), each model outputs its prediction as shown in the following table:

+

Table: Example of flu detection on a patient \(x\) affected with type 2 flu. “−” means this model is not able to predict the corresponding class. Taking the maximal score as prediction, \(h_{FL}\) is consistent with \(h_{US}\), but the combination of two local models \(h_{CA,FL}\) is not since \(3/4 > 4/7\).

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Type1234
\(h_{US}(x)\)2/104/101/103/10
\(h_{CA}(x)\)--1/43/4
\(h_{FL}(x)\)2/74/71/7-
\(h_{\{CA,FL\}}(x)\)2/74/71/43/4
+
+ alt text +
The illustration of running our method on the flu example.
+
+

Example

+

Here is an example of how to use the Max-Model Predictor Algorithm:

+
from fusion_bench.method import MaxModelPredictorAlgorithm
+from fusion_bench.modelpool import ModelPool
+
+# Instantiate the MaxPredictorAlgorithm
+algorithm = MaxModelPredictorAlgorithm()
+
+# Assume we have a ModelPool instance that contains the models we want to ensemble.
+modelpool = ModelPool(...) # or a list of nn.Module
+
+# Run the algorithm on the model pool.
+max_model_predictor : nn.Module = algorithm.run(modelpool)
+
+

Code Integration

+

Configuration template for the Max Predictor Algorithm:

+
config/method/max_model_predictor.yaml
name: max_model_predictor
+
+

To create a max predictor ensemble of models for a specific task, you can use the following command:

+
fusion_bench method=max_model_predictor \
+  modelpool=<modelpool_name> \
+  taskpool=<taskpool_name>
+
+
+
+
    +
  1. +

    Zhu et.al. ICML 2019. Heterogeneous model reuse via optimizing multiparty multiclass margin 

    +
  2. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/model_recombination/index.html b/algorithms/model_recombination/index.html new file mode 100644 index 00000000..dea7ee25 --- /dev/null +++ b/algorithms/model_recombination/index.html @@ -0,0 +1,3060 @@ + + + + + + + + + + + + + + + + + + + + + + + Model Recombination - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Model Recombination

+
+ alt text +
Credit to FedMR
+
+

Usage

+

ModelRecombinationAlgorithm is a class used to recombine models in a model pool. Here's how to use it:

+

First, import the necessary modules:

+
from fusion_bench.method import ModelRecombinationAlgorithm
+from fusion_bench.modelpool import ModelPool, to_modelpool
+from torch import nn
+
+

Create an instance of ModelRecombinationAlgorithm:

+
model_recombination = ModelRecombinationAlgorithm()
+
+

Create a model pool using the to_modelpool function. This function takes a list of models or a dict of models and converts it into a ModelPool:

+
models = [nn.Linear(10, 10) for _ in range(3)]
+modelpool = to_modelpool(models)
+
+

Use the run method of the ModelRecombinationAlgorithm instance to recombine the models in the model pool:

+
new_modelpool = model_recombination.run(modelpool, return_modelpool=True)
+
+

The run method takes two arguments:

+
    +
  • modelpool: The model pool to recombine.
  • +
  • return_modelpool (optional): A boolean indicating whether to return the entire model pool or just the first model. Defaults to True.
  • +
+

If return_modelpool is True, the run method returns a new ModelPool with the recombined models. If False, it returns the first model from the new model pool.

+
new_model = model_recombination.run(modelpool, return_modelpool=False)
+
+

You can check the type of the returned value to ensure that the run method worked correctly:

+
assert isinstance(new_modelpool, ModelPool)
+assert isinstance(new_model, nn.Module)
+
+

Code Integration

+

Configuration template for the model recombination algorithm:

+
config/method/model_recombination.yaml
name: model_recombination
+# if `return_model_pool` is not null, the argument `return_modelpool` passed to the `run` method will be ignored.
+return_modelpool: null
+
+

Construct a model recombination using our CLI tool fusion_bench:

+
fusion_bench \
+    method=model_recombination \
+        method.return_modelpool=false \
+    modelpool=... \
+    taskpool=...
+
+

References

+ + +
+ + + +

+ ModelRecombinationAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + +

Model recombination recombinates the layers of the given models, to create a new set of models.

+ + + + + + +
+ Source code in fusion_bench/method/model_recombination.py +
class ModelRecombinationAlgorithm(BaseAlgorithm):
+    """
+    Model recombination recombinates the layers of the given models, to create a new set of models.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "return_modelpool": "return_modelpool",
+    }
+
+    def __init__(self, return_modelpool: bool, **kwargs):
+        self.return_modelpool = return_modelpool
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def run(
+        self,
+        modelpool: BaseModelPool,
+        return_modelpool: bool = True,
+    ) -> Union[nn.Module, BaseModelPool]:
+        """
+        Executes the model recombination algorithm on a given model pool.
+
+        This method loads models from the model pool, determines their type, and applies the appropriate recombination method.
+        It then creates a new model pool with the recombined models. Depending on the `return_modelpool` flag, it either returns
+        the entire new model pool or just the first model from it.
+
+        - If the models in the model pool are of type `nn.ModuleList`, the recombination method `recombine_modellist` is used. Where each module in the list is shuffled across the models.
+        - If the models are of type `nn.ModuleDict`, the recombination method `recombine_modeldict` is used. Where each module in the dictionary is shuffled across the models.
+        - If the models are of type `nn.Module`, the recombination method `recombine_state_dict` is used. Where the state dictionaries of the models are shuffled across the models.
+
+        Args:
+            modelpool (BaseModelPool): The pool of models to recombine.
+            return_modelpool (bool, optional): Flag indicating whether to return the entire model pool or just the first model. Defaults to True. If this algorithm is initialized with config, the value of `return_modelpool` in the config will be used and this argument passed to the method will be ignored.
+
+        Returns:
+            Union[nn.Module, BaseModelPool]: The recombined model pool or the first model from the recombined pool, depending on the `return_modelpool` flag.
+
+        Raises:
+            ValueError: If the models in the model pool are of an unsupported type.
+        """
+        # If the config has a return_modelpool flag, use that, otherwise use the argument
+        if self.config.get("return_modelpool", None) is not None:
+            return_modelpool = self.config.return_modelpool
+        # check the modelpool type
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+
+        log.info(f"Running model recombination algorithm with {len(modelpool)} models")
+
+        # TODO: optimize the `recombine_*` functions, if `return_modelpool` is False, we don't need to create the new modelpool, just the first model
+        models = [modelpool.load_model(m) for m in modelpool.model_names]
+        if isinstance(models[0], nn.ModuleList):
+            new_models = recombine_modellist(models)
+        elif isinstance(models[0], nn.ModuleDict):
+            new_models = recombine_modeldict(models)
+        elif isinstance(models[0], nn.Module):
+            new_models = recombine_state_dict(models)
+        else:
+            raise ValueError(f"Unsupported model type {type(models[0])}")
+
+        new_modelpool = BaseModelPool(
+            {n: m for n, m in zip(modelpool.model_names, new_models)}
+        )
+        if return_modelpool:
+            return new_modelpool
+        else:
+            return new_modelpool.load_model(new_modelpool.model_names[0])
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool, return_modelpool=True) + +
+ + +
+ +

Executes the model recombination algorithm on a given model pool.

+

This method loads models from the model pool, determines their type, and applies the appropriate recombination method. +It then creates a new model pool with the recombined models. Depending on the return_modelpool flag, it either returns +the entire new model pool or just the first model from it.

+
    +
  • If the models in the model pool are of type nn.ModuleList, the recombination method recombine_modellist is used. Where each module in the list is shuffled across the models.
  • +
  • If the models are of type nn.ModuleDict, the recombination method recombine_modeldict is used. Where each module in the dictionary is shuffled across the models.
  • +
  • If the models are of type nn.Module, the recombination method recombine_state_dict is used. Where the state dictionaries of the models are shuffled across the models.
  • +
+ + +

Parameters:

+
    +
  • +
    modelpool +
    (BaseModelPool) + – +
    +

    The pool of models to recombine.

    +
    +
  • +
  • +
    return_modelpool +
    (bool, default: + True +) + – +
    +

    Flag indicating whether to return the entire model pool or just the first model. Defaults to True. If this algorithm is initialized with config, the value of return_modelpool in the config will be used and this argument passed to the method will be ignored.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + Union[Module, BaseModelPool] + – +
    +

    Union[nn.Module, BaseModelPool]: The recombined model pool or the first model from the recombined pool, depending on the return_modelpool flag.

    +
    +
  • +
+ + +

Raises:

+
    +
  • + ValueError + – +
    +

    If the models in the model pool are of an unsupported type.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/model_recombination.py +
@torch.no_grad()
+def run(
+    self,
+    modelpool: BaseModelPool,
+    return_modelpool: bool = True,
+) -> Union[nn.Module, BaseModelPool]:
+    """
+    Executes the model recombination algorithm on a given model pool.
+
+    This method loads models from the model pool, determines their type, and applies the appropriate recombination method.
+    It then creates a new model pool with the recombined models. Depending on the `return_modelpool` flag, it either returns
+    the entire new model pool or just the first model from it.
+
+    - If the models in the model pool are of type `nn.ModuleList`, the recombination method `recombine_modellist` is used. Where each module in the list is shuffled across the models.
+    - If the models are of type `nn.ModuleDict`, the recombination method `recombine_modeldict` is used. Where each module in the dictionary is shuffled across the models.
+    - If the models are of type `nn.Module`, the recombination method `recombine_state_dict` is used. Where the state dictionaries of the models are shuffled across the models.
+
+    Args:
+        modelpool (BaseModelPool): The pool of models to recombine.
+        return_modelpool (bool, optional): Flag indicating whether to return the entire model pool or just the first model. Defaults to True. If this algorithm is initialized with config, the value of `return_modelpool` in the config will be used and this argument passed to the method will be ignored.
+
+    Returns:
+        Union[nn.Module, BaseModelPool]: The recombined model pool or the first model from the recombined pool, depending on the `return_modelpool` flag.
+
+    Raises:
+        ValueError: If the models in the model pool are of an unsupported type.
+    """
+    # If the config has a return_modelpool flag, use that, otherwise use the argument
+    if self.config.get("return_modelpool", None) is not None:
+        return_modelpool = self.config.return_modelpool
+    # check the modelpool type
+    if not isinstance(modelpool, BaseModelPool):
+        modelpool = BaseModelPool(modelpool)
+
+    log.info(f"Running model recombination algorithm with {len(modelpool)} models")
+
+    # TODO: optimize the `recombine_*` functions, if `return_modelpool` is False, we don't need to create the new modelpool, just the first model
+    models = [modelpool.load_model(m) for m in modelpool.model_names]
+    if isinstance(models[0], nn.ModuleList):
+        new_models = recombine_modellist(models)
+    elif isinstance(models[0], nn.ModuleDict):
+        new_models = recombine_modeldict(models)
+    elif isinstance(models[0], nn.Module):
+        new_models = recombine_state_dict(models)
+    else:
+        raise ValueError(f"Unsupported model type {type(models[0])}")
+
+    new_modelpool = BaseModelPool(
+        {n: m for n, m in zip(modelpool.model_names, new_models)}
+    )
+    if return_modelpool:
+        return new_modelpool
+    else:
+        return new_modelpool.load_model(new_modelpool.model_names[0])
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + +

+ recombine_modellist(models) + +

+ + +
+ +
+ Source code in fusion_bench/method/model_recombination.py +
14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
def recombine_modellist(models: List[nn.ModuleList]):
+    num_models = len(models)
+    num_layers = len(models[0])
+
+    new_models = [[] for _ in range(num_models)]
+    for layer_idx in range(num_layers):
+        shuffled_layers = [m[layer_idx] for m in models]
+        random.shuffle(shuffled_layers)
+        for model_idx in range(num_models):
+            new_models[model_idx].append(shuffled_layers[model_idx])
+    new_models = [nn.ModuleList(m) for m in new_models]
+    return new_models
+
+
+
+ +
+ +
+ + +

+ recombine_modeldict(models) + +

+ + +
+ +
+ Source code in fusion_bench/method/model_recombination.py +
28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
def recombine_modeldict(models: List[nn.ModuleDict]):
+    num_models = len(models)
+
+    new_models = [{} for _ in range(num_models)]
+    for layer_name in models[0].keys():
+        shuffled_layers = [m[layer_name] for m in models]
+        random.shuffle(shuffled_layers)
+        for model_idx in range(num_models):
+            new_models[model_idx][layer_name] = shuffled_layers[model_idx]
+    new_models = [nn.ModuleDict(m) for m in new_models]
+    return new_models
+
+
+
+ +
+ +
+ + +

+ recombine_state_dict(models) + +

+ + +
+ +
+ Source code in fusion_bench/method/model_recombination.py +
41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
def recombine_state_dict(models: List[nn.Module]):
+    num_models = len(models)
+    state_dicts = [model.state_dict() for model in models]
+    new_state_dict = [{} for _ in range(num_models)]
+    for key in state_dicts[0].keys():
+        shuffled_layers = [state_dict[key] for state_dict in state_dicts]
+        random.shuffle(shuffled_layers)
+        for model_idx in range(num_models):
+            new_state_dict[model_idx][key] = shuffled_layers[model_idx]
+    for model_idx in range(num_models):
+        models[model_idx].load_state_dict(new_state_dict[model_idx])
+    return models
+
+
+
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/model_stitching/index.html b/algorithms/model_stitching/index.html new file mode 100644 index 00000000..d39cdc8d --- /dev/null +++ b/algorithms/model_stitching/index.html @@ -0,0 +1,2208 @@ + + + + + + + + + + + + + + + + + + + Model stitching - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Model stitching

+ + + + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/moe_based_merging/index.html b/algorithms/moe_based_merging/index.html new file mode 100644 index 00000000..c890aef7 --- /dev/null +++ b/algorithms/moe_based_merging/index.html @@ -0,0 +1,2970 @@ + + + + + + + + + + + + + + + + + + + + + + + MoE-based Merging - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

MoE-based Model Model Merging

+

Code Intergration

+

Here we provides instructions on how to use the fusion_bench command-line interface to merge models using a Mixture of Experts (MoE) approach.

+

The first code block is a YAML configuration file for the merging method. The name field specifies the name of the merging method. The num_experts field specifies the number of experts to use in the merging process. The experts_per_token field specifies the number of experts to use per token. The save_checkpoint field specifies the path where the merged model will be saved.

+
config/method/mixtral_moe_merging.yaml
name: mixtral_for_causal_lm_moe_merging
+
+experts_per_token: 2
+# path to save the merged model, if provided
+save_checkpoint: null
+
+

The second code block is another YAML configuration file, this time for the model pool. The type field specifies the type of model pool to use. The models field is a list of models to include in the pool. Each model should have a name and a path, and the model is loaded from the path.

+
config/modelpool/mixtral_moe_merging.yaml
type: AutoModelForCausalLMPool
+# each model should have a name and a path, and the model is loaded from the path
+# this is equivalent to `AutoModelForCausalLM.from_pretrained(path)`
+models:
+  - name: _pretrained_
+    path: path_to_your_pretrained_model
+  - name: expert_1
+    path: path_to_your_expert_model_1
+  - name: expert_2
+    path: path_to_your_expert_model_2
+  - name: expert_3
+    path: path_to_your_expert_model_3
+  - name: expert_4
+    path: path_to_your_expert_model_4
+
+

Finally, the third code block is a bash command that runs the fusion_bench command-line interface with the specified method, model pool, and task pool. The method argument specifies the merging method to use. The modelpool argument specifies the model pool to use. The modelpool.models.0.path argument specifies the path to the pretrained model to use. The taskpool argument specifies the task pool to use. In this case, a dummy task pool is used that does nothing but print the parameter counts of the merged model.

+
fusion_bench \
+    method=mixtral_moe_merging \
+    modelpool=mixtral_moe_merging \
+    taskpool=dummy # this is a dummy taskpool that does nothing but print the parameter counts of the merged model
+
+

This guide provides a step-by-step process for merging models using the fusion_bench command-line interface. By following these instructions, you can merge your own models and save them for future use.

+

References

+ + +
+ + + +

+ mixtral_merging + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ MixtralForCausalLMMergingAlgorithm + + +
+ + +
+

+ Bases: MixtralForCausalLMUpscalingAlgorithm

+ + +

This class is responsible for merging models into a MixtralForCausalLM.

+ + + + + + +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py +
class MixtralForCausalLMMergingAlgorithm(MixtralForCausalLMUpscalingAlgorithm):
+    """
+    This class is responsible for merging models into a `MixtralForCausalLM`.
+    """
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool) -> MixtralForCausalLM:
+        """
+        Runs the merging process. It first upscales the models to MixtralForCausalLM,
+        then substitutes the experts of the MixtralForCausalLM with the models from the modelpool.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralForCausalLM` or `LlamaForCausalLM`.
+
+        Returns:
+            MixtralForCausalLM: The merged model.
+        """
+        with open_dict(self.config):
+            self.config.num_experts = len(modelpool)
+
+        # firstly, we upscale the models to MixtralForCausalLM
+        mixtral_model = super()._run(modelpool)
+
+        # then we substitute the experts of the MixtralForCausalLM with the models from the modelpool
+        for model_idx, model_name in enumerate(modelpool.model_names):
+            expert_model: MistralForCausalLM | LlamaForCausalLM = modelpool.load_model(
+                model_name
+            )
+            _substitute_experts(model_idx, expert_model.model, mixtral_model.model)
+
+        if self.config.get("save_checkpoint", None) is not None:
+            mixtral_model.save_pretrained(self.config.save_checkpoint)
+        return mixtral_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + +
+ + +
+ +

Runs the merging process. It first upscales the models to MixtralForCausalLM, +then substitutes the experts of the MixtralForCausalLM with the models from the modelpool.

+ + +

Parameters:

+
    +
  • + modelpool + (ModelPool) + – +
    +

    The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a MistralForCausalLM or LlamaForCausalLM.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +MixtralForCausalLM ( MixtralForCausalLM +) – +
    +

    The merged model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py +
@torch.no_grad()
+def run(self, modelpool: BaseModelPool) -> MixtralForCausalLM:
+    """
+    Runs the merging process. It first upscales the models to MixtralForCausalLM,
+    then substitutes the experts of the MixtralForCausalLM with the models from the modelpool.
+
+    Args:
+        modelpool (ModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralForCausalLM` or `LlamaForCausalLM`.
+
+    Returns:
+        MixtralForCausalLM: The merged model.
+    """
+    with open_dict(self.config):
+        self.config.num_experts = len(modelpool)
+
+    # firstly, we upscale the models to MixtralForCausalLM
+    mixtral_model = super()._run(modelpool)
+
+    # then we substitute the experts of the MixtralForCausalLM with the models from the modelpool
+    for model_idx, model_name in enumerate(modelpool.model_names):
+        expert_model: MistralForCausalLM | LlamaForCausalLM = modelpool.load_model(
+            model_name
+        )
+        _substitute_experts(model_idx, expert_model.model, mixtral_model.model)
+
+    if self.config.get("save_checkpoint", None) is not None:
+        mixtral_model.save_pretrained(self.config.save_checkpoint)
+    return mixtral_model
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +
+ MixtralMoEMergingAlgorithm + + +
+ + +
+

+ Bases: MixtralUpscalingAlgorithm

+ + +

This class is responsible for merging models into a MixtralModel.

+ + + + + + +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py +
48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
class MixtralMoEMergingAlgorithm(MixtralUpscalingAlgorithm):
+    """
+    This class is responsible for merging models into a MixtralModel.
+    """
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool) -> MixtralModel:
+        """
+        Runs the merging process.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralModel` or `LlamaModel`.
+
+        Returns:
+            MixtralModel: The merged model.
+        """
+        with open_dict(self.config):
+            self.config.num_experts = len(modelpool)
+
+        # firstly, we upscale the models to MixtralModel
+        mixtral_model = super()._run(modelpool)
+
+        # then we substitute the experts of the MixtralModel with the models from the modelpool
+        for model_idx, model_name in enumerate(modelpool.model_names):
+            expert_model: MistralModel | LlamaModel = modelpool.load_model(model_name)
+            _substitute_experts(model_idx, expert_model, mixtral_model)
+
+        if self.config.get("save_checkpoint", None) is not None:
+            mixtral_model.save_pretrained(self.config.save_checkpoint)
+        return mixtral_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + +
+ + +
+ +

Runs the merging process.

+ + +

Parameters:

+
    +
  • + modelpool + (ModelPool) + – +
    +

    The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a MistralModel or LlamaModel.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +MixtralModel ( MixtralModel +) – +
    +

    The merged model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py +
53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
@torch.no_grad()
+def run(self, modelpool: BaseModelPool) -> MixtralModel:
+    """
+    Runs the merging process.
+
+    Args:
+        modelpool (ModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralModel` or `LlamaModel`.
+
+    Returns:
+        MixtralModel: The merged model.
+    """
+    with open_dict(self.config):
+        self.config.num_experts = len(modelpool)
+
+    # firstly, we upscale the models to MixtralModel
+    mixtral_model = super()._run(modelpool)
+
+    # then we substitute the experts of the MixtralModel with the models from the modelpool
+    for model_idx, model_name in enumerate(modelpool.model_names):
+        expert_model: MistralModel | LlamaModel = modelpool.load_model(model_name)
+        _substitute_experts(model_idx, expert_model, mixtral_model)
+
+    if self.config.get("save_checkpoint", None) is not None:
+        mixtral_model.save_pretrained(self.config.save_checkpoint)
+    return mixtral_model
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/moe_based_upscaling/index.html b/algorithms/moe_based_upscaling/index.html new file mode 100644 index 00000000..34624219 --- /dev/null +++ b/algorithms/moe_based_upscaling/index.html @@ -0,0 +1,3759 @@ + + + + + + + + + + + + + + + + + + + + + + + MoE-based Upscaling - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

MoE-based Model Model Upscaling (Sparse Upcycling)

+
+ alt text +
+

Sparse upcycling is a technique used to initialize a sparsely activated Mixture-of-Experts (MoE) model from a dense checkpoint. This approach leverages previously incurred training costs to improve the performance of large models while reducing the computational expense. In the process, dense Transformer blocks are partially replaced with MoE blocks, where the MLPs in a Transformer block are replaced by multiple experts. The experts are chosen based on routing probabilities determined by a router. The initialized MoE model is then further trained to recover the performance. This method results in improved performance for both language and vision models while using only a fraction of the original dense pretraining cost 1.

+

Examples

+

Here’s an example demonstrating how to upscale a pre-trained Mistral model to a Mixtral model:

+
import os
+
+from omegaconf import DictConfig
+from transformers import MistralForCausalLM
+
+from fusion_bench.method.mixture_of_experts.mixtral_upcycling import (
+    MixtralForCausalLMUpscalingAlgorithm,
+)
+from fusion_bench.utils import print_parameters
+
+# Load a pre-trained Mistral model
+pretrained_model = MistralForCausalLM.from_pretrained(
+    os.path.expanduser("path_to_mistral_model")
+)
+print("Pretrained model:")
+print_parameters(pretrained_model)
+# Output:
+# Pretrained model:
+# trainable params: 7.24B || all params: 7.24B || trainable%: 100.0000
+
+# Define the configuration for Mixtral
+config = {
+    "num_experts": 4,  # Number of expert channels
+    "experts_per_token": 2,  # Experts to choose per token
+}
+
+# Initialize the upscaling algorithm
+upscaling_for_causal_lm_algorithm = MixtralForCausalLMUpscalingAlgorithm(
+    DictConfig(config)
+)
+
+# Run the upscaling process to get a Mixtral model
+mixtral_for_causal_lm_model = upscaling_for_causal_lm_algorithm.run(pretrained_model)
+
+print("Mixtral model:")
+print_parameters(mixtral_for_causal_lm_model)
+# Outputs:
+# Mixtral model:
+# trainable params: 24.15B || all params: 24.15B || trainable%: 100.0000
+
+# Save the upscaled Mixtral model
+mixtral_for_causal_lm_model.save_pretrained("path_to_save_mixtral_model")
+
+

A Jupyter notebook example is also available at our repo.

+

Code Integration

+

This is a guide on how to use the fusion_bench command-line interface to upscale a Mistral model to a Mixtral model.

+

The first code block is a YAML configuration file for the upscaling method. The name field specifies the name of the upscaling method. The num_experts field specifies the number of experts to use in the upscaling process. The experts_per_token field specifies the number of experts to use per token. The save_checkpoint field specifies the path where the upscaled model will be saved, if provided.

+
config/method/mixtral_moe_upscaling.yaml
name: mixtral_for_causal_lm_moe_upscaling # or "mixtral_moe_upscaling"
+
+num_experts: 4
+experts_per_token: 2
+# path to save the upscaled model
+save_checkpoint: null
+
+

The second code block is another YAML configuration file, this time for the model pool. The type field specifies the type of model pool to use. The models field is a list of models to include in the pool. Each model should have a name and a path, and the model is loaded from the path.

+
config/modelpool/mixtral_moe_upscaling.yaml
type: AutoModelForCausalLMPool
+# each model should have a name and a path, and the model is loaded from the path
+# this is equivalent to `AutoModelForCausalLM.from_pretrained(path)`
+models:
+  - name: _pretrained_
+    path: path_to_your_pretrained_model
+
+

Finally, the third code block is a bash command that runs the fusion_bench command-line interface with the specified method, model pool, and task pool. The method argument specifies the upscaling method to use. The modelpool argument specifies the model pool to use. The modelpool.models.0.path argument specifies the path to the pretrained model to use. The taskpool argument specifies the task pool to use. In this case, a dummy task pool is used that does nothing but print the parameter counts of the merged model.

+
fusion_bench \
+    method=mixtral_moe_upscaling \
+    modelpool=mixtral_moe_upscaling \
+        modelpool.models.0.path=path_to_your_pretrained_model \
+    taskpool=dummy # this is a dummy taskpool that does nothing but print the parameter counts of the merged model
+
+

References

+ + +
+ + + +

+ mixtral_upcycling + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ MixtralForCausalLMUpscalingAlgorithm + + +
+ + +
+

+ Bases: BaseAlgorithm

+ + +

This class is responsible for upscaling a model to a MixtralForCausalLM. +It inherits from the ModelFusionAlgorithm class.

+ + + + + + +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
+    """
+    This class is responsible for upscaling a model to a MixtralForCausalLM.
+    It inherits from the ModelFusionAlgorithm class.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "num_experts": "num_experts",
+        "experts_per_token": "experts_per_token",
+        "save_checkpoint": "save_checkpoint",
+    }
+
+    def __init__(
+        self,
+        num_experts: int,
+        experts_per_token: int,
+        save_checkpoint: str,
+        **kwargs,
+    ):
+        """
+        Initialize the MixtralForCausalLMUpscalingAlgorithm.
+
+        Args:
+            num_experts (int): The number of experts in the Mixtral model.
+            experts_per_token (int): The number of experts per token.
+            save_checkpoint (str): The path to save the checkpoint.
+            **kwargs: Additional keyword arguments.
+        """
+        self.num_experts = num_experts
+        self.experts_per_token = experts_per_token
+        self.save_checkpoint = save_checkpoint
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def _run(
+        self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
+    ) -> MixtralForCausalLM:
+        """
+        Internal method to run the upscaling process.
+
+        Args:
+            modelpool (BaseModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.
+
+        Returns:
+            MixtralForCausalLM: The upscaled model.
+        """
+        if isinstance(modelpool, BaseModelPool):
+            assert modelpool.has_pretrained, "ModelPool must have pretrained model."
+            pretrained_model = modelpool.load_model("_pretrained_")
+        elif isinstance(modelpool, (LlamaForCausalLM, MistralForCausalLM)):
+            pretrained_model = modelpool
+        else:
+            raise ValueError("Invalid modelpool type")
+
+        mixtral_config = _convert_config_to_mixtral(
+            pretrained_model.config,
+            self.config.num_experts,
+            self.config.experts_per_token,
+        )
+
+        with ContextManagers([no_init_weights(True)]):
+            for _ in tqdm(range(1), desc="Initializing Mixtral model"):
+                mixtral_model = MixtralForCausalLM(mixtral_config)
+        upscale_to_mixtral_for_causal_lm(pretrained_model, mixtral_model)
+
+        return mixtral_model
+
+    @torch.no_grad()
+    def run(
+        self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
+    ) -> MixtralForCausalLM:
+        """
+        Runs the upscaling process.
+
+        Args:
+            modelpool (ModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.
+
+        Returns:
+            MixtralForCausalLM: The upscaled model.
+        """
+        mixtral_model = self._run(modelpool)
+
+        if self.config.get("save_checkpoint", None) is not None:
+            mixtral_model.save_pretrained(self.config.save_checkpoint)
+        return mixtral_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ __init__(num_experts, experts_per_token, save_checkpoint, **kwargs) + +
+ + +
+ +

Initialize the MixtralForCausalLMUpscalingAlgorithm.

+ + +

Parameters:

+
    +
  • + num_experts + (int) + – +
    +

    The number of experts in the Mixtral model.

    +
    +
  • +
  • + experts_per_token + (int) + – +
    +

    The number of experts per token.

    +
    +
  • +
  • + save_checkpoint + (str) + – +
    +

    The path to save the checkpoint.

    +
    +
  • +
  • + **kwargs + – +
    +

    Additional keyword arguments.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
def __init__(
+    self,
+    num_experts: int,
+    experts_per_token: int,
+    save_checkpoint: str,
+    **kwargs,
+):
+    """
+    Initialize the MixtralForCausalLMUpscalingAlgorithm.
+
+    Args:
+        num_experts (int): The number of experts in the Mixtral model.
+        experts_per_token (int): The number of experts per token.
+        save_checkpoint (str): The path to save the checkpoint.
+        **kwargs: Additional keyword arguments.
+    """
+    self.num_experts = num_experts
+    self.experts_per_token = experts_per_token
+    self.save_checkpoint = save_checkpoint
+    super().__init__(**kwargs)
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Runs the upscaling process.

+ + +

Parameters:

+
    +
  • + modelpool + (ModelPool | LlamaForCausalLM | MistralForCausalLM) + – +
    +

    The model to be upscaled.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +MixtralForCausalLM ( MixtralForCausalLM +) – +
    +

    The upscaled model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
@torch.no_grad()
+def run(
+    self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
+) -> MixtralForCausalLM:
+    """
+    Runs the upscaling process.
+
+    Args:
+        modelpool (ModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.
+
+    Returns:
+        MixtralForCausalLM: The upscaled model.
+    """
+    mixtral_model = self._run(modelpool)
+
+    if self.config.get("save_checkpoint", None) is not None:
+        mixtral_model.save_pretrained(self.config.save_checkpoint)
+    return mixtral_model
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +
+ MixtralUpscalingAlgorithm + + +
+ + +
+

+ Bases: BaseAlgorithm

+ + +

This class is responsible for upscaling a model to a MixtralModel. +It inherits from the ModelFusionAlgorithm class.

+ + + + + + +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
class MixtralUpscalingAlgorithm(BaseAlgorithm):
+    """
+    This class is responsible for upscaling a model to a MixtralModel.
+    It inherits from the ModelFusionAlgorithm class.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "num_experts": "num_experts",
+        "experts_per_token": "experts_per_token",
+        "save_checkpoint": "save_checkpoint",
+    }
+
+    def __init__(
+        self,
+        num_experts: int,
+        experts_per_token: int,
+        save_checkpoint: str,
+        **kwargs,
+    ):
+        """
+        Initialize the MixtralUpscalingAlgorithm.
+
+        Args:
+            num_experts (int): The number of experts in the Mixtral model.
+            experts_per_token (int): The number of experts per token.
+            save_checkpoint (str): The path to save the checkpoint.
+            **kwargs: Additional keyword arguments.
+        """
+        self.num_experts = num_experts
+        self.experts_per_token = experts_per_token
+        self.save_checkpoint = save_checkpoint
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def _run(
+        self, modelpool: BaseModelPool | LlamaModel | MistralModel
+    ) -> MixtralModel:
+        """
+        Internal method to run the upscaling process.
+
+        Args:
+            modelpool (BaseModelPool | LlamaModel | MistralModel): The model to be upscaled.
+
+        Returns:
+            MixtralModel: The upscaled model.
+        """
+        if isinstance(modelpool, BaseModelPool):
+            assert modelpool.has_pretrained, "ModelPool must have pretrained model."
+            pretrained_model = modelpool.load_model("_pretrained_")
+        elif isinstance(modelpool, (LlamaModel, MistralModel)):
+            pretrained_model = modelpool
+        else:
+            raise ValueError("Invalid modelpool type")
+
+        mixtral_config = _convert_config_to_mixtral(
+            pretrained_model.config,
+            self.config.num_experts,
+            self.config.experts_per_token,
+        )
+
+        with ContextManagers([no_init_weights(True)]):
+            for _ in tqdm(range(1), desc="Initializing Mixtral model"):
+                mixtral_model = MixtralModel(mixtral_config)
+        upscale_to_mixtral_model(pretrained_model, mixtral_model)
+
+        return mixtral_model
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool | LlamaModel | MistralModel) -> MixtralModel:
+        """
+        Runs the upscaling process.
+
+        Args:
+            modelpool (ModelPool | LlamaModel | MistralModel): The model to be upscaled.
+
+        Returns:
+            MixtralModel: The upscaled model.
+        """
+        mixtral_model = self._run(modelpool)
+
+        if self.config.get("save_checkpoint", None) is not None:
+            mixtral_model.save_pretrained(self.config.save_checkpoint)
+        return mixtral_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ __init__(num_experts, experts_per_token, save_checkpoint, **kwargs) + +
+ + +
+ +

Initialize the MixtralUpscalingAlgorithm.

+ + +

Parameters:

+
    +
  • + num_experts + (int) + – +
    +

    The number of experts in the Mixtral model.

    +
    +
  • +
  • + experts_per_token + (int) + – +
    +

    The number of experts per token.

    +
    +
  • +
  • + save_checkpoint + (str) + – +
    +

    The path to save the checkpoint.

    +
    +
  • +
  • + **kwargs + – +
    +

    Additional keyword arguments.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
def __init__(
+    self,
+    num_experts: int,
+    experts_per_token: int,
+    save_checkpoint: str,
+    **kwargs,
+):
+    """
+    Initialize the MixtralUpscalingAlgorithm.
+
+    Args:
+        num_experts (int): The number of experts in the Mixtral model.
+        experts_per_token (int): The number of experts per token.
+        save_checkpoint (str): The path to save the checkpoint.
+        **kwargs: Additional keyword arguments.
+    """
+    self.num_experts = num_experts
+    self.experts_per_token = experts_per_token
+    self.save_checkpoint = save_checkpoint
+    super().__init__(**kwargs)
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Runs the upscaling process.

+ + +

Parameters:

+
    +
  • + modelpool + (ModelPool | LlamaModel | MistralModel) + – +
    +

    The model to be upscaled.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +MixtralModel ( MixtralModel +) – +
    +

    The upscaled model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
@torch.no_grad()
+def run(self, modelpool: BaseModelPool | LlamaModel | MistralModel) -> MixtralModel:
+    """
+    Runs the upscaling process.
+
+    Args:
+        modelpool (ModelPool | LlamaModel | MistralModel): The model to be upscaled.
+
+    Returns:
+        MixtralModel: The upscaled model.
+    """
+    mixtral_model = self._run(modelpool)
+
+    if self.config.get("save_checkpoint", None) is not None:
+        mixtral_model.save_pretrained(self.config.save_checkpoint)
+    return mixtral_model
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +
+ upscale_to_mixtral_for_causal_lm(input_model, output_model) + +
+ + +
+ +

A helper function.

+

Upscales a LlamaForCausalLM or MistralForCausalLM to a MixtralForCausalLM.

+ + +

Parameters:

+
    +
  • +
    input_model +
    (LlamaForCausalLM | MistralForCausalLM) + – +
    +

    The input model to be upscaled.

    +
    +
  • +
  • +
    output_model +
    (MixtralForCausalLM) + – +
    +

    The output model where the upscaled weights will be loaded.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    None

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
def upscale_to_mixtral_for_causal_lm(
+    input_model: LlamaForCausalLM | MistralForCausalLM, output_model: MixtralForCausalLM
+):
+    """
+    A helper function.
+
+    Upscales a LlamaForCausalLM or MistralForCausalLM to a MixtralForCausalLM.
+
+    Args:
+        input_model (LlamaForCausalLM | MistralForCausalLM): The input model to be upscaled.
+        output_model (MixtralForCausalLM): The output model where the upscaled weights will be loaded.
+
+    Returns:
+        None
+    """
+    output_model.lm_head.load_state_dict(input_model.lm_head.state_dict())
+    upscale_to_mixtral_model(input_model.model, output_model.model)
+
+
+
+ +
+ +
+ + +
+ upscale_to_mixtral_model(input_model, output_model) + +
+ + +
+ +

A helper function.

+

Upscales a LlamaModel or MistralModel to a MixtralModel.

+ + +

Parameters:

+
    +
  • +
    input_model +
    (LlamaModel | MistralModel) + – +
    +

    The input model to be upscaled.

    +
    +
  • +
  • +
    output_model +
    (MixtralModel) + – +
    +

    The output model where the upscaled weights will be loaded.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    None

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +
def upscale_to_mixtral_model(
+    input_model: LlamaModel | MistralModel, output_model: MixtralModel
+):
+    """
+    A helper function.
+
+    Upscales a LlamaModel or MistralModel to a MixtralModel.
+
+    Args:
+        input_model (LlamaModel | MistralModel): The input model to be upscaled.
+        output_model (MixtralModel): The output model where the upscaled weights will be loaded.
+
+    Returns:
+        None
+    """
+    # copy the weights from the pretrained model
+    output_model.embed_tokens.load_state_dict(input_model.embed_tokens.state_dict())
+    output_model.norm.load_state_dict(input_model.norm.state_dict())
+    for input_layer, output_layer in tqdm(
+        zip(input_model.layers, output_model.layers),
+        desc="Upscaling layers",
+        total=len(input_model.layers),
+    ):
+        _upscale_decoder_layer(input_layer, output_layer)
+
+
+
+ +
+ + + +
+ +
+ +
+
+
    +
  1. +

    Sparse Upcycling: Training Mixture-of-Experts from Dense Checkpoints. http://arxiv.org/abs/2212.05055 

    +
  2. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/pruning/images/llama_2_4_semistructued_first_layer.png b/algorithms/pruning/images/llama_2_4_semistructued_first_layer.png new file mode 100644 index 00000000..6dcee1ff Binary files /dev/null and b/algorithms/pruning/images/llama_2_4_semistructued_first_layer.png differ diff --git a/algorithms/pruning/magnitude_pruning/index.html b/algorithms/pruning/magnitude_pruning/index.html new file mode 100644 index 00000000..532c313c --- /dev/null +++ b/algorithms/pruning/magnitude_pruning/index.html @@ -0,0 +1,2885 @@ + + + + + + + + + + + + + + + + + + + + + + + Magnitude Pruning - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Magnitude Pruning

+

Examples

+

Pruning a Llama Model

+

Unstructured Magnitude Pruning

+

The following command prunes a Llama model with a sparsity ratio of 0.7 (70% of the weights are pruned) using unstructured magnitude pruning. The pruned model is saved to outputs/llama/magnitude_pruning/unstructured/0.7.

+
fusion_bench \
+    --config-name llama_magnitude_pruning \
+    method.prune_type=unstructured \
+    method.sparsity_ratio=0.7 \
+    modelpool.models.0.path=decapoda-research/llama-7b-hf \
+    merged_model_save_path=outputs/llama/magnitude_pruning/unstructured/0.7
+
+

Semi-Structured Magnitude Pruning

+

The following command prunes a Llama model with a 2:4 semi-structured pruning ratio using magnitude pruning. The pruned model is saved to outputs/llama/magnitude_pruning/semistructure/2_4.

+
fusion_bench \
+    --config-name llama_magnitude_pruning \
+    method.prune_type=semistructured \
+    method.n=2 method.m=4 \
+    modelpool.models.0.path=decapoda-research/llama-7b-hf \
+    merged_model_save_path=outputs/llama/magnitude_pruning/semistructure/2_4
+
+

Below is an example of how to visualize the pruned weights of the first layer of the pruned model.

+
from transformers import AutoModelForCausalLM
+import matplotlib.pyplot as plt
+import seaborn as sns
+import torch
+
+# Load the pruned model
+model = AutoModelForCausalLM.from_pretrained("outputs/llama/magnitude_pruning/semistructure/2_4")
+
+# Extract the tensor data
+tensor_data = model.model.layers[0].self_attn.q_proj.weight[:32, :32]
+
+# Convert to NumPy array
+tensor_data_np = tensor_data.detach().cpu().numpy()
+
+# Plot heatmap
+plt.figure(figsize=(10, 8))
+ax = sns.heatmap(tensor_data_np, center=0, cmap="coolwarm", annot=False)
+
+# Add grid lines for 4x4 cells
+for i in range(0, tensor_data_np.shape[0], 4):
+    ax.axhline(i, color="black", linewidth=0.5)
+    ax.axvline(i, color="black", linewidth=0.5)
+
+plt.title("Heatmap of q_proj.weight[:32, :32]")
+plt.show()
+
+

The following image shows the pruned weights of the first layer of the pruned model.

+

alt text

+

References

+ + +
+ + + +

+ MagnitudePruningForLlama + + +

+ + +
+

+ Bases: BaseAlgorithm, SimpleProfilerMixin

+ + +

Implements magnitude-based pruning for LLama models.

+

This class supports both unstructured and semistructured pruning methods. +It loads a pre-trained model or the first model in the pool and applies the specified pruning technique.

+ + +

Methods:

+
    +
  • + run + – +
    +

    LLamaForCausalLMPool) -> nn.Module: +Executes the pruning process on the model pool and returns the pruned model.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/pruning/llama_magnitude_prune.py +
class MagnitudePruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
+    """
+    Implements magnitude-based pruning for LLama models.
+
+    This class supports both unstructured and semistructured pruning methods.
+    It loads a pre-trained model or the first model in the pool and applies the specified pruning technique.
+
+    Methods:
+        run(modelpool: LLamaForCausalLMPool) -> nn.Module:
+            Executes the pruning process on the model pool and returns the pruned model.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "prune_type": "prune_type",
+        "device": "device",
+        "dtype": "dtype",
+        "sparsity_ratio": "sparsity_ratio",
+        "n": "n",
+        "m": "m",
+    }
+
+    def __init__(
+        self,
+        *,
+        prune_type: Literal["unstructured", "semistructured"],
+        device: str,
+        dtype: Optional[str],
+        sparsity_ratio: float,
+        n: int,
+        m: int,
+        **kwargs,
+    ):
+        self.prune_type = prune_type
+        self.device = device
+        self.dtype = dtype
+        self.sparsity_ratio = sparsity_ratio
+        self.n = n
+        self.m = m
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def run(self, modelpool: CausalLMPool):
+        """
+        Execute the pruning process on the first model from the given model pool.
+
+        Args:
+            modelpool (CausalLMPool): The model pool containing the models to prune.
+
+        Returns:
+            nn.Module: The pruned model.
+        """
+        config = self.config
+
+        # load pre-trained model or the first model in the pool
+        base_model = modelpool.load_pretrained_or_first_model()
+
+        dtype = parse_dtype(config.dtype)
+        device = torch.device(config.device)
+
+        if config.prune_type == "unstructured":
+            unstructured_magnitude_prune_(
+                base_model, config.sparsity_ratio, dtype=dtype, device=device
+            )
+        elif config.prune_type == "semistructured":
+            semistructured_magnitude_prune_(
+                base_model, config.n, config.m, dtype=dtype, device=device
+            )
+        else:
+            raise ValueError(
+                f"Invalid pruning type: {config.prune_type}"
+                "Choose from 'unstructured' or 'semistructured'"
+            )
+
+        return base_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + +
+ + +
+ +

Execute the pruning process on the first model from the given model pool.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (CausalLMPool) + – +
    +

    The model pool containing the models to prune.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    nn.Module: The pruned model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/pruning/llama_magnitude_prune.py +
@torch.no_grad()
+def run(self, modelpool: CausalLMPool):
+    """
+    Execute the pruning process on the first model from the given model pool.
+
+    Args:
+        modelpool (CausalLMPool): The model pool containing the models to prune.
+
+    Returns:
+        nn.Module: The pruned model.
+    """
+    config = self.config
+
+    # load pre-trained model or the first model in the pool
+    base_model = modelpool.load_pretrained_or_first_model()
+
+    dtype = parse_dtype(config.dtype)
+    device = torch.device(config.device)
+
+    if config.prune_type == "unstructured":
+        unstructured_magnitude_prune_(
+            base_model, config.sparsity_ratio, dtype=dtype, device=device
+        )
+    elif config.prune_type == "semistructured":
+        semistructured_magnitude_prune_(
+            base_model, config.n, config.m, dtype=dtype, device=device
+        )
+    else:
+        raise ValueError(
+            f"Invalid pruning type: {config.prune_type}"
+            "Choose from 'unstructured' or 'semistructured'"
+        )
+
+    return base_model
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/pwe_moe/index.html b/algorithms/pwe_moe/index.html new file mode 100644 index 00000000..9ee031fa --- /dev/null +++ b/algorithms/pwe_moe/index.html @@ -0,0 +1,3470 @@ + + + + + + + + + + + + + + + + + + + + + + + PWE MoE - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+ +
+
+ + + +
+
+ + + + + + + +

PWEMoE: Pareto-Driven Weight-Ensembling Mixture of Experts

+

arXiv

+
+alt text +
+ Overview of PWE MoE + (a) An illustration of Pareto front learning in MOOP. Where \(P_1\) and \(P_2\) are performance metrics for two tasks, colored lines represent different Pareto optimal solutions, and the solid black line represents the Pareto front. + (b) An overview of the model up-scaling process. + We upcycle the MLP modules to MoE modules and merge the remaining parts using task arithmetic. + (c) The MoE module, comprising a routing network and a parameter decoder network. + The routing network accepts a user preference vector and generates routing weights for weight-ensembling.
+
+
+

Abstract

+

Solving multi-objective optimization problems for large deep neural networks is a challenging task due to the complexity of the loss landscape and the expensive computational cost of training and evaluating models. +Efficient Pareto front approximation of large models enables multi-objective optimization for various tasks such as multi-task learning and trade-off analysis. +Existing algorithms for learning Pareto set, including (1) evolutionary, hypernetworks, and hypervolume-maximization methods, are computationally expensive and have restricted scalability to large models; +(2) Scalarization algorithms, where a separate model is trained for each objective ray, which is inefficient for learning the entire Pareto set and fails to capture the objective trade-offs effectively. +Inspired by the recent success of model merging, we propose a practical and scalable approach to Pareto set learning problem via mixture of experts (MoE) based model fusion. +By ensembling the weights of specialized single-task models, the MoE module can effectively capture the trade-offs between multiple objectives and closely approximate the entire Pareto set of large neural networks. +Once the routers are learned and a preference vector is set, the MoE module can be unloaded, thus no additional computational cost is introduced during inference. +We conduct extensive experiments on vision and language tasks using large-scale models such as CLIP-ViT and GPT-2. +The experimental results demonstrate that our method efficiently approximates the entire Pareto front of large models. +Using only hundreds of trainable parameters of the MoE routers, our method even has lower memory usage compared to linear scalarization and algorithms that learn a single Pareto optimal solution, and are scalable to both the number of objectives and the size of the model. +Our method significantly reduces the computational burden of learning the Pareto set, for example, in the two-task case, it can be achieved in just a few minutes. +Code is available at: GitHub .

+
+

Examples

+
+

Not tested yet

+

The examples provided below have not been tested yet.

+

For a thoroughly tested and verified implementation of the algorithm, please refer to the original repository: tanganke/pareto_set_learning . +Additionally, the experimental results and further insights into the algorithm can be found in the original research paper: arXiv:2406.09770 .

+
+

PWEMoE-LS on eight image classification tasks using CLIP-ViT-B/32 models, and the results are logged to outputs/logs/ViT-B-32/PWEMoE-LS-8tasks.

+
fusion_bench \
+    method=pwe_moe_ls_for_clip \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 \
+    fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
+    fabric.loggers.name=PWEMoE-LS-8tasks
+
+

References

+ + +
+ + + +

+ clip_pwe_moe + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ PWEMoEAlgorithmForCLIP + + +
+ + +
+

+ Bases: BaseAlgorithm, SimpleProfilerMixin, CLIPClassificationMixin

+ + + + + + + +
+ Source code in fusion_bench/method/pwe_moe/clip_pwe_moe.py +
class PWEMoEAlgorithmForCLIP(
+    BaseAlgorithm,
+    SimpleProfilerMixin,
+    CLIPClassificationMixin,
+):
+    modelpool: CLIPVisionModelPool = None
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "upscale_mlp": "upscale_mlp",
+        "upscale_attn": "upscale_attn",
+        "init_lambda": "init_lambda",
+        "router_hidden_layers": "router_hidden_layers",
+        "lr": "lr",
+        "num_steps": "num_steps",
+        "save_interval": "save_interval",
+        "alpha": "alpha",
+        "checkpoint_path": "checkpoint_path",
+        "eval_grid": "eval_grid",
+        "eval_grid_n": "eval_grid_n",
+        "eval_grid_m": "eval_grid_m",
+        "_dataloader_kwargs": "dataloader_kwargs",
+    }
+
+    def __init__(
+        self,
+        *,
+        upscale_mlp: bool,
+        upscale_attn: bool,
+        init_lambda: float,
+        router_hidden_layers: int,
+        lr: float,
+        num_steps: int,
+        save_interval: int,
+        alpha: float,
+        checkpoint_path: str,
+        eval_grid: bool,
+        eval_grid_n: int,
+        eval_grid_m: int,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.upscale_mlp = upscale_mlp
+        self.upscale_attn = upscale_attn
+        self.init_lambda = init_lambda
+        self.router_hidden_layers = router_hidden_layers
+        self.lr = lr
+        self.num_steps = num_steps
+        self.save_interval = save_interval
+        self.alpha = alpha
+        self.checkpoint_path = checkpoint_path
+        self.eval_grid = eval_grid
+        self.eval_grid_n = eval_grid_n
+        self.eval_grid_m = eval_grid_m
+        self._dataloader_kwargs = dataloader_kwargs
+
+    @override
+    def run(self, modelpool: CLIPVisionModelPool):
+        self.modelpool = modelpool
+
+        model = self.setup_model()
+        if self.checkpoint_path is not None:
+            model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu"))
+        else:
+            train_loaders = self.setup_train_loaders()
+            model = self.train(model, train_loaders)
+
+        if self.eval_grid:
+            return map(
+                lambda m, r: {
+                    "model": ParetoWeightEnsemblingModule.set_preferenece_vector(
+                        m,
+                        torch.as_tensor(
+                            r, device=self.fabric.device, dtype=torch.float32
+                        ),
+                    ),
+                    "preference_vector": r,
+                },
+                itertools.cycle([model]),
+                generate_simplex_grid(self.eval_grid_n, self.eval_grid_m),
+            )
+        return model
+
+    def load_clip_models(self):
+        """
+        Loads the pretrained CLIP model and the fine-tuned models for each dataset specified in the configuration.
+        """
+        # load pretrained and fine-tuned model
+        with timeit_context():
+            log.info("load models")
+            pretrained_model: CLIPVisionModel = self.modelpool.load_model(
+                "_pretrained_"
+            )
+            finetuned_models = {
+                model_name: self.modelpool.load_model(model_name)
+                for model_name in self.modelpool.model_names
+            }
+
+        log.info("pretrained model statistics:")
+        print_parameters(pretrained_model)
+        return pretrained_model, finetuned_models
+
+    def setup_model(self):
+        pretrained_model, finetuned_models = self.load_clip_models()
+        self.setup_zero_shot_classification_head()
+
+        with timeit_context("Building PWEMoE model"):
+            model = deepcopy(pretrained_model)
+
+            # merge the remaining layers using task arithmetic
+            if self.init_lambda != 0:
+                task_arithmetic_merge(
+                    model,
+                    finetuned_models.values(),
+                    scaling_factor=self.init_lambda,
+                    inplace=True,
+                )
+            # fix all parameters
+            model.requires_grad_(False)
+
+            num_layers = len(model.vision_model.encoder.layers)
+
+            def get_layer(m, i):
+                return cast(CLIPEncoderLayer, m.vision_model.encoder.layers[i])
+
+            for layer_idx in tqdm(range(num_layers)):
+                if self.upscale_mlp:
+                    # upscale the mlp layer
+                    get_layer(model, layer_idx).mlp = ParetoWeightEnsemblingModule(
+                        base_model=get_layer(pretrained_model, layer_idx).mlp,
+                        expert_models=[
+                            get_layer(m, layer_idx).mlp
+                            for m in finetuned_models.values()
+                        ],
+                        init_lambda=self.init_lambda,
+                        fix_base_model_and_experts=True,
+                        router_hidden_layers=self.router_hidden_layers,
+                    )
+
+                if self.upscale_attn:
+                    # upscale the Attention layer
+                    get_layer(model, layer_idx).self_attn = (
+                        ParetoWeightEnsemblingModule(
+                            base_model=get_layer(pretrained_model, layer_idx).self_attn,
+                            expert_models=[
+                                get_layer(m, layer_idx).self_attn
+                                for m in finetuned_models.values()
+                            ],
+                            init_lambda=self.init_lambda,
+                            fix_base_model_and_experts=True,
+                            router_hidden_layers=self.router_hidden_layers,
+                        )
+                    )
+
+            print("model statistics after upscaling:")
+            print_parameters(model)
+            return model
+
+    def setup_train_loaders(self):
+        """
+        Loads the datasets specified in the configuration.
+        """
+        train_datasets = {
+            dataset_name: self.modelpool.load_train_dataset(
+                dataset_name, self.clip_processor
+            )
+            for dataset_name in self.modelpool.model_names
+        }
+        train_loaders = {
+            dataset_name: DataLoader(dataset, shuffle=True, **self._dataloader_kwargs)
+            for dataset_name, dataset in train_datasets.items()
+        }
+        train_loaders = {
+            dataset_name: self.fabric.setup_dataloaders(loader)
+            for dataset_name, loader in train_loaders.items()
+        }
+        return train_loaders
+
+    def train(self, model: nn.Module, train_loaders: Dict[str, DataLoader]):
+        config = self.config
+
+        # save the configuration
+        self.log_hyperparams(config, filename="method_config.yaml")
+
+        # setup the model
+        num_objectives = len(self.modelpool.model_names)
+        model = model
+
+        # setup data loaders
+        train_loaders = {
+            name: InfiniteDataLoader(loader) for name, loader in train_loaders.items()
+        }
+
+        # set up the optimizer and learning rate scheduler
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=config.lr,
+        )
+        model, optimizer = self.fabric.setup(model, optimizer)
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer=optimizer, T_max=config.num_steps, eta_min=config.lr * 0.1
+        )
+
+        model.train()
+        device = self.fabric.device
+        for step_idx in tqdm(
+            range(1, 1 + config.num_steps), "training", dynamic_ncols=True
+        ):
+            # sample a preference ray
+            ray = torch.from_numpy(
+                np.random.dirichlet((config.alpha,) * num_objectives, 1)
+                .astype(np.float32)
+                .flatten()
+            ).to(device)
+            ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+            losses = []
+            for dataset_idx, dataset_name in enumerate(train_loaders):
+                batch = next(train_loaders[dataset_name])
+                images, labels = batch
+
+                logits = self.compute_logits(model, images, dataset_name)
+                _loss = F.cross_entropy(logits, labels)
+                losses.append(_loss)
+
+            loss = self.compute_loss(model, ray, losses)
+
+            optimizer.zero_grad()
+            self.fabric.backward(loss)
+            optimizer.step()
+
+            lr_scheduler.step()
+
+            self.fabric.log("train/loss", loss.item(), step=step_idx)
+
+            if step_idx % config.save_interval == 0:
+                (Path(self.log_dir) / "checkpoints").mkdir(exist_ok=True, parents=True)
+                save_path = (
+                    Path(self.log_dir) / "checkpoints" / f"model_step={step_idx}.pt"
+                )
+                torch.save(model.state_dict(), save_path)
+
+        return model
+
+    @abstractmethod
+    def compute_loss(
+        self, model: nn.Module, ray: Tensor, losses: List[Tensor]
+    ) -> Tensor:
+        """
+        Computes the overall losses using the given preference ray.
+
+        Args:
+            model (nn.Module): The model being trained.
+            ray (Tensor): A tensor representing the preference ray, which contains the weights for each objective.
+            losses (List[Tensor]): A list of loss values for each objective.
+        """
+        pass
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ compute_loss(model, ray, losses) + + + abstractmethod + + +
+ + +
+ +

Computes the overall losses using the given preference ray.

+ + +

Parameters:

+
    +
  • + model + (Module) + – +
    +

    The model being trained.

    +
    +
  • +
  • + ray + (Tensor) + – +
    +

    A tensor representing the preference ray, which contains the weights for each objective.

    +
    +
  • +
  • + losses + (List[Tensor]) + – +
    +

    A list of loss values for each objective.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/pwe_moe/clip_pwe_moe.py +
@abstractmethod
+def compute_loss(
+    self, model: nn.Module, ray: Tensor, losses: List[Tensor]
+) -> Tensor:
+    """
+    Computes the overall losses using the given preference ray.
+
+    Args:
+        model (nn.Module): The model being trained.
+        ray (Tensor): A tensor representing the preference ray, which contains the weights for each objective.
+        losses (List[Tensor]): A list of loss values for each objective.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ load_clip_models() + +
+ + +
+ +

Loads the pretrained CLIP model and the fine-tuned models for each dataset specified in the configuration.

+ +
+ Source code in fusion_bench/method/pwe_moe/clip_pwe_moe.py +
def load_clip_models(self):
+    """
+    Loads the pretrained CLIP model and the fine-tuned models for each dataset specified in the configuration.
+    """
+    # load pretrained and fine-tuned model
+    with timeit_context():
+        log.info("load models")
+        pretrained_model: CLIPVisionModel = self.modelpool.load_model(
+            "_pretrained_"
+        )
+        finetuned_models = {
+            model_name: self.modelpool.load_model(model_name)
+            for model_name in self.modelpool.model_names
+        }
+
+    log.info("pretrained model statistics:")
+    print_parameters(pretrained_model)
+    return pretrained_model, finetuned_models
+
+
+
+ +
+ +
+ + +
+ setup_train_loaders() + +
+ + +
+ +

Loads the datasets specified in the configuration.

+ +
+ Source code in fusion_bench/method/pwe_moe/clip_pwe_moe.py +
def setup_train_loaders(self):
+    """
+    Loads the datasets specified in the configuration.
+    """
+    train_datasets = {
+        dataset_name: self.modelpool.load_train_dataset(
+            dataset_name, self.clip_processor
+        )
+        for dataset_name in self.modelpool.model_names
+    }
+    train_loaders = {
+        dataset_name: DataLoader(dataset, shuffle=True, **self._dataloader_kwargs)
+        for dataset_name, dataset in train_datasets.items()
+    }
+    train_loaders = {
+        dataset_name: self.fabric.setup_dataloaders(loader)
+        for dataset_name, loader in train_loaders.items()
+    }
+    return train_loaders
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +
+ PWEMoELinearScalarizationForCLIP + + +
+ + +
+

+ Bases: PWEMoEAlgorithmForCLIP

+ + + + + + + +
+ Source code in fusion_bench/method/pwe_moe/clip_pwe_moe.py +
class PWEMoELinearScalarizationForCLIP(PWEMoEAlgorithmForCLIP):
+    def compute_loss(self, model, ray, losses):
+        loss = 0
+        for r, l in zip(ray, losses):
+            loss += r * l
+        return loss
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +
+ PWEMoExactParetoOptimalForCLIP + + +
+ + +
+

+ Bases: PWEMoEAlgorithmForCLIP

+ + + + + + + +
+ Source code in fusion_bench/method/pwe_moe/clip_pwe_moe.py +
class PWEMoExactParetoOptimalForCLIP(PWEMoEAlgorithmForCLIP):
+    def compute_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        from phn.solvers import EPOSolver
+
+        if self.epo_solver is None:
+            num_objectives = len(self.finetuned_models)
+            self.epo_solver = EPOSolver(n_tasks=num_objectives, n_params=None)
+        epo_solver = self.epo_solver
+
+        losses = torch.stack(losses)
+        loss = epo_solver.get_weighted_loss(
+            losses,
+            ray,
+            tuple(filter(lambda p: p.requires_grad, model.parameters())),
+        )
+        return loss
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/regmean/index.html b/algorithms/regmean/index.html new file mode 100644 index 00000000..4bcd66fb --- /dev/null +++ b/algorithms/regmean/index.html @@ -0,0 +1,2373 @@ + + + + + + + + + + + + + + + + + + + + + + + RegMean - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

RegMean

+

Code Integration

+

Merge CLIP-ViT-B/32 models on eight image classification tasks

+
fusion_bench method=clip_regmean \
+  modelpool=clip-vit-base-patch32_TA8 \
+  taskpool=clip-vit-classification_TA8
+
+

Merge CLIP-ViT-L/14 models on eight image classification tasks

+
fusion_bench \
+  method=clip_regmean \
+    method.batch_size=8 method.num_workers=4 \
+  modelpool=clip-vit-large-patch14_TA8 \
+  taskpool=clip-vit-classification_TA8 \
+    taskpool.clip_model=openai/clip-vit-large-patch14
+
+

Merge GPT-2 models for text classification tasks:

+
fusion_bench \
+  method=gpt2_regmean \
+  modelpool=gpt-2_glue \
+  taskpool=gpt-2_glue
+
+

References

+
+
+
    +
  1. +

    Xisen Jin, et al. "Dataless Knowledge Fusion by Merging Weights of Language Models." http://arxiv.org/abs/2212.09849 

    +
  2. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/simple_averaging/index.html b/algorithms/simple_averaging/index.html new file mode 100644 index 00000000..ec7ad264 --- /dev/null +++ b/algorithms/simple_averaging/index.html @@ -0,0 +1,2787 @@ + + + + + + + + + + + + + + + + + + + + + + + Simple Averaging - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Simple Averaging

+

Simple averaging is known in the literature as isotropic merging, ModelSoups, aims to yield a more robust and generalizable model. +Simple Averaging is a technique frequently employed when there are multiple models that have been fine-tuned or independently trained from scratch. +Specifically, if we possess \(n\) models that share a common architecture but different weights denoted as \(\theta_i\), the weights of the merged model, represented as \(\theta\), are computed as follows:

+
\[ \theta = \frac{1}{n} \sum_{i=1}^{n} \theta_i \]
+

This equation simply states that each weight of the final model is the average of the corresponding weights in the individual models. For example, if we have three models and the weight of the first neuron in the first layer is 0.1, 0.2, and 0.3 in each model respectively, the weight of that neuron in the final model will be (0.1 + 0.2 + 0.3) / 3 = 0.2.

+

Simple averaging is a straightforward and scalable method for model fusion. It does not require any additional training or fine-tuning, making it a good choice when computational resources are limited, where maintaining an ensemble of models is not feasible.

+

This method often assumes that all models are equally good. +If some models are significantly better than others, it might be beneficial to assign more weight to the better models when averaging. +This can be done by using weighted averaging, where each model's contribution to the final model is weighted by its performance on a validation set or some other metric. +See Weighed Averaging for more details. +Otherwise, the poor model may have a negative impact on the merged model.

+

Examples

+

In this example, we will demonstrate how to use the SimpleAverageAlgorithm class from the fusion_bench.method module. +This algorithm is used to merge multiple models by averaging their parameters.

+
from fusion_bench.method.simple_average import SimpleAverageAlgorithm
+
+# Instantiate the SimpleAverageAlgorithm
+# This algorithm will be used to merge multiple models by averaging their parameters.
+algorithm = SimpleAverageAlgorithm()
+
+# Assume we have a list of PyTorch models (nn.Module instances) that we want to merge.
+# The models should all have the same architecture.
+models = [...]
+
+# Run the algorithm on the models.
+# This will return a new model that is the result of averaging the parameters of the input models.
+merged_model = algorithm.run(models)
+
+

The run method of the SimpleAverageAlgorithm class takes a list of models as input and returns a new model. +The new model's parameters are the average of the parameters of the input models. +This is useful in scenarios where you have trained multiple models and want to combine them into a single model that hopefully performs better than any individual model.

+

Code Integration

+

Configuration template for the Simple Averaging algorithm:

+
config/method/simple_average.yaml
name: simple_average
+
+

use the following command to run the Simple Averaging algorithm:

+
fusion_bench method=simple_average ...
+
+

References

+ + +
+ + + +

+ SimpleAverageAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm, SimpleProfilerMixin

+ + + + + + + +
+ Source code in fusion_bench/method/simple_average.py +
class SimpleAverageAlgorithm(
+    BaseAlgorithm,
+    SimpleProfilerMixin,
+):
+    @torch.no_grad()
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+        """
+        Fuse the models in the given model pool using simple averaging.
+
+        This method iterates over the names of the models in the model pool, loads each model, and appends it to a list.
+        It then returns the simple average of the models in the list.
+
+        Args:
+            modelpool: The pool of models to fuse.
+
+        Returns:
+            The fused model obtained by simple averaging.
+        """
+        if isinstance(modelpool, dict):
+            modelpool = BaseModelPool(modelpool)
+
+        log.info(
+            f"Fusing models using simple average on {len(modelpool.model_names)} models."
+            f"models: {modelpool.model_names}"
+        )
+        sd: Optional[StateDictType] = None
+        forward_model = None
+        merged_model_names = []
+
+        for model_name in modelpool.model_names:
+            with self.profile("load model"):
+                model = modelpool.load_model(model_name)
+                merged_model_names.append(model_name)
+                print(f"load model of type: {type(model).__name__}")
+            with self.profile("merge weights"):
+                if sd is None:
+                    # Initialize the state dictionary with the first model's state dictionary
+                    sd = model.state_dict(keep_vars=True)
+                    forward_model = model
+                else:
+                    # Add the current model's state dictionary to the accumulated state dictionary
+                    sd = state_dict_add(sd, model.state_dict(keep_vars=True))
+        with self.profile("merge weights"):
+            # Divide the accumulated state dictionary by the number of models to get the average
+            sd = state_dict_mul(sd, 1 / len(modelpool.model_names))
+
+        forward_model.load_state_dict(sd)
+        # print profile report and log the merged models
+        self.print_profile_summary()
+        log.info(f"merged {len(merged_model_names)} models:")
+        for model_name in merged_model_names:
+            log.info(f"  - {model_name}")
+        return forward_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + +
+ + +
+ +

Fuse the models in the given model pool using simple averaging.

+

This method iterates over the names of the models in the model pool, loads each model, and appends it to a list. +It then returns the simple average of the models in the list.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (Union[BaseModelPool, Dict[str, Module]]) + – +
    +

    The pool of models to fuse.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    The fused model obtained by simple averaging.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/simple_average.py +
@torch.no_grad()
+def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    """
+    Fuse the models in the given model pool using simple averaging.
+
+    This method iterates over the names of the models in the model pool, loads each model, and appends it to a list.
+    It then returns the simple average of the models in the list.
+
+    Args:
+        modelpool: The pool of models to fuse.
+
+    Returns:
+        The fused model obtained by simple averaging.
+    """
+    if isinstance(modelpool, dict):
+        modelpool = BaseModelPool(modelpool)
+
+    log.info(
+        f"Fusing models using simple average on {len(modelpool.model_names)} models."
+        f"models: {modelpool.model_names}"
+    )
+    sd: Optional[StateDictType] = None
+    forward_model = None
+    merged_model_names = []
+
+    for model_name in modelpool.model_names:
+        with self.profile("load model"):
+            model = modelpool.load_model(model_name)
+            merged_model_names.append(model_name)
+            print(f"load model of type: {type(model).__name__}")
+        with self.profile("merge weights"):
+            if sd is None:
+                # Initialize the state dictionary with the first model's state dictionary
+                sd = model.state_dict(keep_vars=True)
+                forward_model = model
+            else:
+                # Add the current model's state dictionary to the accumulated state dictionary
+                sd = state_dict_add(sd, model.state_dict(keep_vars=True))
+    with self.profile("merge weights"):
+        # Divide the accumulated state dictionary by the number of models to get the average
+        sd = state_dict_mul(sd, 1 / len(modelpool.model_names))
+
+    forward_model.load_state_dict(sd)
+    # print profile report and log the merged models
+    self.print_profile_summary()
+    log.info(f"merged {len(merged_model_names)} models:")
+    for model_name in merged_model_names:
+        log.info(f"  - {model_name}")
+    return forward_model
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/simple_ensemble/index.html b/algorithms/simple_ensemble/index.html new file mode 100644 index 00000000..987ddece --- /dev/null +++ b/algorithms/simple_ensemble/index.html @@ -0,0 +1,2633 @@ + + + + + + + + + + + + + + + + + + + + + + + Simple Ensemble - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Simple Ensemble

+

Ensemble methods are simple and effective ways to improve the performance of machine learning models. +They combine the outputs of multiple models to create a stronger model.

+

Examples

+
from fusion_bench.method import EnsembleAlgorithm
+
+# Instantiate the EnsembleAlgorithm
+algorithm = EnsembleAlgorithm()
+
+# Assume we have a list of PyTorch models (nn.Module instances) that we want to ensemble.
+models = [...]
+
+# Run the algorithm on the models.
+merged_model = algorithm.run(models)
+
+

Code Integration

+

Configuration template for the ensemble algorithm:

+
config/method/simple_ensemble.yaml
name: simple_ensemble
+
+

create a simple ensemble of CLIP-ViT models for image classification tasks.

+
fusion_bench \
+  method=ensemble/simple_ensemble \
+  modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+  taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 
+
+

References

+ + +
+ + + +

+ SimpleEnsembleAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + + + + + + +
+ Source code in fusion_bench/method/ensemble.py +
19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
class SimpleEnsembleAlgorithm(BaseAlgorithm):
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+        """
+        Run the simple ensemble algorithm on the given model pool.
+
+        Args:
+            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.
+
+        Returns:
+            EnsembleModule: The ensembled model.
+        """
+        log.info(f"Running ensemble algorithm with {len(modelpool)} models")
+
+        models = [modelpool.load_model(m) for m in modelpool.model_names]
+        ensemble = EnsembleModule(models=models)
+        return ensemble
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + +
+ + +
+ +

Run the simple ensemble algorithm on the given model pool.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (BaseModelPool | List[Module]) + – +
    +

    The pool of models to ensemble.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +EnsembleModule – +
    +

    The ensembled model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/ensemble.py +
20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
@torch.no_grad()
+def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    """
+    Run the simple ensemble algorithm on the given model pool.
+
+    Args:
+        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.
+
+    Returns:
+        EnsembleModule: The ensembled model.
+    """
+    log.info(f"Running ensemble algorithm with {len(modelpool)} models")
+
+    models = [modelpool.load_model(m) for m in modelpool.model_names]
+    ensemble = EnsembleModule(models=models)
+    return ensemble
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/smile_upscaling/index.html b/algorithms/smile_upscaling/index.html new file mode 100644 index 00000000..5b2412f4 --- /dev/null +++ b/algorithms/smile_upscaling/index.html @@ -0,0 +1,4277 @@ + + + + + + + + + + + + + + + + + + + + + + + SMILE Upscaling - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

SMILE Upscaling

+

arXiv

+
+alt text +
The architecture of the Sparse MIxture of Low-rank Experts (SMILE) module.1
+
+

Taxonomy for SMILE Upscaling

+

Here we present the taxonomy for the SMILE upscaling method following "A Survey on Model MoErging" by Yadav et al. (2024) 2.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Expert TrainingStandardExpert DataPrivateRouting DatasetNone
Input GranularityStepDepth GranularityModuleExpert SelectionSparse
Expert AggregationOutputGeneralizationIn-DistributionUser DatasetZero-Shot
+

Configurations

+

The SMILE upscaling method offers several configuration options, which are located in the config/method/ directory.

+
    +
  1. General nn.Module Upscaling: + This configuration is designed for upscaling any neural network module (nn.Module).
  2. +
  3. Mistral Model Upscaling: + This specific configuration is for Mistral models.
  4. +
+

Each configuration file contains detailed parameters and options that can be adjusted to meet the specific needs of your model and application.

+
config/method/smile_upscaling.yaml
name: smile_upscaling
+
+# merge device on cuda can accelerate the SVD computation
+device: cpu
+# device to compute svd
+upscaling_accelerator: cuda
+full_matrices: true # set to false if you are sure k < rank
+
+gate_k: 1
+k: 128
+top_k: 1
+
+routing_use_diff: true
+# average the remaining part, if this is set the False, the remaining part will kept as base model (the pretrained model)
+average_experts: false
+
+# path to save/load the model
+model_path: null
+
+
config/method/smile_mistral_upscaling.yaml
name: smile_mistral_upscaling
+
+device: cpu
+accelerator: cuda
+
+# path to save/load the model
+model_path: null
+model_dtype: float16
+
+num_experts_per_tok: 1
+rank_of_router: 8
+rank_of_expert: 512
+
+

Examples

+

CLIP-ViT-B/32 on eight tasks

+

Evaluate single fine-tuned models and save the results to outputs/ViT-B-32/single-task/ and outputs/ViT-L-14/single-task/ for CLIP-ViT-B/32 and CLIP-ViT-L/14 models, respectively.

+
# evaluate singlue fine-tuned models
+for task in sun397 stanford-cars resisc45 eurosat svhn gtsrb mnist dtd
+do
+    fusion_bench method=dummy \
+        modelpool=clip-vit-base-patch32_individual \
+            modelpool.models.0.path=tanganke/clip-vit-base-patch32_${task} \
+        taskpool=clip-vit-classification_TA8 \
+        report_save_path="outputs/ViT-B-32/single-task/clip-vit-base-patch32_${task}.json"
+done
+
+# if you have multiple GPUs, you can run the following code to evaluate the CLIP-ViT-L/14 models in parallel
+# evaluate singlue fine-tuned models clip-vit-large
+tasks=(sun397 stanford-cars resisc45 eurosat svhn gtsrb mnist dtd)
+CUDA_DEVICES=(0 1 2 3 4 5 6 7)  # List of CUDA devices to use
+
+for i in "${!CUDA_DEVICES[@]}"; do
+    task=${tasks[$i]}
+    CUDA_VISIBLE_DEVICES=${CUDA_DEVICES[$i]} fusion_bench method=dummy \
+        modelpool=clip-vit-large-patch14_individual \
+            modelpool.models.0.path=tanganke/clip-vit-large-patch14_${task} \
+        taskpool=clip-vit-classification_TA8 \
+            taskpool.clip_model=openai/clip-vit-large-patch14 \
+        report_save_path="outputs/ViT-L-14/single-task/clip-vit-large-patch14_${task}.json" &
+done
+
+

Upscale eight CLIP-ViT-B/32 models with SMILE, each CLIP-ViT-B/32 model is trained on a downstream task.

+
gate_k=16
+k=32
+fusion_bench \
+    method=smile_upscaling \
+        method.device=cuda \
+        method.gate_k=$gate_k method.k=$k \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 \
+    report_save_path="outputs/ViT-B-32/eight_tasks/gate_k\=${gate_k}_k\=${k}.json"
+
+

Hyperparameter search for SMILE upscaling. Pre-run results can be found in examples/smile_upscaling/clip-vit-base-patch32.ipynb.

+
for gate_k in 1 2 4 8 16 32 64 128 256 512 768; do
+    for k in 4 8 16 32 64 128 -1; do
+        fusion_bench \
+            method=smile_upscaling \
+                method.device=cuda \
+                method.gate_k=$gate_k method.k=$k \
+            modelpool=clip-vit-base-patch32_TA8 \
+            taskpool=clip-vit-classification_TA8 \
+            report_save_path="outputs/ViT-B-32/eight_tasks/gate_k\=${gate_k}_k\=${k}.json"
+    done
+done
+
+

Ablations on number of experts per token (Top-K). Pre-run results can be found in examples/smile_upscaling/clip-vit-base-patch32-ablations-topk.ipynb.

+
gate_k=16
+k=32
+for top_k in 1 2 4
+do
+fusion_bench \
+    method=smile_upscaling \
+        method.device=cuda \
+        method.gate_k=$gate_k method.k=$k \
+    modelpool=clip-vit-base-patch32_TA8 \
+    taskpool=clip-vit-classification_TA8 \
+    report_save_path="outputs/ViT-B-32/ablation/gate_k\=${gate_k}_k\=${k}.json"
+done
+
+

CLIP-ViT-L/14 on eight tasks

+

hyperparameter search for SMILE upscaling. Pre-run results can be found in examples/smile_upscaling/clip-vit-large-patch14.ipynb.

+
for gate_k in 1 2 4 8 16 32 64 128; do
+    for k in 4 8 16 32 64 128 -1; do
+        fusion_bench \
+            method=smile_upscaling \
+                method.gate_k=$gate_k method.k=$k \
+            modelpool=clip-vit-large-patch14_TA8 \
+            taskpool=clip-vit-classification_TA8 \
+                taskpool.clip_model=openai/clip-vit-large-patch14 \
+            report_save_path="outputs/ViT-B-32/eight_tasks/gate_k\=${gate_k}_k\=${k}.json"
+    done
+done
+
+

Flan-T5 models on eight tasks from GLUE benchmark

+

Hyperparameter search for full fine-tuned and lora fine-tuned Flan-T5 models. +Pre-run results can be found in examples/smile_upscaling/flan-t5-base.ipynb and examples/smile_upscaling/flan-t5-base-lora16.ipynb.

+
# hyperparameter search for full fine-tuned flan-t5-base
+for gate_k in 4 8 16 32; do
+    for k in 16 32 64 128; do
+        fusion_bench \
+            method=smile_upscaling \
+                method.device=cpu \
+                method.gate_k=$gate_k method.k=$k \
+            modelpool=flan-t5-base_glue \
+            taskpool=flan-t5_glue_text_generation \
+            report_save_path="outputs/flan-t5-base/glue_text_generation/gate_k\=${gate_k}_k\=${k}.json"
+    done
+done
+
+# hyperparameter search for lora fine-tuned flan-t5-base
+for gate_k in 2 4 8; do
+    for k in 4 8 16; do
+        fusion_bench \
+            method=smile_upscaling \
+                method.device=cuda \
+                method.gate_k=$gate_k method.k=$k \
+            modelpool=flan-t5-base_glue_lora16 \
+            taskpool=flan-t5_glue_text_generation \
+            report_save_path="outputs/flan-t5-base_lora16/glue_text_generation/gate_k\=${gate_k}_k\=${k}.json"
+    done
+done
+
+

Upscale Mistral-7B models

+

Here we upscale several Mistral-7B models using SMILE. The models are trained on different tasks and are used as experts in the SMILE upscaling.

+

We first provide an example of the upscaled model, where we upscale the linear layers of the original Mistral model into a SMILE linear layer.

+
import torch
+from accelerate import init_empty_weights
+from transformers import AutoConfig
+
+from fusion_bench.models.modeling_smile_mistral import (
+    SmileMistralConfig,
+    SmileMistralForCausalLM,
+)
+
+
+config = AutoConfig.from_pretrained(
+    "mistralai/Mistral-7B-v0.1"
+)
+config = SmileMistralConfig(
+    num_experts_per_tok=1,
+    rank_of_router=8,
+    rank_of_expert=8,
+    num_local_experts=3,
+    **config.to_dict()
+)
+with init_empty_weights():
+    model = SmileMistralForCausalLM(config)
+model.to(dtype=torch.float16).to_empty(device="cuda")
+
+

The model architecture is as follows:

+
SmileMistralForCausalLM(
+  (model): SmileMistralModel(
+    (embed_tokens): Embedding(32000, 4096)
+    (layers): ModuleList(
+      (0-31): 32 x SmileMistralDecoderLayer(
+        (self_attn): SmileMistralAttention(
+          (q_proj): SingularMoELinear(in_features=4096, out_features=4096, num_local_experts=3, num_experts_per_tok=1, rank_of_router=8, rank_of_expert=8)
+          (k_proj): SingularMoELinear(in_features=4096, out_features=1024, num_local_experts=3, num_experts_per_tok=1, rank_of_router=8, rank_of_expert=8)
+          (v_proj): SingularMoELinear(in_features=4096, out_features=1024, num_local_experts=3, num_experts_per_tok=1, rank_of_router=8, rank_of_expert=8)
+          (o_proj): SingularMoELinear(in_features=4096, out_features=4096, num_local_experts=3, num_experts_per_tok=1, rank_of_router=8, rank_of_expert=8)
+          (rotary_emb): MistralRotaryEmbedding()
+        )
+        (mlp): SmileMistralMLP(
+          (gate_proj): SingularMoELinear(in_features=4096, out_features=14336, num_local_experts=3, num_experts_per_tok=1, rank_of_router=8, rank_of_expert=8)
+          (up_proj): SingularMoELinear(in_features=4096, out_features=14336, num_local_experts=3, num_experts_per_tok=1, rank_of_router=8, rank_of_expert=8)
+          (down_proj): SingularMoELinear(in_features=14336, out_features=4096, num_local_experts=3, num_experts_per_tok=1, rank_of_router=8, rank_of_expert=8)
+          (act_fn): SiLU()
+        )
+        (input_layernorm): MistralRMSNorm()
+        (post_attention_layernorm): MistralRMSNorm()
+      )
+    )
+    (norm): MistralRMSNorm()
+  )
+  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
+)
+
+

Knowing the model architecture, we can upscale the Mistral-7B models using the following steps:

+
    +
  1. +

    Prepare the following 4 configuration files in configs/modelpool:

    +
    config/modelpool/smile_mistral_exp_v1.yaml
    type: AutoModelForCausalLMPool
    +models:
    +- name: _pretrained_
    +    path: mistralai/Mistral-7B-v0.1
    +- name: expert_1
    +    path: meta-math/MetaMath-Mistral-7B
    +
    +dtype: float16
    +
    +
    config/modelpool/smile_mistral_exp_v2.yaml
    type: AutoModelForCausalLMPool
    +models:
    +- name: _pretrained_
    +    path: mistralai/Mistral-7B-v0.1
    +- name: expert_1
    +    path: cognitivecomputations/dolphin-2.1-mistral-7b
    +
    +dtype: float16
    +
    +
    config/modelpool/smile_mistral_exp_v3.yaml
    type: AutoModelForCausalLMPool
    +models:
    +- name: _pretrained_
    +    path: mistralai/Mistral-7B-v0.1
    +- name: expert_1
    +    path: uukuguy/speechless-code-mistral-7b-v1.0
    +
    +dtype: float16
    +
    +
    config/modelpool/smile_mistral_exp_v4.yaml
    type: AutoModelForCausalLMPool
    +models:
    +- name: _pretrained_
    +    path: mistralai/Mistral-7B-v0.1
    +- name: expert_1
    +    path: meta-math/MetaMath-Mistral-7B
    +- name: expert_2
    +    path: cognitivecomputations/dolphin-2.1-mistral-7b
    +- name: expert_3
    +    path: uukuguy/speechless-code-mistral-7b-v1.0
    +
    +dtype: float16
    +
    +
  2. +
  3. +

    Upscale Mistral-7B models. The upscaled models are saved in outputs/mistral/gate_k-${gate_k}_k-${k}/version_${version}.

    +
    function model_fusion() {
    +    output_dir=outputs/mistral/gate_k-${gate_k}_k-${k}/version_${version}
    +    fusion_bench \
    +        method=smile_mistral_upscaling \
    +            method.rank_of_router=$gate_k method.rank_of_expert=$k \
    +            method.model_path=${output_dir} \
    +        modelpool=smile_mistral_exp_v${version} \
    +            modelpool.dtype=float32 \
    +        taskpool=dummy \
    +        report_save_path="${output_dir}/model_info.json"
    +}
    +
    +gate_k=8
    +for k in 8 16 32 64 128 256 384 512; do
    +    for version in 1 2 3 4; do
    +        model_fusion
    +    done
    +done
    +
    +
  4. +
  5. +

    Use lm-evaluation-harness to evaluate the models. We use the default configurations for each task.

    +
    # For some GPUs, the following environment variables need to be set
    +# export NCCL_P2P_DISABLE="1"
    +# export NCCL_IB_DISABLE="1"
    +
    +function model_eval() {
    +    output_dir=outputs/mistral/gate_k-${gate_k}_k-${k}/version_${version}
    +
    +    # Check if ${output_dir}/${task}.json exists as a directory and return if it does
    +    if [ -d "${output_dir}/${task}.json" ]; then
    +        echo "Directory ${output_dir}/${task}.json already exists. Skipping evaluation."
    +        return
    +    fi
    +
    +    lm_eval --model hf \
    +        --model_args pretrained=${output_dir},dtype="float16",parallelize=True \
    +        --tasks ${task} \
    +        --output_path ${output_dir}/${task}.json \
    +        --batch_size 6
    +}
    +
    +

    The above function can be used to evaluate the models on specified task. +Pre-run results can be found in examples/smile_upscaling/mistral_gsm8k.ipynb.

    +
    # Evaluate all the models on GSM8K task
    +gate_k=8
    +task=gsm8k
    +for k in 8 16 32 64 128 256 384 512; do
    +    for version in 1 2 3 4; do
    +        model_eval
    +    done
    +done
    +
    +# Evaluate all M0;123 models on truthfulqa gsm8k arc_challenge mmlu
    +k=8
    +version=4
    +for task in truthfulqa gsm8k arc_challenge mmlu; do
    +    model_eval
    +done
    +
    +

    The reported metrics are:

    +
      +
    • mmlu (general): acc
    • +
    • truthfulqa (truthful): mc2
    • +
    • gsm8k (math): flexible exact match
    • +
    • arc_challenge (reasoning): acc_norm
    • +
    +
  6. +
+

Scope

+

Projection Merge Experiments

+

Pre-run results can be found in examples/smile_upscaling/clip-vit-base-patch32_single-task_projection-merging.ipynb.

+
# project into different subspaces
+for task in sun397 stanford-cars resisc45 eurosat svhn gtsrb mnist dtd
+do
+    # Space I
+    CUDA_VISIBLE_DEVICES=0 fusion_bench \
+        method=singular_projection_merging \
+            method.device=cuda method.rank=low method.k=-1 method.full_matrices=false \
+        modelpool=clip-vit-base-patch32_single_finetuned \
+            modelpool.models.1.name=${task} \
+            modelpool.models.1.path=tanganke/clip-vit-base-patch32_${task} \
+        taskpool=clip-vit-classification_TA8 \
+        report_save_path="outputs/ViT-B-32/single-task/projection_merging_zone1_${task}.json" &
+
+    # Space II
+    CUDA_VISIBLE_DEVICES=1 fusion_bench \
+        method=singular_projection_merging \
+            method.device=cuda method.rank=high method.k=-1 method.full_matrices=false \
+        modelpool=clip-vit-base-patch32_single_finetuned \
+            modelpool.models.1.name=${task} \
+            modelpool.models.1.path=tanganke/clip-vit-base-patch32_${task} \
+        taskpool=clip-vit-classification_TA8 \
+        report_save_path="outputs/ViT-B-32/single-task/projection_merging_zone2_${task}.json" &
+
+    # Space III
+    CUDA_VISIBLE_DEVICES=2 fusion_bench \
+        method=singular_projection_merging \
+            method.device=cuda method.rank=high method.k=-1 method.full_matrices=true \
+        modelpool=clip-vit-base-patch32_single_finetuned \
+            modelpool.models.1.name=${task} \
+            modelpool.models.1.path=tanganke/clip-vit-base-patch32_${task} \
+        taskpool=clip-vit-classification_TA8 \
+        report_save_path="outputs/ViT-B-32/single-task/projection_merging_zone23_${task}.json" &
+    wait
+done
+
+

References

+

Algorithms

+ + +
+ + + +

+ SmileUpscalingAlgorithm + + +

+ + +
+

+ Bases: SimpleProfilerMixin, BaseAlgorithm

+ + + + + + + +
+ Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py +
class SmileUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    _linear_layer_cls = (nn.Linear,)
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "device": "device",
+        "upscaling_accelerator": "upscaling_accelerator",
+        "full_matrices": "full_matrices",
+        "gate_k": "gate_k",
+        "k": "k",
+        "top_k": "top_k",
+        "routing_use_diff": "routing_use_diff",
+        "average_experts": "average_experts",
+        "model_path": "model_path",
+    }
+
+    def __init__(
+        self,
+        *,
+        device: str = "cuda",
+        upscaling_accelerator: str = None,
+        full_matrices: bool = True,
+        gate_k: int = 256,
+        k: int = 256,
+        top_k: int = 1,
+        routing_use_diff: bool = True,
+        average_experts: bool = False,
+        model_path: str = None,
+        **kwargs,
+    ):
+        """
+        Initialize the SmileUpscalingAlgorithm.
+
+        Args:
+            device (str): The device to perform the computation on.
+            upscaling_accelerator (str): The device to perform the SVD computation on.
+            full_matrices (bool): Whether to compute the full-sized U and V matrices.
+            gate_k (int): The number of singular values to keep for the gate.
+            k (int): The number of singular values to keep for the experts.
+            top_k (int): The number of top experts to select.
+            routing_use_diff (bool): Whether to use weight differences for routing.
+            average_experts (bool): Whether to average the experts.
+            model_path (str): The path to save/load the model.
+            **kwargs: Additional arguments.
+        """
+        super().__init__()
+        self.device = device
+        self.upscaling_accelerator = upscaling_accelerator
+        self.full_matrices = full_matrices
+        self.gate_k = gate_k
+        self.k = k
+        self.top_k = top_k
+        self.routing_use_diff = routing_use_diff
+        self.average_experts = average_experts
+        self.model_path = model_path
+        for key, value in kwargs.items():
+            log.warning(f"Unrecognized argument: {key}")
+            setattr(self, key, value)
+
+        # print `self.config` as yaml
+        print(f"=== Config for `{type(self).__name__}` ===")
+        print(OmegaConf.to_yaml(self.config))
+        print(f"=== Config for `{type(self).__name__}` ===")
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool):
+        """
+        Executes the upscaling process.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be used for upscaling.
+
+        Returns:
+            nn.Module: The upscaled model.
+        """
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+
+        if self.config.model_path is not None and os.path.exists(
+            self.config.model_path
+        ):
+            log.info(f"Loading model from {self.config.model_path}")
+            model = torch.load(self.config.model_path)
+            print_parameters(model)
+            return model
+
+        with self.profile("load pretrained model"):
+            pretrained_model = modelpool.load_model("_pretrained_")
+        with self.profile("load fine-tuned model"):
+            finetuned_models = [
+                m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+            ]
+
+        if self.config.device == "cuda" and torch.cuda.is_available():
+            pretrained_model = pretrained_model.cuda()
+            finetuned_models = [m.cuda() for m in finetuned_models]
+
+        with self.profile("merge model"):
+            model = self.merge(pretrained_model, finetuned_models)
+
+        self.print_profile_summary()
+        if self.config.model_path is not None:
+            os.makedirs(os.path.dirname(self.config.model_path), exist_ok=True)
+            log.info(f"Saving model to {self.config.model_path}")
+            torch.save(model, self.config.model_path)
+        print_parameters(model)
+        return model
+
+    def merge(
+        self,
+        pretrained_model: nn.Module,
+        finetuned_models: List[nn.Module],
+        in_place: bool = True,
+    ):
+        """
+        Merges the pretrained model with the fine-tuned models to create an upscaled model.
+
+        Args:
+            pretrained_model (nn.Module): The pretrained model.
+            finetuned_models (List[nn.Module]): A list of fine-tuned models.
+            in_place (bool): If True, modifies the pretrained model in place. Otherwise, creates a copy.
+
+        Returns:
+            nn.Module: The merged model.
+        """
+        if in_place:
+            model = pretrained_model
+        else:
+            model = deepcopy(pretrained_model)
+
+        self._upscale_submodules(model, finetuned_models)
+        return model
+
+    def _upscale_linear_layer(
+        self,
+        pretrained_model,
+        finetuned_models,
+        name: str,
+    ):
+        """
+        Upscale a linear layer by merging it with the corresponding layers from the fine-tuned models.
+
+        Args:
+            pretrained_model (nn.Module): The pretrained model.
+            finetuned_models (List[nn.Module]): A list of fine-tuned models.
+            name (str): The name of the linear layer to upscale.
+        """
+        config = self.config
+
+        name_list = name.split(".")
+        module = get_attr(pretrained_model, name_list)
+        experts = [get_attr(m, name_list) for m in finetuned_models]
+        try:
+            moe_linear = SmileMoELinear(
+                module,
+                experts,
+                gate_k=config.gate_k,
+                k=config.k,
+                top_k=config.top_k,
+                routing_use_diff=self.routing_use_diff,
+                full_matrices=self.full_matrices,
+                upscaling_accelerator=self.upscaling_accelerator,
+            )
+        except ExpertNotTrainedError:
+            print(f"skip {name} because the experts are not trained.")
+            return
+        set_attr(pretrained_model, name_list, moe_linear)
+        # remove the original module from fine-tuned models to save memory
+        for m in finetuned_models:
+            set_attr(m, name_list, None)
+
+    def _average_experts(self, pretarined_model, finetuned_models, name: str):
+        """
+        Average the experts for a given layer.
+
+        Args:
+            pretarined_model (nn.Module): The pretrained model.
+            finetuned_models (List[nn.Module]): A list of fine-tuned models.
+            name (str): The name of the layer to average.
+        """
+        name_list = name.split(".")
+        experts = [get_attr(m, name_list) for m in finetuned_models]
+        averaged_module = simple_average(experts)
+        set_attr(pretarined_model, name_list, averaged_module)
+
+    def _upscale_submodules(
+        self,
+        pretrained_model: nn.Module,
+        finetuned_models: List[nn.Module],
+        tqdm_desc: str = "Upscaling Linear Modules",
+    ):
+        """
+        Upscales the submodules of the pretrained model by merging them with the corresponding submodules from the fine-tuned models.
+
+        Args:
+            pretrained_model (nn.Module): The pretrained model.
+            finetuned_models (List[nn.Module]): A list of fine-tuned models.
+            tqdm_desc (str): Description for the tqdm progress bar.
+        """
+        config = self.config
+        for name, module in tqdm(
+            tuple(pretrained_model.named_modules()),
+            tqdm_desc,
+            leave=False,
+            dynamic_ncols=True,
+        ):
+            if isinstance(module, self._linear_layer_cls):
+                self._upscale_linear_layer(
+                    pretrained_model=pretrained_model,
+                    finetuned_models=finetuned_models,
+                    name=name,
+                )
+            elif config.average_experts and len(tuple(module.named_modules())) == 1:
+                # if the module is a leaf module, we perform a parameter average
+                self._average_experts(pretrained_model, finetuned_models, name)
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ __init__(*, device='cuda', upscaling_accelerator=None, full_matrices=True, gate_k=256, k=256, top_k=1, routing_use_diff=True, average_experts=False, model_path=None, **kwargs) + +
+ + +
+ +

Initialize the SmileUpscalingAlgorithm.

+ + +

Parameters:

+
    +
  • +
    device +
    (str, default: + 'cuda' +) + – +
    +

    The device to perform the computation on.

    +
    +
  • +
  • +
    upscaling_accelerator +
    (str, default: + None +) + – +
    +

    The device to perform the SVD computation on.

    +
    +
  • +
  • +
    full_matrices +
    (bool, default: + True +) + – +
    +

    Whether to compute the full-sized U and V matrices.

    +
    +
  • +
  • +
    gate_k +
    (int, default: + 256 +) + – +
    +

    The number of singular values to keep for the gate.

    +
    +
  • +
  • +
    k +
    (int, default: + 256 +) + – +
    +

    The number of singular values to keep for the experts.

    +
    +
  • +
  • +
    top_k +
    (int, default: + 1 +) + – +
    +

    The number of top experts to select.

    +
    +
  • +
  • +
    routing_use_diff +
    (bool, default: + True +) + – +
    +

    Whether to use weight differences for routing.

    +
    +
  • +
  • +
    average_experts +
    (bool, default: + False +) + – +
    +

    Whether to average the experts.

    +
    +
  • +
  • +
    model_path +
    (str, default: + None +) + – +
    +

    The path to save/load the model.

    +
    +
  • +
  • +
    **kwargs +
    – +
    +

    Additional arguments.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py +
def __init__(
+    self,
+    *,
+    device: str = "cuda",
+    upscaling_accelerator: str = None,
+    full_matrices: bool = True,
+    gate_k: int = 256,
+    k: int = 256,
+    top_k: int = 1,
+    routing_use_diff: bool = True,
+    average_experts: bool = False,
+    model_path: str = None,
+    **kwargs,
+):
+    """
+    Initialize the SmileUpscalingAlgorithm.
+
+    Args:
+        device (str): The device to perform the computation on.
+        upscaling_accelerator (str): The device to perform the SVD computation on.
+        full_matrices (bool): Whether to compute the full-sized U and V matrices.
+        gate_k (int): The number of singular values to keep for the gate.
+        k (int): The number of singular values to keep for the experts.
+        top_k (int): The number of top experts to select.
+        routing_use_diff (bool): Whether to use weight differences for routing.
+        average_experts (bool): Whether to average the experts.
+        model_path (str): The path to save/load the model.
+        **kwargs: Additional arguments.
+    """
+    super().__init__()
+    self.device = device
+    self.upscaling_accelerator = upscaling_accelerator
+    self.full_matrices = full_matrices
+    self.gate_k = gate_k
+    self.k = k
+    self.top_k = top_k
+    self.routing_use_diff = routing_use_diff
+    self.average_experts = average_experts
+    self.model_path = model_path
+    for key, value in kwargs.items():
+        log.warning(f"Unrecognized argument: {key}")
+        setattr(self, key, value)
+
+    # print `self.config` as yaml
+    print(f"=== Config for `{type(self).__name__}` ===")
+    print(OmegaConf.to_yaml(self.config))
+    print(f"=== Config for `{type(self).__name__}` ===")
+
+
+
+ +
+ +
+ + +
+ merge(pretrained_model, finetuned_models, in_place=True) + +
+ + +
+ +

Merges the pretrained model with the fine-tuned models to create an upscaled model.

+ + +

Parameters:

+
    +
  • +
    pretrained_model +
    (Module) + – +
    +

    The pretrained model.

    +
    +
  • +
  • +
    finetuned_models +
    (List[Module]) + – +
    +

    A list of fine-tuned models.

    +
    +
  • +
  • +
    in_place +
    (bool, default: + True +) + – +
    +

    If True, modifies the pretrained model in place. Otherwise, creates a copy.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    nn.Module: The merged model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py +
def merge(
+    self,
+    pretrained_model: nn.Module,
+    finetuned_models: List[nn.Module],
+    in_place: bool = True,
+):
+    """
+    Merges the pretrained model with the fine-tuned models to create an upscaled model.
+
+    Args:
+        pretrained_model (nn.Module): The pretrained model.
+        finetuned_models (List[nn.Module]): A list of fine-tuned models.
+        in_place (bool): If True, modifies the pretrained model in place. Otherwise, creates a copy.
+
+    Returns:
+        nn.Module: The merged model.
+    """
+    if in_place:
+        model = pretrained_model
+    else:
+        model = deepcopy(pretrained_model)
+
+    self._upscale_submodules(model, finetuned_models)
+    return model
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Executes the upscaling process.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (ModelPool) + – +
    +

    The pool of models to be used for upscaling.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    nn.Module: The upscaled model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py +
@torch.no_grad()
+def run(self, modelpool: BaseModelPool):
+    """
+    Executes the upscaling process.
+
+    Args:
+        modelpool (ModelPool): The pool of models to be used for upscaling.
+
+    Returns:
+        nn.Module: The upscaled model.
+    """
+    if not isinstance(modelpool, BaseModelPool):
+        modelpool = BaseModelPool(modelpool)
+
+    if self.config.model_path is not None and os.path.exists(
+        self.config.model_path
+    ):
+        log.info(f"Loading model from {self.config.model_path}")
+        model = torch.load(self.config.model_path)
+        print_parameters(model)
+        return model
+
+    with self.profile("load pretrained model"):
+        pretrained_model = modelpool.load_model("_pretrained_")
+    with self.profile("load fine-tuned model"):
+        finetuned_models = [
+            m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+        ]
+
+    if self.config.device == "cuda" and torch.cuda.is_available():
+        pretrained_model = pretrained_model.cuda()
+        finetuned_models = [m.cuda() for m in finetuned_models]
+
+    with self.profile("merge model"):
+        model = self.merge(pretrained_model, finetuned_models)
+
+    self.print_profile_summary()
+    if self.config.model_path is not None:
+        os.makedirs(os.path.dirname(self.config.model_path), exist_ok=True)
+        log.info(f"Saving model to {self.config.model_path}")
+        torch.save(model, self.config.model_path)
+    print_parameters(model)
+    return model
+
+
+
+ +
+ + + +
+ +
+ +
+
+
    +
  1. +

    A. Tang et. al. SMILE: Zero-Shot Sparse Mixture of Low-Rank Experts Construction From Pre-Trained Foundation Models. Aug, 2024. +https://arxiv.org/abs/2408.10174 

    +
  2. +
  3. +

    Yadav, Prateek, et al. "A Survey on Model MoErging: Recycling and Routing Among Specialized Experts for Collaborative Learning." arXiv preprint arXiv:2408.07057 (2024). 

    +
  4. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/specification_ensemble/index.html b/algorithms/specification_ensemble/index.html new file mode 100644 index 00000000..4755ba17 --- /dev/null +++ b/algorithms/specification_ensemble/index.html @@ -0,0 +1,2208 @@ + + + + + + + + + + + + + + + + + + + Specification ensemble - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Specification ensemble

+ + + + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/task_arithmetic/index.html b/algorithms/task_arithmetic/index.html new file mode 100644 index 00000000..a23d4dac --- /dev/null +++ b/algorithms/task_arithmetic/index.html @@ -0,0 +1,3081 @@ + + + + + + + + + + + + + + + + + + + + + + + Task Arithmetic - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Task Arithmetic

+

In the rapidly advancing field of machine learning, multi-task learning has emerged as a powerful paradigm, allowing models to leverage information from multiple tasks to improve performance and generalization. One intriguing method in this domain is Task Arithmetic, which involves the combination of task-specific vectors derived from model parameters.

+
+ Image title +
Task Arithmetic. This figure is credited to 2
+
+

Task Vector. A task vector is used to encapsulate the adjustments needed by a model to specialize in a specific task. +It is derived from the differences between a pre-trained model's parameters and those fine-tuned for a particular task. +Formally, if \(\theta_i\) represents the model parameters fine-tuned for the i-th task and \(\theta_0\) denotes the parameters of the pre-trained model, the task vector for the i-th task is defined as:

+
\[\tau_i = \theta_i - \theta_0\]
+

This representation is crucial for methods like Task Arithmetic, where multiple task vectors are aggregated and scaled to form a comprehensive multi-task model.

+

Task Arithmetic1 begins by computing a task vector \(\tau_i\) for each individual task, using the set of model parameters \(\theta_0 \cup \{\theta_i\}_i\) where \(\theta_0\) is the pre-trained model and \(\theta_i\) are the fine-tuned parameters for i-th task. +These task vectors are then aggregated to form a multi-task vector. +Subsequently, the multi-task vector is combined with the pre-trained model parameters to obtain the final multi-task model. +This process involves scaling the combined vector element-wise by a scaling coefficient (denoted as \(\lambda\)), before adding it to the initial pre-trained model parameters. +The resulting formulation for obtaining a multi-task model is expressed as

+
\[ \theta = \theta_0 + \lambda \sum_{i} \tau_i. \]
+

The choice of the scaling coefficient \(\lambda\) plays a crucial role in the final model performance. Typically, \(\lambda\) is chosen based on validation set performance.

+

Examples

+

To use the Task Arithmetic algorithm, you can use the TaskArithmeticAlgorithm class from the fusion_bench.method module.

+
from fusion_bench.method.task_arithmetic import TaskArithmeticAlgorithm
+from omegaconf import DictConfig
+
+# Instantiate the TaskArithmeticAlgorithm
+method_config = {'name': 'task_arithmetic', 'scaling_factor': 0.5}
+algorithm = TaskArithmeticAlgorithm(DictConfig(method_config))
+
+# Assume we have a dict of PyTorch models (nn.Module instances) that we want to merge.
+# The models should all have the same architecture.
+# the dict must contain the pre-trained model with the key '_pretrained_', and arbitrary number of fine-tuned models.
+models = {'_pretrained_': nn.Linear(10,10), 'model_1': nn.Linear(10,10), 'model_2': nn.Linear(10,10)}
+
+# Run the algorithm on the models.
+# This will return a new model that is the result of task arithmetic on the input models.
+merged_model = algorithm.run(models)
+
+

Code Integration

+

Configuration template for the Task Arithmetic algorithm:

+
config/method/task_arithmetic.yaml
name: task_arithmetic
+scaling_factor: 0.5 # Scaling factor for task vectors
+
+

Use the following command to run the Task Arithmetic algorithm:

+
fusion_bench method=task_arithmetic ...
+
+

For example, to run the Task Arithmetic algorithm on two models with scaling factor 0.5:

+
fusion_bench method=task_arithmetic \
+    method.scaling_factor=0.5 \
+  modelpool=clip-vit-base-patch32_svhn_and_mnist \
+  taskpool=clip-vit-base-patch32_svhn_and_mnist
+
+

where the configuration for the model pool is:

+
config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml
type: huggingface_clip_vision
+# the modelpool must contain the pre-trained model with the name '_pretrained_', 
+# and arbitrary number of fine-tuned models.
+models:
+  - name: _pretrained_
+    path: openai/clip-vit-base-patch32
+  - name: svhn
+    path: tanganke/clip-vit-base-patch32_svhn
+  - name: mnist
+    path: tanganke/clip-vit-base-patch32_mnist
+
+

and the configuration for the task pool:

+
config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml
type: clip_vit_classification
+
+dataset_type: huggingface_image_classification
+tasks:
+  - name: svhn
+    dataset:
+      type: instantiate
+      name: svhn
+      object: 
+        _target_: datasets.load_dataset
+        _args_:
+          - svhn
+          - cropped_digits
+        split: test
+  - name: mnist
+    dataset:
+      name: mnist
+      split: test
+
+...
+
+

References

+ + +
+ + + +

+ TaskArithmeticAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm, SimpleProfilerMixin

+ + +

Task Arithmetic Algorithm for model fusion.

+

This class implements the Task Arithmetic method for fusing models. It inherits from +BaseModelFusionAlgorithm and SimpleProfilerMixin to provide the necessary functionality +for model fusion and profiling.

+ + +

Attributes:

+
    +
  • + scaling_factor + (int) + – +
    +

    The factor by which the task vectors will be scaled before merging.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/task_arithmetic/task_arithmetic.py +
class TaskArithmeticAlgorithm(
+    BaseAlgorithm,
+    SimpleProfilerMixin,
+):
+    """
+    Task Arithmetic Algorithm for model fusion.
+
+    This class implements the Task Arithmetic method for fusing models. It inherits from
+    BaseModelFusionAlgorithm and SimpleProfilerMixin to provide the necessary functionality
+    for model fusion and profiling.
+
+    Attributes:
+        scaling_factor (int): The factor by which the task vectors will be scaled before merging.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "scaling_factor": "scaling_factor"
+    }
+
+    def __init__(self, scaling_factor: int):
+        """
+        Initializes the TaskArithmeticAlgorithm with the given scaling factor.
+
+        Args:
+            scaling_factor (int): The factor by which the task vectors will be scaled before merging.
+        """
+        self.scaling_factor = scaling_factor
+        super().__init__()
+
+    @torch.no_grad()
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+        """
+        Runs the Task Arithmetic Algorithm to fuse models in the given model pool.
+
+        Args:
+            modelpool (Union[BaseModelPool, Dict[str, nn.Module]]): The pool of models to fuse.
+
+        Returns:
+            nn.Module: The pre-trained model with the merged task vectors.
+        """
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+
+        log.info("Fusing models using task arithmetic.")
+        task_vector = None
+        with self.profile("load model"):
+            pretrained_model = modelpool.load_model("_pretrained_")
+
+        # Calculate the total task vector
+        for model_name in modelpool.model_names:
+            with self.profile("load model"):
+                model = modelpool.load_model(model_name)
+            with self.profile("merge weights"):
+                if task_vector is None:
+                    task_vector = state_dict_sub(
+                        model.state_dict(keep_vars=True),
+                        pretrained_model.state_dict(keep_vars=True),
+                    )
+                else:
+                    task_vector = state_dict_add(
+                        task_vector,
+                        state_dict_sub(
+                            model.state_dict(keep_vars=True),
+                            pretrained_model.state_dict(keep_vars=True),
+                        ),
+                    )
+        with self.profile("merge weights"):
+            # scale the task vector
+            task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
+            # add the task vector to the pretrained model
+            state_dict = state_dict_add(
+                pretrained_model.state_dict(keep_vars=True), task_vector
+            )
+
+        self.print_profile_summary()
+        pretrained_model.load_state_dict(state_dict)
+        return pretrained_model
+
+
+ + + +
+ + + + + + + +
+ + + +
+ _config_mapping = BaseAlgorithm._config_mapping | {'scaling_factor': 'scaling_factor'} + + + class-attribute + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ scaling_factor = scaling_factor + + + instance-attribute + + +
+ + +
+
+ +
+ + + +
+ + +
+ __init__(scaling_factor) + +
+ + +
+ +

Initializes the TaskArithmeticAlgorithm with the given scaling factor.

+ + +

Parameters:

+
    +
  • +
    scaling_factor +
    (int) + – +
    +

    The factor by which the task vectors will be scaled before merging.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/task_arithmetic/task_arithmetic.py +
def __init__(self, scaling_factor: int):
+    """
+    Initializes the TaskArithmeticAlgorithm with the given scaling factor.
+
+    Args:
+        scaling_factor (int): The factor by which the task vectors will be scaled before merging.
+    """
+    self.scaling_factor = scaling_factor
+    super().__init__()
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Runs the Task Arithmetic Algorithm to fuse models in the given model pool.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (Union[BaseModelPool, Dict[str, Module]]) + – +
    +

    The pool of models to fuse.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    nn.Module: The pre-trained model with the merged task vectors.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/task_arithmetic/task_arithmetic.py +
@torch.no_grad()
+def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    """
+    Runs the Task Arithmetic Algorithm to fuse models in the given model pool.
+
+    Args:
+        modelpool (Union[BaseModelPool, Dict[str, nn.Module]]): The pool of models to fuse.
+
+    Returns:
+        nn.Module: The pre-trained model with the merged task vectors.
+    """
+    if not isinstance(modelpool, BaseModelPool):
+        modelpool = BaseModelPool(modelpool)
+
+    log.info("Fusing models using task arithmetic.")
+    task_vector = None
+    with self.profile("load model"):
+        pretrained_model = modelpool.load_model("_pretrained_")
+
+    # Calculate the total task vector
+    for model_name in modelpool.model_names:
+        with self.profile("load model"):
+            model = modelpool.load_model(model_name)
+        with self.profile("merge weights"):
+            if task_vector is None:
+                task_vector = state_dict_sub(
+                    model.state_dict(keep_vars=True),
+                    pretrained_model.state_dict(keep_vars=True),
+                )
+            else:
+                task_vector = state_dict_add(
+                    task_vector,
+                    state_dict_sub(
+                        model.state_dict(keep_vars=True),
+                        pretrained_model.state_dict(keep_vars=True),
+                    ),
+                )
+    with self.profile("merge weights"):
+        # scale the task vector
+        task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
+        # add the task vector to the pretrained model
+        state_dict = state_dict_add(
+            pretrained_model.state_dict(keep_vars=True), task_vector
+        )
+
+    self.print_profile_summary()
+    pretrained_model.load_state_dict(state_dict)
+    return pretrained_model
+
+
+
+ +
+ + + +
+ +
+ +
+
+
    +
  1. +

    (ICLR 2023) Editing Models with Task Arithmetic. http://arxiv.org/abs/2212.04089 

    +
  2. +
  3. +

    (ICLR 2024) AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575 

    +
  4. +
  5. +

    (NIPS 2023 Oral) Guillermo Ortiz-Jimenez, Alessandro Favero, and Pascal Frossard, “Task Arithmetic in the Tangent Space: Improved Editing of Pre-Trained Models,” doi: 10.48550/arXiv.2305.12827. 

    +
  6. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/ties_merging/index.html b/algorithms/ties_merging/index.html new file mode 100644 index 00000000..c4a24559 --- /dev/null +++ b/algorithms/ties_merging/index.html @@ -0,0 +1,3327 @@ + + + + + + + + + + + + + + + + + + + + + + + Ties-Merging - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+ +
+ + + +
+
+ + + + + + + +

Ties Merging

+
+ Image title +
+ Ties-Merging. Credit to 1 +
+
+

Ties-Merging1 represents a novel and structured approach to consolidating multiple task-specific models into a single, efficient multi-task model. This method employs a sequence of deliberate steps to systematically merge task vectors, ensuring that the final model effectively integrates the strengths of each individual task-specific model and resolves potential conflicts between them.

+

The Ties-Merging algorithm operates through three primary steps:

+
    +
  1. Trim: This initial step involves refining the task-specific models by trimming unnecessary parameters, focusing the model on essential elements for each task.
  2. +
  3. Elect Sign of Parameters: In this step, the algorithm selects the appropriate signs for the parameters, ensuring that the integrated model parameters are optimally oriented for multi-task learning.
  4. +
  5. Disjoint Merge: Finally, the method performs a disjoint merge to combine the task-specific parameters into a single cohesive task vector, denoted as \(\tau\).
  6. +
+

Given the final merged task vector \(\tau\), the ultimate model is determined similarly to the method used in task arithmetic. The formulation is expressed as:

+
\[ +\theta = \theta_0 + \lambda \tau +\]
+

where \(\lambda\) is a hyperparameter chosen based on the validation set to ensure the best-performing model.

+

By following these structured steps, Ties-Merging effectively integrates multiple task-specific models into a unified multi-task model, balancing the contributions of each task to enhance overall performance. The process ensures that the final model retains the benefits of the pre-trained model while optimally incorporating the diverse knowledge contained within the individual task-specific models.

+

Hyperparameter Tuning

+
+alt text +
+Task Arithmetic and Ties-Merging. Here we illustrate the average performance of models merged using Task Arithmetic and Ties-Merging methods, with varying scaling coefficients. +The subfigures represent different models: CLIP-ViT-B/32, CLIP-ViT-L/14, Flan-T5-base (LoRA fine-tuned), and Flan-T5-large (LoRA fine-tuned). +
+
+

In the above figure, we show the average performance of Task Arithmetic and Ties-Merging merged models as the scaling coefficient varies. Subfigure (a), (b), (c), and (d) show the results of CLIP-ViT-B/32, CLIP-ViT-L/14, Flan-T5-base (LoRA fine-tuned), and Flan-T5-large (LoRA fine-tuned), respectively. It is evident that the merged multi-task model hits a peak in average performance across various tasks when the scaling coefficient is set around 0.3. This value was empirically selected as the scaling coefficient in our experiments. As we increase the scaling coefficient beyond this point, the average performance of the model begins to decline, eventually even falling below the level of the pre-trained model’s original performance. This suggests that too high of a scaling coefficient can have a negative impact on the knowledge that the pre-trained model initially possessed, emphasizing the importance of calibrating the scaling coefficient parameter \(\lambda\) to avoid diminishing the model’s existing strengths.

+

Code Integration

+

Configuration template for the Ties-Merging algorithm:

+
config/method/ties_merging.yaml
name: ties_merging
+# Scaling factor $\lambda$
+scaling_factor: 0.5
+threshold: 0.5
+# List of keys to remove from the state dict, default is empty
+remove_keys: []
+# Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
+merge_func: sum 
+
+

Use the following command to run the Ties-Merging algorithm:

+
fusion_bench method=ties_merging ...
+
+

Reference

+ + +
+ + + +

+ TiesMergingAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + +

TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.

+ + +

Attributes:

+
    +
  • + scaling_factor + (float) + – +
    +

    The scaling factor to apply to the merged task vector.

    +
    +
  • +
  • + threshold + (float) + – +
    +

    The threshold for resetting values in the task vector.

    +
    +
  • +
  • + remove_keys + (List[str]) + – +
    +

    List of keys to remove from the state dictionary.

    +
    +
  • +
  • + merge_func + (Literal['sum', 'mean', 'max']) + – +
    +

    The merge function to use for disjoint merging.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/ties_merging/ties_merging.py +
class TiesMergingAlgorithm(BaseAlgorithm):
+    """
+    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+
+    Attributes:
+        scaling_factor (float): The scaling factor to apply to the merged task vector.
+        threshold (float): The threshold for resetting values in the task vector.
+        remove_keys (List[str]): List of keys to remove from the state dictionary.
+        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "scaling_factor": "scaling_factor",
+        "threshold": "threshold",
+        "remove_keys": "remove_keys",
+        "merge_func": "merge_func",
+    }
+
+    def __init__(
+        self,
+        scaling_factor: float,
+        threshold: float,
+        remove_keys: List[str],
+        merge_func: Literal["sum", "mean", "max"],
+        **kwargs,
+    ):
+        """
+        Initialize the TiesMergingAlgorithm with the given parameters.
+
+        Args:
+            scaling_factor (float): The scaling factor to apply to the merged task vector.
+            threshold (float): The threshold for resetting values in the task vector.
+            remove_keys (List[str]): List of keys to remove from the state dictionary.
+            merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
+            **kwargs: Additional keyword arguments for the base class.
+        """
+        self.scaling_factor = scaling_factor
+        self.threshold = threshold
+        self.remove_keys = remove_keys
+        self.merge_func = merge_func
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module]):
+        """
+        Run the TIES merging algorithm to fuse models in the model pool.
+
+        Args:
+            modelpool (BaseModelPool | Dict[str, nn.Module]): The model pool containing the models to fuse.
+
+        Returns:
+            nn.Module: The fused model.
+        """
+        log.info("Fusing models using ties merging.")
+        modelpool = to_modelpool(modelpool)
+        remove_keys = self.config.get("remove_keys", [])
+        merge_func = self.config.get("merge_func", "sum")
+        scaling_factor = self.scaling_factor
+        threshold = self.threshold
+
+        # Load the pretrained model
+        pretrained_model = modelpool.load_model("_pretrained_")
+
+        # Load the state dicts of the models
+        ft_checks: List[StateDictType] = [
+            modelpool.load_model(model_name).state_dict(keep_vars=True)
+            for model_name in modelpool.model_names
+        ]
+        ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
+
+        # Compute the task vectors
+        flat_ft: Tensor = torch.vstack(
+            [state_dict_to_vector(check, remove_keys) for check in ft_checks]
+        )
+        flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
+        tv_flat_checks = flat_ft - flat_ptm
+
+        # Perform TIES Merging
+        merged_tv = ties_merging(
+            tv_flat_checks,
+            reset_thresh=threshold,
+            merge_func=merge_func,
+        )
+        merged_check = flat_ptm + scaling_factor * merged_tv
+        merged_state_dict = vector_to_state_dict(
+            merged_check, ptm_check, remove_keys=remove_keys
+        )
+
+        # Load the merged state dict into the pretrained model
+        pretrained_model.load_state_dict(merged_state_dict)
+        return pretrained_model
+
+
+ + + +
+ + + + + + + +
+ + + +
+ _config_mapping = BaseAlgorithm._config_mapping | {'scaling_factor': 'scaling_factor', 'threshold': 'threshold', 'remove_keys': 'remove_keys', 'merge_func': 'merge_func'} + + + class-attribute + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ merge_func = merge_func + + + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ remove_keys = remove_keys + + + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ scaling_factor = scaling_factor + + + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ threshold = threshold + + + instance-attribute + + +
+ + +
+
+ +
+ + + +
+ + +
+ __init__(scaling_factor, threshold, remove_keys, merge_func, **kwargs) + +
+ + +
+ +

Initialize the TiesMergingAlgorithm with the given parameters.

+ + +

Parameters:

+
    +
  • +
    scaling_factor +
    (float) + – +
    +

    The scaling factor to apply to the merged task vector.

    +
    +
  • +
  • +
    threshold +
    (float) + – +
    +

    The threshold for resetting values in the task vector.

    +
    +
  • +
  • +
    remove_keys +
    (List[str]) + – +
    +

    List of keys to remove from the state dictionary.

    +
    +
  • +
  • +
    merge_func +
    (Literal['sum', 'mean', 'max']) + – +
    +

    The merge function to use for disjoint merging.

    +
    +
  • +
  • +
    **kwargs +
    – +
    +

    Additional keyword arguments for the base class.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/ties_merging/ties_merging.py +
35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
def __init__(
+    self,
+    scaling_factor: float,
+    threshold: float,
+    remove_keys: List[str],
+    merge_func: Literal["sum", "mean", "max"],
+    **kwargs,
+):
+    """
+    Initialize the TiesMergingAlgorithm with the given parameters.
+
+    Args:
+        scaling_factor (float): The scaling factor to apply to the merged task vector.
+        threshold (float): The threshold for resetting values in the task vector.
+        remove_keys (List[str]): List of keys to remove from the state dictionary.
+        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
+        **kwargs: Additional keyword arguments for the base class.
+    """
+    self.scaling_factor = scaling_factor
+    self.threshold = threshold
+    self.remove_keys = remove_keys
+    self.merge_func = merge_func
+    super().__init__(**kwargs)
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Run the TIES merging algorithm to fuse models in the model pool.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (BaseModelPool | Dict[str, Module]) + – +
    +

    The model pool containing the models to fuse.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + – +
    +

    nn.Module: The fused model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/ties_merging/ties_merging.py +
@torch.no_grad()
+def run(self, modelpool: BaseModelPool | Dict[str, nn.Module]):
+    """
+    Run the TIES merging algorithm to fuse models in the model pool.
+
+    Args:
+        modelpool (BaseModelPool | Dict[str, nn.Module]): The model pool containing the models to fuse.
+
+    Returns:
+        nn.Module: The fused model.
+    """
+    log.info("Fusing models using ties merging.")
+    modelpool = to_modelpool(modelpool)
+    remove_keys = self.config.get("remove_keys", [])
+    merge_func = self.config.get("merge_func", "sum")
+    scaling_factor = self.scaling_factor
+    threshold = self.threshold
+
+    # Load the pretrained model
+    pretrained_model = modelpool.load_model("_pretrained_")
+
+    # Load the state dicts of the models
+    ft_checks: List[StateDictType] = [
+        modelpool.load_model(model_name).state_dict(keep_vars=True)
+        for model_name in modelpool.model_names
+    ]
+    ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
+
+    # Compute the task vectors
+    flat_ft: Tensor = torch.vstack(
+        [state_dict_to_vector(check, remove_keys) for check in ft_checks]
+    )
+    flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
+    tv_flat_checks = flat_ft - flat_ptm
+
+    # Perform TIES Merging
+    merged_tv = ties_merging(
+        tv_flat_checks,
+        reset_thresh=threshold,
+        merge_func=merge_func,
+    )
+    merged_check = flat_ptm + scaling_factor * merged_tv
+    merged_state_dict = vector_to_state_dict(
+        merged_check, ptm_check, remove_keys=remove_keys
+    )
+
+    # Load the merged state dict into the pretrained model
+    pretrained_model.load_state_dict(merged_state_dict)
+    return pretrained_model
+
+
+
+ +
+ + + +
+ +
+ +
+
+
    +
  1. +

    (NIPS 2023) Resolving Interference When Merging Models. http://arxiv.org/abs/2306.01708 

    +
  2. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/weight_ensembling_moe/index.html b/algorithms/weight_ensembling_moe/index.html new file mode 100644 index 00000000..62cb1907 --- /dev/null +++ b/algorithms/weight_ensembling_moe/index.html @@ -0,0 +1,5472 @@ + + + + + + + + + + + + + + + + + + + + + + + Weight-Ensembling MoE - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Weight-Ensembling Mixture of Experts (Data-Adaptive Model Merging)

+

arxiv +github

+
+ alt text +
+(a) Framework overview. This figure shows the overall framework of our proposed method to merge the pre-trained model and fine-tuned task-specific models. We merge weights in the Transformer Layers except for the MLPs. For the MLPs, we upcycle them into weight-assembling MoE modules. +(b) Wieght-Ensembling Mixture of Experts (MoE) Module. Here we outline the detailed structure of the Weight-Ensembling MoE module, composed of the router, pre-trained MLP weights, and a collection of task vectors. Collaboration between shared weights and task vectors is employed to create input-conditioned weights dynamically. In this way, we separate shared information and task-specific knowledge, which are then combined based on input in time. +
+
+

This method is designed to handle a wide range of tasks by segregating shared information and task-specific knowledge. +It dynamically combines these elements based on the input samples.

+

The Weight-Ensembling MoE module consists of three main components: the router, the pre-trained MLP weights, and a collection of task vectors. +The router, which is an MLP, processes the input data and generates routing weights. These weights determine how the knowledge from different tasks is combined. +The pre-trained MLP weights are crucial as they have been trained to recognize a wide range of data patterns. +The task vectors represent the differences between the MLPs that have been fine-tuned for specific tasks and the pre-trained ones, capturing the unique adjustments made to optimize them for specific tasks. +The routing weights are averaged across the input tokens, and these weights are used to select task vectors from a dictionary matrix. +These task vectors are then added to the pre-trained MLP weights to create input-conditioned weights.

+

Algorithm Requirements:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MethodAccess to labeled tasks dataAccess to validation data (labeled)Test time adaptation
Fisher MergingYes (Estimate Fisher information matrix)NoNo
RegMeanYes (compute Gram Matrix)NoNo
Task ArithmeticNoYes (select sacling factor)No
Ties-MergingNoYes (select sacling factor)No
AdaMergingNoNoYes
OursNoNoYes
+

WEMoE V2: E-WEMoE

+

L. Shen, A. Tang, E. Yang et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging. Oct, 2024.3

+

arXiv +github

+
+ alt text + + (a) Overview of the Efficient Weight-Ensembling Mixture of Experts (E-WEMoE) Framework. It merges all non-MLP modules through task arithmetic and upgrades the MLP modules into an efficient E-WEMoE module. (b) E-WEMoE Module. The module includes a router shared across all Transformer blocks, the pre-trained MLP module, and a set of sparse task-specific vectors w.r.t. MLP modules. + +
+
+ alt text + + Comparison of (a) trainable parameters and (b) total parameters between WEMoE and E-WEMoE-90%. + +
+
+ alt text + + Comparison of the relationship between parameter count and performance across various model merging methods. + +
+

Parameters Comparison

+
+

Tip for reducing the parameter count

+

Here we present the parameter count for the method outlined in the original paper1. +An effective strategy to minimize the number of parameters involves employing Singular Value Decomposition (SVD) to compress the task vectors. +This approach significantly cuts down on the number of parameters while only marginally impacting performance. +For additional information, please refer to the Twin-Merging paper2. +Which not only reduces the number of parameters but also conducts extensive experiments to demonstrate the effectiveness of data-adaptive merging on language domain.

+
+

Here is the number of parameters compared to a single pre-trained model (OpenCLIP CLIP-ViT-B/32):

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MethodTrainable ParametersTotal ParametersParemeters Reduced by Merging
Single Pre-trained113.45M (100%)113.45M-
WEMoE (2-layer, 1 task)7.10M (4.00%)177.21M-
WEMoE (2-layer, 2 tasks)7.11M (3.04%)233.89M2*113.45-233.89=-6.99M
WEMoE (2-layer, 3 tasks)7.11M (2.45%)290.57M3*113.45-290.57=49.78M
WEMoE (2-layer, 4 tasks)7.12M (2.02%)347.25M4*113.45-347.25=106.55M
WEMoE (2-layer, 5 tasks)7.13M (1.77%)403.93M5*113.45-403.93=163.32M
WEMoE (2-layer, 6 tasks)7.14M (1.55%)460.61M6*113.45-460.61=220.09M
WEMoE (2-layer, 7 tasks)7.15M (1.38%)517.28M7*113.45-517.28=276.87M
WEMoE (2-layer, 8 tasks)7.16M (1.25%)573.96M8*113.45-573.96=333.64M
+

The number of parameter count of HuggingFace CLIP vision models (of type transformers.models.clip.modeling_clip.CLIPVisionModel) are different from the OpenCLIP models downloaded from the task arithmetic repo, because the OpenCLIP models (of type src.modeling.ImageEncoder) include the embedding layer for text tokens, while the HuggingFace CLIP vision models do not. +Therefore, the relative parameter count of the upscaled model using Transformer CLIP vision models will be larger than the OpenCLIP models.

+
+
+
+
ImageEncoder( # (1)
+  (model): CLIP(
+    (visual): VisualTransformer( # (2)
+      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
+      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+      (transformer): Transformer(
+        (resblocks): ModuleList(
+          (0-11): 12 x ResidualAttentionBlock(
+            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+            (attn): MultiheadAttention(
+              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
+            )
+            (ln_attn): Identity()
+            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+            (mlp): Sequential(
+              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
+              (ln): Identity()
+              (gelu): QuickGELU()
+              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
+            )
+          )
+        )
+      )
+      (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+    )
+    (token_embedding): Embedding(49408, 512) # (3)
+    (ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
+  )
+)
+
+
    +
  1. trainable params: 113.45M || all params: 113.45M || trainable%: 100.0000
  2. +
  3. trainable params: 87.85M || all params: 87.85M || trainable%: 100.0000
  4. +
  5. trainable params: 25.30M || all params: 25.30M || trainable%: 100.0000
  6. +
+
+
+
CLIPVisionModel( # (1)
+  (vision_model): CLIPVisionTransformer(
+    (embeddings): CLIPVisionEmbeddings(
+      (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
+      (position_embedding): Embedding(50, 768)
+    )
+    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+    (encoder): CLIPEncoder(
+      (layers): ModuleList(
+        (0-11): 12 x CLIPEncoderLayer(
+          (self_attn): CLIPAttention(
+            (k_proj): Linear(in_features=768, out_features=768, bias=True)
+            (v_proj): Linear(in_features=768, out_features=768, bias=True)
+            (q_proj): Linear(in_features=768, out_features=768, bias=True)
+            (out_proj): Linear(in_features=768, out_features=768, bias=True)
+          )
+          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+          (mlp): CLIPMLP(
+            (activation_fn): QuickGELUActivation()
+            (fc1): Linear(in_features=768, out_features=3072, bias=True)
+            (fc2): Linear(in_features=3072, out_features=768, bias=True)
+          )
+          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+        )
+      )
+    )
+    (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+  )
+)
+
+
    +
  1. trainable params: 87.85M || all params: 87.85M || trainable%: 100.0000
  2. +
+
+
+
+

Loss Landscape Visualization

+
+alt text +
+Visualization of the joint loss \(\mathcal{L}_1 + \mathcal{L}_2\) and five task pairs for CLIP-ViT-B/32 in the loss landscape. + We perform interpolations between pre-trained weights and two fine-tuned weights in the weight space on a 2D plane using the formula \(\theta=\theta_0 + \lambda_1 \tau_1 + \lambda_2 \tau_2\), where \(\theta_0\) represents pre-trained weights, \(\tau_i=\theta_i -\theta_0\) are two task vectors with \(\lambda_i\) in the range [-1, 1]. +
+
+

Hyperparameter Tuning

+

In the below figure, we show the performance of the merged models with varying numbers of steps. +Figure (b) shows the performance of the merged WEMoE models with varying number of steps. +In Figure (a), we merge CLIP-ViT-B/32 models with different learning rate configurations. +We observe that the performance of the merged model shows an upward trend with an increase in the number of training steps, and it converges rapidly, reaching a high accuracy level in just 200 steps. +Furthermore, the influence of different learning rates is not significant, suggesting that our method is insensitive to the learning rate parameter. This is a desirable property as it reduces the need for hyperparameter tuning.

+
+alt text +
+The performance of the merged models with a varying number of steps.
+(a) CLIP-ViT-B/32 model with different learning rates.
+(b) Comparison of CLIP-ViT-B/32 and CLIP-ViT-L/14. +
+
+

Ablations of Router Depth

+

Table: Parameter comparison of WEMoE (1-layer) and WEMoE (2-layer) on CLIP-ViT-B/32 models (OpenCLIP).

+ + + + + + + + + + + + + + + + + + + + + +
MethodNumber of Trainable Parameters
AdaMerging (layer-wise)1.3K
WEMoE (1-layer)73.8K (0.01%)
WEMoE (2-layer)7.16M (1.25%)
+

Table: Ablation study of the router depth on the performance of the up-scaled CLIP-ViT-B/32 models (OpenCLIP).

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MethodSUN397CARSRESISC45EuroSATSVHNGRSRBMNISTDTDAvg.
AdaMerging (layer-wise)66.668.382.492.586.593.797.761.180.9
WEMoE (1-layer)73.276.793.898.695.798.699.574.588.3
WEMoE (2-layer)74.177.493.799.196.298.999.676.489.4
+

To explore the influence of router depth on the performance of the scaled-up model, we perform an ablation study where the router depth is varied. In WEMoE modules, the router is implemented as a multi-layer perceptron (MLP).

+
    +
  • WEMoE (0-layer) functions as a bias-only model, representing a special case of an MLP with no hidden layers. It generates a constant routing weight for all inputs, captured by the formula as \(r(h) = b_0\), indicating that it does not adjust based on the input. + When we only up-scale the MLP modules of the vision Transformers to MoE modules, WEMoE (0-layer) can be considered as a partial implementation of AdaMerging. Add when we up-scale the vision Transformers layer-wisely, WEMoE (0-layer) can be considered equivalent to AdaMerging. + For WEMoE (0-layer), the MoE modules can be unloaded, thus no additional parameters and inference cost are introduced.
  • +
  • For WEMoE (1-layer), each router is a one-layer MLP that takes the input sample \(h\) and outputs the routing weight \(r(h)\), which is adaptive to the input. The routing weight is calculated as \(r(h) = W_1 h + b_1\).
  • +
  • For WEMoE (2-layer), each router is a two-layer MLP and the routing weight is calculated as \(r(h) = W_2 ReLU(W_1 h + b_1) + b_2\).
  • +
+

In the above two Tables, we present additional findings to support our argument. We compare the number of trainable parameters and performance between WEMoE (1-layer) and WEMoE (2-layer). The data reveal that WEMoE (1-layer) possesses 73.8K trainable parameters, which constitute only 0.01% of the total parameters in the merged model. Notably, the performance of WEMoE (1-layer) is significantly better than AdaMerging and nearly matches that of WEMoE (2-layer) across all tasks. This evidence underscores our claim that the MoE design is crucial for performance enhancement.

+

Code Integration

+

multi-task model fusion experiment on eight image classification tasks.

+
# merge eight CLIP-ViT-B/32 models using WE MoE
+fusion_bench \
+  method=weight_ensembling_moe \
+    method.name=clip_weight_ensembling_moe \
+    method.use_grad_accumulate=false \
+    method.save_checkpoint=outputs/clip-vit-base-patch32_TA8_weight_ensembling_moe_checkpoint.ckpt \
+  modelpool=clip-vit-base-patch32_TA8 \
+  taskpool=clip-vit-classification_TA8
+
+

merge eight CLIP-ViT-L/14 models:

+
# merge eight CLIP-ViT-L/14 models using WE MoE, fine-tune the routers
+fusion_bench print_config=false \
+  method=weight_ensembling_moe \
+    method.name=clip_weight_ensembling_moe \
+    method.use_grad_accumulate=true \
+    method.save_checkpoint=outputs/clip-vit-large-patch14_TA8_weight_ensembling_moe_checkpoint.ckpt \
+    method.batch_size=4 method.devices=4 \
+  modelpool=clip-vit-large-patch14_TA8 \
+  taskpool=dummy &&
+
+# load the checkpoint and evaluate the model
+fusion_bench \
+  method=weight_ensembling_moe \
+    method.name=clip_weight_ensembling_moe \
+    method.checkpoint=outputs/clip-vit-large-patch14_TA8_weight_ensembling_moe_checkpoint.ckpt \
+  modelpool=clip-vit-large-patch14_TA8 \
+  taskpool=clip-vit-classification_TA8 \
+    taskpool.clip_model=openai/clip-vit-large-patch14
+
+

Reference

+ + +
+ + + +

+ we_moe + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ WeightEnsemblingMoEAlgorithm + + +
+ + +
+

+ Bases: ModelFusionAlgorithm

+ + +

Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).

+

This class provides methods for constructing the MoE model, performing test-time adaptation, +and running the fusion process.

+ + +

Attributes:

+
    +
  • + _fabric + (Fabric) + – +
    +

    The fabric for distributed training.

    +
    +
  • +
  • + modelpool + (ModelPool) + – +
    +

    The pool of models to be fused.

    +
    +
  • +
  • + profiler + (SimpleProfiler) + – +
    +

    The profiler for measuring performance.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
+    """
+    Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
+
+    This class provides methods for constructing the MoE model, performing test-time adaptation,
+    and running the fusion process.
+
+    Attributes:
+        _fabric (L.Fabric): The fabric for distributed training.
+        modelpool (ModelPool): The pool of models to be fused.
+        profiler (SimpleProfiler): The profiler for measuring performance.
+    """
+
+    _fabric: L.Fabric = None
+    modelpool: ModelPool = None
+
+    def __init__(self, algorithm_config: DictConfig):
+        """
+        Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.
+
+        Args:
+            algorithm_config (DictConfig): The configuration for the algorithm.
+        """
+        super().__init__(algorithm_config)
+
+        if self._fabric is None and torch.cuda.is_available():
+            self._fabric = L.Fabric(
+                devices=self.config.get("devices", 1),
+            )
+            self._fabric.launch()
+        else:
+            assert "No CUDA device available."
+        self.profiler = SimpleProfiler(
+            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
+        )
+
+    @abstractmethod
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The checkpoint file to load.
+        """
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The checkpoint file to save.
+        """
+        pass
+
+    @abstractmethod
+    def construct_moe_model(self) -> WeightEnsemblingMoE:
+        """
+        Construct the Mixture of Experts model using the models in the model pool.
+
+        Returns:
+            WeightEnsemblingMoE: The constructed MoE model.
+        """
+        pass
+
+    def on_test_time_adaptation_start(self):
+        """
+        Hook method called at the start of test-time adaptation.
+        """
+        pass
+
+    @abstractmethod
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        """
+        Get an iterator for the shuffled test data loader for a specific task.
+
+        Args:
+            task (str): The task for which to get the test data loader.
+
+        Returns:
+            DataLoader: The shuffled test data loader iterator.
+        """
+        pass
+
+    @abstractmethod
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for a given batch and task.
+
+        Args:
+            module: The model module to use for computing logits.
+            batch: The batch of data.
+            task: The task for which to compute logits.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        pass
+
+    def test_time_adaptation(self, module: WeightEnsemblingMoE):
+        """
+        Perform test-time adaptation for the given module.
+
+        Args:
+            module (WeightEnsemblingMoE): The MoE module to adapt.
+
+        Returns:
+            WeightEnsemblingMoE: The adapted MoE module.
+        """
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.config.optimizer == "adam":
+            optimizer = torch.optim.Adam(
+                [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
+            )
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+        if self._fabric is not None:
+            module, optimizer = self._fabric.setup(module, optimizer)
+
+        module.train()
+
+        if self.config.get("fast_dev_run", False):
+            log.info("Running fast_dev_run, only one step")
+            pbar = tqdm(
+                range(1),
+                "Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        else:
+            pbar = tqdm(
+                range(self.config.max_steps),
+                "Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        for step_idx in pbar:
+            if self.config.use_grad_accumulate:
+                for task in self.modelpool.model_names:
+                    with self.profiler.profile("data time"):
+                        batch = next(self.get_shuffled_test_loader_iter(task))
+                    with self.profiler.profile("forward pass"):
+                        logits = self.compute_logits(module, batch, task)
+                        assert (
+                            logits.dim() == 2
+                        ), f"Expected logits to be 2D, got {logits.dim()}"
+                        loss = entropy_loss(logits)
+                    # .backward() accumulates when .zero_grad() wasn't called
+                    # this can save memory
+                    with self.profiler.profile("backward pass"):
+                        self._fabric.backward(loss, retain_graph=True)
+            else:
+                loss = 0
+                for task in self.modelpool.model_names:
+                    with self.profiler.profile("data time"):
+                        batch = next(self.get_shuffled_test_loader_iter(task))
+                    with self.profiler.profile("forward pass"):
+                        logits = self.compute_logits(module, batch, task)
+                        assert (
+                            logits.dim() == 2
+                        ), f"Expected logits to be 2D, got {logits.dim()}"
+                        loss = loss + entropy_loss(logits)
+                with self.profiler.profile("backward pass"):
+                    self._fabric.backward(loss, retain_graph=True)
+
+            with self.profiler.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+        return module
+
+    def run(self, modelpool: ModelPool):
+        """
+        Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be fused.
+
+        Returns:
+            WeightEnsemblingMoE: The fused MoE model.
+        """
+        log.info("Fusing models using WeightEnsembling Mixture of Experts modules.")
+        self.modelpool = modelpool
+
+        with timeit_context("upscaling models to a weight-ensembling MoE model"):
+            moe_model = self.construct_moe_model()
+            print_parameters(moe_model)
+
+        if self.config.get("checkpoint", False):
+            log.info(
+                f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
+            )
+            self.load_checkpoint(moe_model, self.config.checkpoint)
+        else:
+            with self.profiler.profile("test-time adaptation"):
+                moe_model = self.test_time_adaptation(moe_model)
+            if self.config.get("save_checkpoint", False):
+                log.info(f"save checkpoint to {self.config.save_checkpoint}")
+                self.save_checkpoint(moe_model, self.config.save_checkpoint)
+
+            if lightning.fabric.wrappers.is_wrapped(moe_model):
+                moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+        # enable sample-wise adaptation
+        moe_model.batch_reduce = False
+        print(self.profiler.summary())
+        return moe_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ __init__(algorithm_config) + +
+ + +
+ +

Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.

+ + +

Parameters:

+
    +
  • + algorithm_config + (DictConfig) + – +
    +

    The configuration for the algorithm.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
def __init__(self, algorithm_config: DictConfig):
+    """
+    Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.
+
+    Args:
+        algorithm_config (DictConfig): The configuration for the algorithm.
+    """
+    super().__init__(algorithm_config)
+
+    if self._fabric is None and torch.cuda.is_available():
+        self._fabric = L.Fabric(
+            devices=self.config.get("devices", 1),
+        )
+        self._fabric.launch()
+    else:
+        assert "No CUDA device available."
+    self.profiler = SimpleProfiler(
+        self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
+    )
+
+
+
+ +
+ +
+ + +
+ compute_logits(module, batch, task) + + + abstractmethod + + +
+ + +
+ +

Compute the logits for a given batch and task.

+ + +

Parameters:

+
    +
  • + module + – +
    +

    The model module to use for computing logits.

    +
    +
  • +
  • + batch + – +
    +

    The batch of data.

    +
    +
  • +
  • + task + – +
    +

    The task for which to compute logits.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Tensor ( Tensor +) – +
    +

    The computed logits.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
@abstractmethod
+def compute_logits(self, module, batch, task) -> Tensor:
+    """
+    Compute the logits for a given batch and task.
+
+    Args:
+        module: The model module to use for computing logits.
+        batch: The batch of data.
+        task: The task for which to compute logits.
+
+    Returns:
+        Tensor: The computed logits.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ construct_moe_model() + + + abstractmethod + + +
+ + +
+ +

Construct the Mixture of Experts model using the models in the model pool.

+ + +

Returns:

+
    +
  • +WeightEnsemblingMoE ( WeightEnsemblingMoE +) – +
    +

    The constructed MoE model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
@abstractmethod
+def construct_moe_model(self) -> WeightEnsemblingMoE:
+    """
+    Construct the Mixture of Experts model using the models in the model pool.
+
+    Returns:
+        WeightEnsemblingMoE: The constructed MoE model.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ get_shuffled_test_loader_iter(task) + + + abstractmethod + + +
+ + +
+ +

Get an iterator for the shuffled test data loader for a specific task.

+ + +

Parameters:

+
    +
  • + task + (str) + – +
    +

    The task for which to get the test data loader.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +DataLoader ( DataLoader +) – +
    +

    The shuffled test data loader iterator.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
@abstractmethod
+def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+    """
+    Get an iterator for the shuffled test data loader for a specific task.
+
+    Args:
+        task (str): The task for which to get the test data loader.
+
+    Returns:
+        DataLoader: The shuffled test data loader iterator.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ load_checkpoint(model, checkpoint) + + + abstractmethod + + +
+ + +
+ +

Load the checkpoint file.

+ + +

Parameters:

+
    +
  • + model + – +
    +

    The model to load the checkpoint into.

    +
    +
  • +
  • + checkpoint + – +
    +

    The checkpoint file to load.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
73
+74
+75
+76
+77
+78
+79
+80
+81
+82
@abstractmethod
+def load_checkpoint(self, model, checkpoint):
+    """
+    Load the checkpoint file.
+
+    Args:
+        model: The model to load the checkpoint into.
+        checkpoint: The checkpoint file to load.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ on_test_time_adaptation_start() + +
+ + +
+ +

Hook method called at the start of test-time adaptation.

+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
def on_test_time_adaptation_start(self):
+    """
+    Hook method called at the start of test-time adaptation.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

+ + +

Parameters:

+
    +
  • + modelpool + (ModelPool) + – +
    +

    The pool of models to be fused.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +WeightEnsemblingMoE – +
    +

    The fused MoE model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
def run(self, modelpool: ModelPool):
+    """
+    Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
+
+    Args:
+        modelpool (ModelPool): The pool of models to be fused.
+
+    Returns:
+        WeightEnsemblingMoE: The fused MoE model.
+    """
+    log.info("Fusing models using WeightEnsembling Mixture of Experts modules.")
+    self.modelpool = modelpool
+
+    with timeit_context("upscaling models to a weight-ensembling MoE model"):
+        moe_model = self.construct_moe_model()
+        print_parameters(moe_model)
+
+    if self.config.get("checkpoint", False):
+        log.info(
+            f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
+        )
+        self.load_checkpoint(moe_model, self.config.checkpoint)
+    else:
+        with self.profiler.profile("test-time adaptation"):
+            moe_model = self.test_time_adaptation(moe_model)
+        if self.config.get("save_checkpoint", False):
+            log.info(f"save checkpoint to {self.config.save_checkpoint}")
+            self.save_checkpoint(moe_model, self.config.save_checkpoint)
+
+        if lightning.fabric.wrappers.is_wrapped(moe_model):
+            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+    # enable sample-wise adaptation
+    moe_model.batch_reduce = False
+    print(self.profiler.summary())
+    return moe_model
+
+
+
+ +
+ +
+ + +
+ save_checkpoint(model, checkpoint) + + + abstractmethod + + +
+ + +
+ +

Save the checkpoint file.

+ + +

Parameters:

+
    +
  • + model + – +
    +

    The model to save the checkpoint from.

    +
    +
  • +
  • + checkpoint + – +
    +

    The checkpoint file to save.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
84
+85
+86
+87
+88
+89
+90
+91
+92
+93
@abstractmethod
+def save_checkpoint(self, model, checkpoint):
+    """
+    Save the checkpoint file.
+
+    Args:
+        model: The model to save the checkpoint from.
+        checkpoint: The checkpoint file to save.
+    """
+    pass
+
+
+
+ +
+ +
+ + +
+ test_time_adaptation(module) + +
+ + +
+ +

Perform test-time adaptation for the given module.

+ + +

Parameters:

+
    +
  • + module + (WeightEnsemblingMoE) + – +
    +

    The MoE module to adapt.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +WeightEnsemblingMoE – +
    +

    The adapted MoE module.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
def test_time_adaptation(self, module: WeightEnsemblingMoE):
+    """
+    Perform test-time adaptation for the given module.
+
+    Args:
+        module (WeightEnsemblingMoE): The MoE module to adapt.
+
+    Returns:
+        WeightEnsemblingMoE: The adapted MoE module.
+    """
+    self.on_test_time_adaptation_start()
+
+    # configure optimizer
+    if self.config.optimizer == "adam":
+        optimizer = torch.optim.Adam(
+            [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
+        )
+    else:
+        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+    if self._fabric is not None:
+        module, optimizer = self._fabric.setup(module, optimizer)
+
+    module.train()
+
+    if self.config.get("fast_dev_run", False):
+        log.info("Running fast_dev_run, only one step")
+        pbar = tqdm(
+            range(1),
+            "Test-time adaptation",
+            dynamic_ncols=True,
+        )
+    else:
+        pbar = tqdm(
+            range(self.config.max_steps),
+            "Test-time adaptation",
+            dynamic_ncols=True,
+        )
+    for step_idx in pbar:
+        if self.config.use_grad_accumulate:
+            for task in self.modelpool.model_names:
+                with self.profiler.profile("data time"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                with self.profiler.profile("forward pass"):
+                    logits = self.compute_logits(module, batch, task)
+                    assert (
+                        logits.dim() == 2
+                    ), f"Expected logits to be 2D, got {logits.dim()}"
+                    loss = entropy_loss(logits)
+                # .backward() accumulates when .zero_grad() wasn't called
+                # this can save memory
+                with self.profiler.profile("backward pass"):
+                    self._fabric.backward(loss, retain_graph=True)
+        else:
+            loss = 0
+            for task in self.modelpool.model_names:
+                with self.profiler.profile("data time"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                with self.profiler.profile("forward pass"):
+                    logits = self.compute_logits(module, batch, task)
+                    assert (
+                        logits.dim() == 2
+                    ), f"Expected logits to be 2D, got {logits.dim()}"
+                    loss = loss + entropy_loss(logits)
+            with self.profiler.profile("backward pass"):
+                self._fabric.backward(loss, retain_graph=True)
+
+        with self.profiler.profile("optimizer step"):
+            optimizer.step()
+            optimizer.zero_grad()
+
+    return module
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +
+ entropy_loss(logits) + +
+ + +
+ +

Compute the entropy loss of a set of logits.

+ + +

Parameters:

+
    +
  • +
    logits +
    (Tensor) + – +
    +

    The logits to compute the entropy loss of.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Tensor ( Tensor +) – +
    +

    The entropy loss of the logits.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/we_moe.py +
23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
def entropy_loss(logits: Tensor) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    probs = torch.softmax(logits, dim=-1)
+    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ clip_we_moe + + +

+ +
+ + + + + + + + +
+ + + + + + + + +
+ + + +
+ CLIPWeightEnsemblingMoEAlgorithm + + +
+ + +
+

+ Bases: WeightEnsemblingMoEAlgorithm, CLIPClassificationMixin

+ + +

CLIPWeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm +for CLIP models. It extends the WeightEnsemblingMoEAlgorithm and CLIPClassificationMixin classes.

+ + +

Attributes:

+
    +
  • + modelpool + (CLIPVisionModelPool) + – +
    +

    The model pool containing the CLIP models.

    +
    +
  • +
+ + + + + + +
+ Source code in fusion_bench/method/we_moe/clip_we_moe.py +
class CLIPWeightEnsemblingMoEAlgorithm(
+    WeightEnsemblingMoEAlgorithm,
+    CLIPClassificationMixin,
+):
+    """
+    CLIPWeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
+    for CLIP models. It extends the WeightEnsemblingMoEAlgorithm and CLIPClassificationMixin classes.
+
+    Attributes:
+        modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
+    """
+
+    modelpool: CLIPVisionModelPool = None
+
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The path to the checkpoint file.
+        """
+        state = {"model": model}
+        self._fabric.load(checkpoint, state)
+
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The path to the checkpoint file.
+        """
+        self._fabric.save(checkpoint, {"model": model})
+
+    def construct_moe_model(self) -> WeightEnsemblingMoE:
+        """
+        Construct the Mixture of Experts (MoE) model using the models in the model pool.
+
+        Returns:
+            WeightEnsemblingMoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(m) for m in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.config.init_lambda,
+        ).requires_grad_(False)
+
+        # Up-scale MLP modules
+        base_encoder: CLIPEncoder = base_model.vision_model.encoder
+        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
+        expert_encoders = [m.vision_model.encoder for m in expert_models]
+
+        num_layers = len(base_encoder.layers)
+        for layer_idx in range(num_layers):
+            base_mlp = base_encoder.layers[layer_idx].mlp
+            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]
+
+            moe_encoder.layers[layer_idx].mlp = WeightEnsemblingMoE(
+                hidden_size=base_encoder.config.hidden_size,
+                base_model=base_mlp,
+                expert_models=expert_mlps,
+                init_lambda=self.config.init_lambda,
+                batch_first=True,  # For open_clip models this is False
+                router_hidden_layers=self.config.router_hidden_layers,
+                batch_reduce=self.config.batch_reduce,
+            )
+
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+        """
+        Get an iterator for the shuffled test data loader.
+
+        Args:
+            tta_dataset (str): The name of the test-time adaptation dataset.
+
+        Returns:
+            Iterator: An iterator for the shuffled test data loader.
+        """
+        dataset = self.modelpool.load_test_dataset(tta_dataset)
+        dataset = CLIPDataset(dataset, processor=self.clip_processor)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.config.batch_size,
+            shuffle=True,
+            num_workers=self.config.num_workers,
+            pin_memory=True,
+        )
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def on_test_time_adaptation_start(self):
+        """
+        Load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module: The model module.
+            batch: The input batch.
+            task: The task name.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # Normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # Cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ compute_logits(module, batch, task) + +
+ + +
+ +

Compute the logits for the given batch and task.

+ + +

Parameters:

+
    +
  • + module + – +
    +

    The model module.

    +
    +
  • +
  • + batch + – +
    +

    The input batch.

    +
    +
  • +
  • + task + – +
    +

    The task name.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Tensor ( Tensor +) – +
    +

    The computed logits.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/clip_we_moe.py +
def compute_logits(self, module, batch, task) -> Tensor:
+    """
+    Compute the logits for the given batch and task.
+
+    Args:
+        module: The model module.
+        batch: The input batch.
+        task: The task name.
+
+    Returns:
+        Tensor: The computed logits.
+    """
+    images, _ = batch
+    text_embeds = self.zeroshot_weights[task]
+
+    image_embeds = module(images)[1]
+    image_embeds = self.visual_projection(image_embeds)
+
+    # Normalize embeddings
+    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+    # Cosine similarity
+    logits_per_text = (
+        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+    )
+    logits_per_image = logits_per_text.t()
+
+    return logits_per_image
+
+
+
+ +
+ +
+ + +
+ construct_moe_model() + +
+ + +
+ +

Construct the Mixture of Experts (MoE) model using the models in the model pool.

+ + +

Returns:

+
    +
  • +WeightEnsemblingMoE ( WeightEnsemblingMoE +) – +
    +

    The constructed MoE model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/clip_we_moe.py +
def construct_moe_model(self) -> WeightEnsemblingMoE:
+    """
+    Construct the Mixture of Experts (MoE) model using the models in the model pool.
+
+    Returns:
+        WeightEnsemblingMoE: The constructed MoE model.
+    """
+    base_model = self.modelpool.load_model("_pretrained_")
+    expert_models = [
+        self.modelpool.load_model(m) for m in self.modelpool.model_names
+    ]
+
+    # Merge the models using task arithmetic
+    moe_model = task_arithmetic_merge(
+        # This function modifies the model in place, so we need to pass a deepcopy
+        deepcopy(base_model),
+        expert_models,
+        scaling_factor=self.config.init_lambda,
+    ).requires_grad_(False)
+
+    # Up-scale MLP modules
+    base_encoder: CLIPEncoder = base_model.vision_model.encoder
+    moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
+    expert_encoders = [m.vision_model.encoder for m in expert_models]
+
+    num_layers = len(base_encoder.layers)
+    for layer_idx in range(num_layers):
+        base_mlp = base_encoder.layers[layer_idx].mlp
+        expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]
+
+        moe_encoder.layers[layer_idx].mlp = WeightEnsemblingMoE(
+            hidden_size=base_encoder.config.hidden_size,
+            base_model=base_mlp,
+            expert_models=expert_mlps,
+            init_lambda=self.config.init_lambda,
+            batch_first=True,  # For open_clip models this is False
+            router_hidden_layers=self.config.router_hidden_layers,
+            batch_reduce=self.config.batch_reduce,
+        )
+
+    return moe_model
+
+
+
+ +
+ +
+ + +
+ get_shuffled_test_loader_iter(tta_dataset) + + + cached + + +
+ + +
+ +

Get an iterator for the shuffled test data loader.

+ + +

Parameters:

+
    +
  • + tta_dataset + (str) + – +
    +

    The name of the test-time adaptation dataset.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +Iterator – +
    +

    An iterator for the shuffled test data loader.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/clip_we_moe.py +
@functools.cache
+def get_shuffled_test_loader_iter(self, tta_dataset: str):
+    """
+    Get an iterator for the shuffled test data loader.
+
+    Args:
+        tta_dataset (str): The name of the test-time adaptation dataset.
+
+    Returns:
+        Iterator: An iterator for the shuffled test data loader.
+    """
+    dataset = self.modelpool.load_test_dataset(tta_dataset)
+    dataset = CLIPDataset(dataset, processor=self.clip_processor)
+    log.info("get_shuffled_test_loader_iter")
+    loader = DataLoader(
+        dataset,
+        batch_size=self.config.batch_size,
+        shuffle=True,
+        num_workers=self.config.num_workers,
+        pin_memory=True,
+    )
+    loader = self.fabric.setup_dataloaders(loader)
+    return iter(InfiniteDataLoader(loader))
+
+
+
+ +
+ +
+ + +
+ load_checkpoint(model, checkpoint) + +
+ + +
+ +

Load the checkpoint file.

+ + +

Parameters:

+
    +
  • + model + – +
    +

    The model to load the checkpoint into.

    +
    +
  • +
  • + checkpoint + – +
    +

    The path to the checkpoint file.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/clip_we_moe.py +
41
+42
+43
+44
+45
+46
+47
+48
+49
+50
def load_checkpoint(self, model, checkpoint):
+    """
+    Load the checkpoint file.
+
+    Args:
+        model: The model to load the checkpoint into.
+        checkpoint: The path to the checkpoint file.
+    """
+    state = {"model": model}
+    self._fabric.load(checkpoint, state)
+
+
+
+ +
+ +
+ + +
+ on_test_time_adaptation_start() + +
+ + +
+ +

Load the CLIP processor and construct the zero-shot classification head for each task.

+ +
+ Source code in fusion_bench/method/we_moe/clip_we_moe.py +
def on_test_time_adaptation_start(self):
+    """
+    Load the CLIP processor and construct the zero-shot classification head for each task.
+    """
+    self.setup_zero_shot_classification_head()
+
+
+
+ +
+ +
+ + +
+ save_checkpoint(model, checkpoint) + +
+ + +
+ +

Save the checkpoint file.

+ + +

Parameters:

+
    +
  • + model + – +
    +

    The model to save the checkpoint from.

    +
    +
  • +
  • + checkpoint + – +
    +

    The path to the checkpoint file.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/we_moe/clip_we_moe.py +
52
+53
+54
+55
+56
+57
+58
+59
+60
def save_checkpoint(self, model, checkpoint):
+    """
+    Save the checkpoint file.
+
+    Args:
+        model: The model to save the checkpoint from.
+        checkpoint: The path to the checkpoint file.
+    """
+    self._fabric.save(checkpoint, {"model": model})
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+
+
    +
  1. +

    Anke Tang et.al. ICML 2024. Merging Multi-Task Models via Weight-Ensembling Mixture of Experts. http://arxiv.org/abs/2402.00433 ICML 2024. 

    +
  2. +
  3. +

    Z. Lu, C. Fan, W. Wei, X. Qu, D. Chen, and Y. Cheng, “Twin-Merging: Dynamic Integration of Modular Expertise in Model Merging,” doi: 10.48550/arXiv.2406.15479. NeurIPS 2024. 

    +
  4. +
  5. +

    L. Shen, A. Tang, E. Yang et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging. Oct, 2024. 

    +
  6. +
+
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/weighted_averaging/index.html b/algorithms/weighted_averaging/index.html new file mode 100644 index 00000000..105f9d50 --- /dev/null +++ b/algorithms/weighted_averaging/index.html @@ -0,0 +1,3808 @@ + + + + + + + + + + + + + + + + + + + + + + + Weighted Averaging - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Weighted Averaging

+

Weighted averaging, also known as weight-ensembling. +In the context of full fine-tuned models, the weights are averaged according to their respective performance weights. Concretely, this means that if we have \(n\) models with their respective weights \(\theta_i\) and model-wise weights \(w_i\), the weights of the final model \(\theta\) are computed as:

+
\[ \theta = \sum_{i=1}^{n} w_i \theta_i \]
+

Examples

+

General Usage

+

Configuration template for the Weighted Averaging algorithm:

+
config/method/weighted_average.yaml
name: weighted_average
+normalize: true # if true, the weights will be normalized before merging
+weights: # List of weights for each model
+  - 0.5
+  - 0.5
+
+

Use the following command to run the Weighted Averaging algorithm:

+
fusion_bench method=weighted_average ...
+
+

Merge CLIP-ViT Models

+

The following command merges eight clip-ViT models using a weighted average approach. +Because method.normalize is set to true, the weights are normalized to sum to 1, thus equivalent to simple average.

+
fusion_bench \
+    method=weighted_average \
+    method.normalize=true \
+    method.weights=[0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3] \
+    modelpool=clip-vit-base-patch32_TA8_model_only \
+    taskpool=clip-vit-classification_TA8
+
+

Merge Llama/Mistral Models

+

Here is an example of how to use the Weighted Averaging algorithm to merge two LLama models. In particular, LLaMa models of the type transformers.LlamaForCausalLM are merged using the Weighted Averaging algorithm.

+
fusion_bench \
+    method=weighted_average_for_llama \
+    method.merged_model_save_path=outputs/test_merged_llama_model \
+    modelpool=llama_for_causallm \
+    taskpool=dummy
+
+

or using the following configuration file config/llama_weighted_average.yaml

+
fusion_bench --config-name llama_weighted_average
+
+
config/llama_weighted_average.yaml
defaults:
+  - example_config
+  - override method: weighted_average_for_llama
+  - override modelpool: llama_for_causallm
+  - _self_
+
+modelpool:
+  models:
+    # the pre-trained model (base model) is optional
+    # if not provided, the first model will be used as the base model
+    - name: _pretrained_
+      path: meta-llama/Meta-Llama-3-8B
+    - name: expert_1
+      path: meta-llama/Meta-Llama-3-8B
+    - name: expert_2
+      path: meta-llama/Meta-Llama-3-8B-Instruct
+
+method:
+  normalize: true # if true, the weights will be normalized before merging
+  weights: # List of weights for each model
+    - 0.5
+    - 0.5
+  # if true, only the backbone of the model will be merged and the head will be keeped as the pre-trained model (if the pre-trained model is provided, otherwise the head of the first model will be used)
+  # if false, the whole model will be merged
+  backbone_only: true
+
+  merged_model_save_path: null
+  save_tokenizer: true
+  push_to_hub: false
+
+

References

+ + +
+ + + +

+ WeightedAverageAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm, SimpleProfilerMixin

+ + + + + + + +
+ Source code in fusion_bench/method/weighted_average/weighted_average.py +
class WeightedAverageAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "normalize": "normalize",
+        "weights": "weights",
+    }
+
+    def __init__(
+        self,
+        normalize: bool,
+        weights: List[float],
+        verbose: bool = True,
+        **kwargs,
+    ):
+        self.normalize = normalize
+        self.weights = weights
+        self.verbose = verbose
+        log.disabled = not self.verbose
+        super().__init__(**kwargs)
+
+    @override
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool):
+        """
+        Fuses the models in the model pool using a weighted average approach.
+
+        Parameters
+            modelpool (ModelPool): The pool of models to be fused.
+
+        Raises
+            ValueError: If the number of weights does not match the number of models in the model pool.
+
+        Returns
+            forward_model (torch.nn.Module): The resulting model after fusion.
+        """
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+
+        log.info("Fusing models using weighted average.")
+        weights = np.asarray(self.weights)
+        if len(weights) != len(modelpool.model_names):
+            raise ValueError(
+                "Number of weights must match the number of models.,"
+                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
+                f"weights: {weights}, models: {modelpool.model_names}"
+            )
+        if self.normalize:
+            weights = weights / np.sum(weights)
+        if self.verbose:
+            print(f"weights: {weights}, normalized: {self.normalize}")
+
+        sd: Optional[StateDictType] = None
+        forward_model = None
+
+        for model_name, weight in zip(modelpool.model_names, weights):
+            with self.profile("load_model"):
+                model = modelpool.load_model(model_name)
+            with self.profile("merge weights"):
+                if sd is None:
+                    sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
+                    forward_model = model
+                else:
+                    sd = state_dict_add(
+                        sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
+                    )
+
+        forward_model.load_state_dict(sd)
+        if self.verbose:
+            self.print_profile_summary()
+        return forward_model
+
+
+ + + +
+ + + + + + + +
+ + + +
+ _config_mapping = BaseAlgorithm._config_mapping | {'normalize': 'normalize', 'weights': 'weights'} + + + class-attribute + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ normalize = normalize + + + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ verbose = verbose + + + instance-attribute + + +
+ + +
+
+ +
+ +
+ + + +
+ weights = weights + + + instance-attribute + + +
+ + +
+
+ +
+ + + +
+ + +
+ __init__(normalize, weights, verbose=True, **kwargs) + +
+ + +
+ +
+ Source code in fusion_bench/method/weighted_average/weighted_average.py +
40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
def __init__(
+    self,
+    normalize: bool,
+    weights: List[float],
+    verbose: bool = True,
+    **kwargs,
+):
+    self.normalize = normalize
+    self.weights = weights
+    self.verbose = verbose
+    log.disabled = not self.verbose
+    super().__init__(**kwargs)
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Fuses the models in the model pool using a weighted average approach.

+

Parameters + modelpool (ModelPool): The pool of models to be fused.

+

Raises + ValueError: If the number of weights does not match the number of models in the model pool.

+

Returns + forward_model (torch.nn.Module): The resulting model after fusion.

+ +
+ Source code in fusion_bench/method/weighted_average/weighted_average.py +
@override
+@torch.no_grad()
+def run(self, modelpool: BaseModelPool):
+    """
+    Fuses the models in the model pool using a weighted average approach.
+
+    Parameters
+        modelpool (ModelPool): The pool of models to be fused.
+
+    Raises
+        ValueError: If the number of weights does not match the number of models in the model pool.
+
+    Returns
+        forward_model (torch.nn.Module): The resulting model after fusion.
+    """
+    if not isinstance(modelpool, BaseModelPool):
+        modelpool = BaseModelPool(modelpool)
+
+    log.info("Fusing models using weighted average.")
+    weights = np.asarray(self.weights)
+    if len(weights) != len(modelpool.model_names):
+        raise ValueError(
+            "Number of weights must match the number of models.,"
+            f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
+            f"weights: {weights}, models: {modelpool.model_names}"
+        )
+    if self.normalize:
+        weights = weights / np.sum(weights)
+    if self.verbose:
+        print(f"weights: {weights}, normalized: {self.normalize}")
+
+    sd: Optional[StateDictType] = None
+    forward_model = None
+
+    for model_name, weight in zip(modelpool.model_names, weights):
+        with self.profile("load_model"):
+            model = modelpool.load_model(model_name)
+        with self.profile("merge weights"):
+            if sd is None:
+                sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
+                forward_model = model
+            else:
+                sd = state_dict_add(
+                    sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
+                )
+
+    forward_model.load_state_dict(sd)
+    if self.verbose:
+        self.print_profile_summary()
+    return forward_model
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ WeightedAverageForLLama + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + +

A class to perform weighted averaging of LlaMa/Mistral models.

+ + + + + + +
+ Source code in fusion_bench/method/weighted_average/llama.py +
class WeightedAverageForLLama(BaseAlgorithm):
+    """
+    A class to perform weighted averaging of LlaMa/Mistral models.
+    """
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "normalize": "normalize",
+        "weights": "weights",
+        "backbone_only": "backbone_only",
+        "merged_model_save_path": "merged_model_save_path",
+        "save_tokenizer": "save_tokenizer",
+        "push_to_hub": "push_to_hub",
+    }
+
+    def __init__(
+        self,
+        normalize: bool,
+        weights: List[float],
+        backbone_only: bool,
+        merged_model_save_path: str,
+        save_tokenizer: bool,
+        push_to_hub: bool,
+        **kwargs,
+    ):
+        """
+        Initialize the WeightedAverageForLLama class with the given parameters.
+
+        Args:
+            normalize (bool): Whether to normalize the weights.
+            weights (List[float]): The weights for averaging the models.
+            backbone_only (bool): Whether to use only the backbone of the models.
+            merged_model_save_path (str): The path to save the merged model.
+            save_tokenizer (bool): Whether to save the tokenizer.
+            push_to_hub (bool): Whether to push the model to the hub.
+        """
+        self.normalize = normalize
+        self.weights = weights
+        self.backbone_only = backbone_only
+        self.merged_model_save_path = merged_model_save_path
+        self.save_tokenizer = save_tokenizer
+        self.push_to_hub = push_to_hub
+        super().__init__(**kwargs)
+
+    @override
+    @torch.no_grad()
+    def run(self, modelpool: CausalLMPool):
+        """
+        Executes the weighted averaging of models in the provided model pool.
+
+        Args:
+            modelpool (LLamaForCausalLMPoolThe):  pool of models to be averaged.
+
+        Returns:
+            base_model: The base model after merging the state dictionaries of the models in the pool.
+
+        Raises:
+            ValueError: If the number of weights does not match the number of models in the pool.
+        """
+        if modelpool.has_pretrained:
+            base_model = modelpool.load_model("_pretrained_")
+        else:
+            base_model = modelpool.load_model(modelpool.model_names[0])
+
+        weights = self.weights
+        if len(weights) != len(modelpool.model_names):
+            raise ValueError(
+                "Number of weights must match the number of models.,"
+                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
+                f"weights: {weights}, models: {modelpool.model_names}"
+            )
+        if self.normalize:
+            weights = np.asarray(weights)
+            weights = weights / np.sum(weights)
+
+        merged_state_dict: StateDictType = None
+        for model_name, weight in zip(modelpool.model_names, weights):
+            model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
+            sd = state_dict_mul(model.state_dict(), weight)
+            if merged_state_dict is None:
+                merged_state_dict = sd
+            else:
+                merged_state_dict = state_dict_add(merged_state_dict, sd)
+
+        base_model.load_state_dict(
+            merged_state_dict, strict=False if self.backbone_only else True
+        )
+        if self.merged_model_save_path is not None:
+            with timeit_context(
+                f"Saving the merged model to {self.merged_model_save_path}"
+            ):
+                modelpool.save_model(
+                    base_model,
+                    path=self.merged_model_save_path,
+                    save_tokenizer=self.save_tokenizer,
+                    push_to_hub=self.push_to_hub,
+                )
+        return base_model
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ __init__(normalize, weights, backbone_only, merged_model_save_path, save_tokenizer, push_to_hub, **kwargs) + +
+ + +
+ +

Initialize the WeightedAverageForLLama class with the given parameters.

+ + +

Parameters:

+
    +
  • +
    normalize +
    (bool) + – +
    +

    Whether to normalize the weights.

    +
    +
  • +
  • +
    weights +
    (List[float]) + – +
    +

    The weights for averaging the models.

    +
    +
  • +
  • +
    backbone_only +
    (bool) + – +
    +

    Whether to use only the backbone of the models.

    +
    +
  • +
  • +
    merged_model_save_path +
    (str) + – +
    +

    The path to save the merged model.

    +
    +
  • +
  • +
    save_tokenizer +
    (bool) + – +
    +

    Whether to save the tokenizer.

    +
    +
  • +
  • +
    push_to_hub +
    (bool) + – +
    +

    Whether to push the model to the hub.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/weighted_average/llama.py +
31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
def __init__(
+    self,
+    normalize: bool,
+    weights: List[float],
+    backbone_only: bool,
+    merged_model_save_path: str,
+    save_tokenizer: bool,
+    push_to_hub: bool,
+    **kwargs,
+):
+    """
+    Initialize the WeightedAverageForLLama class with the given parameters.
+
+    Args:
+        normalize (bool): Whether to normalize the weights.
+        weights (List[float]): The weights for averaging the models.
+        backbone_only (bool): Whether to use only the backbone of the models.
+        merged_model_save_path (str): The path to save the merged model.
+        save_tokenizer (bool): Whether to save the tokenizer.
+        push_to_hub (bool): Whether to push the model to the hub.
+    """
+    self.normalize = normalize
+    self.weights = weights
+    self.backbone_only = backbone_only
+    self.merged_model_save_path = merged_model_save_path
+    self.save_tokenizer = save_tokenizer
+    self.push_to_hub = push_to_hub
+    super().__init__(**kwargs)
+
+
+
+ +
+ +
+ + +
+ run(modelpool) + +
+ + +
+ +

Executes the weighted averaging of models in the provided model pool.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (LLamaForCausalLMPoolThe) + – +
    +

    pool of models to be averaged.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +base_model – +
    +

    The base model after merging the state dictionaries of the models in the pool.

    +
    +
  • +
+ + +

Raises:

+
    +
  • + ValueError + – +
    +

    If the number of weights does not match the number of models in the pool.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/weighted_average/llama.py +
@override
+@torch.no_grad()
+def run(self, modelpool: CausalLMPool):
+    """
+    Executes the weighted averaging of models in the provided model pool.
+
+    Args:
+        modelpool (LLamaForCausalLMPoolThe):  pool of models to be averaged.
+
+    Returns:
+        base_model: The base model after merging the state dictionaries of the models in the pool.
+
+    Raises:
+        ValueError: If the number of weights does not match the number of models in the pool.
+    """
+    if modelpool.has_pretrained:
+        base_model = modelpool.load_model("_pretrained_")
+    else:
+        base_model = modelpool.load_model(modelpool.model_names[0])
+
+    weights = self.weights
+    if len(weights) != len(modelpool.model_names):
+        raise ValueError(
+            "Number of weights must match the number of models.,"
+            f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
+            f"weights: {weights}, models: {modelpool.model_names}"
+        )
+    if self.normalize:
+        weights = np.asarray(weights)
+        weights = weights / np.sum(weights)
+
+    merged_state_dict: StateDictType = None
+    for model_name, weight in zip(modelpool.model_names, weights):
+        model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
+        sd = state_dict_mul(model.state_dict(), weight)
+        if merged_state_dict is None:
+            merged_state_dict = sd
+        else:
+            merged_state_dict = state_dict_add(merged_state_dict, sd)
+
+    base_model.load_state_dict(
+        merged_state_dict, strict=False if self.backbone_only else True
+    )
+    if self.merged_model_save_path is not None:
+        with timeit_context(
+            f"Saving the merged model to {self.merged_model_save_path}"
+        ):
+            modelpool.save_model(
+                base_model,
+                path=self.merged_model_save_path,
+                save_tokenizer=self.save_tokenizer,
+                push_to_hub=self.push_to_hub,
+            )
+    return base_model
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/algorithms/weighted_ensemble/index.html b/algorithms/weighted_ensemble/index.html new file mode 100644 index 00000000..8ca8b36d --- /dev/null +++ b/algorithms/weighted_ensemble/index.html @@ -0,0 +1,2757 @@ + + + + + + + + + + + + + + + + + + + + + + + Weighted Ensemble - FusionBench + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Weighted Ensemble

+

A weighted ensemble is a machine learning technique that combines the predictions of multiple models to produce a final prediction. The idea is to leverage the strengths of each individual model to improve overall performance and robustness.

+

Formally, a weighted ensemble can be defined as follows:

+

Given a set of \(n\) models, each model \(f_i\) produces a prediction \(f_i(x)\) for an input \(x\). Each model \(i\) also has an associated weight \(w_i\). The final prediction \(F(x)\) of the weighted ensemble is a weighted sum of the individual model predictions:

+
\[ +F(x) = w_1 f_1(x) + w_2 f_2(x) + ... + w_n f_n(x) +\]
+

The weights \(w_i\) are typically non-negative and sum to 1 (i.e., \(\sum_{i=1}^n w_i = 1\)), which ensures that the final prediction is a convex combination of the individual model predictions. +The weights can be determined in various ways. They could be set based on the performance of the models on a validation set, or they could be learned as part of the training process. In some cases, all models might be given equal weight. +The goal of a weighted ensemble is to produce a final prediction that is more accurate or robust than any individual model. This is particularly useful when the individual models have complementary strengths and weaknesses.

+

Examples

+

The following Python code snippet demonstrates how to use the WeightedEnsembleAlgorithm class from the fusion_bench.method module to create a weighted ensemble of PyTorch models.

+
from omegaconf import DictConfig
+from fusion_bench.method import WeightedEnsembleAlgorithm
+
+#Instantiate the algorithm
+method_config = {'name': 'weighted_ensemble', 'weights': [0.3, 0.7]}
+algorithm = WeightedEnsembleAlgorithm(DictConfig(method_config))
+
+# Assume we have a list of PyTorch models (nn.Module instances) that we want to ensemble.
+models = [...]
+
+# Run the algorithm on the models.
+merged_model = algorithm.run(models)
+
+

Here's a step-by-step explanation:

+
    +
  1. +

    Instantiate the WeightedEnsembleAlgorithm:

    +
      +
    • A dictionary method_config is created with two keys: 'name' and 'weights'. The 'name' key is set to 'weighted_ensemble' indicating the type of ensemble method to use. The 'weights' key is set to a list of weights [0.3, 0.7] indicating the weights assigned to each model in the ensemble.
    • +
    • The method_config dictionary is converted to a DictConfig object, which is a configuration object used by the omegaconf library.
    • +
    • The WeightedEnsembleAlgorithm is then instantiated with the DictConfig object as an argument.
    • +
    +
  2. +
  3. +

    Assume a list of PyTorch models that you want to ensemble. This list is assigned to the variable models. The actual models are not shown in this code snippet.

    +
  4. +
  5. +

    Run the algorithm on the models: The run method of the WeightedEnsembleAlgorithm instance is called with the models list as an argument. The result is a merged model that represents the weighted ensemble of the input models. This merged model is assigned to the variable merged_model.

    +
  6. +
+

Here we list the options for the weighted ensemble algorithm:

+ + + + + + + + + + + + + + + + + + + + +
OptionDefaultDescription
weightsA list of floats representing the weights for each model in the ensemble.
normalizeTrueWhether to normalize the weights so that they sum to 1. Default is True.
+

if normalize is set to True, the weights will be normalized so that they sum to 1. Mathematically, this means that the weights \(w_i\) will be divided by the sum of all weights, so that

+
\[ +F(x) = \frac{w_1}{\sum_{i=1}^n w_i} f_1(x) + \frac{w_2}{\sum_{i=1}^n w_i} f_2(x) + ... + \frac{w_n}{\sum_{i=1}^n w_i} f_n(x) +\]
+

Code Intergration

+

Configuration template for the weighted ensemble algorithm:

+
config/method.weighted_ensemble.yaml
name: weighted_ensemble
+
+# this should be a list of floats, one for each model in the ensemble
+# If weights is null, the ensemble will use the default weights, which are equal weights for all models.
+weights: null
+nomalize: true
+
+

Construct a weighted ensemble using our CLI tool fusion_bench:

+
fusion_bench method=weighted_ensemble \
+    method.weights=[0.3, 0.7] \
+  modelpool=... \
+  taskpool=...
+
+

References

+ + +
+ + + +

+ WeightedEnsembleAlgorithm + + +

+ + +
+

+ Bases: BaseAlgorithm

+ + + + + + + +
+ Source code in fusion_bench/method/ensemble.py +
38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
class WeightedEnsembleAlgorithm(BaseAlgorithm):
+
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "normalize": "normalize",
+        "weights": "weights",
+    }
+
+    def __init__(self, normalize: bool, weights: List[float], **kwargs):
+        self.normalize = normalize
+        self.weights = weights
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+        """
+        Run the weighted ensemble algorithm on the given model pool.
+
+        Args:
+            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.
+
+        Returns:
+            WeightedEnsembleModule: The weighted ensembled model.
+        """
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(models=modelpool)
+
+        log.info(f"Running weighted ensemble algorithm with {len(modelpool)} models")
+
+        models = [modelpool.load_model(m) for m in modelpool.model_names]
+        if self.weights is None:
+            weights = np.ones(len(models)) / len(models)
+        else:
+            weights = self.weights
+        ensemble = WeightedEnsembleModule(
+            models,
+            weights=weights,
+            normalize=self.config.get("normalize", True),
+        )
+        return ensemble
+
+
+ + + +
+ + + + + + + + + +
+ + +
+ run(modelpool) + +
+ + +
+ +

Run the weighted ensemble algorithm on the given model pool.

+ + +

Parameters:

+
    +
  • +
    modelpool +
    (BaseModelPool | List[Module]) + – +
    +

    The pool of models to ensemble.

    +
    +
  • +
+ + +

Returns:

+
    +
  • +WeightedEnsembleModule – +
    +

    The weighted ensembled model.

    +
    +
  • +
+ +
+ Source code in fusion_bench/method/ensemble.py +
50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
@torch.no_grad()
+def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    """
+    Run the weighted ensemble algorithm on the given model pool.
+
+    Args:
+        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.
+
+    Returns:
+        WeightedEnsembleModule: The weighted ensembled model.
+    """
+    if not isinstance(modelpool, BaseModelPool):
+        modelpool = BaseModelPool(models=modelpool)
+
+    log.info(f"Running weighted ensemble algorithm with {len(modelpool)} models")
+
+    models = [modelpool.load_model(m) for m in modelpool.model_names]
+    if self.weights is None:
+        weights = np.ones(len(models)) / len(models)
+    else:
+        weights = self.weights
+    ensemble = WeightedEnsembleModule(
+        models,
+        weights=weights,
+        normalize=self.config.get("normalize", True),
+    )
+    return ensemble
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/_mkdocstrings.css b/assets/_mkdocstrings.css new file mode 100644 index 00000000..b500381b --- /dev/null +++ b/assets/_mkdocstrings.css @@ -0,0 +1,143 @@ + +/* Avoid breaking parameter names, etc. in table cells. */ +.doc-contents td code { + word-break: normal !important; +} + +/* No line break before first paragraph of descriptions. */ +.doc-md-description, +.doc-md-description>p:first-child { + display: inline; +} + +/* Max width for docstring sections tables. */ +.doc .md-typeset__table, +.doc .md-typeset__table table { + display: table !important; + width: 100%; +} + +.doc .md-typeset__table tr { + display: table-row; +} + +/* Defaults in Spacy table style. */ +.doc-param-default { + float: right; +} + +/* Parameter headings must be inline, not blocks. */ +.doc-heading-parameter { + display: inline; +} + +/* Prefer space on the right, not the left of parameter permalinks. */ +.doc-heading-parameter .headerlink { + margin-left: 0 !important; + margin-right: 0.2rem; +} + +/* Backward-compatibility: docstring section titles in bold. */ +.doc-section-title { + font-weight: bold; +} + +/* Symbols in Navigation and ToC. */ +:root, :host, +[data-md-color-scheme="default"] { + --doc-symbol-parameter-fg-color: #df50af; + --doc-symbol-attribute-fg-color: #953800; + --doc-symbol-function-fg-color: #8250df; + --doc-symbol-method-fg-color: #8250df; + --doc-symbol-class-fg-color: #0550ae; + --doc-symbol-module-fg-color: #5cad0f; + + --doc-symbol-parameter-bg-color: #df50af1a; + --doc-symbol-attribute-bg-color: #9538001a; + --doc-symbol-function-bg-color: #8250df1a; + --doc-symbol-method-bg-color: #8250df1a; + --doc-symbol-class-bg-color: #0550ae1a; + --doc-symbol-module-bg-color: #5cad0f1a; +} + +[data-md-color-scheme="slate"] { + --doc-symbol-parameter-fg-color: #ffa8cc; + --doc-symbol-attribute-fg-color: #ffa657; + --doc-symbol-function-fg-color: #d2a8ff; + --doc-symbol-method-fg-color: #d2a8ff; + --doc-symbol-class-fg-color: #79c0ff; + --doc-symbol-module-fg-color: #baff79; + + --doc-symbol-parameter-bg-color: #ffa8cc1a; + --doc-symbol-attribute-bg-color: #ffa6571a; + --doc-symbol-function-bg-color: #d2a8ff1a; + --doc-symbol-method-bg-color: #d2a8ff1a; + --doc-symbol-class-bg-color: #79c0ff1a; + --doc-symbol-module-bg-color: #baff791a; +} + +code.doc-symbol { + border-radius: .1rem; + font-size: .85em; + padding: 0 .3em; + font-weight: bold; +} + +code.doc-symbol-parameter { + color: var(--doc-symbol-parameter-fg-color); + background-color: var(--doc-symbol-parameter-bg-color); +} + +code.doc-symbol-parameter::after { + content: "param"; +} + +code.doc-symbol-attribute { + color: var(--doc-symbol-attribute-fg-color); + background-color: var(--doc-symbol-attribute-bg-color); +} + +code.doc-symbol-attribute::after { + content: "attr"; +} + +code.doc-symbol-function { + color: var(--doc-symbol-function-fg-color); + background-color: var(--doc-symbol-function-bg-color); +} + +code.doc-symbol-function::after { + content: "func"; +} + +code.doc-symbol-method { + color: var(--doc-symbol-method-fg-color); + background-color: var(--doc-symbol-method-bg-color); +} + +code.doc-symbol-method::after { + content: "meth"; +} + +code.doc-symbol-class { + color: var(--doc-symbol-class-fg-color); + background-color: var(--doc-symbol-class-bg-color); +} + +code.doc-symbol-class::after { + content: "class"; +} + +code.doc-symbol-module { + color: var(--doc-symbol-module-fg-color); + background-color: var(--doc-symbol-module-bg-color); +} + +code.doc-symbol-module::after { + content: "mod"; +} + +.doc-signature .autorefs { + color: inherit; + border-bottom: 1px dotted currentcolor; +} diff --git a/assets/images/favicon.png b/assets/images/favicon.png new file mode 100644 index 00000000..1cf13b9f Binary files /dev/null and b/assets/images/favicon.png differ diff --git a/assets/javascripts/bundle.83f73b43.min.js b/assets/javascripts/bundle.83f73b43.min.js new file mode 100644 index 00000000..43d8b70f --- /dev/null +++ b/assets/javascripts/bundle.83f73b43.min.js @@ -0,0 +1,16 @@ +"use strict";(()=>{var Wi=Object.create;var gr=Object.defineProperty;var Di=Object.getOwnPropertyDescriptor;var Vi=Object.getOwnPropertyNames,Vt=Object.getOwnPropertySymbols,Ni=Object.getPrototypeOf,yr=Object.prototype.hasOwnProperty,ao=Object.prototype.propertyIsEnumerable;var io=(e,t,r)=>t in e?gr(e,t,{enumerable:!0,configurable:!0,writable:!0,value:r}):e[t]=r,$=(e,t)=>{for(var r in t||(t={}))yr.call(t,r)&&io(e,r,t[r]);if(Vt)for(var r of Vt(t))ao.call(t,r)&&io(e,r,t[r]);return e};var so=(e,t)=>{var r={};for(var o in e)yr.call(e,o)&&t.indexOf(o)<0&&(r[o]=e[o]);if(e!=null&&Vt)for(var o of Vt(e))t.indexOf(o)<0&&ao.call(e,o)&&(r[o]=e[o]);return r};var xr=(e,t)=>()=>(t||e((t={exports:{}}).exports,t),t.exports);var zi=(e,t,r,o)=>{if(t&&typeof t=="object"||typeof t=="function")for(let n of Vi(t))!yr.call(e,n)&&n!==r&&gr(e,n,{get:()=>t[n],enumerable:!(o=Di(t,n))||o.enumerable});return e};var Mt=(e,t,r)=>(r=e!=null?Wi(Ni(e)):{},zi(t||!e||!e.__esModule?gr(r,"default",{value:e,enumerable:!0}):r,e));var co=(e,t,r)=>new Promise((o,n)=>{var i=p=>{try{s(r.next(p))}catch(c){n(c)}},a=p=>{try{s(r.throw(p))}catch(c){n(c)}},s=p=>p.done?o(p.value):Promise.resolve(p.value).then(i,a);s((r=r.apply(e,t)).next())});var lo=xr((Er,po)=>{(function(e,t){typeof Er=="object"&&typeof po!="undefined"?t():typeof define=="function"&&define.amd?define(t):t()})(Er,function(){"use strict";function e(r){var o=!0,n=!1,i=null,a={text:!0,search:!0,url:!0,tel:!0,email:!0,password:!0,number:!0,date:!0,month:!0,week:!0,time:!0,datetime:!0,"datetime-local":!0};function s(k){return!!(k&&k!==document&&k.nodeName!=="HTML"&&k.nodeName!=="BODY"&&"classList"in k&&"contains"in k.classList)}function p(k){var ft=k.type,qe=k.tagName;return!!(qe==="INPUT"&&a[ft]&&!k.readOnly||qe==="TEXTAREA"&&!k.readOnly||k.isContentEditable)}function c(k){k.classList.contains("focus-visible")||(k.classList.add("focus-visible"),k.setAttribute("data-focus-visible-added",""))}function l(k){k.hasAttribute("data-focus-visible-added")&&(k.classList.remove("focus-visible"),k.removeAttribute("data-focus-visible-added"))}function f(k){k.metaKey||k.altKey||k.ctrlKey||(s(r.activeElement)&&c(r.activeElement),o=!0)}function u(k){o=!1}function d(k){s(k.target)&&(o||p(k.target))&&c(k.target)}function y(k){s(k.target)&&(k.target.classList.contains("focus-visible")||k.target.hasAttribute("data-focus-visible-added"))&&(n=!0,window.clearTimeout(i),i=window.setTimeout(function(){n=!1},100),l(k.target))}function L(k){document.visibilityState==="hidden"&&(n&&(o=!0),X())}function X(){document.addEventListener("mousemove",J),document.addEventListener("mousedown",J),document.addEventListener("mouseup",J),document.addEventListener("pointermove",J),document.addEventListener("pointerdown",J),document.addEventListener("pointerup",J),document.addEventListener("touchmove",J),document.addEventListener("touchstart",J),document.addEventListener("touchend",J)}function te(){document.removeEventListener("mousemove",J),document.removeEventListener("mousedown",J),document.removeEventListener("mouseup",J),document.removeEventListener("pointermove",J),document.removeEventListener("pointerdown",J),document.removeEventListener("pointerup",J),document.removeEventListener("touchmove",J),document.removeEventListener("touchstart",J),document.removeEventListener("touchend",J)}function J(k){k.target.nodeName&&k.target.nodeName.toLowerCase()==="html"||(o=!1,te())}document.addEventListener("keydown",f,!0),document.addEventListener("mousedown",u,!0),document.addEventListener("pointerdown",u,!0),document.addEventListener("touchstart",u,!0),document.addEventListener("visibilitychange",L,!0),X(),r.addEventListener("focus",d,!0),r.addEventListener("blur",y,!0),r.nodeType===Node.DOCUMENT_FRAGMENT_NODE&&r.host?r.host.setAttribute("data-js-focus-visible",""):r.nodeType===Node.DOCUMENT_NODE&&(document.documentElement.classList.add("js-focus-visible"),document.documentElement.setAttribute("data-js-focus-visible",""))}if(typeof window!="undefined"&&typeof document!="undefined"){window.applyFocusVisiblePolyfill=e;var t;try{t=new CustomEvent("focus-visible-polyfill-ready")}catch(r){t=document.createEvent("CustomEvent"),t.initCustomEvent("focus-visible-polyfill-ready",!1,!1,{})}window.dispatchEvent(t)}typeof document!="undefined"&&e(document)})});var qr=xr((hy,On)=>{"use strict";/*! + * escape-html + * Copyright(c) 2012-2013 TJ Holowaychuk + * Copyright(c) 2015 Andreas Lubbe + * Copyright(c) 2015 Tiancheng "Timothy" Gu + * MIT Licensed + */var $a=/["'&<>]/;On.exports=Pa;function Pa(e){var t=""+e,r=$a.exec(t);if(!r)return t;var o,n="",i=0,a=0;for(i=r.index;i{/*! + * clipboard.js v2.0.11 + * https://clipboardjs.com/ + * + * Licensed MIT © Zeno Rocha + */(function(t,r){typeof It=="object"&&typeof Yr=="object"?Yr.exports=r():typeof define=="function"&&define.amd?define([],r):typeof It=="object"?It.ClipboardJS=r():t.ClipboardJS=r()})(It,function(){return function(){var e={686:function(o,n,i){"use strict";i.d(n,{default:function(){return Ui}});var a=i(279),s=i.n(a),p=i(370),c=i.n(p),l=i(817),f=i.n(l);function u(V){try{return document.execCommand(V)}catch(A){return!1}}var d=function(A){var M=f()(A);return u("cut"),M},y=d;function L(V){var A=document.documentElement.getAttribute("dir")==="rtl",M=document.createElement("textarea");M.style.fontSize="12pt",M.style.border="0",M.style.padding="0",M.style.margin="0",M.style.position="absolute",M.style[A?"right":"left"]="-9999px";var F=window.pageYOffset||document.documentElement.scrollTop;return M.style.top="".concat(F,"px"),M.setAttribute("readonly",""),M.value=V,M}var X=function(A,M){var F=L(A);M.container.appendChild(F);var D=f()(F);return u("copy"),F.remove(),D},te=function(A){var M=arguments.length>1&&arguments[1]!==void 0?arguments[1]:{container:document.body},F="";return typeof A=="string"?F=X(A,M):A instanceof HTMLInputElement&&!["text","search","url","tel","password"].includes(A==null?void 0:A.type)?F=X(A.value,M):(F=f()(A),u("copy")),F},J=te;function k(V){"@babel/helpers - typeof";return typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?k=function(M){return typeof M}:k=function(M){return M&&typeof Symbol=="function"&&M.constructor===Symbol&&M!==Symbol.prototype?"symbol":typeof M},k(V)}var ft=function(){var A=arguments.length>0&&arguments[0]!==void 0?arguments[0]:{},M=A.action,F=M===void 0?"copy":M,D=A.container,Y=A.target,$e=A.text;if(F!=="copy"&&F!=="cut")throw new Error('Invalid "action" value, use either "copy" or "cut"');if(Y!==void 0)if(Y&&k(Y)==="object"&&Y.nodeType===1){if(F==="copy"&&Y.hasAttribute("disabled"))throw new Error('Invalid "target" attribute. Please use "readonly" instead of "disabled" attribute');if(F==="cut"&&(Y.hasAttribute("readonly")||Y.hasAttribute("disabled")))throw new Error(`Invalid "target" attribute. You can't cut text from elements with "readonly" or "disabled" attributes`)}else throw new Error('Invalid "target" value, use a valid Element');if($e)return J($e,{container:D});if(Y)return F==="cut"?y(Y):J(Y,{container:D})},qe=ft;function Fe(V){"@babel/helpers - typeof";return typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?Fe=function(M){return typeof M}:Fe=function(M){return M&&typeof Symbol=="function"&&M.constructor===Symbol&&M!==Symbol.prototype?"symbol":typeof M},Fe(V)}function ki(V,A){if(!(V instanceof A))throw new TypeError("Cannot call a class as a function")}function no(V,A){for(var M=0;M0&&arguments[0]!==void 0?arguments[0]:{};this.action=typeof D.action=="function"?D.action:this.defaultAction,this.target=typeof D.target=="function"?D.target:this.defaultTarget,this.text=typeof D.text=="function"?D.text:this.defaultText,this.container=Fe(D.container)==="object"?D.container:document.body}},{key:"listenClick",value:function(D){var Y=this;this.listener=c()(D,"click",function($e){return Y.onClick($e)})}},{key:"onClick",value:function(D){var Y=D.delegateTarget||D.currentTarget,$e=this.action(Y)||"copy",Dt=qe({action:$e,container:this.container,target:this.target(Y),text:this.text(Y)});this.emit(Dt?"success":"error",{action:$e,text:Dt,trigger:Y,clearSelection:function(){Y&&Y.focus(),window.getSelection().removeAllRanges()}})}},{key:"defaultAction",value:function(D){return vr("action",D)}},{key:"defaultTarget",value:function(D){var Y=vr("target",D);if(Y)return document.querySelector(Y)}},{key:"defaultText",value:function(D){return vr("text",D)}},{key:"destroy",value:function(){this.listener.destroy()}}],[{key:"copy",value:function(D){var Y=arguments.length>1&&arguments[1]!==void 0?arguments[1]:{container:document.body};return J(D,Y)}},{key:"cut",value:function(D){return y(D)}},{key:"isSupported",value:function(){var D=arguments.length>0&&arguments[0]!==void 0?arguments[0]:["copy","cut"],Y=typeof D=="string"?[D]:D,$e=!!document.queryCommandSupported;return Y.forEach(function(Dt){$e=$e&&!!document.queryCommandSupported(Dt)}),$e}}]),M}(s()),Ui=Fi},828:function(o){var n=9;if(typeof Element!="undefined"&&!Element.prototype.matches){var i=Element.prototype;i.matches=i.matchesSelector||i.mozMatchesSelector||i.msMatchesSelector||i.oMatchesSelector||i.webkitMatchesSelector}function a(s,p){for(;s&&s.nodeType!==n;){if(typeof s.matches=="function"&&s.matches(p))return s;s=s.parentNode}}o.exports=a},438:function(o,n,i){var a=i(828);function s(l,f,u,d,y){var L=c.apply(this,arguments);return l.addEventListener(u,L,y),{destroy:function(){l.removeEventListener(u,L,y)}}}function p(l,f,u,d,y){return typeof l.addEventListener=="function"?s.apply(null,arguments):typeof u=="function"?s.bind(null,document).apply(null,arguments):(typeof l=="string"&&(l=document.querySelectorAll(l)),Array.prototype.map.call(l,function(L){return s(L,f,u,d,y)}))}function c(l,f,u,d){return function(y){y.delegateTarget=a(y.target,f),y.delegateTarget&&d.call(l,y)}}o.exports=p},879:function(o,n){n.node=function(i){return i!==void 0&&i instanceof HTMLElement&&i.nodeType===1},n.nodeList=function(i){var a=Object.prototype.toString.call(i);return i!==void 0&&(a==="[object NodeList]"||a==="[object HTMLCollection]")&&"length"in i&&(i.length===0||n.node(i[0]))},n.string=function(i){return typeof i=="string"||i instanceof String},n.fn=function(i){var a=Object.prototype.toString.call(i);return a==="[object Function]"}},370:function(o,n,i){var a=i(879),s=i(438);function p(u,d,y){if(!u&&!d&&!y)throw new Error("Missing required arguments");if(!a.string(d))throw new TypeError("Second argument must be a String");if(!a.fn(y))throw new TypeError("Third argument must be a Function");if(a.node(u))return c(u,d,y);if(a.nodeList(u))return l(u,d,y);if(a.string(u))return f(u,d,y);throw new TypeError("First argument must be a String, HTMLElement, HTMLCollection, or NodeList")}function c(u,d,y){return u.addEventListener(d,y),{destroy:function(){u.removeEventListener(d,y)}}}function l(u,d,y){return Array.prototype.forEach.call(u,function(L){L.addEventListener(d,y)}),{destroy:function(){Array.prototype.forEach.call(u,function(L){L.removeEventListener(d,y)})}}}function f(u,d,y){return s(document.body,u,d,y)}o.exports=p},817:function(o){function n(i){var a;if(i.nodeName==="SELECT")i.focus(),a=i.value;else if(i.nodeName==="INPUT"||i.nodeName==="TEXTAREA"){var s=i.hasAttribute("readonly");s||i.setAttribute("readonly",""),i.select(),i.setSelectionRange(0,i.value.length),s||i.removeAttribute("readonly"),a=i.value}else{i.hasAttribute("contenteditable")&&i.focus();var p=window.getSelection(),c=document.createRange();c.selectNodeContents(i),p.removeAllRanges(),p.addRange(c),a=p.toString()}return a}o.exports=n},279:function(o){function n(){}n.prototype={on:function(i,a,s){var p=this.e||(this.e={});return(p[i]||(p[i]=[])).push({fn:a,ctx:s}),this},once:function(i,a,s){var p=this;function c(){p.off(i,c),a.apply(s,arguments)}return c._=a,this.on(i,c,s)},emit:function(i){var a=[].slice.call(arguments,1),s=((this.e||(this.e={}))[i]||[]).slice(),p=0,c=s.length;for(p;p0&&i[i.length-1])&&(c[0]===6||c[0]===2)){r=0;continue}if(c[0]===3&&(!i||c[1]>i[0]&&c[1]=e.length&&(e=void 0),{value:e&&e[o++],done:!e}}};throw new TypeError(t?"Object is not iterable.":"Symbol.iterator is not defined.")}function N(e,t){var r=typeof Symbol=="function"&&e[Symbol.iterator];if(!r)return e;var o=r.call(e),n,i=[],a;try{for(;(t===void 0||t-- >0)&&!(n=o.next()).done;)i.push(n.value)}catch(s){a={error:s}}finally{try{n&&!n.done&&(r=o.return)&&r.call(o)}finally{if(a)throw a.error}}return i}function q(e,t,r){if(r||arguments.length===2)for(var o=0,n=t.length,i;o1||p(d,L)})},y&&(n[d]=y(n[d])))}function p(d,y){try{c(o[d](y))}catch(L){u(i[0][3],L)}}function c(d){d.value instanceof nt?Promise.resolve(d.value.v).then(l,f):u(i[0][2],d)}function l(d){p("next",d)}function f(d){p("throw",d)}function u(d,y){d(y),i.shift(),i.length&&p(i[0][0],i[0][1])}}function uo(e){if(!Symbol.asyncIterator)throw new TypeError("Symbol.asyncIterator is not defined.");var t=e[Symbol.asyncIterator],r;return t?t.call(e):(e=typeof he=="function"?he(e):e[Symbol.iterator](),r={},o("next"),o("throw"),o("return"),r[Symbol.asyncIterator]=function(){return this},r);function o(i){r[i]=e[i]&&function(a){return new Promise(function(s,p){a=e[i](a),n(s,p,a.done,a.value)})}}function n(i,a,s,p){Promise.resolve(p).then(function(c){i({value:c,done:s})},a)}}function H(e){return typeof e=="function"}function ut(e){var t=function(o){Error.call(o),o.stack=new Error().stack},r=e(t);return r.prototype=Object.create(Error.prototype),r.prototype.constructor=r,r}var zt=ut(function(e){return function(r){e(this),this.message=r?r.length+` errors occurred during unsubscription: +`+r.map(function(o,n){return n+1+") "+o.toString()}).join(` + `):"",this.name="UnsubscriptionError",this.errors=r}});function Qe(e,t){if(e){var r=e.indexOf(t);0<=r&&e.splice(r,1)}}var Ue=function(){function e(t){this.initialTeardown=t,this.closed=!1,this._parentage=null,this._finalizers=null}return e.prototype.unsubscribe=function(){var t,r,o,n,i;if(!this.closed){this.closed=!0;var a=this._parentage;if(a)if(this._parentage=null,Array.isArray(a))try{for(var s=he(a),p=s.next();!p.done;p=s.next()){var c=p.value;c.remove(this)}}catch(L){t={error:L}}finally{try{p&&!p.done&&(r=s.return)&&r.call(s)}finally{if(t)throw t.error}}else a.remove(this);var l=this.initialTeardown;if(H(l))try{l()}catch(L){i=L instanceof zt?L.errors:[L]}var f=this._finalizers;if(f){this._finalizers=null;try{for(var u=he(f),d=u.next();!d.done;d=u.next()){var y=d.value;try{ho(y)}catch(L){i=i!=null?i:[],L instanceof zt?i=q(q([],N(i)),N(L.errors)):i.push(L)}}}catch(L){o={error:L}}finally{try{d&&!d.done&&(n=u.return)&&n.call(u)}finally{if(o)throw o.error}}}if(i)throw new zt(i)}},e.prototype.add=function(t){var r;if(t&&t!==this)if(this.closed)ho(t);else{if(t instanceof e){if(t.closed||t._hasParent(this))return;t._addParent(this)}(this._finalizers=(r=this._finalizers)!==null&&r!==void 0?r:[]).push(t)}},e.prototype._hasParent=function(t){var r=this._parentage;return r===t||Array.isArray(r)&&r.includes(t)},e.prototype._addParent=function(t){var r=this._parentage;this._parentage=Array.isArray(r)?(r.push(t),r):r?[r,t]:t},e.prototype._removeParent=function(t){var r=this._parentage;r===t?this._parentage=null:Array.isArray(r)&&Qe(r,t)},e.prototype.remove=function(t){var r=this._finalizers;r&&Qe(r,t),t instanceof e&&t._removeParent(this)},e.EMPTY=function(){var t=new e;return t.closed=!0,t}(),e}();var Tr=Ue.EMPTY;function qt(e){return e instanceof Ue||e&&"closed"in e&&H(e.remove)&&H(e.add)&&H(e.unsubscribe)}function ho(e){H(e)?e():e.unsubscribe()}var Pe={onUnhandledError:null,onStoppedNotification:null,Promise:void 0,useDeprecatedSynchronousErrorHandling:!1,useDeprecatedNextContext:!1};var dt={setTimeout:function(e,t){for(var r=[],o=2;o0},enumerable:!1,configurable:!0}),t.prototype._trySubscribe=function(r){return this._throwIfClosed(),e.prototype._trySubscribe.call(this,r)},t.prototype._subscribe=function(r){return this._throwIfClosed(),this._checkFinalizedStatuses(r),this._innerSubscribe(r)},t.prototype._innerSubscribe=function(r){var o=this,n=this,i=n.hasError,a=n.isStopped,s=n.observers;return i||a?Tr:(this.currentObservers=null,s.push(r),new Ue(function(){o.currentObservers=null,Qe(s,r)}))},t.prototype._checkFinalizedStatuses=function(r){var o=this,n=o.hasError,i=o.thrownError,a=o.isStopped;n?r.error(i):a&&r.complete()},t.prototype.asObservable=function(){var r=new j;return r.source=this,r},t.create=function(r,o){return new To(r,o)},t}(j);var To=function(e){oe(t,e);function t(r,o){var n=e.call(this)||this;return n.destination=r,n.source=o,n}return t.prototype.next=function(r){var o,n;(n=(o=this.destination)===null||o===void 0?void 0:o.next)===null||n===void 0||n.call(o,r)},t.prototype.error=function(r){var o,n;(n=(o=this.destination)===null||o===void 0?void 0:o.error)===null||n===void 0||n.call(o,r)},t.prototype.complete=function(){var r,o;(o=(r=this.destination)===null||r===void 0?void 0:r.complete)===null||o===void 0||o.call(r)},t.prototype._subscribe=function(r){var o,n;return(n=(o=this.source)===null||o===void 0?void 0:o.subscribe(r))!==null&&n!==void 0?n:Tr},t}(g);var _r=function(e){oe(t,e);function t(r){var o=e.call(this)||this;return o._value=r,o}return Object.defineProperty(t.prototype,"value",{get:function(){return this.getValue()},enumerable:!1,configurable:!0}),t.prototype._subscribe=function(r){var o=e.prototype._subscribe.call(this,r);return!o.closed&&r.next(this._value),o},t.prototype.getValue=function(){var r=this,o=r.hasError,n=r.thrownError,i=r._value;if(o)throw n;return this._throwIfClosed(),i},t.prototype.next=function(r){e.prototype.next.call(this,this._value=r)},t}(g);var At={now:function(){return(At.delegate||Date).now()},delegate:void 0};var Ct=function(e){oe(t,e);function t(r,o,n){r===void 0&&(r=1/0),o===void 0&&(o=1/0),n===void 0&&(n=At);var i=e.call(this)||this;return i._bufferSize=r,i._windowTime=o,i._timestampProvider=n,i._buffer=[],i._infiniteTimeWindow=!0,i._infiniteTimeWindow=o===1/0,i._bufferSize=Math.max(1,r),i._windowTime=Math.max(1,o),i}return t.prototype.next=function(r){var o=this,n=o.isStopped,i=o._buffer,a=o._infiniteTimeWindow,s=o._timestampProvider,p=o._windowTime;n||(i.push(r),!a&&i.push(s.now()+p)),this._trimBuffer(),e.prototype.next.call(this,r)},t.prototype._subscribe=function(r){this._throwIfClosed(),this._trimBuffer();for(var o=this._innerSubscribe(r),n=this,i=n._infiniteTimeWindow,a=n._buffer,s=a.slice(),p=0;p0?e.prototype.schedule.call(this,r,o):(this.delay=o,this.state=r,this.scheduler.flush(this),this)},t.prototype.execute=function(r,o){return o>0||this.closed?e.prototype.execute.call(this,r,o):this._execute(r,o)},t.prototype.requestAsyncId=function(r,o,n){return n===void 0&&(n=0),n!=null&&n>0||n==null&&this.delay>0?e.prototype.requestAsyncId.call(this,r,o,n):(r.flush(this),0)},t}(gt);var Lo=function(e){oe(t,e);function t(){return e!==null&&e.apply(this,arguments)||this}return t}(yt);var kr=new Lo(Oo);var Mo=function(e){oe(t,e);function t(r,o){var n=e.call(this,r,o)||this;return n.scheduler=r,n.work=o,n}return t.prototype.requestAsyncId=function(r,o,n){return n===void 0&&(n=0),n!==null&&n>0?e.prototype.requestAsyncId.call(this,r,o,n):(r.actions.push(this),r._scheduled||(r._scheduled=vt.requestAnimationFrame(function(){return r.flush(void 0)})))},t.prototype.recycleAsyncId=function(r,o,n){var i;if(n===void 0&&(n=0),n!=null?n>0:this.delay>0)return e.prototype.recycleAsyncId.call(this,r,o,n);var a=r.actions;o!=null&&((i=a[a.length-1])===null||i===void 0?void 0:i.id)!==o&&(vt.cancelAnimationFrame(o),r._scheduled=void 0)},t}(gt);var _o=function(e){oe(t,e);function t(){return e!==null&&e.apply(this,arguments)||this}return t.prototype.flush=function(r){this._active=!0;var o=this._scheduled;this._scheduled=void 0;var n=this.actions,i;r=r||n.shift();do if(i=r.execute(r.state,r.delay))break;while((r=n[0])&&r.id===o&&n.shift());if(this._active=!1,i){for(;(r=n[0])&&r.id===o&&n.shift();)r.unsubscribe();throw i}},t}(yt);var me=new _o(Mo);var S=new j(function(e){return e.complete()});function Yt(e){return e&&H(e.schedule)}function Hr(e){return e[e.length-1]}function Xe(e){return H(Hr(e))?e.pop():void 0}function ke(e){return Yt(Hr(e))?e.pop():void 0}function Bt(e,t){return typeof Hr(e)=="number"?e.pop():t}var xt=function(e){return e&&typeof e.length=="number"&&typeof e!="function"};function Gt(e){return H(e==null?void 0:e.then)}function Jt(e){return H(e[bt])}function Xt(e){return Symbol.asyncIterator&&H(e==null?void 0:e[Symbol.asyncIterator])}function Zt(e){return new TypeError("You provided "+(e!==null&&typeof e=="object"?"an invalid object":"'"+e+"'")+" where a stream was expected. You can provide an Observable, Promise, ReadableStream, Array, AsyncIterable, or Iterable.")}function Zi(){return typeof Symbol!="function"||!Symbol.iterator?"@@iterator":Symbol.iterator}var er=Zi();function tr(e){return H(e==null?void 0:e[er])}function rr(e){return fo(this,arguments,function(){var r,o,n,i;return Nt(this,function(a){switch(a.label){case 0:r=e.getReader(),a.label=1;case 1:a.trys.push([1,,9,10]),a.label=2;case 2:return[4,nt(r.read())];case 3:return o=a.sent(),n=o.value,i=o.done,i?[4,nt(void 0)]:[3,5];case 4:return[2,a.sent()];case 5:return[4,nt(n)];case 6:return[4,a.sent()];case 7:return a.sent(),[3,2];case 8:return[3,10];case 9:return r.releaseLock(),[7];case 10:return[2]}})})}function or(e){return H(e==null?void 0:e.getReader)}function U(e){if(e instanceof j)return e;if(e!=null){if(Jt(e))return ea(e);if(xt(e))return ta(e);if(Gt(e))return ra(e);if(Xt(e))return Ao(e);if(tr(e))return oa(e);if(or(e))return na(e)}throw Zt(e)}function ea(e){return new j(function(t){var r=e[bt]();if(H(r.subscribe))return r.subscribe(t);throw new TypeError("Provided object does not correctly implement Symbol.observable")})}function ta(e){return new j(function(t){for(var r=0;r=2;return function(o){return o.pipe(e?b(function(n,i){return e(n,i,o)}):le,Te(1),r?De(t):Qo(function(){return new ir}))}}function jr(e){return e<=0?function(){return S}:E(function(t,r){var o=[];t.subscribe(T(r,function(n){o.push(n),e=2,!0))}function pe(e){e===void 0&&(e={});var t=e.connector,r=t===void 0?function(){return new g}:t,o=e.resetOnError,n=o===void 0?!0:o,i=e.resetOnComplete,a=i===void 0?!0:i,s=e.resetOnRefCountZero,p=s===void 0?!0:s;return function(c){var l,f,u,d=0,y=!1,L=!1,X=function(){f==null||f.unsubscribe(),f=void 0},te=function(){X(),l=u=void 0,y=L=!1},J=function(){var k=l;te(),k==null||k.unsubscribe()};return E(function(k,ft){d++,!L&&!y&&X();var qe=u=u!=null?u:r();ft.add(function(){d--,d===0&&!L&&!y&&(f=Ur(J,p))}),qe.subscribe(ft),!l&&d>0&&(l=new at({next:function(Fe){return qe.next(Fe)},error:function(Fe){L=!0,X(),f=Ur(te,n,Fe),qe.error(Fe)},complete:function(){y=!0,X(),f=Ur(te,a),qe.complete()}}),U(k).subscribe(l))})(c)}}function Ur(e,t){for(var r=[],o=2;oe.next(document)),e}function P(e,t=document){return Array.from(t.querySelectorAll(e))}function R(e,t=document){let r=fe(e,t);if(typeof r=="undefined")throw new ReferenceError(`Missing element: expected "${e}" to be present`);return r}function fe(e,t=document){return t.querySelector(e)||void 0}function Ie(){var e,t,r,o;return(o=(r=(t=(e=document.activeElement)==null?void 0:e.shadowRoot)==null?void 0:t.activeElement)!=null?r:document.activeElement)!=null?o:void 0}var wa=O(h(document.body,"focusin"),h(document.body,"focusout")).pipe(_e(1),Q(void 0),m(()=>Ie()||document.body),G(1));function et(e){return wa.pipe(m(t=>e.contains(t)),K())}function $t(e,t){return C(()=>O(h(e,"mouseenter").pipe(m(()=>!0)),h(e,"mouseleave").pipe(m(()=>!1))).pipe(t?Ht(r=>Le(+!r*t)):le,Q(e.matches(":hover"))))}function Jo(e,t){if(typeof t=="string"||typeof t=="number")e.innerHTML+=t.toString();else if(t instanceof Node)e.appendChild(t);else if(Array.isArray(t))for(let r of t)Jo(e,r)}function x(e,t,...r){let o=document.createElement(e);if(t)for(let n of Object.keys(t))typeof t[n]!="undefined"&&(typeof t[n]!="boolean"?o.setAttribute(n,t[n]):o.setAttribute(n,""));for(let n of r)Jo(o,n);return o}function sr(e){if(e>999){let t=+((e-950)%1e3>99);return`${((e+1e-6)/1e3).toFixed(t)}k`}else return e.toString()}function Tt(e){let t=x("script",{src:e});return C(()=>(document.head.appendChild(t),O(h(t,"load"),h(t,"error").pipe(v(()=>$r(()=>new ReferenceError(`Invalid script: ${e}`))))).pipe(m(()=>{}),_(()=>document.head.removeChild(t)),Te(1))))}var Xo=new g,Ta=C(()=>typeof ResizeObserver=="undefined"?Tt("https://unpkg.com/resize-observer-polyfill"):I(void 0)).pipe(m(()=>new ResizeObserver(e=>e.forEach(t=>Xo.next(t)))),v(e=>O(Ye,I(e)).pipe(_(()=>e.disconnect()))),G(1));function ce(e){return{width:e.offsetWidth,height:e.offsetHeight}}function ge(e){let t=e;for(;t.clientWidth===0&&t.parentElement;)t=t.parentElement;return Ta.pipe(w(r=>r.observe(t)),v(r=>Xo.pipe(b(o=>o.target===t),_(()=>r.unobserve(t)))),m(()=>ce(e)),Q(ce(e)))}function St(e){return{width:e.scrollWidth,height:e.scrollHeight}}function cr(e){let t=e.parentElement;for(;t&&(e.scrollWidth<=t.scrollWidth&&e.scrollHeight<=t.scrollHeight);)t=(e=t).parentElement;return t?e:void 0}function Zo(e){let t=[],r=e.parentElement;for(;r;)(e.clientWidth>r.clientWidth||e.clientHeight>r.clientHeight)&&t.push(r),r=(e=r).parentElement;return t.length===0&&t.push(document.documentElement),t}function Ve(e){return{x:e.offsetLeft,y:e.offsetTop}}function en(e){let t=e.getBoundingClientRect();return{x:t.x+window.scrollX,y:t.y+window.scrollY}}function tn(e){return O(h(window,"load"),h(window,"resize")).pipe(Me(0,me),m(()=>Ve(e)),Q(Ve(e)))}function pr(e){return{x:e.scrollLeft,y:e.scrollTop}}function Ne(e){return O(h(e,"scroll"),h(window,"scroll"),h(window,"resize")).pipe(Me(0,me),m(()=>pr(e)),Q(pr(e)))}var rn=new g,Sa=C(()=>I(new IntersectionObserver(e=>{for(let t of e)rn.next(t)},{threshold:0}))).pipe(v(e=>O(Ye,I(e)).pipe(_(()=>e.disconnect()))),G(1));function tt(e){return Sa.pipe(w(t=>t.observe(e)),v(t=>rn.pipe(b(({target:r})=>r===e),_(()=>t.unobserve(e)),m(({isIntersecting:r})=>r))))}function on(e,t=16){return Ne(e).pipe(m(({y:r})=>{let o=ce(e),n=St(e);return r>=n.height-o.height-t}),K())}var lr={drawer:R("[data-md-toggle=drawer]"),search:R("[data-md-toggle=search]")};function nn(e){return lr[e].checked}function Je(e,t){lr[e].checked!==t&&lr[e].click()}function ze(e){let t=lr[e];return h(t,"change").pipe(m(()=>t.checked),Q(t.checked))}function Oa(e,t){switch(e.constructor){case HTMLInputElement:return e.type==="radio"?/^Arrow/.test(t):!0;case HTMLSelectElement:case HTMLTextAreaElement:return!0;default:return e.isContentEditable}}function La(){return O(h(window,"compositionstart").pipe(m(()=>!0)),h(window,"compositionend").pipe(m(()=>!1))).pipe(Q(!1))}function an(){let e=h(window,"keydown").pipe(b(t=>!(t.metaKey||t.ctrlKey)),m(t=>({mode:nn("search")?"search":"global",type:t.key,claim(){t.preventDefault(),t.stopPropagation()}})),b(({mode:t,type:r})=>{if(t==="global"){let o=Ie();if(typeof o!="undefined")return!Oa(o,r)}return!0}),pe());return La().pipe(v(t=>t?S:e))}function ye(){return new URL(location.href)}function lt(e,t=!1){if(B("navigation.instant")&&!t){let r=x("a",{href:e.href});document.body.appendChild(r),r.click(),r.remove()}else location.href=e.href}function sn(){return new g}function cn(){return location.hash.slice(1)}function pn(e){let t=x("a",{href:e});t.addEventListener("click",r=>r.stopPropagation()),t.click()}function Ma(e){return O(h(window,"hashchange"),e).pipe(m(cn),Q(cn()),b(t=>t.length>0),G(1))}function ln(e){return Ma(e).pipe(m(t=>fe(`[id="${t}"]`)),b(t=>typeof t!="undefined"))}function Pt(e){let t=matchMedia(e);return ar(r=>t.addListener(()=>r(t.matches))).pipe(Q(t.matches))}function mn(){let e=matchMedia("print");return O(h(window,"beforeprint").pipe(m(()=>!0)),h(window,"afterprint").pipe(m(()=>!1))).pipe(Q(e.matches))}function Nr(e,t){return e.pipe(v(r=>r?t():S))}function zr(e,t){return new j(r=>{let o=new XMLHttpRequest;return o.open("GET",`${e}`),o.responseType="blob",o.addEventListener("load",()=>{o.status>=200&&o.status<300?(r.next(o.response),r.complete()):r.error(new Error(o.statusText))}),o.addEventListener("error",()=>{r.error(new Error("Network error"))}),o.addEventListener("abort",()=>{r.complete()}),typeof(t==null?void 0:t.progress$)!="undefined"&&(o.addEventListener("progress",n=>{var i;if(n.lengthComputable)t.progress$.next(n.loaded/n.total*100);else{let a=(i=o.getResponseHeader("Content-Length"))!=null?i:0;t.progress$.next(n.loaded/+a*100)}}),t.progress$.next(5)),o.send(),()=>o.abort()})}function je(e,t){return zr(e,t).pipe(v(r=>r.text()),m(r=>JSON.parse(r)),G(1))}function fn(e,t){let r=new DOMParser;return zr(e,t).pipe(v(o=>o.text()),m(o=>r.parseFromString(o,"text/html")),G(1))}function un(e,t){let r=new DOMParser;return zr(e,t).pipe(v(o=>o.text()),m(o=>r.parseFromString(o,"text/xml")),G(1))}function dn(){return{x:Math.max(0,scrollX),y:Math.max(0,scrollY)}}function hn(){return O(h(window,"scroll",{passive:!0}),h(window,"resize",{passive:!0})).pipe(m(dn),Q(dn()))}function bn(){return{width:innerWidth,height:innerHeight}}function vn(){return h(window,"resize",{passive:!0}).pipe(m(bn),Q(bn()))}function gn(){return z([hn(),vn()]).pipe(m(([e,t])=>({offset:e,size:t})),G(1))}function mr(e,{viewport$:t,header$:r}){let o=t.pipe(ee("size")),n=z([o,r]).pipe(m(()=>Ve(e)));return z([r,t,n]).pipe(m(([{height:i},{offset:a,size:s},{x:p,y:c}])=>({offset:{x:a.x-p,y:a.y-c+i},size:s})))}function _a(e){return h(e,"message",t=>t.data)}function Aa(e){let t=new g;return t.subscribe(r=>e.postMessage(r)),t}function yn(e,t=new Worker(e)){let r=_a(t),o=Aa(t),n=new g;n.subscribe(o);let i=o.pipe(Z(),ie(!0));return n.pipe(Z(),Re(r.pipe(W(i))),pe())}var Ca=R("#__config"),Ot=JSON.parse(Ca.textContent);Ot.base=`${new URL(Ot.base,ye())}`;function xe(){return Ot}function B(e){return Ot.features.includes(e)}function Ee(e,t){return typeof t!="undefined"?Ot.translations[e].replace("#",t.toString()):Ot.translations[e]}function Se(e,t=document){return R(`[data-md-component=${e}]`,t)}function ae(e,t=document){return P(`[data-md-component=${e}]`,t)}function ka(e){let t=R(".md-typeset > :first-child",e);return h(t,"click",{once:!0}).pipe(m(()=>R(".md-typeset",e)),m(r=>({hash:__md_hash(r.innerHTML)})))}function xn(e){if(!B("announce.dismiss")||!e.childElementCount)return S;if(!e.hidden){let t=R(".md-typeset",e);__md_hash(t.innerHTML)===__md_get("__announce")&&(e.hidden=!0)}return C(()=>{let t=new g;return t.subscribe(({hash:r})=>{e.hidden=!0,__md_set("__announce",r)}),ka(e).pipe(w(r=>t.next(r)),_(()=>t.complete()),m(r=>$({ref:e},r)))})}function Ha(e,{target$:t}){return t.pipe(m(r=>({hidden:r!==e})))}function En(e,t){let r=new g;return r.subscribe(({hidden:o})=>{e.hidden=o}),Ha(e,t).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))}function Rt(e,t){return t==="inline"?x("div",{class:"md-tooltip md-tooltip--inline",id:e,role:"tooltip"},x("div",{class:"md-tooltip__inner md-typeset"})):x("div",{class:"md-tooltip",id:e,role:"tooltip"},x("div",{class:"md-tooltip__inner md-typeset"}))}function wn(...e){return x("div",{class:"md-tooltip2",role:"tooltip"},x("div",{class:"md-tooltip2__inner md-typeset"},e))}function Tn(e,t){if(t=t?`${t}_annotation_${e}`:void 0,t){let r=t?`#${t}`:void 0;return x("aside",{class:"md-annotation",tabIndex:0},Rt(t),x("a",{href:r,class:"md-annotation__index",tabIndex:-1},x("span",{"data-md-annotation-id":e})))}else return x("aside",{class:"md-annotation",tabIndex:0},Rt(t),x("span",{class:"md-annotation__index",tabIndex:-1},x("span",{"data-md-annotation-id":e})))}function Sn(e){return x("button",{class:"md-clipboard md-icon",title:Ee("clipboard.copy"),"data-clipboard-target":`#${e} > code`})}var Ln=Mt(qr());function Qr(e,t){let r=t&2,o=t&1,n=Object.keys(e.terms).filter(p=>!e.terms[p]).reduce((p,c)=>[...p,x("del",null,(0,Ln.default)(c))," "],[]).slice(0,-1),i=xe(),a=new URL(e.location,i.base);B("search.highlight")&&a.searchParams.set("h",Object.entries(e.terms).filter(([,p])=>p).reduce((p,[c])=>`${p} ${c}`.trim(),""));let{tags:s}=xe();return x("a",{href:`${a}`,class:"md-search-result__link",tabIndex:-1},x("article",{class:"md-search-result__article md-typeset","data-md-score":e.score.toFixed(2)},r>0&&x("div",{class:"md-search-result__icon md-icon"}),r>0&&x("h1",null,e.title),r<=0&&x("h2",null,e.title),o>0&&e.text.length>0&&e.text,e.tags&&x("nav",{class:"md-tags"},e.tags.map(p=>{let c=s?p in s?`md-tag-icon md-tag--${s[p]}`:"md-tag-icon":"";return x("span",{class:`md-tag ${c}`},p)})),o>0&&n.length>0&&x("p",{class:"md-search-result__terms"},Ee("search.result.term.missing"),": ",...n)))}function Mn(e){let t=e[0].score,r=[...e],o=xe(),n=r.findIndex(l=>!`${new URL(l.location,o.base)}`.includes("#")),[i]=r.splice(n,1),a=r.findIndex(l=>l.scoreQr(l,1)),...p.length?[x("details",{class:"md-search-result__more"},x("summary",{tabIndex:-1},x("div",null,p.length>0&&p.length===1?Ee("search.result.more.one"):Ee("search.result.more.other",p.length))),...p.map(l=>Qr(l,1)))]:[]];return x("li",{class:"md-search-result__item"},c)}function _n(e){return x("ul",{class:"md-source__facts"},Object.entries(e).map(([t,r])=>x("li",{class:`md-source__fact md-source__fact--${t}`},typeof r=="number"?sr(r):r)))}function Kr(e){let t=`tabbed-control tabbed-control--${e}`;return x("div",{class:t,hidden:!0},x("button",{class:"tabbed-button",tabIndex:-1,"aria-hidden":"true"}))}function An(e){return x("div",{class:"md-typeset__scrollwrap"},x("div",{class:"md-typeset__table"},e))}function Ra(e){var o;let t=xe(),r=new URL(`../${e.version}/`,t.base);return x("li",{class:"md-version__item"},x("a",{href:`${r}`,class:"md-version__link"},e.title,((o=t.version)==null?void 0:o.alias)&&e.aliases.length>0&&x("span",{class:"md-version__alias"},e.aliases[0])))}function Cn(e,t){var o;let r=xe();return e=e.filter(n=>{var i;return!((i=n.properties)!=null&&i.hidden)}),x("div",{class:"md-version"},x("button",{class:"md-version__current","aria-label":Ee("select.version")},t.title,((o=r.version)==null?void 0:o.alias)&&t.aliases.length>0&&x("span",{class:"md-version__alias"},t.aliases[0])),x("ul",{class:"md-version__list"},e.map(Ra)))}var Ia=0;function ja(e){let t=z([et(e),$t(e)]).pipe(m(([o,n])=>o||n),K()),r=C(()=>Zo(e)).pipe(ne(Ne),pt(1),He(t),m(()=>en(e)));return t.pipe(Ae(o=>o),v(()=>z([t,r])),m(([o,n])=>({active:o,offset:n})),pe())}function Fa(e,t){let{content$:r,viewport$:o}=t,n=`__tooltip2_${Ia++}`;return C(()=>{let i=new g,a=new _r(!1);i.pipe(Z(),ie(!1)).subscribe(a);let s=a.pipe(Ht(c=>Le(+!c*250,kr)),K(),v(c=>c?r:S),w(c=>c.id=n),pe());z([i.pipe(m(({active:c})=>c)),s.pipe(v(c=>$t(c,250)),Q(!1))]).pipe(m(c=>c.some(l=>l))).subscribe(a);let p=a.pipe(b(c=>c),re(s,o),m(([c,l,{size:f}])=>{let u=e.getBoundingClientRect(),d=u.width/2;if(l.role==="tooltip")return{x:d,y:8+u.height};if(u.y>=f.height/2){let{height:y}=ce(l);return{x:d,y:-16-y}}else return{x:d,y:16+u.height}}));return z([s,i,p]).subscribe(([c,{offset:l},f])=>{c.style.setProperty("--md-tooltip-host-x",`${l.x}px`),c.style.setProperty("--md-tooltip-host-y",`${l.y}px`),c.style.setProperty("--md-tooltip-x",`${f.x}px`),c.style.setProperty("--md-tooltip-y",`${f.y}px`),c.classList.toggle("md-tooltip2--top",f.y<0),c.classList.toggle("md-tooltip2--bottom",f.y>=0)}),a.pipe(b(c=>c),re(s,(c,l)=>l),b(c=>c.role==="tooltip")).subscribe(c=>{let l=ce(R(":scope > *",c));c.style.setProperty("--md-tooltip-width",`${l.width}px`),c.style.setProperty("--md-tooltip-tail","0px")}),a.pipe(K(),ve(me),re(s)).subscribe(([c,l])=>{l.classList.toggle("md-tooltip2--active",c)}),z([a.pipe(b(c=>c)),s]).subscribe(([c,l])=>{l.role==="dialog"?(e.setAttribute("aria-controls",n),e.setAttribute("aria-haspopup","dialog")):e.setAttribute("aria-describedby",n)}),a.pipe(b(c=>!c)).subscribe(()=>{e.removeAttribute("aria-controls"),e.removeAttribute("aria-describedby"),e.removeAttribute("aria-haspopup")}),ja(e).pipe(w(c=>i.next(c)),_(()=>i.complete()),m(c=>$({ref:e},c)))})}function mt(e,{viewport$:t},r=document.body){return Fa(e,{content$:new j(o=>{let n=e.title,i=wn(n);return o.next(i),e.removeAttribute("title"),r.append(i),()=>{i.remove(),e.setAttribute("title",n)}}),viewport$:t})}function Ua(e,t){let r=C(()=>z([tn(e),Ne(t)])).pipe(m(([{x:o,y:n},i])=>{let{width:a,height:s}=ce(e);return{x:o-i.x+a/2,y:n-i.y+s/2}}));return et(e).pipe(v(o=>r.pipe(m(n=>({active:o,offset:n})),Te(+!o||1/0))))}function kn(e,t,{target$:r}){let[o,n]=Array.from(e.children);return C(()=>{let i=new g,a=i.pipe(Z(),ie(!0));return i.subscribe({next({offset:s}){e.style.setProperty("--md-tooltip-x",`${s.x}px`),e.style.setProperty("--md-tooltip-y",`${s.y}px`)},complete(){e.style.removeProperty("--md-tooltip-x"),e.style.removeProperty("--md-tooltip-y")}}),tt(e).pipe(W(a)).subscribe(s=>{e.toggleAttribute("data-md-visible",s)}),O(i.pipe(b(({active:s})=>s)),i.pipe(_e(250),b(({active:s})=>!s))).subscribe({next({active:s}){s?e.prepend(o):o.remove()},complete(){e.prepend(o)}}),i.pipe(Me(16,me)).subscribe(({active:s})=>{o.classList.toggle("md-tooltip--active",s)}),i.pipe(pt(125,me),b(()=>!!e.offsetParent),m(()=>e.offsetParent.getBoundingClientRect()),m(({x:s})=>s)).subscribe({next(s){s?e.style.setProperty("--md-tooltip-0",`${-s}px`):e.style.removeProperty("--md-tooltip-0")},complete(){e.style.removeProperty("--md-tooltip-0")}}),h(n,"click").pipe(W(a),b(s=>!(s.metaKey||s.ctrlKey))).subscribe(s=>{s.stopPropagation(),s.preventDefault()}),h(n,"mousedown").pipe(W(a),re(i)).subscribe(([s,{active:p}])=>{var c;if(s.button!==0||s.metaKey||s.ctrlKey)s.preventDefault();else if(p){s.preventDefault();let l=e.parentElement.closest(".md-annotation");l instanceof HTMLElement?l.focus():(c=Ie())==null||c.blur()}}),r.pipe(W(a),b(s=>s===o),Ge(125)).subscribe(()=>e.focus()),Ua(e,t).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))})}function Wa(e){return e.tagName==="CODE"?P(".c, .c1, .cm",e):[e]}function Da(e){let t=[];for(let r of Wa(e)){let o=[],n=document.createNodeIterator(r,NodeFilter.SHOW_TEXT);for(let i=n.nextNode();i;i=n.nextNode())o.push(i);for(let i of o){let a;for(;a=/(\(\d+\))(!)?/.exec(i.textContent);){let[,s,p]=a;if(typeof p=="undefined"){let c=i.splitText(a.index);i=c.splitText(s.length),t.push(c)}else{i.textContent=s,t.push(i);break}}}}return t}function Hn(e,t){t.append(...Array.from(e.childNodes))}function fr(e,t,{target$:r,print$:o}){let n=t.closest("[id]"),i=n==null?void 0:n.id,a=new Map;for(let s of Da(t)){let[,p]=s.textContent.match(/\((\d+)\)/);fe(`:scope > li:nth-child(${p})`,e)&&(a.set(p,Tn(p,i)),s.replaceWith(a.get(p)))}return a.size===0?S:C(()=>{let s=new g,p=s.pipe(Z(),ie(!0)),c=[];for(let[l,f]of a)c.push([R(".md-typeset",f),R(`:scope > li:nth-child(${l})`,e)]);return o.pipe(W(p)).subscribe(l=>{e.hidden=!l,e.classList.toggle("md-annotation-list",l);for(let[f,u]of c)l?Hn(f,u):Hn(u,f)}),O(...[...a].map(([,l])=>kn(l,t,{target$:r}))).pipe(_(()=>s.complete()),pe())})}function $n(e){if(e.nextElementSibling){let t=e.nextElementSibling;if(t.tagName==="OL")return t;if(t.tagName==="P"&&!t.children.length)return $n(t)}}function Pn(e,t){return C(()=>{let r=$n(e);return typeof r!="undefined"?fr(r,e,t):S})}var Rn=Mt(Br());var Va=0;function In(e){if(e.nextElementSibling){let t=e.nextElementSibling;if(t.tagName==="OL")return t;if(t.tagName==="P"&&!t.children.length)return In(t)}}function Na(e){return ge(e).pipe(m(({width:t})=>({scrollable:St(e).width>t})),ee("scrollable"))}function jn(e,t){let{matches:r}=matchMedia("(hover)"),o=C(()=>{let n=new g,i=n.pipe(jr(1));n.subscribe(({scrollable:c})=>{c&&r?e.setAttribute("tabindex","0"):e.removeAttribute("tabindex")});let a=[];if(Rn.default.isSupported()&&(e.closest(".copy")||B("content.code.copy")&&!e.closest(".no-copy"))){let c=e.closest("pre");c.id=`__code_${Va++}`;let l=Sn(c.id);c.insertBefore(l,e),B("content.tooltips")&&a.push(mt(l,{viewport$}))}let s=e.closest(".highlight");if(s instanceof HTMLElement){let c=In(s);if(typeof c!="undefined"&&(s.classList.contains("annotate")||B("content.code.annotate"))){let l=fr(c,e,t);a.push(ge(s).pipe(W(i),m(({width:f,height:u})=>f&&u),K(),v(f=>f?l:S)))}}return P(":scope > span[id]",e).length&&e.classList.add("md-code__content"),Na(e).pipe(w(c=>n.next(c)),_(()=>n.complete()),m(c=>$({ref:e},c)),Re(...a))});return B("content.lazy")?tt(e).pipe(b(n=>n),Te(1),v(()=>o)):o}function za(e,{target$:t,print$:r}){let o=!0;return O(t.pipe(m(n=>n.closest("details:not([open])")),b(n=>e===n),m(()=>({action:"open",reveal:!0}))),r.pipe(b(n=>n||!o),w(()=>o=e.open),m(n=>({action:n?"open":"close"}))))}function Fn(e,t){return C(()=>{let r=new g;return r.subscribe(({action:o,reveal:n})=>{e.toggleAttribute("open",o==="open"),n&&e.scrollIntoView()}),za(e,t).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}var Un=".node circle,.node ellipse,.node path,.node polygon,.node rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}marker{fill:var(--md-mermaid-edge-color)!important}.edgeLabel .label rect{fill:#0000}.label{color:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.label foreignObject{line-height:normal;overflow:visible}.label div .edgeLabel{color:var(--md-mermaid-label-fg-color)}.edgeLabel,.edgeLabel p,.label div .edgeLabel{background-color:var(--md-mermaid-label-bg-color)}.edgeLabel,.edgeLabel p{fill:var(--md-mermaid-label-bg-color);color:var(--md-mermaid-edge-color)}.edgePath .path,.flowchart-link{stroke:var(--md-mermaid-edge-color);stroke-width:.05rem}.edgePath .arrowheadPath{fill:var(--md-mermaid-edge-color);stroke:none}.cluster rect{fill:var(--md-default-fg-color--lightest);stroke:var(--md-default-fg-color--lighter)}.cluster span{color:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}g #flowchart-circleEnd,g #flowchart-circleStart,g #flowchart-crossEnd,g #flowchart-crossStart,g #flowchart-pointEnd,g #flowchart-pointStart{stroke:none}g.classGroup line,g.classGroup rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}g.classGroup text{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.classLabel .box{fill:var(--md-mermaid-label-bg-color);background-color:var(--md-mermaid-label-bg-color);opacity:1}.classLabel .label{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.node .divider{stroke:var(--md-mermaid-node-fg-color)}.relation{stroke:var(--md-mermaid-edge-color)}.cardinality{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.cardinality text{fill:inherit!important}defs #classDiagram-compositionEnd,defs #classDiagram-compositionStart,defs #classDiagram-dependencyEnd,defs #classDiagram-dependencyStart,defs #classDiagram-extensionEnd,defs #classDiagram-extensionStart{fill:var(--md-mermaid-edge-color)!important;stroke:var(--md-mermaid-edge-color)!important}defs #classDiagram-aggregationEnd,defs #classDiagram-aggregationStart{fill:var(--md-mermaid-label-bg-color)!important;stroke:var(--md-mermaid-edge-color)!important}g.stateGroup rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}g.stateGroup .state-title{fill:var(--md-mermaid-label-fg-color)!important;font-family:var(--md-mermaid-font-family)}g.stateGroup .composit{fill:var(--md-mermaid-label-bg-color)}.nodeLabel,.nodeLabel p{color:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}a .nodeLabel{text-decoration:underline}.node circle.state-end,.node circle.state-start,.start-state{fill:var(--md-mermaid-edge-color);stroke:none}.end-state-inner,.end-state-outer{fill:var(--md-mermaid-edge-color)}.end-state-inner,.node circle.state-end{stroke:var(--md-mermaid-label-bg-color)}.transition{stroke:var(--md-mermaid-edge-color)}[id^=state-fork] rect,[id^=state-join] rect{fill:var(--md-mermaid-edge-color)!important;stroke:none!important}.statediagram-cluster.statediagram-cluster .inner{fill:var(--md-default-bg-color)}.statediagram-cluster rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}.statediagram-state rect.divider{fill:var(--md-default-fg-color--lightest);stroke:var(--md-default-fg-color--lighter)}defs #statediagram-barbEnd{stroke:var(--md-mermaid-edge-color)}.attributeBoxEven,.attributeBoxOdd{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}.entityBox{fill:var(--md-mermaid-label-bg-color);stroke:var(--md-mermaid-node-fg-color)}.entityLabel{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.relationshipLabelBox{fill:var(--md-mermaid-label-bg-color);fill-opacity:1;background-color:var(--md-mermaid-label-bg-color);opacity:1}.relationshipLabel{fill:var(--md-mermaid-label-fg-color)}.relationshipLine{stroke:var(--md-mermaid-edge-color)}defs #ONE_OR_MORE_END *,defs #ONE_OR_MORE_START *,defs #ONLY_ONE_END *,defs #ONLY_ONE_START *,defs #ZERO_OR_MORE_END *,defs #ZERO_OR_MORE_START *,defs #ZERO_OR_ONE_END *,defs #ZERO_OR_ONE_START *{stroke:var(--md-mermaid-edge-color)!important}defs #ZERO_OR_MORE_END circle,defs #ZERO_OR_MORE_START circle{fill:var(--md-mermaid-label-bg-color)}.actor{fill:var(--md-mermaid-sequence-actor-bg-color);stroke:var(--md-mermaid-sequence-actor-border-color)}text.actor>tspan{fill:var(--md-mermaid-sequence-actor-fg-color);font-family:var(--md-mermaid-font-family)}line{stroke:var(--md-mermaid-sequence-actor-line-color)}.actor-man circle,.actor-man line{fill:var(--md-mermaid-sequence-actorman-bg-color);stroke:var(--md-mermaid-sequence-actorman-line-color)}.messageLine0,.messageLine1{stroke:var(--md-mermaid-sequence-message-line-color)}.note{fill:var(--md-mermaid-sequence-note-bg-color);stroke:var(--md-mermaid-sequence-note-border-color)}.loopText,.loopText>tspan,.messageText,.noteText>tspan{stroke:none;font-family:var(--md-mermaid-font-family)!important}.messageText{fill:var(--md-mermaid-sequence-message-fg-color)}.loopText,.loopText>tspan{fill:var(--md-mermaid-sequence-loop-fg-color)}.noteText>tspan{fill:var(--md-mermaid-sequence-note-fg-color)}#arrowhead path{fill:var(--md-mermaid-sequence-message-line-color);stroke:none}.loopLine{fill:var(--md-mermaid-sequence-loop-bg-color);stroke:var(--md-mermaid-sequence-loop-border-color)}.labelBox{fill:var(--md-mermaid-sequence-label-bg-color);stroke:none}.labelText,.labelText>span{fill:var(--md-mermaid-sequence-label-fg-color);font-family:var(--md-mermaid-font-family)}.sequenceNumber{fill:var(--md-mermaid-sequence-number-fg-color)}rect.rect{fill:var(--md-mermaid-sequence-box-bg-color);stroke:none}rect.rect+text.text{fill:var(--md-mermaid-sequence-box-fg-color)}defs #sequencenumber{fill:var(--md-mermaid-sequence-number-bg-color)!important}";var Gr,Qa=0;function Ka(){return typeof mermaid=="undefined"||mermaid instanceof Element?Tt("https://unpkg.com/mermaid@11/dist/mermaid.min.js"):I(void 0)}function Wn(e){return e.classList.remove("mermaid"),Gr||(Gr=Ka().pipe(w(()=>mermaid.initialize({startOnLoad:!1,themeCSS:Un,sequence:{actorFontSize:"16px",messageFontSize:"16px",noteFontSize:"16px"}})),m(()=>{}),G(1))),Gr.subscribe(()=>co(this,null,function*(){e.classList.add("mermaid");let t=`__mermaid_${Qa++}`,r=x("div",{class:"mermaid"}),o=e.textContent,{svg:n,fn:i}=yield mermaid.render(t,o),a=r.attachShadow({mode:"closed"});a.innerHTML=n,e.replaceWith(r),i==null||i(a)})),Gr.pipe(m(()=>({ref:e})))}var Dn=x("table");function Vn(e){return e.replaceWith(Dn),Dn.replaceWith(An(e)),I({ref:e})}function Ya(e){let t=e.find(r=>r.checked)||e[0];return O(...e.map(r=>h(r,"change").pipe(m(()=>R(`label[for="${r.id}"]`))))).pipe(Q(R(`label[for="${t.id}"]`)),m(r=>({active:r})))}function Nn(e,{viewport$:t,target$:r}){let o=R(".tabbed-labels",e),n=P(":scope > input",e),i=Kr("prev");e.append(i);let a=Kr("next");return e.append(a),C(()=>{let s=new g,p=s.pipe(Z(),ie(!0));z([s,ge(e),tt(e)]).pipe(W(p),Me(1,me)).subscribe({next([{active:c},l]){let f=Ve(c),{width:u}=ce(c);e.style.setProperty("--md-indicator-x",`${f.x}px`),e.style.setProperty("--md-indicator-width",`${u}px`);let d=pr(o);(f.xd.x+l.width)&&o.scrollTo({left:Math.max(0,f.x-16),behavior:"smooth"})},complete(){e.style.removeProperty("--md-indicator-x"),e.style.removeProperty("--md-indicator-width")}}),z([Ne(o),ge(o)]).pipe(W(p)).subscribe(([c,l])=>{let f=St(o);i.hidden=c.x<16,a.hidden=c.x>f.width-l.width-16}),O(h(i,"click").pipe(m(()=>-1)),h(a,"click").pipe(m(()=>1))).pipe(W(p)).subscribe(c=>{let{width:l}=ce(o);o.scrollBy({left:l*c,behavior:"smooth"})}),r.pipe(W(p),b(c=>n.includes(c))).subscribe(c=>c.click()),o.classList.add("tabbed-labels--linked");for(let c of n){let l=R(`label[for="${c.id}"]`);l.replaceChildren(x("a",{href:`#${l.htmlFor}`,tabIndex:-1},...Array.from(l.childNodes))),h(l.firstElementChild,"click").pipe(W(p),b(f=>!(f.metaKey||f.ctrlKey)),w(f=>{f.preventDefault(),f.stopPropagation()})).subscribe(()=>{history.replaceState({},"",`#${l.htmlFor}`),l.click()})}return B("content.tabs.link")&&s.pipe(Ce(1),re(t)).subscribe(([{active:c},{offset:l}])=>{let f=c.innerText.trim();if(c.hasAttribute("data-md-switching"))c.removeAttribute("data-md-switching");else{let u=e.offsetTop-l.y;for(let y of P("[data-tabs]"))for(let L of P(":scope > input",y)){let X=R(`label[for="${L.id}"]`);if(X!==c&&X.innerText.trim()===f){X.setAttribute("data-md-switching",""),L.click();break}}window.scrollTo({top:e.offsetTop-u});let d=__md_get("__tabs")||[];__md_set("__tabs",[...new Set([f,...d])])}}),s.pipe(W(p)).subscribe(()=>{for(let c of P("audio, video",e))c.pause()}),Ya(n).pipe(w(c=>s.next(c)),_(()=>s.complete()),m(c=>$({ref:e},c)))}).pipe(Ke(se))}function zn(e,{viewport$:t,target$:r,print$:o}){return O(...P(".annotate:not(.highlight)",e).map(n=>Pn(n,{target$:r,print$:o})),...P("pre:not(.mermaid) > code",e).map(n=>jn(n,{target$:r,print$:o})),...P("pre.mermaid",e).map(n=>Wn(n)),...P("table:not([class])",e).map(n=>Vn(n)),...P("details",e).map(n=>Fn(n,{target$:r,print$:o})),...P("[data-tabs]",e).map(n=>Nn(n,{viewport$:t,target$:r})),...P("[title]",e).filter(()=>B("content.tooltips")).map(n=>mt(n,{viewport$:t})))}function Ba(e,{alert$:t}){return t.pipe(v(r=>O(I(!0),I(!1).pipe(Ge(2e3))).pipe(m(o=>({message:r,active:o})))))}function qn(e,t){let r=R(".md-typeset",e);return C(()=>{let o=new g;return o.subscribe(({message:n,active:i})=>{e.classList.toggle("md-dialog--active",i),r.textContent=n}),Ba(e,t).pipe(w(n=>o.next(n)),_(()=>o.complete()),m(n=>$({ref:e},n)))})}var Ga=0;function Ja(e,t){document.body.append(e);let{width:r}=ce(e);e.style.setProperty("--md-tooltip-width",`${r}px`),e.remove();let o=cr(t),n=typeof o!="undefined"?Ne(o):I({x:0,y:0}),i=O(et(t),$t(t)).pipe(K());return z([i,n]).pipe(m(([a,s])=>{let{x:p,y:c}=Ve(t),l=ce(t),f=t.closest("table");return f&&t.parentElement&&(p+=f.offsetLeft+t.parentElement.offsetLeft,c+=f.offsetTop+t.parentElement.offsetTop),{active:a,offset:{x:p-s.x+l.width/2-r/2,y:c-s.y+l.height+8}}}))}function Qn(e){let t=e.title;if(!t.length)return S;let r=`__tooltip_${Ga++}`,o=Rt(r,"inline"),n=R(".md-typeset",o);return n.innerHTML=t,C(()=>{let i=new g;return i.subscribe({next({offset:a}){o.style.setProperty("--md-tooltip-x",`${a.x}px`),o.style.setProperty("--md-tooltip-y",`${a.y}px`)},complete(){o.style.removeProperty("--md-tooltip-x"),o.style.removeProperty("--md-tooltip-y")}}),O(i.pipe(b(({active:a})=>a)),i.pipe(_e(250),b(({active:a})=>!a))).subscribe({next({active:a}){a?(e.insertAdjacentElement("afterend",o),e.setAttribute("aria-describedby",r),e.removeAttribute("title")):(o.remove(),e.removeAttribute("aria-describedby"),e.setAttribute("title",t))},complete(){o.remove(),e.removeAttribute("aria-describedby"),e.setAttribute("title",t)}}),i.pipe(Me(16,me)).subscribe(({active:a})=>{o.classList.toggle("md-tooltip--active",a)}),i.pipe(pt(125,me),b(()=>!!e.offsetParent),m(()=>e.offsetParent.getBoundingClientRect()),m(({x:a})=>a)).subscribe({next(a){a?o.style.setProperty("--md-tooltip-0",`${-a}px`):o.style.removeProperty("--md-tooltip-0")},complete(){o.style.removeProperty("--md-tooltip-0")}}),Ja(o,e).pipe(w(a=>i.next(a)),_(()=>i.complete()),m(a=>$({ref:e},a)))}).pipe(Ke(se))}function Xa({viewport$:e}){if(!B("header.autohide"))return I(!1);let t=e.pipe(m(({offset:{y:n}})=>n),Be(2,1),m(([n,i])=>[nMath.abs(i-n.y)>100),m(([,[n]])=>n),K()),o=ze("search");return z([e,o]).pipe(m(([{offset:n},i])=>n.y>400&&!i),K(),v(n=>n?r:I(!1)),Q(!1))}function Kn(e,t){return C(()=>z([ge(e),Xa(t)])).pipe(m(([{height:r},o])=>({height:r,hidden:o})),K((r,o)=>r.height===o.height&&r.hidden===o.hidden),G(1))}function Yn(e,{header$:t,main$:r}){return C(()=>{let o=new g,n=o.pipe(Z(),ie(!0));o.pipe(ee("active"),He(t)).subscribe(([{active:a},{hidden:s}])=>{e.classList.toggle("md-header--shadow",a&&!s),e.hidden=s});let i=ue(P("[title]",e)).pipe(b(()=>B("content.tooltips")),ne(a=>Qn(a)));return r.subscribe(o),t.pipe(W(n),m(a=>$({ref:e},a)),Re(i.pipe(W(n))))})}function Za(e,{viewport$:t,header$:r}){return mr(e,{viewport$:t,header$:r}).pipe(m(({offset:{y:o}})=>{let{height:n}=ce(e);return{active:o>=n}}),ee("active"))}function Bn(e,t){return C(()=>{let r=new g;r.subscribe({next({active:n}){e.classList.toggle("md-header__title--active",n)},complete(){e.classList.remove("md-header__title--active")}});let o=fe(".md-content h1");return typeof o=="undefined"?S:Za(o,t).pipe(w(n=>r.next(n)),_(()=>r.complete()),m(n=>$({ref:e},n)))})}function Gn(e,{viewport$:t,header$:r}){let o=r.pipe(m(({height:i})=>i),K()),n=o.pipe(v(()=>ge(e).pipe(m(({height:i})=>({top:e.offsetTop,bottom:e.offsetTop+i})),ee("bottom"))));return z([o,n,t]).pipe(m(([i,{top:a,bottom:s},{offset:{y:p},size:{height:c}}])=>(c=Math.max(0,c-Math.max(0,a-p,i)-Math.max(0,c+p-s)),{offset:a-i,height:c,active:a-i<=p})),K((i,a)=>i.offset===a.offset&&i.height===a.height&&i.active===a.active))}function es(e){let t=__md_get("__palette")||{index:e.findIndex(o=>matchMedia(o.getAttribute("data-md-color-media")).matches)},r=Math.max(0,Math.min(t.index,e.length-1));return I(...e).pipe(ne(o=>h(o,"change").pipe(m(()=>o))),Q(e[r]),m(o=>({index:e.indexOf(o),color:{media:o.getAttribute("data-md-color-media"),scheme:o.getAttribute("data-md-color-scheme"),primary:o.getAttribute("data-md-color-primary"),accent:o.getAttribute("data-md-color-accent")}})),G(1))}function Jn(e){let t=P("input",e),r=x("meta",{name:"theme-color"});document.head.appendChild(r);let o=x("meta",{name:"color-scheme"});document.head.appendChild(o);let n=Pt("(prefers-color-scheme: light)");return C(()=>{let i=new g;return i.subscribe(a=>{if(document.body.setAttribute("data-md-color-switching",""),a.color.media==="(prefers-color-scheme)"){let s=matchMedia("(prefers-color-scheme: light)"),p=document.querySelector(s.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");a.color.scheme=p.getAttribute("data-md-color-scheme"),a.color.primary=p.getAttribute("data-md-color-primary"),a.color.accent=p.getAttribute("data-md-color-accent")}for(let[s,p]of Object.entries(a.color))document.body.setAttribute(`data-md-color-${s}`,p);for(let s=0;sa.key==="Enter"),re(i,(a,s)=>s)).subscribe(({index:a})=>{a=(a+1)%t.length,t[a].click(),t[a].focus()}),i.pipe(m(()=>{let a=Se("header"),s=window.getComputedStyle(a);return o.content=s.colorScheme,s.backgroundColor.match(/\d+/g).map(p=>(+p).toString(16).padStart(2,"0")).join("")})).subscribe(a=>r.content=`#${a}`),i.pipe(ve(se)).subscribe(()=>{document.body.removeAttribute("data-md-color-switching")}),es(t).pipe(W(n.pipe(Ce(1))),ct(),w(a=>i.next(a)),_(()=>i.complete()),m(a=>$({ref:e},a)))})}function Xn(e,{progress$:t}){return C(()=>{let r=new g;return r.subscribe(({value:o})=>{e.style.setProperty("--md-progress-value",`${o}`)}),t.pipe(w(o=>r.next({value:o})),_(()=>r.complete()),m(o=>({ref:e,value:o})))})}var Jr=Mt(Br());function ts(e){e.setAttribute("data-md-copying","");let t=e.closest("[data-copy]"),r=t?t.getAttribute("data-copy"):e.innerText;return e.removeAttribute("data-md-copying"),r.trimEnd()}function Zn({alert$:e}){Jr.default.isSupported()&&new j(t=>{new Jr.default("[data-clipboard-target], [data-clipboard-text]",{text:r=>r.getAttribute("data-clipboard-text")||ts(R(r.getAttribute("data-clipboard-target")))}).on("success",r=>t.next(r))}).pipe(w(t=>{t.trigger.focus()}),m(()=>Ee("clipboard.copied"))).subscribe(e)}function ei(e,t){return e.protocol=t.protocol,e.hostname=t.hostname,e}function rs(e,t){let r=new Map;for(let o of P("url",e)){let n=R("loc",o),i=[ei(new URL(n.textContent),t)];r.set(`${i[0]}`,i);for(let a of P("[rel=alternate]",o)){let s=a.getAttribute("href");s!=null&&i.push(ei(new URL(s),t))}}return r}function ur(e){return un(new URL("sitemap.xml",e)).pipe(m(t=>rs(t,new URL(e))),de(()=>I(new Map)))}function os(e,t){if(!(e.target instanceof Element))return S;let r=e.target.closest("a");if(r===null)return S;if(r.target||e.metaKey||e.ctrlKey)return S;let o=new URL(r.href);return o.search=o.hash="",t.has(`${o}`)?(e.preventDefault(),I(new URL(r.href))):S}function ti(e){let t=new Map;for(let r of P(":scope > *",e.head))t.set(r.outerHTML,r);return t}function ri(e){for(let t of P("[href], [src]",e))for(let r of["href","src"]){let o=t.getAttribute(r);if(o&&!/^(?:[a-z]+:)?\/\//i.test(o)){t[r]=t[r];break}}return I(e)}function ns(e){for(let o of["[data-md-component=announce]","[data-md-component=container]","[data-md-component=header-topic]","[data-md-component=outdated]","[data-md-component=logo]","[data-md-component=skip]",...B("navigation.tabs.sticky")?["[data-md-component=tabs]"]:[]]){let n=fe(o),i=fe(o,e);typeof n!="undefined"&&typeof i!="undefined"&&n.replaceWith(i)}let t=ti(document);for(let[o,n]of ti(e))t.has(o)?t.delete(o):document.head.appendChild(n);for(let o of t.values()){let n=o.getAttribute("name");n!=="theme-color"&&n!=="color-scheme"&&o.remove()}let r=Se("container");return We(P("script",r)).pipe(v(o=>{let n=e.createElement("script");if(o.src){for(let i of o.getAttributeNames())n.setAttribute(i,o.getAttribute(i));return o.replaceWith(n),new j(i=>{n.onload=()=>i.complete()})}else return n.textContent=o.textContent,o.replaceWith(n),S}),Z(),ie(document))}function oi({location$:e,viewport$:t,progress$:r}){let o=xe();if(location.protocol==="file:")return S;let n=ur(o.base);I(document).subscribe(ri);let i=h(document.body,"click").pipe(He(n),v(([p,c])=>os(p,c)),pe()),a=h(window,"popstate").pipe(m(ye),pe());i.pipe(re(t)).subscribe(([p,{offset:c}])=>{history.replaceState(c,""),history.pushState(null,"",p)}),O(i,a).subscribe(e);let s=e.pipe(ee("pathname"),v(p=>fn(p,{progress$:r}).pipe(de(()=>(lt(p,!0),S)))),v(ri),v(ns),pe());return O(s.pipe(re(e,(p,c)=>c)),s.pipe(v(()=>e),ee("pathname"),v(()=>e),ee("hash")),e.pipe(K((p,c)=>p.pathname===c.pathname&&p.hash===c.hash),v(()=>i),w(()=>history.back()))).subscribe(p=>{var c,l;history.state!==null||!p.hash?window.scrollTo(0,(l=(c=history.state)==null?void 0:c.y)!=null?l:0):(history.scrollRestoration="auto",pn(p.hash),history.scrollRestoration="manual")}),e.subscribe(()=>{history.scrollRestoration="manual"}),h(window,"beforeunload").subscribe(()=>{history.scrollRestoration="auto"}),t.pipe(ee("offset"),_e(100)).subscribe(({offset:p})=>{history.replaceState(p,"")}),s}var ni=Mt(qr());function ii(e){let t=e.separator.split("|").map(n=>n.replace(/(\(\?[!=<][^)]+\))/g,"").length===0?"\uFFFD":n).join("|"),r=new RegExp(t,"img"),o=(n,i,a)=>`${i}${a}`;return n=>{n=n.replace(/[\s*+\-:~^]+/g," ").trim();let i=new RegExp(`(^|${e.separator}|)(${n.replace(/[|\\{}()[\]^$+*?.-]/g,"\\$&").replace(r,"|")})`,"img");return a=>(0,ni.default)(a).replace(i,o).replace(/<\/mark>(\s+)]*>/img,"$1")}}function jt(e){return e.type===1}function dr(e){return e.type===3}function ai(e,t){let r=yn(e);return O(I(location.protocol!=="file:"),ze("search")).pipe(Ae(o=>o),v(()=>t)).subscribe(({config:o,docs:n})=>r.next({type:0,data:{config:o,docs:n,options:{suggest:B("search.suggest")}}})),r}function si(e){var l;let{selectedVersionSitemap:t,selectedVersionBaseURL:r,currentLocation:o,currentBaseURL:n}=e,i=(l=Xr(n))==null?void 0:l.pathname;if(i===void 0)return;let a=ss(o.pathname,i);if(a===void 0)return;let s=ps(t.keys());if(!t.has(s))return;let p=Xr(a,s);if(!p||!t.has(p.href))return;let c=Xr(a,r);if(c)return c.hash=o.hash,c.search=o.search,c}function Xr(e,t){try{return new URL(e,t)}catch(r){return}}function ss(e,t){if(e.startsWith(t))return e.slice(t.length)}function cs(e,t){let r=Math.min(e.length,t.length),o;for(o=0;oS)),o=r.pipe(m(n=>{let[,i]=t.base.match(/([^/]+)\/?$/);return n.find(({version:a,aliases:s})=>a===i||s.includes(i))||n[0]}));r.pipe(m(n=>new Map(n.map(i=>[`${new URL(`../${i.version}/`,t.base)}`,i]))),v(n=>h(document.body,"click").pipe(b(i=>!i.metaKey&&!i.ctrlKey),re(o),v(([i,a])=>{if(i.target instanceof Element){let s=i.target.closest("a");if(s&&!s.target&&n.has(s.href)){let p=s.href;return!i.target.closest(".md-version")&&n.get(p)===a?S:(i.preventDefault(),I(new URL(p)))}}return S}),v(i=>ur(i).pipe(m(a=>{var s;return(s=si({selectedVersionSitemap:a,selectedVersionBaseURL:i,currentLocation:ye(),currentBaseURL:t.base}))!=null?s:i})))))).subscribe(n=>lt(n,!0)),z([r,o]).subscribe(([n,i])=>{R(".md-header__topic").appendChild(Cn(n,i))}),e.pipe(v(()=>o)).subscribe(n=>{var a;let i=__md_get("__outdated",sessionStorage);if(i===null){i=!0;let s=((a=t.version)==null?void 0:a.default)||"latest";Array.isArray(s)||(s=[s]);e:for(let p of s)for(let c of n.aliases.concat(n.version))if(new RegExp(p,"i").test(c)){i=!1;break e}__md_set("__outdated",i,sessionStorage)}if(i)for(let s of ae("outdated"))s.hidden=!1})}function ls(e,{worker$:t}){let{searchParams:r}=ye();r.has("q")&&(Je("search",!0),e.value=r.get("q"),e.focus(),ze("search").pipe(Ae(i=>!i)).subscribe(()=>{let i=ye();i.searchParams.delete("q"),history.replaceState({},"",`${i}`)}));let o=et(e),n=O(t.pipe(Ae(jt)),h(e,"keyup"),o).pipe(m(()=>e.value),K());return z([n,o]).pipe(m(([i,a])=>({value:i,focus:a})),G(1))}function pi(e,{worker$:t}){let r=new g,o=r.pipe(Z(),ie(!0));z([t.pipe(Ae(jt)),r],(i,a)=>a).pipe(ee("value")).subscribe(({value:i})=>t.next({type:2,data:i})),r.pipe(ee("focus")).subscribe(({focus:i})=>{i&&Je("search",i)}),h(e.form,"reset").pipe(W(o)).subscribe(()=>e.focus());let n=R("header [for=__search]");return h(n,"click").subscribe(()=>e.focus()),ls(e,{worker$:t}).pipe(w(i=>r.next(i)),_(()=>r.complete()),m(i=>$({ref:e},i)),G(1))}function li(e,{worker$:t,query$:r}){let o=new g,n=on(e.parentElement).pipe(b(Boolean)),i=e.parentElement,a=R(":scope > :first-child",e),s=R(":scope > :last-child",e);ze("search").subscribe(l=>s.setAttribute("role",l?"list":"presentation")),o.pipe(re(r),Wr(t.pipe(Ae(jt)))).subscribe(([{items:l},{value:f}])=>{switch(l.length){case 0:a.textContent=f.length?Ee("search.result.none"):Ee("search.result.placeholder");break;case 1:a.textContent=Ee("search.result.one");break;default:let u=sr(l.length);a.textContent=Ee("search.result.other",u)}});let p=o.pipe(w(()=>s.innerHTML=""),v(({items:l})=>O(I(...l.slice(0,10)),I(...l.slice(10)).pipe(Be(4),Vr(n),v(([f])=>f)))),m(Mn),pe());return p.subscribe(l=>s.appendChild(l)),p.pipe(ne(l=>{let f=fe("details",l);return typeof f=="undefined"?S:h(f,"toggle").pipe(W(o),m(()=>f))})).subscribe(l=>{l.open===!1&&l.offsetTop<=i.scrollTop&&i.scrollTo({top:l.offsetTop})}),t.pipe(b(dr),m(({data:l})=>l)).pipe(w(l=>o.next(l)),_(()=>o.complete()),m(l=>$({ref:e},l)))}function ms(e,{query$:t}){return t.pipe(m(({value:r})=>{let o=ye();return o.hash="",r=r.replace(/\s+/g,"+").replace(/&/g,"%26").replace(/=/g,"%3D"),o.search=`q=${r}`,{url:o}}))}function mi(e,t){let r=new g,o=r.pipe(Z(),ie(!0));return r.subscribe(({url:n})=>{e.setAttribute("data-clipboard-text",e.href),e.href=`${n}`}),h(e,"click").pipe(W(o)).subscribe(n=>n.preventDefault()),ms(e,t).pipe(w(n=>r.next(n)),_(()=>r.complete()),m(n=>$({ref:e},n)))}function fi(e,{worker$:t,keyboard$:r}){let o=new g,n=Se("search-query"),i=O(h(n,"keydown"),h(n,"focus")).pipe(ve(se),m(()=>n.value),K());return o.pipe(He(i),m(([{suggest:s},p])=>{let c=p.split(/([\s-]+)/);if(s!=null&&s.length&&c[c.length-1]){let l=s[s.length-1];l.startsWith(c[c.length-1])&&(c[c.length-1]=l)}else c.length=0;return c})).subscribe(s=>e.innerHTML=s.join("").replace(/\s/g," ")),r.pipe(b(({mode:s})=>s==="search")).subscribe(s=>{switch(s.type){case"ArrowRight":e.innerText.length&&n.selectionStart===n.value.length&&(n.value=e.innerText);break}}),t.pipe(b(dr),m(({data:s})=>s)).pipe(w(s=>o.next(s)),_(()=>o.complete()),m(()=>({ref:e})))}function ui(e,{index$:t,keyboard$:r}){let o=xe();try{let n=ai(o.search,t),i=Se("search-query",e),a=Se("search-result",e);h(e,"click").pipe(b(({target:p})=>p instanceof Element&&!!p.closest("a"))).subscribe(()=>Je("search",!1)),r.pipe(b(({mode:p})=>p==="search")).subscribe(p=>{let c=Ie();switch(p.type){case"Enter":if(c===i){let l=new Map;for(let f of P(":first-child [href]",a)){let u=f.firstElementChild;l.set(f,parseFloat(u.getAttribute("data-md-score")))}if(l.size){let[[f]]=[...l].sort(([,u],[,d])=>d-u);f.click()}p.claim()}break;case"Escape":case"Tab":Je("search",!1),i.blur();break;case"ArrowUp":case"ArrowDown":if(typeof c=="undefined")i.focus();else{let l=[i,...P(":not(details) > [href], summary, details[open] [href]",a)],f=Math.max(0,(Math.max(0,l.indexOf(c))+l.length+(p.type==="ArrowUp"?-1:1))%l.length);l[f].focus()}p.claim();break;default:i!==Ie()&&i.focus()}}),r.pipe(b(({mode:p})=>p==="global")).subscribe(p=>{switch(p.type){case"f":case"s":case"/":i.focus(),i.select(),p.claim();break}});let s=pi(i,{worker$:n});return O(s,li(a,{worker$:n,query$:s})).pipe(Re(...ae("search-share",e).map(p=>mi(p,{query$:s})),...ae("search-suggest",e).map(p=>fi(p,{worker$:n,keyboard$:r}))))}catch(n){return e.hidden=!0,Ye}}function di(e,{index$:t,location$:r}){return z([t,r.pipe(Q(ye()),b(o=>!!o.searchParams.get("h")))]).pipe(m(([o,n])=>ii(o.config)(n.searchParams.get("h"))),m(o=>{var a;let n=new Map,i=document.createNodeIterator(e,NodeFilter.SHOW_TEXT);for(let s=i.nextNode();s;s=i.nextNode())if((a=s.parentElement)!=null&&a.offsetHeight){let p=s.textContent,c=o(p);c.length>p.length&&n.set(s,c)}for(let[s,p]of n){let{childNodes:c}=x("span",null,p);s.replaceWith(...Array.from(c))}return{ref:e,nodes:n}}))}function fs(e,{viewport$:t,main$:r}){let o=e.closest(".md-grid"),n=o.offsetTop-o.parentElement.offsetTop;return z([r,t]).pipe(m(([{offset:i,height:a},{offset:{y:s}}])=>(a=a+Math.min(n,Math.max(0,s-i))-n,{height:a,locked:s>=i+n})),K((i,a)=>i.height===a.height&&i.locked===a.locked))}function Zr(e,o){var n=o,{header$:t}=n,r=so(n,["header$"]);let i=R(".md-sidebar__scrollwrap",e),{y:a}=Ve(i);return C(()=>{let s=new g,p=s.pipe(Z(),ie(!0)),c=s.pipe(Me(0,me));return c.pipe(re(t)).subscribe({next([{height:l},{height:f}]){i.style.height=`${l-2*a}px`,e.style.top=`${f}px`},complete(){i.style.height="",e.style.top=""}}),c.pipe(Ae()).subscribe(()=>{for(let l of P(".md-nav__link--active[href]",e)){if(!l.clientHeight)continue;let f=l.closest(".md-sidebar__scrollwrap");if(typeof f!="undefined"){let u=l.offsetTop-f.offsetTop,{height:d}=ce(f);f.scrollTo({top:u-d/2})}}}),ue(P("label[tabindex]",e)).pipe(ne(l=>h(l,"click").pipe(ve(se),m(()=>l),W(p)))).subscribe(l=>{let f=R(`[id="${l.htmlFor}"]`);R(`[aria-labelledby="${l.id}"]`).setAttribute("aria-expanded",`${f.checked}`)}),fs(e,r).pipe(w(l=>s.next(l)),_(()=>s.complete()),m(l=>$({ref:e},l)))})}function hi(e,t){if(typeof t!="undefined"){let r=`https://api.github.com/repos/${e}/${t}`;return st(je(`${r}/releases/latest`).pipe(de(()=>S),m(o=>({version:o.tag_name})),De({})),je(r).pipe(de(()=>S),m(o=>({stars:o.stargazers_count,forks:o.forks_count})),De({}))).pipe(m(([o,n])=>$($({},o),n)))}else{let r=`https://api.github.com/users/${e}`;return je(r).pipe(m(o=>({repositories:o.public_repos})),De({}))}}function bi(e,t){let r=`https://${e}/api/v4/projects/${encodeURIComponent(t)}`;return st(je(`${r}/releases/permalink/latest`).pipe(de(()=>S),m(({tag_name:o})=>({version:o})),De({})),je(r).pipe(de(()=>S),m(({star_count:o,forks_count:n})=>({stars:o,forks:n})),De({}))).pipe(m(([o,n])=>$($({},o),n)))}function vi(e){let t=e.match(/^.+github\.com\/([^/]+)\/?([^/]+)?/i);if(t){let[,r,o]=t;return hi(r,o)}if(t=e.match(/^.+?([^/]*gitlab[^/]+)\/(.+?)\/?$/i),t){let[,r,o]=t;return bi(r,o)}return S}var us;function ds(e){return us||(us=C(()=>{let t=__md_get("__source",sessionStorage);if(t)return I(t);if(ae("consent").length){let o=__md_get("__consent");if(!(o&&o.github))return S}return vi(e.href).pipe(w(o=>__md_set("__source",o,sessionStorage)))}).pipe(de(()=>S),b(t=>Object.keys(t).length>0),m(t=>({facts:t})),G(1)))}function gi(e){let t=R(":scope > :last-child",e);return C(()=>{let r=new g;return r.subscribe(({facts:o})=>{t.appendChild(_n(o)),t.classList.add("md-source__repository--active")}),ds(e).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}function hs(e,{viewport$:t,header$:r}){return ge(document.body).pipe(v(()=>mr(e,{header$:r,viewport$:t})),m(({offset:{y:o}})=>({hidden:o>=10})),ee("hidden"))}function yi(e,t){return C(()=>{let r=new g;return r.subscribe({next({hidden:o}){e.hidden=o},complete(){e.hidden=!1}}),(B("navigation.tabs.sticky")?I({hidden:!1}):hs(e,t)).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}function bs(e,{viewport$:t,header$:r}){let o=new Map,n=P(".md-nav__link",e);for(let s of n){let p=decodeURIComponent(s.hash.substring(1)),c=fe(`[id="${p}"]`);typeof c!="undefined"&&o.set(s,c)}let i=r.pipe(ee("height"),m(({height:s})=>{let p=Se("main"),c=R(":scope > :first-child",p);return s+.8*(c.offsetTop-p.offsetTop)}),pe());return ge(document.body).pipe(ee("height"),v(s=>C(()=>{let p=[];return I([...o].reduce((c,[l,f])=>{for(;p.length&&o.get(p[p.length-1]).tagName>=f.tagName;)p.pop();let u=f.offsetTop;for(;!u&&f.parentElement;)f=f.parentElement,u=f.offsetTop;let d=f.offsetParent;for(;d;d=d.offsetParent)u+=d.offsetTop;return c.set([...p=[...p,l]].reverse(),u)},new Map))}).pipe(m(p=>new Map([...p].sort(([,c],[,l])=>c-l))),He(i),v(([p,c])=>t.pipe(Fr(([l,f],{offset:{y:u},size:d})=>{let y=u+d.height>=Math.floor(s.height);for(;f.length;){let[,L]=f[0];if(L-c=u&&!y)f=[l.pop(),...f];else break}return[l,f]},[[],[...p]]),K((l,f)=>l[0]===f[0]&&l[1]===f[1])))))).pipe(m(([s,p])=>({prev:s.map(([c])=>c),next:p.map(([c])=>c)})),Q({prev:[],next:[]}),Be(2,1),m(([s,p])=>s.prev.length{let i=new g,a=i.pipe(Z(),ie(!0));if(i.subscribe(({prev:s,next:p})=>{for(let[c]of p)c.classList.remove("md-nav__link--passed"),c.classList.remove("md-nav__link--active");for(let[c,[l]]of s.entries())l.classList.add("md-nav__link--passed"),l.classList.toggle("md-nav__link--active",c===s.length-1)}),B("toc.follow")){let s=O(t.pipe(_e(1),m(()=>{})),t.pipe(_e(250),m(()=>"smooth")));i.pipe(b(({prev:p})=>p.length>0),He(o.pipe(ve(se))),re(s)).subscribe(([[{prev:p}],c])=>{let[l]=p[p.length-1];if(l.offsetHeight){let f=cr(l);if(typeof f!="undefined"){let u=l.offsetTop-f.offsetTop,{height:d}=ce(f);f.scrollTo({top:u-d/2,behavior:c})}}})}return B("navigation.tracking")&&t.pipe(W(a),ee("offset"),_e(250),Ce(1),W(n.pipe(Ce(1))),ct({delay:250}),re(i)).subscribe(([,{prev:s}])=>{let p=ye(),c=s[s.length-1];if(c&&c.length){let[l]=c,{hash:f}=new URL(l.href);p.hash!==f&&(p.hash=f,history.replaceState({},"",`${p}`))}else p.hash="",history.replaceState({},"",`${p}`)}),bs(e,{viewport$:t,header$:r}).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))})}function vs(e,{viewport$:t,main$:r,target$:o}){let n=t.pipe(m(({offset:{y:a}})=>a),Be(2,1),m(([a,s])=>a>s&&s>0),K()),i=r.pipe(m(({active:a})=>a));return z([i,n]).pipe(m(([a,s])=>!(a&&s)),K(),W(o.pipe(Ce(1))),ie(!0),ct({delay:250}),m(a=>({hidden:a})))}function Ei(e,{viewport$:t,header$:r,main$:o,target$:n}){let i=new g,a=i.pipe(Z(),ie(!0));return i.subscribe({next({hidden:s}){e.hidden=s,s?(e.setAttribute("tabindex","-1"),e.blur()):e.removeAttribute("tabindex")},complete(){e.style.top="",e.hidden=!0,e.removeAttribute("tabindex")}}),r.pipe(W(a),ee("height")).subscribe(({height:s})=>{e.style.top=`${s+16}px`}),h(e,"click").subscribe(s=>{s.preventDefault(),window.scrollTo({top:0})}),vs(e,{viewport$:t,main$:o,target$:n}).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))}function wi({document$:e,viewport$:t}){e.pipe(v(()=>P(".md-ellipsis")),ne(r=>tt(r).pipe(W(e.pipe(Ce(1))),b(o=>o),m(()=>r),Te(1))),b(r=>r.offsetWidth{let o=r.innerText,n=r.closest("a")||r;return n.title=o,B("content.tooltips")?mt(n,{viewport$:t}).pipe(W(e.pipe(Ce(1))),_(()=>n.removeAttribute("title"))):S})).subscribe(),B("content.tooltips")&&e.pipe(v(()=>P(".md-status")),ne(r=>mt(r,{viewport$:t}))).subscribe()}function Ti({document$:e,tablet$:t}){e.pipe(v(()=>P(".md-toggle--indeterminate")),w(r=>{r.indeterminate=!0,r.checked=!1}),ne(r=>h(r,"change").pipe(Dr(()=>r.classList.contains("md-toggle--indeterminate")),m(()=>r))),re(t)).subscribe(([r,o])=>{r.classList.remove("md-toggle--indeterminate"),o&&(r.checked=!1)})}function gs(){return/(iPad|iPhone|iPod)/.test(navigator.userAgent)}function Si({document$:e}){e.pipe(v(()=>P("[data-md-scrollfix]")),w(t=>t.removeAttribute("data-md-scrollfix")),b(gs),ne(t=>h(t,"touchstart").pipe(m(()=>t)))).subscribe(t=>{let r=t.scrollTop;r===0?t.scrollTop=1:r+t.offsetHeight===t.scrollHeight&&(t.scrollTop=r-1)})}function Oi({viewport$:e,tablet$:t}){z([ze("search"),t]).pipe(m(([r,o])=>r&&!o),v(r=>I(r).pipe(Ge(r?400:100))),re(e)).subscribe(([r,{offset:{y:o}}])=>{if(r)document.body.setAttribute("data-md-scrolllock",""),document.body.style.top=`-${o}px`;else{let n=-1*parseInt(document.body.style.top,10);document.body.removeAttribute("data-md-scrolllock"),document.body.style.top="",n&&window.scrollTo(0,n)}})}Object.entries||(Object.entries=function(e){let t=[];for(let r of Object.keys(e))t.push([r,e[r]]);return t});Object.values||(Object.values=function(e){let t=[];for(let r of Object.keys(e))t.push(e[r]);return t});typeof Element!="undefined"&&(Element.prototype.scrollTo||(Element.prototype.scrollTo=function(e,t){typeof e=="object"?(this.scrollLeft=e.left,this.scrollTop=e.top):(this.scrollLeft=e,this.scrollTop=t)}),Element.prototype.replaceWith||(Element.prototype.replaceWith=function(...e){let t=this.parentNode;if(t){e.length===0&&t.removeChild(this);for(let r=e.length-1;r>=0;r--){let o=e[r];typeof o=="string"?o=document.createTextNode(o):o.parentNode&&o.parentNode.removeChild(o),r?t.insertBefore(this.previousSibling,o):t.replaceChild(o,this)}}}));function ys(){return location.protocol==="file:"?Tt(`${new URL("search/search_index.js",eo.base)}`).pipe(m(()=>__index),G(1)):je(new URL("search/search_index.json",eo.base))}document.documentElement.classList.remove("no-js");document.documentElement.classList.add("js");var ot=Go(),Ut=sn(),Lt=ln(Ut),to=an(),Oe=gn(),hr=Pt("(min-width: 960px)"),Mi=Pt("(min-width: 1220px)"),_i=mn(),eo=xe(),Ai=document.forms.namedItem("search")?ys():Ye,ro=new g;Zn({alert$:ro});var oo=new g;B("navigation.instant")&&oi({location$:Ut,viewport$:Oe,progress$:oo}).subscribe(ot);var Li;((Li=eo.version)==null?void 0:Li.provider)==="mike"&&ci({document$:ot});O(Ut,Lt).pipe(Ge(125)).subscribe(()=>{Je("drawer",!1),Je("search",!1)});to.pipe(b(({mode:e})=>e==="global")).subscribe(e=>{switch(e.type){case"p":case",":let t=fe("link[rel=prev]");typeof t!="undefined"&<(t);break;case"n":case".":let r=fe("link[rel=next]");typeof r!="undefined"&<(r);break;case"Enter":let o=Ie();o instanceof HTMLLabelElement&&o.click()}});wi({viewport$:Oe,document$:ot});Ti({document$:ot,tablet$:hr});Si({document$:ot});Oi({viewport$:Oe,tablet$:hr});var rt=Kn(Se("header"),{viewport$:Oe}),Ft=ot.pipe(m(()=>Se("main")),v(e=>Gn(e,{viewport$:Oe,header$:rt})),G(1)),xs=O(...ae("consent").map(e=>En(e,{target$:Lt})),...ae("dialog").map(e=>qn(e,{alert$:ro})),...ae("palette").map(e=>Jn(e)),...ae("progress").map(e=>Xn(e,{progress$:oo})),...ae("search").map(e=>ui(e,{index$:Ai,keyboard$:to})),...ae("source").map(e=>gi(e))),Es=C(()=>O(...ae("announce").map(e=>xn(e)),...ae("content").map(e=>zn(e,{viewport$:Oe,target$:Lt,print$:_i})),...ae("content").map(e=>B("search.highlight")?di(e,{index$:Ai,location$:Ut}):S),...ae("header").map(e=>Yn(e,{viewport$:Oe,header$:rt,main$:Ft})),...ae("header-title").map(e=>Bn(e,{viewport$:Oe,header$:rt})),...ae("sidebar").map(e=>e.getAttribute("data-md-type")==="navigation"?Nr(Mi,()=>Zr(e,{viewport$:Oe,header$:rt,main$:Ft})):Nr(hr,()=>Zr(e,{viewport$:Oe,header$:rt,main$:Ft}))),...ae("tabs").map(e=>yi(e,{viewport$:Oe,header$:rt})),...ae("toc").map(e=>xi(e,{viewport$:Oe,header$:rt,main$:Ft,target$:Lt})),...ae("top").map(e=>Ei(e,{viewport$:Oe,header$:rt,main$:Ft,target$:Lt})))),Ci=ot.pipe(v(()=>Es),Re(xs),G(1));Ci.subscribe();window.document$=ot;window.location$=Ut;window.target$=Lt;window.keyboard$=to;window.viewport$=Oe;window.tablet$=hr;window.screen$=Mi;window.print$=_i;window.alert$=ro;window.progress$=oo;window.component$=Ci;})(); +//# sourceMappingURL=bundle.83f73b43.min.js.map + diff --git a/assets/javascripts/bundle.83f73b43.min.js.map b/assets/javascripts/bundle.83f73b43.min.js.map new file mode 100644 index 00000000..fe920b7d --- /dev/null +++ b/assets/javascripts/bundle.83f73b43.min.js.map @@ -0,0 +1,7 @@ +{ + "version": 3, + "sources": ["node_modules/focus-visible/dist/focus-visible.js", "node_modules/escape-html/index.js", "node_modules/clipboard/dist/clipboard.js", "src/templates/assets/javascripts/bundle.ts", "node_modules/tslib/tslib.es6.mjs", "node_modules/rxjs/src/internal/util/isFunction.ts", "node_modules/rxjs/src/internal/util/createErrorClass.ts", "node_modules/rxjs/src/internal/util/UnsubscriptionError.ts", "node_modules/rxjs/src/internal/util/arrRemove.ts", "node_modules/rxjs/src/internal/Subscription.ts", "node_modules/rxjs/src/internal/config.ts", "node_modules/rxjs/src/internal/scheduler/timeoutProvider.ts", "node_modules/rxjs/src/internal/util/reportUnhandledError.ts", "node_modules/rxjs/src/internal/util/noop.ts", "node_modules/rxjs/src/internal/NotificationFactories.ts", "node_modules/rxjs/src/internal/util/errorContext.ts", "node_modules/rxjs/src/internal/Subscriber.ts", "node_modules/rxjs/src/internal/symbol/observable.ts", "node_modules/rxjs/src/internal/util/identity.ts", "node_modules/rxjs/src/internal/util/pipe.ts", "node_modules/rxjs/src/internal/Observable.ts", "node_modules/rxjs/src/internal/util/lift.ts", "node_modules/rxjs/src/internal/operators/OperatorSubscriber.ts", "node_modules/rxjs/src/internal/scheduler/animationFrameProvider.ts", "node_modules/rxjs/src/internal/util/ObjectUnsubscribedError.ts", "node_modules/rxjs/src/internal/Subject.ts", "node_modules/rxjs/src/internal/BehaviorSubject.ts", "node_modules/rxjs/src/internal/scheduler/dateTimestampProvider.ts", "node_modules/rxjs/src/internal/ReplaySubject.ts", "node_modules/rxjs/src/internal/scheduler/Action.ts", "node_modules/rxjs/src/internal/scheduler/intervalProvider.ts", "node_modules/rxjs/src/internal/scheduler/AsyncAction.ts", "node_modules/rxjs/src/internal/Scheduler.ts", "node_modules/rxjs/src/internal/scheduler/AsyncScheduler.ts", "node_modules/rxjs/src/internal/scheduler/async.ts", "node_modules/rxjs/src/internal/scheduler/QueueAction.ts", "node_modules/rxjs/src/internal/scheduler/QueueScheduler.ts", "node_modules/rxjs/src/internal/scheduler/queue.ts", "node_modules/rxjs/src/internal/scheduler/AnimationFrameAction.ts", "node_modules/rxjs/src/internal/scheduler/AnimationFrameScheduler.ts", "node_modules/rxjs/src/internal/scheduler/animationFrame.ts", "node_modules/rxjs/src/internal/observable/empty.ts", "node_modules/rxjs/src/internal/util/isScheduler.ts", "node_modules/rxjs/src/internal/util/args.ts", "node_modules/rxjs/src/internal/util/isArrayLike.ts", "node_modules/rxjs/src/internal/util/isPromise.ts", "node_modules/rxjs/src/internal/util/isInteropObservable.ts", "node_modules/rxjs/src/internal/util/isAsyncIterable.ts", "node_modules/rxjs/src/internal/util/throwUnobservableError.ts", "node_modules/rxjs/src/internal/symbol/iterator.ts", "node_modules/rxjs/src/internal/util/isIterable.ts", "node_modules/rxjs/src/internal/util/isReadableStreamLike.ts", "node_modules/rxjs/src/internal/observable/innerFrom.ts", "node_modules/rxjs/src/internal/util/executeSchedule.ts", "node_modules/rxjs/src/internal/operators/observeOn.ts", "node_modules/rxjs/src/internal/operators/subscribeOn.ts", "node_modules/rxjs/src/internal/scheduled/scheduleObservable.ts", "node_modules/rxjs/src/internal/scheduled/schedulePromise.ts", "node_modules/rxjs/src/internal/scheduled/scheduleArray.ts", "node_modules/rxjs/src/internal/scheduled/scheduleIterable.ts", "node_modules/rxjs/src/internal/scheduled/scheduleAsyncIterable.ts", "node_modules/rxjs/src/internal/scheduled/scheduleReadableStreamLike.ts", "node_modules/rxjs/src/internal/scheduled/scheduled.ts", "node_modules/rxjs/src/internal/observable/from.ts", "node_modules/rxjs/src/internal/observable/of.ts", "node_modules/rxjs/src/internal/observable/throwError.ts", "node_modules/rxjs/src/internal/util/EmptyError.ts", "node_modules/rxjs/src/internal/util/isDate.ts", "node_modules/rxjs/src/internal/operators/map.ts", "node_modules/rxjs/src/internal/util/mapOneOrManyArgs.ts", "node_modules/rxjs/src/internal/util/argsArgArrayOrObject.ts", "node_modules/rxjs/src/internal/util/createObject.ts", "node_modules/rxjs/src/internal/observable/combineLatest.ts", "node_modules/rxjs/src/internal/operators/mergeInternals.ts", "node_modules/rxjs/src/internal/operators/mergeMap.ts", "node_modules/rxjs/src/internal/operators/mergeAll.ts", "node_modules/rxjs/src/internal/operators/concatAll.ts", "node_modules/rxjs/src/internal/observable/concat.ts", "node_modules/rxjs/src/internal/observable/defer.ts", "node_modules/rxjs/src/internal/observable/fromEvent.ts", "node_modules/rxjs/src/internal/observable/fromEventPattern.ts", "node_modules/rxjs/src/internal/observable/timer.ts", "node_modules/rxjs/src/internal/observable/merge.ts", "node_modules/rxjs/src/internal/observable/never.ts", "node_modules/rxjs/src/internal/util/argsOrArgArray.ts", "node_modules/rxjs/src/internal/operators/filter.ts", "node_modules/rxjs/src/internal/observable/zip.ts", "node_modules/rxjs/src/internal/operators/audit.ts", "node_modules/rxjs/src/internal/operators/auditTime.ts", "node_modules/rxjs/src/internal/operators/bufferCount.ts", "node_modules/rxjs/src/internal/operators/catchError.ts", "node_modules/rxjs/src/internal/operators/scanInternals.ts", "node_modules/rxjs/src/internal/operators/combineLatest.ts", "node_modules/rxjs/src/internal/operators/combineLatestWith.ts", "node_modules/rxjs/src/internal/operators/debounce.ts", "node_modules/rxjs/src/internal/operators/debounceTime.ts", "node_modules/rxjs/src/internal/operators/defaultIfEmpty.ts", "node_modules/rxjs/src/internal/operators/take.ts", "node_modules/rxjs/src/internal/operators/ignoreElements.ts", "node_modules/rxjs/src/internal/operators/mapTo.ts", "node_modules/rxjs/src/internal/operators/delayWhen.ts", "node_modules/rxjs/src/internal/operators/delay.ts", "node_modules/rxjs/src/internal/operators/distinctUntilChanged.ts", "node_modules/rxjs/src/internal/operators/distinctUntilKeyChanged.ts", "node_modules/rxjs/src/internal/operators/throwIfEmpty.ts", "node_modules/rxjs/src/internal/operators/endWith.ts", "node_modules/rxjs/src/internal/operators/finalize.ts", "node_modules/rxjs/src/internal/operators/first.ts", "node_modules/rxjs/src/internal/operators/takeLast.ts", "node_modules/rxjs/src/internal/operators/merge.ts", "node_modules/rxjs/src/internal/operators/mergeWith.ts", "node_modules/rxjs/src/internal/operators/repeat.ts", "node_modules/rxjs/src/internal/operators/scan.ts", "node_modules/rxjs/src/internal/operators/share.ts", "node_modules/rxjs/src/internal/operators/shareReplay.ts", "node_modules/rxjs/src/internal/operators/skip.ts", "node_modules/rxjs/src/internal/operators/skipUntil.ts", "node_modules/rxjs/src/internal/operators/startWith.ts", "node_modules/rxjs/src/internal/operators/switchMap.ts", "node_modules/rxjs/src/internal/operators/takeUntil.ts", "node_modules/rxjs/src/internal/operators/takeWhile.ts", "node_modules/rxjs/src/internal/operators/tap.ts", "node_modules/rxjs/src/internal/operators/throttle.ts", "node_modules/rxjs/src/internal/operators/throttleTime.ts", "node_modules/rxjs/src/internal/operators/withLatestFrom.ts", "node_modules/rxjs/src/internal/operators/zip.ts", "node_modules/rxjs/src/internal/operators/zipWith.ts", "src/templates/assets/javascripts/browser/document/index.ts", "src/templates/assets/javascripts/browser/element/_/index.ts", "src/templates/assets/javascripts/browser/element/focus/index.ts", "src/templates/assets/javascripts/browser/element/hover/index.ts", "src/templates/assets/javascripts/utilities/h/index.ts", "src/templates/assets/javascripts/utilities/round/index.ts", "src/templates/assets/javascripts/browser/script/index.ts", "src/templates/assets/javascripts/browser/element/size/_/index.ts", "src/templates/assets/javascripts/browser/element/size/content/index.ts", "src/templates/assets/javascripts/browser/element/offset/_/index.ts", "src/templates/assets/javascripts/browser/element/offset/content/index.ts", "src/templates/assets/javascripts/browser/element/visibility/index.ts", "src/templates/assets/javascripts/browser/toggle/index.ts", "src/templates/assets/javascripts/browser/keyboard/index.ts", "src/templates/assets/javascripts/browser/location/_/index.ts", "src/templates/assets/javascripts/browser/location/hash/index.ts", "src/templates/assets/javascripts/browser/media/index.ts", "src/templates/assets/javascripts/browser/request/index.ts", "src/templates/assets/javascripts/browser/viewport/offset/index.ts", "src/templates/assets/javascripts/browser/viewport/size/index.ts", "src/templates/assets/javascripts/browser/viewport/_/index.ts", "src/templates/assets/javascripts/browser/viewport/at/index.ts", "src/templates/assets/javascripts/browser/worker/index.ts", "src/templates/assets/javascripts/_/index.ts", "src/templates/assets/javascripts/components/_/index.ts", "src/templates/assets/javascripts/components/announce/index.ts", "src/templates/assets/javascripts/components/consent/index.ts", "src/templates/assets/javascripts/templates/tooltip/index.tsx", "src/templates/assets/javascripts/templates/annotation/index.tsx", "src/templates/assets/javascripts/templates/clipboard/index.tsx", "src/templates/assets/javascripts/templates/search/index.tsx", "src/templates/assets/javascripts/templates/source/index.tsx", "src/templates/assets/javascripts/templates/tabbed/index.tsx", "src/templates/assets/javascripts/templates/table/index.tsx", "src/templates/assets/javascripts/templates/version/index.tsx", "src/templates/assets/javascripts/components/tooltip2/index.ts", "src/templates/assets/javascripts/components/content/annotation/_/index.ts", "src/templates/assets/javascripts/components/content/annotation/list/index.ts", "src/templates/assets/javascripts/components/content/annotation/block/index.ts", "src/templates/assets/javascripts/components/content/code/_/index.ts", "src/templates/assets/javascripts/components/content/details/index.ts", "src/templates/assets/javascripts/components/content/mermaid/index.css", "src/templates/assets/javascripts/components/content/mermaid/index.ts", "src/templates/assets/javascripts/components/content/table/index.ts", "src/templates/assets/javascripts/components/content/tabs/index.ts", "src/templates/assets/javascripts/components/content/_/index.ts", "src/templates/assets/javascripts/components/dialog/index.ts", "src/templates/assets/javascripts/components/tooltip/index.ts", "src/templates/assets/javascripts/components/header/_/index.ts", "src/templates/assets/javascripts/components/header/title/index.ts", "src/templates/assets/javascripts/components/main/index.ts", "src/templates/assets/javascripts/components/palette/index.ts", "src/templates/assets/javascripts/components/progress/index.ts", "src/templates/assets/javascripts/integrations/clipboard/index.ts", "src/templates/assets/javascripts/integrations/sitemap/index.ts", "src/templates/assets/javascripts/integrations/instant/index.ts", "src/templates/assets/javascripts/integrations/search/highlighter/index.ts", "src/templates/assets/javascripts/integrations/search/worker/message/index.ts", "src/templates/assets/javascripts/integrations/search/worker/_/index.ts", "src/templates/assets/javascripts/integrations/version/findurl/index.ts", "src/templates/assets/javascripts/integrations/version/index.ts", "src/templates/assets/javascripts/components/search/query/index.ts", "src/templates/assets/javascripts/components/search/result/index.ts", "src/templates/assets/javascripts/components/search/share/index.ts", "src/templates/assets/javascripts/components/search/suggest/index.ts", "src/templates/assets/javascripts/components/search/_/index.ts", "src/templates/assets/javascripts/components/search/highlight/index.ts", "src/templates/assets/javascripts/components/sidebar/index.ts", "src/templates/assets/javascripts/components/source/facts/github/index.ts", "src/templates/assets/javascripts/components/source/facts/gitlab/index.ts", "src/templates/assets/javascripts/components/source/facts/_/index.ts", "src/templates/assets/javascripts/components/source/_/index.ts", "src/templates/assets/javascripts/components/tabs/index.ts", "src/templates/assets/javascripts/components/toc/index.ts", "src/templates/assets/javascripts/components/top/index.ts", "src/templates/assets/javascripts/patches/ellipsis/index.ts", "src/templates/assets/javascripts/patches/indeterminate/index.ts", "src/templates/assets/javascripts/patches/scrollfix/index.ts", "src/templates/assets/javascripts/patches/scrolllock/index.ts", "src/templates/assets/javascripts/polyfills/index.ts"], + "sourcesContent": ["(function (global, factory) {\n typeof exports === 'object' && typeof module !== 'undefined' ? factory() :\n typeof define === 'function' && define.amd ? define(factory) :\n (factory());\n}(this, (function () { 'use strict';\n\n /**\n * Applies the :focus-visible polyfill at the given scope.\n * A scope in this case is either the top-level Document or a Shadow Root.\n *\n * @param {(Document|ShadowRoot)} scope\n * @see https://github.com/WICG/focus-visible\n */\n function applyFocusVisiblePolyfill(scope) {\n var hadKeyboardEvent = true;\n var hadFocusVisibleRecently = false;\n var hadFocusVisibleRecentlyTimeout = null;\n\n var inputTypesAllowlist = {\n text: true,\n search: true,\n url: true,\n tel: true,\n email: true,\n password: true,\n number: true,\n date: true,\n month: true,\n week: true,\n time: true,\n datetime: true,\n 'datetime-local': true\n };\n\n /**\n * Helper function for legacy browsers and iframes which sometimes focus\n * elements like document, body, and non-interactive SVG.\n * @param {Element} el\n */\n function isValidFocusTarget(el) {\n if (\n el &&\n el !== document &&\n el.nodeName !== 'HTML' &&\n el.nodeName !== 'BODY' &&\n 'classList' in el &&\n 'contains' in el.classList\n ) {\n return true;\n }\n return false;\n }\n\n /**\n * Computes whether the given element should automatically trigger the\n * `focus-visible` class being added, i.e. whether it should always match\n * `:focus-visible` when focused.\n * @param {Element} el\n * @return {boolean}\n */\n function focusTriggersKeyboardModality(el) {\n var type = el.type;\n var tagName = el.tagName;\n\n if (tagName === 'INPUT' && inputTypesAllowlist[type] && !el.readOnly) {\n return true;\n }\n\n if (tagName === 'TEXTAREA' && !el.readOnly) {\n return true;\n }\n\n if (el.isContentEditable) {\n return true;\n }\n\n return false;\n }\n\n /**\n * Add the `focus-visible` class to the given element if it was not added by\n * the author.\n * @param {Element} el\n */\n function addFocusVisibleClass(el) {\n if (el.classList.contains('focus-visible')) {\n return;\n }\n el.classList.add('focus-visible');\n el.setAttribute('data-focus-visible-added', '');\n }\n\n /**\n * Remove the `focus-visible` class from the given element if it was not\n * originally added by the author.\n * @param {Element} el\n */\n function removeFocusVisibleClass(el) {\n if (!el.hasAttribute('data-focus-visible-added')) {\n return;\n }\n el.classList.remove('focus-visible');\n el.removeAttribute('data-focus-visible-added');\n }\n\n /**\n * If the most recent user interaction was via the keyboard;\n * and the key press did not include a meta, alt/option, or control key;\n * then the modality is keyboard. Otherwise, the modality is not keyboard.\n * Apply `focus-visible` to any current active element and keep track\n * of our keyboard modality state with `hadKeyboardEvent`.\n * @param {KeyboardEvent} e\n */\n function onKeyDown(e) {\n if (e.metaKey || e.altKey || e.ctrlKey) {\n return;\n }\n\n if (isValidFocusTarget(scope.activeElement)) {\n addFocusVisibleClass(scope.activeElement);\n }\n\n hadKeyboardEvent = true;\n }\n\n /**\n * If at any point a user clicks with a pointing device, ensure that we change\n * the modality away from keyboard.\n * This avoids the situation where a user presses a key on an already focused\n * element, and then clicks on a different element, focusing it with a\n * pointing device, while we still think we're in keyboard modality.\n * @param {Event} e\n */\n function onPointerDown(e) {\n hadKeyboardEvent = false;\n }\n\n /**\n * On `focus`, add the `focus-visible` class to the target if:\n * - the target received focus as a result of keyboard navigation, or\n * - the event target is an element that will likely require interaction\n * via the keyboard (e.g. a text box)\n * @param {Event} e\n */\n function onFocus(e) {\n // Prevent IE from focusing the document or HTML element.\n if (!isValidFocusTarget(e.target)) {\n return;\n }\n\n if (hadKeyboardEvent || focusTriggersKeyboardModality(e.target)) {\n addFocusVisibleClass(e.target);\n }\n }\n\n /**\n * On `blur`, remove the `focus-visible` class from the target.\n * @param {Event} e\n */\n function onBlur(e) {\n if (!isValidFocusTarget(e.target)) {\n return;\n }\n\n if (\n e.target.classList.contains('focus-visible') ||\n e.target.hasAttribute('data-focus-visible-added')\n ) {\n // To detect a tab/window switch, we look for a blur event followed\n // rapidly by a visibility change.\n // If we don't see a visibility change within 100ms, it's probably a\n // regular focus change.\n hadFocusVisibleRecently = true;\n window.clearTimeout(hadFocusVisibleRecentlyTimeout);\n hadFocusVisibleRecentlyTimeout = window.setTimeout(function() {\n hadFocusVisibleRecently = false;\n }, 100);\n removeFocusVisibleClass(e.target);\n }\n }\n\n /**\n * If the user changes tabs, keep track of whether or not the previously\n * focused element had .focus-visible.\n * @param {Event} e\n */\n function onVisibilityChange(e) {\n if (document.visibilityState === 'hidden') {\n // If the tab becomes active again, the browser will handle calling focus\n // on the element (Safari actually calls it twice).\n // If this tab change caused a blur on an element with focus-visible,\n // re-apply the class when the user switches back to the tab.\n if (hadFocusVisibleRecently) {\n hadKeyboardEvent = true;\n }\n addInitialPointerMoveListeners();\n }\n }\n\n /**\n * Add a group of listeners to detect usage of any pointing devices.\n * These listeners will be added when the polyfill first loads, and anytime\n * the window is blurred, so that they are active when the window regains\n * focus.\n */\n function addInitialPointerMoveListeners() {\n document.addEventListener('mousemove', onInitialPointerMove);\n document.addEventListener('mousedown', onInitialPointerMove);\n document.addEventListener('mouseup', onInitialPointerMove);\n document.addEventListener('pointermove', onInitialPointerMove);\n document.addEventListener('pointerdown', onInitialPointerMove);\n document.addEventListener('pointerup', onInitialPointerMove);\n document.addEventListener('touchmove', onInitialPointerMove);\n document.addEventListener('touchstart', onInitialPointerMove);\n document.addEventListener('touchend', onInitialPointerMove);\n }\n\n function removeInitialPointerMoveListeners() {\n document.removeEventListener('mousemove', onInitialPointerMove);\n document.removeEventListener('mousedown', onInitialPointerMove);\n document.removeEventListener('mouseup', onInitialPointerMove);\n document.removeEventListener('pointermove', onInitialPointerMove);\n document.removeEventListener('pointerdown', onInitialPointerMove);\n document.removeEventListener('pointerup', onInitialPointerMove);\n document.removeEventListener('touchmove', onInitialPointerMove);\n document.removeEventListener('touchstart', onInitialPointerMove);\n document.removeEventListener('touchend', onInitialPointerMove);\n }\n\n /**\n * When the polfyill first loads, assume the user is in keyboard modality.\n * If any event is received from a pointing device (e.g. mouse, pointer,\n * touch), turn off keyboard modality.\n * This accounts for situations where focus enters the page from the URL bar.\n * @param {Event} e\n */\n function onInitialPointerMove(e) {\n // Work around a Safari quirk that fires a mousemove on whenever the\n // window blurs, even if you're tabbing out of the page. \u00AF\\_(\u30C4)_/\u00AF\n if (e.target.nodeName && e.target.nodeName.toLowerCase() === 'html') {\n return;\n }\n\n hadKeyboardEvent = false;\n removeInitialPointerMoveListeners();\n }\n\n // For some kinds of state, we are interested in changes at the global scope\n // only. For example, global pointer input, global key presses and global\n // visibility change should affect the state at every scope:\n document.addEventListener('keydown', onKeyDown, true);\n document.addEventListener('mousedown', onPointerDown, true);\n document.addEventListener('pointerdown', onPointerDown, true);\n document.addEventListener('touchstart', onPointerDown, true);\n document.addEventListener('visibilitychange', onVisibilityChange, true);\n\n addInitialPointerMoveListeners();\n\n // For focus and blur, we specifically care about state changes in the local\n // scope. This is because focus / blur events that originate from within a\n // shadow root are not re-dispatched from the host element if it was already\n // the active element in its own scope:\n scope.addEventListener('focus', onFocus, true);\n scope.addEventListener('blur', onBlur, true);\n\n // We detect that a node is a ShadowRoot by ensuring that it is a\n // DocumentFragment and also has a host property. This check covers native\n // implementation and polyfill implementation transparently. If we only cared\n // about the native implementation, we could just check if the scope was\n // an instance of a ShadowRoot.\n if (scope.nodeType === Node.DOCUMENT_FRAGMENT_NODE && scope.host) {\n // Since a ShadowRoot is a special kind of DocumentFragment, it does not\n // have a root element to add a class to. So, we add this attribute to the\n // host element instead:\n scope.host.setAttribute('data-js-focus-visible', '');\n } else if (scope.nodeType === Node.DOCUMENT_NODE) {\n document.documentElement.classList.add('js-focus-visible');\n document.documentElement.setAttribute('data-js-focus-visible', '');\n }\n }\n\n // It is important to wrap all references to global window and document in\n // these checks to support server-side rendering use cases\n // @see https://github.com/WICG/focus-visible/issues/199\n if (typeof window !== 'undefined' && typeof document !== 'undefined') {\n // Make the polyfill helper globally available. This can be used as a signal\n // to interested libraries that wish to coordinate with the polyfill for e.g.,\n // applying the polyfill to a shadow root:\n window.applyFocusVisiblePolyfill = applyFocusVisiblePolyfill;\n\n // Notify interested libraries of the polyfill's presence, in case the\n // polyfill was loaded lazily:\n var event;\n\n try {\n event = new CustomEvent('focus-visible-polyfill-ready');\n } catch (error) {\n // IE11 does not support using CustomEvent as a constructor directly:\n event = document.createEvent('CustomEvent');\n event.initCustomEvent('focus-visible-polyfill-ready', false, false, {});\n }\n\n window.dispatchEvent(event);\n }\n\n if (typeof document !== 'undefined') {\n // Apply the polyfill to the global document, so that no JavaScript\n // coordination is required to use the polyfill in the top-level document:\n applyFocusVisiblePolyfill(document);\n }\n\n})));\n", "/*!\n * escape-html\n * Copyright(c) 2012-2013 TJ Holowaychuk\n * Copyright(c) 2015 Andreas Lubbe\n * Copyright(c) 2015 Tiancheng \"Timothy\" Gu\n * MIT Licensed\n */\n\n'use strict';\n\n/**\n * Module variables.\n * @private\n */\n\nvar matchHtmlRegExp = /[\"'&<>]/;\n\n/**\n * Module exports.\n * @public\n */\n\nmodule.exports = escapeHtml;\n\n/**\n * Escape special characters in the given string of html.\n *\n * @param {string} string The string to escape for inserting into HTML\n * @return {string}\n * @public\n */\n\nfunction escapeHtml(string) {\n var str = '' + string;\n var match = matchHtmlRegExp.exec(str);\n\n if (!match) {\n return str;\n }\n\n var escape;\n var html = '';\n var index = 0;\n var lastIndex = 0;\n\n for (index = match.index; index < str.length; index++) {\n switch (str.charCodeAt(index)) {\n case 34: // \"\n escape = '"';\n break;\n case 38: // &\n escape = '&';\n break;\n case 39: // '\n escape = ''';\n break;\n case 60: // <\n escape = '<';\n break;\n case 62: // >\n escape = '>';\n break;\n default:\n continue;\n }\n\n if (lastIndex !== index) {\n html += str.substring(lastIndex, index);\n }\n\n lastIndex = index + 1;\n html += escape;\n }\n\n return lastIndex !== index\n ? html + str.substring(lastIndex, index)\n : html;\n}\n", "/*!\n * clipboard.js v2.0.11\n * https://clipboardjs.com/\n *\n * Licensed MIT \u00A9 Zeno Rocha\n */\n(function webpackUniversalModuleDefinition(root, factory) {\n\tif(typeof exports === 'object' && typeof module === 'object')\n\t\tmodule.exports = factory();\n\telse if(typeof define === 'function' && define.amd)\n\t\tdefine([], factory);\n\telse if(typeof exports === 'object')\n\t\texports[\"ClipboardJS\"] = factory();\n\telse\n\t\troot[\"ClipboardJS\"] = factory();\n})(this, function() {\nreturn /******/ (function() { // webpackBootstrap\n/******/ \tvar __webpack_modules__ = ({\n\n/***/ 686:\n/***/ (function(__unused_webpack_module, __webpack_exports__, __webpack_require__) {\n\n\"use strict\";\n\n// EXPORTS\n__webpack_require__.d(__webpack_exports__, {\n \"default\": function() { return /* binding */ clipboard; }\n});\n\n// EXTERNAL MODULE: ./node_modules/tiny-emitter/index.js\nvar tiny_emitter = __webpack_require__(279);\nvar tiny_emitter_default = /*#__PURE__*/__webpack_require__.n(tiny_emitter);\n// EXTERNAL MODULE: ./node_modules/good-listener/src/listen.js\nvar listen = __webpack_require__(370);\nvar listen_default = /*#__PURE__*/__webpack_require__.n(listen);\n// EXTERNAL MODULE: ./node_modules/select/src/select.js\nvar src_select = __webpack_require__(817);\nvar select_default = /*#__PURE__*/__webpack_require__.n(src_select);\n;// CONCATENATED MODULE: ./src/common/command.js\n/**\n * Executes a given operation type.\n * @param {String} type\n * @return {Boolean}\n */\nfunction command(type) {\n try {\n return document.execCommand(type);\n } catch (err) {\n return false;\n }\n}\n;// CONCATENATED MODULE: ./src/actions/cut.js\n\n\n/**\n * Cut action wrapper.\n * @param {String|HTMLElement} target\n * @return {String}\n */\n\nvar ClipboardActionCut = function ClipboardActionCut(target) {\n var selectedText = select_default()(target);\n command('cut');\n return selectedText;\n};\n\n/* harmony default export */ var actions_cut = (ClipboardActionCut);\n;// CONCATENATED MODULE: ./src/common/create-fake-element.js\n/**\n * Creates a fake textarea element with a value.\n * @param {String} value\n * @return {HTMLElement}\n */\nfunction createFakeElement(value) {\n var isRTL = document.documentElement.getAttribute('dir') === 'rtl';\n var fakeElement = document.createElement('textarea'); // Prevent zooming on iOS\n\n fakeElement.style.fontSize = '12pt'; // Reset box model\n\n fakeElement.style.border = '0';\n fakeElement.style.padding = '0';\n fakeElement.style.margin = '0'; // Move element out of screen horizontally\n\n fakeElement.style.position = 'absolute';\n fakeElement.style[isRTL ? 'right' : 'left'] = '-9999px'; // Move element to the same position vertically\n\n var yPosition = window.pageYOffset || document.documentElement.scrollTop;\n fakeElement.style.top = \"\".concat(yPosition, \"px\");\n fakeElement.setAttribute('readonly', '');\n fakeElement.value = value;\n return fakeElement;\n}\n;// CONCATENATED MODULE: ./src/actions/copy.js\n\n\n\n/**\n * Create fake copy action wrapper using a fake element.\n * @param {String} target\n * @param {Object} options\n * @return {String}\n */\n\nvar fakeCopyAction = function fakeCopyAction(value, options) {\n var fakeElement = createFakeElement(value);\n options.container.appendChild(fakeElement);\n var selectedText = select_default()(fakeElement);\n command('copy');\n fakeElement.remove();\n return selectedText;\n};\n/**\n * Copy action wrapper.\n * @param {String|HTMLElement} target\n * @param {Object} options\n * @return {String}\n */\n\n\nvar ClipboardActionCopy = function ClipboardActionCopy(target) {\n var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {\n container: document.body\n };\n var selectedText = '';\n\n if (typeof target === 'string') {\n selectedText = fakeCopyAction(target, options);\n } else if (target instanceof HTMLInputElement && !['text', 'search', 'url', 'tel', 'password'].includes(target === null || target === void 0 ? void 0 : target.type)) {\n // If input type doesn't support `setSelectionRange`. Simulate it. https://developer.mozilla.org/en-US/docs/Web/API/HTMLInputElement/setSelectionRange\n selectedText = fakeCopyAction(target.value, options);\n } else {\n selectedText = select_default()(target);\n command('copy');\n }\n\n return selectedText;\n};\n\n/* harmony default export */ var actions_copy = (ClipboardActionCopy);\n;// CONCATENATED MODULE: ./src/actions/default.js\nfunction _typeof(obj) { \"@babel/helpers - typeof\"; if (typeof Symbol === \"function\" && typeof Symbol.iterator === \"symbol\") { _typeof = function _typeof(obj) { return typeof obj; }; } else { _typeof = function _typeof(obj) { return obj && typeof Symbol === \"function\" && obj.constructor === Symbol && obj !== Symbol.prototype ? \"symbol\" : typeof obj; }; } return _typeof(obj); }\n\n\n\n/**\n * Inner function which performs selection from either `text` or `target`\n * properties and then executes copy or cut operations.\n * @param {Object} options\n */\n\nvar ClipboardActionDefault = function ClipboardActionDefault() {\n var options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};\n // Defines base properties passed from constructor.\n var _options$action = options.action,\n action = _options$action === void 0 ? 'copy' : _options$action,\n container = options.container,\n target = options.target,\n text = options.text; // Sets the `action` to be performed which can be either 'copy' or 'cut'.\n\n if (action !== 'copy' && action !== 'cut') {\n throw new Error('Invalid \"action\" value, use either \"copy\" or \"cut\"');\n } // Sets the `target` property using an element that will be have its content copied.\n\n\n if (target !== undefined) {\n if (target && _typeof(target) === 'object' && target.nodeType === 1) {\n if (action === 'copy' && target.hasAttribute('disabled')) {\n throw new Error('Invalid \"target\" attribute. Please use \"readonly\" instead of \"disabled\" attribute');\n }\n\n if (action === 'cut' && (target.hasAttribute('readonly') || target.hasAttribute('disabled'))) {\n throw new Error('Invalid \"target\" attribute. You can\\'t cut text from elements with \"readonly\" or \"disabled\" attributes');\n }\n } else {\n throw new Error('Invalid \"target\" value, use a valid Element');\n }\n } // Define selection strategy based on `text` property.\n\n\n if (text) {\n return actions_copy(text, {\n container: container\n });\n } // Defines which selection strategy based on `target` property.\n\n\n if (target) {\n return action === 'cut' ? actions_cut(target) : actions_copy(target, {\n container: container\n });\n }\n};\n\n/* harmony default export */ var actions_default = (ClipboardActionDefault);\n;// CONCATENATED MODULE: ./src/clipboard.js\nfunction clipboard_typeof(obj) { \"@babel/helpers - typeof\"; if (typeof Symbol === \"function\" && typeof Symbol.iterator === \"symbol\") { clipboard_typeof = function _typeof(obj) { return typeof obj; }; } else { clipboard_typeof = function _typeof(obj) { return obj && typeof Symbol === \"function\" && obj.constructor === Symbol && obj !== Symbol.prototype ? \"symbol\" : typeof obj; }; } return clipboard_typeof(obj); }\n\nfunction _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError(\"Cannot call a class as a function\"); } }\n\nfunction _defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if (\"value\" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } }\n\nfunction _createClass(Constructor, protoProps, staticProps) { if (protoProps) _defineProperties(Constructor.prototype, protoProps); if (staticProps) _defineProperties(Constructor, staticProps); return Constructor; }\n\nfunction _inherits(subClass, superClass) { if (typeof superClass !== \"function\" && superClass !== null) { throw new TypeError(\"Super expression must either be null or a function\"); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, writable: true, configurable: true } }); if (superClass) _setPrototypeOf(subClass, superClass); }\n\nfunction _setPrototypeOf(o, p) { _setPrototypeOf = Object.setPrototypeOf || function _setPrototypeOf(o, p) { o.__proto__ = p; return o; }; return _setPrototypeOf(o, p); }\n\nfunction _createSuper(Derived) { var hasNativeReflectConstruct = _isNativeReflectConstruct(); return function _createSuperInternal() { var Super = _getPrototypeOf(Derived), result; if (hasNativeReflectConstruct) { var NewTarget = _getPrototypeOf(this).constructor; result = Reflect.construct(Super, arguments, NewTarget); } else { result = Super.apply(this, arguments); } return _possibleConstructorReturn(this, result); }; }\n\nfunction _possibleConstructorReturn(self, call) { if (call && (clipboard_typeof(call) === \"object\" || typeof call === \"function\")) { return call; } return _assertThisInitialized(self); }\n\nfunction _assertThisInitialized(self) { if (self === void 0) { throw new ReferenceError(\"this hasn't been initialised - super() hasn't been called\"); } return self; }\n\nfunction _isNativeReflectConstruct() { if (typeof Reflect === \"undefined\" || !Reflect.construct) return false; if (Reflect.construct.sham) return false; if (typeof Proxy === \"function\") return true; try { Date.prototype.toString.call(Reflect.construct(Date, [], function () {})); return true; } catch (e) { return false; } }\n\nfunction _getPrototypeOf(o) { _getPrototypeOf = Object.setPrototypeOf ? Object.getPrototypeOf : function _getPrototypeOf(o) { return o.__proto__ || Object.getPrototypeOf(o); }; return _getPrototypeOf(o); }\n\n\n\n\n\n\n/**\n * Helper function to retrieve attribute value.\n * @param {String} suffix\n * @param {Element} element\n */\n\nfunction getAttributeValue(suffix, element) {\n var attribute = \"data-clipboard-\".concat(suffix);\n\n if (!element.hasAttribute(attribute)) {\n return;\n }\n\n return element.getAttribute(attribute);\n}\n/**\n * Base class which takes one or more elements, adds event listeners to them,\n * and instantiates a new `ClipboardAction` on each click.\n */\n\n\nvar Clipboard = /*#__PURE__*/function (_Emitter) {\n _inherits(Clipboard, _Emitter);\n\n var _super = _createSuper(Clipboard);\n\n /**\n * @param {String|HTMLElement|HTMLCollection|NodeList} trigger\n * @param {Object} options\n */\n function Clipboard(trigger, options) {\n var _this;\n\n _classCallCheck(this, Clipboard);\n\n _this = _super.call(this);\n\n _this.resolveOptions(options);\n\n _this.listenClick(trigger);\n\n return _this;\n }\n /**\n * Defines if attributes would be resolved using internal setter functions\n * or custom functions that were passed in the constructor.\n * @param {Object} options\n */\n\n\n _createClass(Clipboard, [{\n key: \"resolveOptions\",\n value: function resolveOptions() {\n var options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};\n this.action = typeof options.action === 'function' ? options.action : this.defaultAction;\n this.target = typeof options.target === 'function' ? options.target : this.defaultTarget;\n this.text = typeof options.text === 'function' ? options.text : this.defaultText;\n this.container = clipboard_typeof(options.container) === 'object' ? options.container : document.body;\n }\n /**\n * Adds a click event listener to the passed trigger.\n * @param {String|HTMLElement|HTMLCollection|NodeList} trigger\n */\n\n }, {\n key: \"listenClick\",\n value: function listenClick(trigger) {\n var _this2 = this;\n\n this.listener = listen_default()(trigger, 'click', function (e) {\n return _this2.onClick(e);\n });\n }\n /**\n * Defines a new `ClipboardAction` on each click event.\n * @param {Event} e\n */\n\n }, {\n key: \"onClick\",\n value: function onClick(e) {\n var trigger = e.delegateTarget || e.currentTarget;\n var action = this.action(trigger) || 'copy';\n var text = actions_default({\n action: action,\n container: this.container,\n target: this.target(trigger),\n text: this.text(trigger)\n }); // Fires an event based on the copy operation result.\n\n this.emit(text ? 'success' : 'error', {\n action: action,\n text: text,\n trigger: trigger,\n clearSelection: function clearSelection() {\n if (trigger) {\n trigger.focus();\n }\n\n window.getSelection().removeAllRanges();\n }\n });\n }\n /**\n * Default `action` lookup function.\n * @param {Element} trigger\n */\n\n }, {\n key: \"defaultAction\",\n value: function defaultAction(trigger) {\n return getAttributeValue('action', trigger);\n }\n /**\n * Default `target` lookup function.\n * @param {Element} trigger\n */\n\n }, {\n key: \"defaultTarget\",\n value: function defaultTarget(trigger) {\n var selector = getAttributeValue('target', trigger);\n\n if (selector) {\n return document.querySelector(selector);\n }\n }\n /**\n * Allow fire programmatically a copy action\n * @param {String|HTMLElement} target\n * @param {Object} options\n * @returns Text copied.\n */\n\n }, {\n key: \"defaultText\",\n\n /**\n * Default `text` lookup function.\n * @param {Element} trigger\n */\n value: function defaultText(trigger) {\n return getAttributeValue('text', trigger);\n }\n /**\n * Destroy lifecycle.\n */\n\n }, {\n key: \"destroy\",\n value: function destroy() {\n this.listener.destroy();\n }\n }], [{\n key: \"copy\",\n value: function copy(target) {\n var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {\n container: document.body\n };\n return actions_copy(target, options);\n }\n /**\n * Allow fire programmatically a cut action\n * @param {String|HTMLElement} target\n * @returns Text cutted.\n */\n\n }, {\n key: \"cut\",\n value: function cut(target) {\n return actions_cut(target);\n }\n /**\n * Returns the support of the given action, or all actions if no action is\n * given.\n * @param {String} [action]\n */\n\n }, {\n key: \"isSupported\",\n value: function isSupported() {\n var action = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : ['copy', 'cut'];\n var actions = typeof action === 'string' ? [action] : action;\n var support = !!document.queryCommandSupported;\n actions.forEach(function (action) {\n support = support && !!document.queryCommandSupported(action);\n });\n return support;\n }\n }]);\n\n return Clipboard;\n}((tiny_emitter_default()));\n\n/* harmony default export */ var clipboard = (Clipboard);\n\n/***/ }),\n\n/***/ 828:\n/***/ (function(module) {\n\nvar DOCUMENT_NODE_TYPE = 9;\n\n/**\n * A polyfill for Element.matches()\n */\nif (typeof Element !== 'undefined' && !Element.prototype.matches) {\n var proto = Element.prototype;\n\n proto.matches = proto.matchesSelector ||\n proto.mozMatchesSelector ||\n proto.msMatchesSelector ||\n proto.oMatchesSelector ||\n proto.webkitMatchesSelector;\n}\n\n/**\n * Finds the closest parent that matches a selector.\n *\n * @param {Element} element\n * @param {String} selector\n * @return {Function}\n */\nfunction closest (element, selector) {\n while (element && element.nodeType !== DOCUMENT_NODE_TYPE) {\n if (typeof element.matches === 'function' &&\n element.matches(selector)) {\n return element;\n }\n element = element.parentNode;\n }\n}\n\nmodule.exports = closest;\n\n\n/***/ }),\n\n/***/ 438:\n/***/ (function(module, __unused_webpack_exports, __webpack_require__) {\n\nvar closest = __webpack_require__(828);\n\n/**\n * Delegates event to a selector.\n *\n * @param {Element} element\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @param {Boolean} useCapture\n * @return {Object}\n */\nfunction _delegate(element, selector, type, callback, useCapture) {\n var listenerFn = listener.apply(this, arguments);\n\n element.addEventListener(type, listenerFn, useCapture);\n\n return {\n destroy: function() {\n element.removeEventListener(type, listenerFn, useCapture);\n }\n }\n}\n\n/**\n * Delegates event to a selector.\n *\n * @param {Element|String|Array} [elements]\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @param {Boolean} useCapture\n * @return {Object}\n */\nfunction delegate(elements, selector, type, callback, useCapture) {\n // Handle the regular Element usage\n if (typeof elements.addEventListener === 'function') {\n return _delegate.apply(null, arguments);\n }\n\n // Handle Element-less usage, it defaults to global delegation\n if (typeof type === 'function') {\n // Use `document` as the first parameter, then apply arguments\n // This is a short way to .unshift `arguments` without running into deoptimizations\n return _delegate.bind(null, document).apply(null, arguments);\n }\n\n // Handle Selector-based usage\n if (typeof elements === 'string') {\n elements = document.querySelectorAll(elements);\n }\n\n // Handle Array-like based usage\n return Array.prototype.map.call(elements, function (element) {\n return _delegate(element, selector, type, callback, useCapture);\n });\n}\n\n/**\n * Finds closest match and invokes callback.\n *\n * @param {Element} element\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @return {Function}\n */\nfunction listener(element, selector, type, callback) {\n return function(e) {\n e.delegateTarget = closest(e.target, selector);\n\n if (e.delegateTarget) {\n callback.call(element, e);\n }\n }\n}\n\nmodule.exports = delegate;\n\n\n/***/ }),\n\n/***/ 879:\n/***/ (function(__unused_webpack_module, exports) {\n\n/**\n * Check if argument is a HTML element.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.node = function(value) {\n return value !== undefined\n && value instanceof HTMLElement\n && value.nodeType === 1;\n};\n\n/**\n * Check if argument is a list of HTML elements.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.nodeList = function(value) {\n var type = Object.prototype.toString.call(value);\n\n return value !== undefined\n && (type === '[object NodeList]' || type === '[object HTMLCollection]')\n && ('length' in value)\n && (value.length === 0 || exports.node(value[0]));\n};\n\n/**\n * Check if argument is a string.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.string = function(value) {\n return typeof value === 'string'\n || value instanceof String;\n};\n\n/**\n * Check if argument is a function.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.fn = function(value) {\n var type = Object.prototype.toString.call(value);\n\n return type === '[object Function]';\n};\n\n\n/***/ }),\n\n/***/ 370:\n/***/ (function(module, __unused_webpack_exports, __webpack_require__) {\n\nvar is = __webpack_require__(879);\nvar delegate = __webpack_require__(438);\n\n/**\n * Validates all params and calls the right\n * listener function based on its target type.\n *\n * @param {String|HTMLElement|HTMLCollection|NodeList} target\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listen(target, type, callback) {\n if (!target && !type && !callback) {\n throw new Error('Missing required arguments');\n }\n\n if (!is.string(type)) {\n throw new TypeError('Second argument must be a String');\n }\n\n if (!is.fn(callback)) {\n throw new TypeError('Third argument must be a Function');\n }\n\n if (is.node(target)) {\n return listenNode(target, type, callback);\n }\n else if (is.nodeList(target)) {\n return listenNodeList(target, type, callback);\n }\n else if (is.string(target)) {\n return listenSelector(target, type, callback);\n }\n else {\n throw new TypeError('First argument must be a String, HTMLElement, HTMLCollection, or NodeList');\n }\n}\n\n/**\n * Adds an event listener to a HTML element\n * and returns a remove listener function.\n *\n * @param {HTMLElement} node\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenNode(node, type, callback) {\n node.addEventListener(type, callback);\n\n return {\n destroy: function() {\n node.removeEventListener(type, callback);\n }\n }\n}\n\n/**\n * Add an event listener to a list of HTML elements\n * and returns a remove listener function.\n *\n * @param {NodeList|HTMLCollection} nodeList\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenNodeList(nodeList, type, callback) {\n Array.prototype.forEach.call(nodeList, function(node) {\n node.addEventListener(type, callback);\n });\n\n return {\n destroy: function() {\n Array.prototype.forEach.call(nodeList, function(node) {\n node.removeEventListener(type, callback);\n });\n }\n }\n}\n\n/**\n * Add an event listener to a selector\n * and returns a remove listener function.\n *\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenSelector(selector, type, callback) {\n return delegate(document.body, selector, type, callback);\n}\n\nmodule.exports = listen;\n\n\n/***/ }),\n\n/***/ 817:\n/***/ (function(module) {\n\nfunction select(element) {\n var selectedText;\n\n if (element.nodeName === 'SELECT') {\n element.focus();\n\n selectedText = element.value;\n }\n else if (element.nodeName === 'INPUT' || element.nodeName === 'TEXTAREA') {\n var isReadOnly = element.hasAttribute('readonly');\n\n if (!isReadOnly) {\n element.setAttribute('readonly', '');\n }\n\n element.select();\n element.setSelectionRange(0, element.value.length);\n\n if (!isReadOnly) {\n element.removeAttribute('readonly');\n }\n\n selectedText = element.value;\n }\n else {\n if (element.hasAttribute('contenteditable')) {\n element.focus();\n }\n\n var selection = window.getSelection();\n var range = document.createRange();\n\n range.selectNodeContents(element);\n selection.removeAllRanges();\n selection.addRange(range);\n\n selectedText = selection.toString();\n }\n\n return selectedText;\n}\n\nmodule.exports = select;\n\n\n/***/ }),\n\n/***/ 279:\n/***/ (function(module) {\n\nfunction E () {\n // Keep this empty so it's easier to inherit from\n // (via https://github.com/lipsmack from https://github.com/scottcorgan/tiny-emitter/issues/3)\n}\n\nE.prototype = {\n on: function (name, callback, ctx) {\n var e = this.e || (this.e = {});\n\n (e[name] || (e[name] = [])).push({\n fn: callback,\n ctx: ctx\n });\n\n return this;\n },\n\n once: function (name, callback, ctx) {\n var self = this;\n function listener () {\n self.off(name, listener);\n callback.apply(ctx, arguments);\n };\n\n listener._ = callback\n return this.on(name, listener, ctx);\n },\n\n emit: function (name) {\n var data = [].slice.call(arguments, 1);\n var evtArr = ((this.e || (this.e = {}))[name] || []).slice();\n var i = 0;\n var len = evtArr.length;\n\n for (i; i < len; i++) {\n evtArr[i].fn.apply(evtArr[i].ctx, data);\n }\n\n return this;\n },\n\n off: function (name, callback) {\n var e = this.e || (this.e = {});\n var evts = e[name];\n var liveEvents = [];\n\n if (evts && callback) {\n for (var i = 0, len = evts.length; i < len; i++) {\n if (evts[i].fn !== callback && evts[i].fn._ !== callback)\n liveEvents.push(evts[i]);\n }\n }\n\n // Remove event from queue to prevent memory leak\n // Suggested by https://github.com/lazd\n // Ref: https://github.com/scottcorgan/tiny-emitter/commit/c6ebfaa9bc973b33d110a84a307742b7cf94c953#commitcomment-5024910\n\n (liveEvents.length)\n ? e[name] = liveEvents\n : delete e[name];\n\n return this;\n }\n};\n\nmodule.exports = E;\nmodule.exports.TinyEmitter = E;\n\n\n/***/ })\n\n/******/ \t});\n/************************************************************************/\n/******/ \t// The module cache\n/******/ \tvar __webpack_module_cache__ = {};\n/******/ \t\n/******/ \t// The require function\n/******/ \tfunction __webpack_require__(moduleId) {\n/******/ \t\t// Check if module is in cache\n/******/ \t\tif(__webpack_module_cache__[moduleId]) {\n/******/ \t\t\treturn __webpack_module_cache__[moduleId].exports;\n/******/ \t\t}\n/******/ \t\t// Create a new module (and put it into the cache)\n/******/ \t\tvar module = __webpack_module_cache__[moduleId] = {\n/******/ \t\t\t// no module.id needed\n/******/ \t\t\t// no module.loaded needed\n/******/ \t\t\texports: {}\n/******/ \t\t};\n/******/ \t\n/******/ \t\t// Execute the module function\n/******/ \t\t__webpack_modules__[moduleId](module, module.exports, __webpack_require__);\n/******/ \t\n/******/ \t\t// Return the exports of the module\n/******/ \t\treturn module.exports;\n/******/ \t}\n/******/ \t\n/************************************************************************/\n/******/ \t/* webpack/runtime/compat get default export */\n/******/ \t!function() {\n/******/ \t\t// getDefaultExport function for compatibility with non-harmony modules\n/******/ \t\t__webpack_require__.n = function(module) {\n/******/ \t\t\tvar getter = module && module.__esModule ?\n/******/ \t\t\t\tfunction() { return module['default']; } :\n/******/ \t\t\t\tfunction() { return module; };\n/******/ \t\t\t__webpack_require__.d(getter, { a: getter });\n/******/ \t\t\treturn getter;\n/******/ \t\t};\n/******/ \t}();\n/******/ \t\n/******/ \t/* webpack/runtime/define property getters */\n/******/ \t!function() {\n/******/ \t\t// define getter functions for harmony exports\n/******/ \t\t__webpack_require__.d = function(exports, definition) {\n/******/ \t\t\tfor(var key in definition) {\n/******/ \t\t\t\tif(__webpack_require__.o(definition, key) && !__webpack_require__.o(exports, key)) {\n/******/ \t\t\t\t\tObject.defineProperty(exports, key, { enumerable: true, get: definition[key] });\n/******/ \t\t\t\t}\n/******/ \t\t\t}\n/******/ \t\t};\n/******/ \t}();\n/******/ \t\n/******/ \t/* webpack/runtime/hasOwnProperty shorthand */\n/******/ \t!function() {\n/******/ \t\t__webpack_require__.o = function(obj, prop) { return Object.prototype.hasOwnProperty.call(obj, prop); }\n/******/ \t}();\n/******/ \t\n/************************************************************************/\n/******/ \t// module exports must be returned from runtime so entry inlining is disabled\n/******/ \t// startup\n/******/ \t// Load entry module and return exports\n/******/ \treturn __webpack_require__(686);\n/******/ })()\n.default;\n});", "/*\n * Copyright (c) 2016-2024 Martin Donath \n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\nimport \"focus-visible\"\n\nimport {\n EMPTY,\n NEVER,\n Observable,\n Subject,\n defer,\n delay,\n filter,\n map,\n merge,\n mergeWith,\n shareReplay,\n switchMap\n} from \"rxjs\"\n\nimport { configuration, feature } from \"./_\"\nimport {\n at,\n getActiveElement,\n getOptionalElement,\n requestJSON,\n setLocation,\n setToggle,\n watchDocument,\n watchKeyboard,\n watchLocation,\n watchLocationTarget,\n watchMedia,\n watchPrint,\n watchScript,\n watchViewport\n} from \"./browser\"\nimport {\n getComponentElement,\n getComponentElements,\n mountAnnounce,\n mountBackToTop,\n mountConsent,\n mountContent,\n mountDialog,\n mountHeader,\n mountHeaderTitle,\n mountPalette,\n mountProgress,\n mountSearch,\n mountSearchHiglight,\n mountSidebar,\n mountSource,\n mountTableOfContents,\n mountTabs,\n watchHeader,\n watchMain\n} from \"./components\"\nimport {\n SearchIndex,\n setupClipboardJS,\n setupInstantNavigation,\n setupVersionSelector\n} from \"./integrations\"\nimport {\n patchEllipsis,\n patchIndeterminate,\n patchScrollfix,\n patchScrolllock\n} from \"./patches\"\nimport \"./polyfills\"\n\n/* ----------------------------------------------------------------------------\n * Functions - @todo refactor\n * ------------------------------------------------------------------------- */\n\n/**\n * Fetch search index\n *\n * @returns Search index observable\n */\nfunction fetchSearchIndex(): Observable {\n if (location.protocol === \"file:\") {\n return watchScript(\n `${new URL(\"search/search_index.js\", config.base)}`\n )\n .pipe(\n // @ts-ignore - @todo fix typings\n map(() => __index),\n shareReplay(1)\n )\n } else {\n return requestJSON(\n new URL(\"search/search_index.json\", config.base)\n )\n }\n}\n\n/* ----------------------------------------------------------------------------\n * Application\n * ------------------------------------------------------------------------- */\n\n/* Yay, JavaScript is available */\ndocument.documentElement.classList.remove(\"no-js\")\ndocument.documentElement.classList.add(\"js\")\n\n/* Set up navigation observables and subjects */\nconst document$ = watchDocument()\nconst location$ = watchLocation()\nconst target$ = watchLocationTarget(location$)\nconst keyboard$ = watchKeyboard()\n\n/* Set up media observables */\nconst viewport$ = watchViewport()\nconst tablet$ = watchMedia(\"(min-width: 960px)\")\nconst screen$ = watchMedia(\"(min-width: 1220px)\")\nconst print$ = watchPrint()\n\n/* Retrieve search index, if search is enabled */\nconst config = configuration()\nconst index$ = document.forms.namedItem(\"search\")\n ? fetchSearchIndex()\n : NEVER\n\n/* Set up Clipboard.js integration */\nconst alert$ = new Subject()\nsetupClipboardJS({ alert$ })\n\n/* Set up progress indicator */\nconst progress$ = new Subject()\n\n/* Set up instant navigation, if enabled */\nif (feature(\"navigation.instant\"))\n setupInstantNavigation({ location$, viewport$, progress$ })\n .subscribe(document$)\n\n/* Set up version selector */\nif (config.version?.provider === \"mike\")\n setupVersionSelector({ document$ })\n\n/* Always close drawer and search on navigation */\nmerge(location$, target$)\n .pipe(\n delay(125)\n )\n .subscribe(() => {\n setToggle(\"drawer\", false)\n setToggle(\"search\", false)\n })\n\n/* Set up global keyboard handlers */\nkeyboard$\n .pipe(\n filter(({ mode }) => mode === \"global\")\n )\n .subscribe(key => {\n switch (key.type) {\n\n /* Go to previous page */\n case \"p\":\n case \",\":\n const prev = getOptionalElement(\"link[rel=prev]\")\n if (typeof prev !== \"undefined\")\n setLocation(prev)\n break\n\n /* Go to next page */\n case \"n\":\n case \".\":\n const next = getOptionalElement(\"link[rel=next]\")\n if (typeof next !== \"undefined\")\n setLocation(next)\n break\n\n /* Expand navigation, see https://bit.ly/3ZjG5io */\n case \"Enter\":\n const active = getActiveElement()\n if (active instanceof HTMLLabelElement)\n active.click()\n }\n })\n\n/* Set up patches */\npatchEllipsis({ viewport$, document$ })\npatchIndeterminate({ document$, tablet$ })\npatchScrollfix({ document$ })\npatchScrolllock({ viewport$, tablet$ })\n\n/* Set up header and main area observable */\nconst header$ = watchHeader(getComponentElement(\"header\"), { viewport$ })\nconst main$ = document$\n .pipe(\n map(() => getComponentElement(\"main\")),\n switchMap(el => watchMain(el, { viewport$, header$ })),\n shareReplay(1)\n )\n\n/* Set up control component observables */\nconst control$ = merge(\n\n /* Consent */\n ...getComponentElements(\"consent\")\n .map(el => mountConsent(el, { target$ })),\n\n /* Dialog */\n ...getComponentElements(\"dialog\")\n .map(el => mountDialog(el, { alert$ })),\n\n /* Color palette */\n ...getComponentElements(\"palette\")\n .map(el => mountPalette(el)),\n\n /* Progress bar */\n ...getComponentElements(\"progress\")\n .map(el => mountProgress(el, { progress$ })),\n\n /* Search */\n ...getComponentElements(\"search\")\n .map(el => mountSearch(el, { index$, keyboard$ })),\n\n /* Repository information */\n ...getComponentElements(\"source\")\n .map(el => mountSource(el))\n)\n\n/* Set up content component observables */\nconst content$ = defer(() => merge(\n\n /* Announcement bar */\n ...getComponentElements(\"announce\")\n .map(el => mountAnnounce(el)),\n\n /* Content */\n ...getComponentElements(\"content\")\n .map(el => mountContent(el, { viewport$, target$, print$ })),\n\n /* Search highlighting */\n ...getComponentElements(\"content\")\n .map(el => feature(\"search.highlight\")\n ? mountSearchHiglight(el, { index$, location$ })\n : EMPTY\n ),\n\n /* Header */\n ...getComponentElements(\"header\")\n .map(el => mountHeader(el, { viewport$, header$, main$ })),\n\n /* Header title */\n ...getComponentElements(\"header-title\")\n .map(el => mountHeaderTitle(el, { viewport$, header$ })),\n\n /* Sidebar */\n ...getComponentElements(\"sidebar\")\n .map(el => el.getAttribute(\"data-md-type\") === \"navigation\"\n ? at(screen$, () => mountSidebar(el, { viewport$, header$, main$ }))\n : at(tablet$, () => mountSidebar(el, { viewport$, header$, main$ }))\n ),\n\n /* Navigation tabs */\n ...getComponentElements(\"tabs\")\n .map(el => mountTabs(el, { viewport$, header$ })),\n\n /* Table of contents */\n ...getComponentElements(\"toc\")\n .map(el => mountTableOfContents(el, {\n viewport$, header$, main$, target$\n })),\n\n /* Back-to-top button */\n ...getComponentElements(\"top\")\n .map(el => mountBackToTop(el, { viewport$, header$, main$, target$ }))\n))\n\n/* Set up component observables */\nconst component$ = document$\n .pipe(\n switchMap(() => content$),\n mergeWith(control$),\n shareReplay(1)\n )\n\n/* Subscribe to all components */\ncomponent$.subscribe()\n\n/* ----------------------------------------------------------------------------\n * Exports\n * ------------------------------------------------------------------------- */\n\nwindow.document$ = document$ /* Document observable */\nwindow.location$ = location$ /* Location subject */\nwindow.target$ = target$ /* Location target observable */\nwindow.keyboard$ = keyboard$ /* Keyboard observable */\nwindow.viewport$ = viewport$ /* Viewport observable */\nwindow.tablet$ = tablet$ /* Media tablet observable */\nwindow.screen$ = screen$ /* Media screen observable */\nwindow.print$ = print$ /* Media print observable */\nwindow.alert$ = alert$ /* Alert subject */\nwindow.progress$ = progress$ /* Progress indicator subject */\nwindow.component$ = component$ /* Component observable */\n", "/******************************************************************************\nCopyright (c) Microsoft Corporation.\n\nPermission to use, copy, modify, and/or distribute this software for any\npurpose with or without fee is hereby granted.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH\nREGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY\nAND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,\nINDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM\nLOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR\nOTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR\nPERFORMANCE OF THIS SOFTWARE.\n***************************************************************************** */\n/* global Reflect, Promise, SuppressedError, Symbol, Iterator */\n\nvar extendStatics = function(d, b) {\n extendStatics = Object.setPrototypeOf ||\n ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||\n function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; };\n return extendStatics(d, b);\n};\n\nexport function __extends(d, b) {\n if (typeof b !== \"function\" && b !== null)\n throw new TypeError(\"Class extends value \" + String(b) + \" is not a constructor or null\");\n extendStatics(d, b);\n function __() { this.constructor = d; }\n d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());\n}\n\nexport var __assign = function() {\n __assign = Object.assign || function __assign(t) {\n for (var s, i = 1, n = arguments.length; i < n; i++) {\n s = arguments[i];\n for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p)) t[p] = s[p];\n }\n return t;\n }\n return __assign.apply(this, arguments);\n}\n\nexport function __rest(s, e) {\n var t = {};\n for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)\n t[p] = s[p];\n if (s != null && typeof Object.getOwnPropertySymbols === \"function\")\n for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {\n if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))\n t[p[i]] = s[p[i]];\n }\n return t;\n}\n\nexport function __decorate(decorators, target, key, desc) {\n var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;\n if (typeof Reflect === \"object\" && typeof Reflect.decorate === \"function\") r = Reflect.decorate(decorators, target, key, desc);\n else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;\n return c > 3 && r && Object.defineProperty(target, key, r), r;\n}\n\nexport function __param(paramIndex, decorator) {\n return function (target, key) { decorator(target, key, paramIndex); }\n}\n\nexport function __esDecorate(ctor, descriptorIn, decorators, contextIn, initializers, extraInitializers) {\n function accept(f) { if (f !== void 0 && typeof f !== \"function\") throw new TypeError(\"Function expected\"); return f; }\n var kind = contextIn.kind, key = kind === \"getter\" ? \"get\" : kind === \"setter\" ? \"set\" : \"value\";\n var target = !descriptorIn && ctor ? contextIn[\"static\"] ? ctor : ctor.prototype : null;\n var descriptor = descriptorIn || (target ? Object.getOwnPropertyDescriptor(target, contextIn.name) : {});\n var _, done = false;\n for (var i = decorators.length - 1; i >= 0; i--) {\n var context = {};\n for (var p in contextIn) context[p] = p === \"access\" ? {} : contextIn[p];\n for (var p in contextIn.access) context.access[p] = contextIn.access[p];\n context.addInitializer = function (f) { if (done) throw new TypeError(\"Cannot add initializers after decoration has completed\"); extraInitializers.push(accept(f || null)); };\n var result = (0, decorators[i])(kind === \"accessor\" ? { get: descriptor.get, set: descriptor.set } : descriptor[key], context);\n if (kind === \"accessor\") {\n if (result === void 0) continue;\n if (result === null || typeof result !== \"object\") throw new TypeError(\"Object expected\");\n if (_ = accept(result.get)) descriptor.get = _;\n if (_ = accept(result.set)) descriptor.set = _;\n if (_ = accept(result.init)) initializers.unshift(_);\n }\n else if (_ = accept(result)) {\n if (kind === \"field\") initializers.unshift(_);\n else descriptor[key] = _;\n }\n }\n if (target) Object.defineProperty(target, contextIn.name, descriptor);\n done = true;\n};\n\nexport function __runInitializers(thisArg, initializers, value) {\n var useValue = arguments.length > 2;\n for (var i = 0; i < initializers.length; i++) {\n value = useValue ? initializers[i].call(thisArg, value) : initializers[i].call(thisArg);\n }\n return useValue ? value : void 0;\n};\n\nexport function __propKey(x) {\n return typeof x === \"symbol\" ? x : \"\".concat(x);\n};\n\nexport function __setFunctionName(f, name, prefix) {\n if (typeof name === \"symbol\") name = name.description ? \"[\".concat(name.description, \"]\") : \"\";\n return Object.defineProperty(f, \"name\", { configurable: true, value: prefix ? \"\".concat(prefix, \" \", name) : name });\n};\n\nexport function __metadata(metadataKey, metadataValue) {\n if (typeof Reflect === \"object\" && typeof Reflect.metadata === \"function\") return Reflect.metadata(metadataKey, metadataValue);\n}\n\nexport function __awaiter(thisArg, _arguments, P, generator) {\n function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }\n return new (P || (P = Promise))(function (resolve, reject) {\n function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }\n function rejected(value) { try { step(generator[\"throw\"](value)); } catch (e) { reject(e); } }\n function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }\n step((generator = generator.apply(thisArg, _arguments || [])).next());\n });\n}\n\nexport function __generator(thisArg, body) {\n var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g = Object.create((typeof Iterator === \"function\" ? Iterator : Object).prototype);\n return g.next = verb(0), g[\"throw\"] = verb(1), g[\"return\"] = verb(2), typeof Symbol === \"function\" && (g[Symbol.iterator] = function() { return this; }), g;\n function verb(n) { return function (v) { return step([n, v]); }; }\n function step(op) {\n if (f) throw new TypeError(\"Generator is already executing.\");\n while (g && (g = 0, op[0] && (_ = 0)), _) try {\n if (f = 1, y && (t = op[0] & 2 ? y[\"return\"] : op[0] ? y[\"throw\"] || ((t = y[\"return\"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;\n if (y = 0, t) op = [op[0] & 2, t.value];\n switch (op[0]) {\n case 0: case 1: t = op; break;\n case 4: _.label++; return { value: op[1], done: false };\n case 5: _.label++; y = op[1]; op = [0]; continue;\n case 7: op = _.ops.pop(); _.trys.pop(); continue;\n default:\n if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }\n if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }\n if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }\n if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }\n if (t[2]) _.ops.pop();\n _.trys.pop(); continue;\n }\n op = body.call(thisArg, _);\n } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }\n if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };\n }\n}\n\nexport var __createBinding = Object.create ? (function(o, m, k, k2) {\n if (k2 === undefined) k2 = k;\n var desc = Object.getOwnPropertyDescriptor(m, k);\n if (!desc || (\"get\" in desc ? !m.__esModule : desc.writable || desc.configurable)) {\n desc = { enumerable: true, get: function() { return m[k]; } };\n }\n Object.defineProperty(o, k2, desc);\n}) : (function(o, m, k, k2) {\n if (k2 === undefined) k2 = k;\n o[k2] = m[k];\n});\n\nexport function __exportStar(m, o) {\n for (var p in m) if (p !== \"default\" && !Object.prototype.hasOwnProperty.call(o, p)) __createBinding(o, m, p);\n}\n\nexport function __values(o) {\n var s = typeof Symbol === \"function\" && Symbol.iterator, m = s && o[s], i = 0;\n if (m) return m.call(o);\n if (o && typeof o.length === \"number\") return {\n next: function () {\n if (o && i >= o.length) o = void 0;\n return { value: o && o[i++], done: !o };\n }\n };\n throw new TypeError(s ? \"Object is not iterable.\" : \"Symbol.iterator is not defined.\");\n}\n\nexport function __read(o, n) {\n var m = typeof Symbol === \"function\" && o[Symbol.iterator];\n if (!m) return o;\n var i = m.call(o), r, ar = [], e;\n try {\n while ((n === void 0 || n-- > 0) && !(r = i.next()).done) ar.push(r.value);\n }\n catch (error) { e = { error: error }; }\n finally {\n try {\n if (r && !r.done && (m = i[\"return\"])) m.call(i);\n }\n finally { if (e) throw e.error; }\n }\n return ar;\n}\n\n/** @deprecated */\nexport function __spread() {\n for (var ar = [], i = 0; i < arguments.length; i++)\n ar = ar.concat(__read(arguments[i]));\n return ar;\n}\n\n/** @deprecated */\nexport function __spreadArrays() {\n for (var s = 0, i = 0, il = arguments.length; i < il; i++) s += arguments[i].length;\n for (var r = Array(s), k = 0, i = 0; i < il; i++)\n for (var a = arguments[i], j = 0, jl = a.length; j < jl; j++, k++)\n r[k] = a[j];\n return r;\n}\n\nexport function __spreadArray(to, from, pack) {\n if (pack || arguments.length === 2) for (var i = 0, l = from.length, ar; i < l; i++) {\n if (ar || !(i in from)) {\n if (!ar) ar = Array.prototype.slice.call(from, 0, i);\n ar[i] = from[i];\n }\n }\n return to.concat(ar || Array.prototype.slice.call(from));\n}\n\nexport function __await(v) {\n return this instanceof __await ? (this.v = v, this) : new __await(v);\n}\n\nexport function __asyncGenerator(thisArg, _arguments, generator) {\n if (!Symbol.asyncIterator) throw new TypeError(\"Symbol.asyncIterator is not defined.\");\n var g = generator.apply(thisArg, _arguments || []), i, q = [];\n return i = Object.create((typeof AsyncIterator === \"function\" ? AsyncIterator : Object).prototype), verb(\"next\"), verb(\"throw\"), verb(\"return\", awaitReturn), i[Symbol.asyncIterator] = function () { return this; }, i;\n function awaitReturn(f) { return function (v) { return Promise.resolve(v).then(f, reject); }; }\n function verb(n, f) { if (g[n]) { i[n] = function (v) { return new Promise(function (a, b) { q.push([n, v, a, b]) > 1 || resume(n, v); }); }; if (f) i[n] = f(i[n]); } }\n function resume(n, v) { try { step(g[n](v)); } catch (e) { settle(q[0][3], e); } }\n function step(r) { r.value instanceof __await ? Promise.resolve(r.value.v).then(fulfill, reject) : settle(q[0][2], r); }\n function fulfill(value) { resume(\"next\", value); }\n function reject(value) { resume(\"throw\", value); }\n function settle(f, v) { if (f(v), q.shift(), q.length) resume(q[0][0], q[0][1]); }\n}\n\nexport function __asyncDelegator(o) {\n var i, p;\n return i = {}, verb(\"next\"), verb(\"throw\", function (e) { throw e; }), verb(\"return\"), i[Symbol.iterator] = function () { return this; }, i;\n function verb(n, f) { i[n] = o[n] ? function (v) { return (p = !p) ? { value: __await(o[n](v)), done: false } : f ? f(v) : v; } : f; }\n}\n\nexport function __asyncValues(o) {\n if (!Symbol.asyncIterator) throw new TypeError(\"Symbol.asyncIterator is not defined.\");\n var m = o[Symbol.asyncIterator], i;\n return m ? m.call(o) : (o = typeof __values === \"function\" ? __values(o) : o[Symbol.iterator](), i = {}, verb(\"next\"), verb(\"throw\"), verb(\"return\"), i[Symbol.asyncIterator] = function () { return this; }, i);\n function verb(n) { i[n] = o[n] && function (v) { return new Promise(function (resolve, reject) { v = o[n](v), settle(resolve, reject, v.done, v.value); }); }; }\n function settle(resolve, reject, d, v) { Promise.resolve(v).then(function(v) { resolve({ value: v, done: d }); }, reject); }\n}\n\nexport function __makeTemplateObject(cooked, raw) {\n if (Object.defineProperty) { Object.defineProperty(cooked, \"raw\", { value: raw }); } else { cooked.raw = raw; }\n return cooked;\n};\n\nvar __setModuleDefault = Object.create ? (function(o, v) {\n Object.defineProperty(o, \"default\", { enumerable: true, value: v });\n}) : function(o, v) {\n o[\"default\"] = v;\n};\n\nexport function __importStar(mod) {\n if (mod && mod.__esModule) return mod;\n var result = {};\n if (mod != null) for (var k in mod) if (k !== \"default\" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);\n __setModuleDefault(result, mod);\n return result;\n}\n\nexport function __importDefault(mod) {\n return (mod && mod.__esModule) ? mod : { default: mod };\n}\n\nexport function __classPrivateFieldGet(receiver, state, kind, f) {\n if (kind === \"a\" && !f) throw new TypeError(\"Private accessor was defined without a getter\");\n if (typeof state === \"function\" ? receiver !== state || !f : !state.has(receiver)) throw new TypeError(\"Cannot read private member from an object whose class did not declare it\");\n return kind === \"m\" ? f : kind === \"a\" ? f.call(receiver) : f ? f.value : state.get(receiver);\n}\n\nexport function __classPrivateFieldSet(receiver, state, value, kind, f) {\n if (kind === \"m\") throw new TypeError(\"Private method is not writable\");\n if (kind === \"a\" && !f) throw new TypeError(\"Private accessor was defined without a setter\");\n if (typeof state === \"function\" ? receiver !== state || !f : !state.has(receiver)) throw new TypeError(\"Cannot write private member to an object whose class did not declare it\");\n return (kind === \"a\" ? f.call(receiver, value) : f ? f.value = value : state.set(receiver, value)), value;\n}\n\nexport function __classPrivateFieldIn(state, receiver) {\n if (receiver === null || (typeof receiver !== \"object\" && typeof receiver !== \"function\")) throw new TypeError(\"Cannot use 'in' operator on non-object\");\n return typeof state === \"function\" ? receiver === state : state.has(receiver);\n}\n\nexport function __addDisposableResource(env, value, async) {\n if (value !== null && value !== void 0) {\n if (typeof value !== \"object\" && typeof value !== \"function\") throw new TypeError(\"Object expected.\");\n var dispose, inner;\n if (async) {\n if (!Symbol.asyncDispose) throw new TypeError(\"Symbol.asyncDispose is not defined.\");\n dispose = value[Symbol.asyncDispose];\n }\n if (dispose === void 0) {\n if (!Symbol.dispose) throw new TypeError(\"Symbol.dispose is not defined.\");\n dispose = value[Symbol.dispose];\n if (async) inner = dispose;\n }\n if (typeof dispose !== \"function\") throw new TypeError(\"Object not disposable.\");\n if (inner) dispose = function() { try { inner.call(this); } catch (e) { return Promise.reject(e); } };\n env.stack.push({ value: value, dispose: dispose, async: async });\n }\n else if (async) {\n env.stack.push({ async: true });\n }\n return value;\n}\n\nvar _SuppressedError = typeof SuppressedError === \"function\" ? SuppressedError : function (error, suppressed, message) {\n var e = new Error(message);\n return e.name = \"SuppressedError\", e.error = error, e.suppressed = suppressed, e;\n};\n\nexport function __disposeResources(env) {\n function fail(e) {\n env.error = env.hasError ? new _SuppressedError(e, env.error, \"An error was suppressed during disposal.\") : e;\n env.hasError = true;\n }\n var r, s = 0;\n function next() {\n while (r = env.stack.pop()) {\n try {\n if (!r.async && s === 1) return s = 0, env.stack.push(r), Promise.resolve().then(next);\n if (r.dispose) {\n var result = r.dispose.call(r.value);\n if (r.async) return s |= 2, Promise.resolve(result).then(next, function(e) { fail(e); return next(); });\n }\n else s |= 1;\n }\n catch (e) {\n fail(e);\n }\n }\n if (s === 1) return env.hasError ? Promise.reject(env.error) : Promise.resolve();\n if (env.hasError) throw env.error;\n }\n return next();\n}\n\nexport default {\n __extends,\n __assign,\n __rest,\n __decorate,\n __param,\n __metadata,\n __awaiter,\n __generator,\n __createBinding,\n __exportStar,\n __values,\n __read,\n __spread,\n __spreadArrays,\n __spreadArray,\n __await,\n __asyncGenerator,\n __asyncDelegator,\n __asyncValues,\n __makeTemplateObject,\n __importStar,\n __importDefault,\n __classPrivateFieldGet,\n __classPrivateFieldSet,\n __classPrivateFieldIn,\n __addDisposableResource,\n __disposeResources,\n};\n", "/**\n * Returns true if the object is a function.\n * @param value The value to check\n */\nexport function isFunction(value: any): value is (...args: any[]) => any {\n return typeof value === 'function';\n}\n", "/**\n * Used to create Error subclasses until the community moves away from ES5.\n *\n * This is because compiling from TypeScript down to ES5 has issues with subclassing Errors\n * as well as other built-in types: https://github.com/Microsoft/TypeScript/issues/12123\n *\n * @param createImpl A factory function to create the actual constructor implementation. The returned\n * function should be a named function that calls `_super` internally.\n */\nexport function createErrorClass(createImpl: (_super: any) => any): T {\n const _super = (instance: any) => {\n Error.call(instance);\n instance.stack = new Error().stack;\n };\n\n const ctorFunc = createImpl(_super);\n ctorFunc.prototype = Object.create(Error.prototype);\n ctorFunc.prototype.constructor = ctorFunc;\n return ctorFunc;\n}\n", "import { createErrorClass } from './createErrorClass';\n\nexport interface UnsubscriptionError extends Error {\n readonly errors: any[];\n}\n\nexport interface UnsubscriptionErrorCtor {\n /**\n * @deprecated Internal implementation detail. Do not construct error instances.\n * Cannot be tagged as internal: https://github.com/ReactiveX/rxjs/issues/6269\n */\n new (errors: any[]): UnsubscriptionError;\n}\n\n/**\n * An error thrown when one or more errors have occurred during the\n * `unsubscribe` of a {@link Subscription}.\n */\nexport const UnsubscriptionError: UnsubscriptionErrorCtor = createErrorClass(\n (_super) =>\n function UnsubscriptionErrorImpl(this: any, errors: (Error | string)[]) {\n _super(this);\n this.message = errors\n ? `${errors.length} errors occurred during unsubscription:\n${errors.map((err, i) => `${i + 1}) ${err.toString()}`).join('\\n ')}`\n : '';\n this.name = 'UnsubscriptionError';\n this.errors = errors;\n }\n);\n", "/**\n * Removes an item from an array, mutating it.\n * @param arr The array to remove the item from\n * @param item The item to remove\n */\nexport function arrRemove(arr: T[] | undefined | null, item: T) {\n if (arr) {\n const index = arr.indexOf(item);\n 0 <= index && arr.splice(index, 1);\n }\n}\n", "import { isFunction } from './util/isFunction';\nimport { UnsubscriptionError } from './util/UnsubscriptionError';\nimport { SubscriptionLike, TeardownLogic, Unsubscribable } from './types';\nimport { arrRemove } from './util/arrRemove';\n\n/**\n * Represents a disposable resource, such as the execution of an Observable. A\n * Subscription has one important method, `unsubscribe`, that takes no argument\n * and just disposes the resource held by the subscription.\n *\n * Additionally, subscriptions may be grouped together through the `add()`\n * method, which will attach a child Subscription to the current Subscription.\n * When a Subscription is unsubscribed, all its children (and its grandchildren)\n * will be unsubscribed as well.\n *\n * @class Subscription\n */\nexport class Subscription implements SubscriptionLike {\n /** @nocollapse */\n public static EMPTY = (() => {\n const empty = new Subscription();\n empty.closed = true;\n return empty;\n })();\n\n /**\n * A flag to indicate whether this Subscription has already been unsubscribed.\n */\n public closed = false;\n\n private _parentage: Subscription[] | Subscription | null = null;\n\n /**\n * The list of registered finalizers to execute upon unsubscription. Adding and removing from this\n * list occurs in the {@link #add} and {@link #remove} methods.\n */\n private _finalizers: Exclude[] | null = null;\n\n /**\n * @param initialTeardown A function executed first as part of the finalization\n * process that is kicked off when {@link #unsubscribe} is called.\n */\n constructor(private initialTeardown?: () => void) {}\n\n /**\n * Disposes the resources held by the subscription. May, for instance, cancel\n * an ongoing Observable execution or cancel any other type of work that\n * started when the Subscription was created.\n * @return {void}\n */\n unsubscribe(): void {\n let errors: any[] | undefined;\n\n if (!this.closed) {\n this.closed = true;\n\n // Remove this from it's parents.\n const { _parentage } = this;\n if (_parentage) {\n this._parentage = null;\n if (Array.isArray(_parentage)) {\n for (const parent of _parentage) {\n parent.remove(this);\n }\n } else {\n _parentage.remove(this);\n }\n }\n\n const { initialTeardown: initialFinalizer } = this;\n if (isFunction(initialFinalizer)) {\n try {\n initialFinalizer();\n } catch (e) {\n errors = e instanceof UnsubscriptionError ? e.errors : [e];\n }\n }\n\n const { _finalizers } = this;\n if (_finalizers) {\n this._finalizers = null;\n for (const finalizer of _finalizers) {\n try {\n execFinalizer(finalizer);\n } catch (err) {\n errors = errors ?? [];\n if (err instanceof UnsubscriptionError) {\n errors = [...errors, ...err.errors];\n } else {\n errors.push(err);\n }\n }\n }\n }\n\n if (errors) {\n throw new UnsubscriptionError(errors);\n }\n }\n }\n\n /**\n * Adds a finalizer to this subscription, so that finalization will be unsubscribed/called\n * when this subscription is unsubscribed. If this subscription is already {@link #closed},\n * because it has already been unsubscribed, then whatever finalizer is passed to it\n * will automatically be executed (unless the finalizer itself is also a closed subscription).\n *\n * Closed Subscriptions cannot be added as finalizers to any subscription. Adding a closed\n * subscription to a any subscription will result in no operation. (A noop).\n *\n * Adding a subscription to itself, or adding `null` or `undefined` will not perform any\n * operation at all. (A noop).\n *\n * `Subscription` instances that are added to this instance will automatically remove themselves\n * if they are unsubscribed. Functions and {@link Unsubscribable} objects that you wish to remove\n * will need to be removed manually with {@link #remove}\n *\n * @param teardown The finalization logic to add to this subscription.\n */\n add(teardown: TeardownLogic): void {\n // Only add the finalizer if it's not undefined\n // and don't add a subscription to itself.\n if (teardown && teardown !== this) {\n if (this.closed) {\n // If this subscription is already closed,\n // execute whatever finalizer is handed to it automatically.\n execFinalizer(teardown);\n } else {\n if (teardown instanceof Subscription) {\n // We don't add closed subscriptions, and we don't add the same subscription\n // twice. Subscription unsubscribe is idempotent.\n if (teardown.closed || teardown._hasParent(this)) {\n return;\n }\n teardown._addParent(this);\n }\n (this._finalizers = this._finalizers ?? []).push(teardown);\n }\n }\n }\n\n /**\n * Checks to see if a this subscription already has a particular parent.\n * This will signal that this subscription has already been added to the parent in question.\n * @param parent the parent to check for\n */\n private _hasParent(parent: Subscription) {\n const { _parentage } = this;\n return _parentage === parent || (Array.isArray(_parentage) && _parentage.includes(parent));\n }\n\n /**\n * Adds a parent to this subscription so it can be removed from the parent if it\n * unsubscribes on it's own.\n *\n * NOTE: THIS ASSUMES THAT {@link _hasParent} HAS ALREADY BEEN CHECKED.\n * @param parent The parent subscription to add\n */\n private _addParent(parent: Subscription) {\n const { _parentage } = this;\n this._parentage = Array.isArray(_parentage) ? (_parentage.push(parent), _parentage) : _parentage ? [_parentage, parent] : parent;\n }\n\n /**\n * Called on a child when it is removed via {@link #remove}.\n * @param parent The parent to remove\n */\n private _removeParent(parent: Subscription) {\n const { _parentage } = this;\n if (_parentage === parent) {\n this._parentage = null;\n } else if (Array.isArray(_parentage)) {\n arrRemove(_parentage, parent);\n }\n }\n\n /**\n * Removes a finalizer from this subscription that was previously added with the {@link #add} method.\n *\n * Note that `Subscription` instances, when unsubscribed, will automatically remove themselves\n * from every other `Subscription` they have been added to. This means that using the `remove` method\n * is not a common thing and should be used thoughtfully.\n *\n * If you add the same finalizer instance of a function or an unsubscribable object to a `Subscription` instance\n * more than once, you will need to call `remove` the same number of times to remove all instances.\n *\n * All finalizer instances are removed to free up memory upon unsubscription.\n *\n * @param teardown The finalizer to remove from this subscription\n */\n remove(teardown: Exclude): void {\n const { _finalizers } = this;\n _finalizers && arrRemove(_finalizers, teardown);\n\n if (teardown instanceof Subscription) {\n teardown._removeParent(this);\n }\n }\n}\n\nexport const EMPTY_SUBSCRIPTION = Subscription.EMPTY;\n\nexport function isSubscription(value: any): value is Subscription {\n return (\n value instanceof Subscription ||\n (value && 'closed' in value && isFunction(value.remove) && isFunction(value.add) && isFunction(value.unsubscribe))\n );\n}\n\nfunction execFinalizer(finalizer: Unsubscribable | (() => void)) {\n if (isFunction(finalizer)) {\n finalizer();\n } else {\n finalizer.unsubscribe();\n }\n}\n", "import { Subscriber } from './Subscriber';\nimport { ObservableNotification } from './types';\n\n/**\n * The {@link GlobalConfig} object for RxJS. It is used to configure things\n * like how to react on unhandled errors.\n */\nexport const config: GlobalConfig = {\n onUnhandledError: null,\n onStoppedNotification: null,\n Promise: undefined,\n useDeprecatedSynchronousErrorHandling: false,\n useDeprecatedNextContext: false,\n};\n\n/**\n * The global configuration object for RxJS, used to configure things\n * like how to react on unhandled errors. Accessible via {@link config}\n * object.\n */\nexport interface GlobalConfig {\n /**\n * A registration point for unhandled errors from RxJS. These are errors that\n * cannot were not handled by consuming code in the usual subscription path. For\n * example, if you have this configured, and you subscribe to an observable without\n * providing an error handler, errors from that subscription will end up here. This\n * will _always_ be called asynchronously on another job in the runtime. This is because\n * we do not want errors thrown in this user-configured handler to interfere with the\n * behavior of the library.\n */\n onUnhandledError: ((err: any) => void) | null;\n\n /**\n * A registration point for notifications that cannot be sent to subscribers because they\n * have completed, errored or have been explicitly unsubscribed. By default, next, complete\n * and error notifications sent to stopped subscribers are noops. However, sometimes callers\n * might want a different behavior. For example, with sources that attempt to report errors\n * to stopped subscribers, a caller can configure RxJS to throw an unhandled error instead.\n * This will _always_ be called asynchronously on another job in the runtime. This is because\n * we do not want errors thrown in this user-configured handler to interfere with the\n * behavior of the library.\n */\n onStoppedNotification: ((notification: ObservableNotification, subscriber: Subscriber) => void) | null;\n\n /**\n * The promise constructor used by default for {@link Observable#toPromise toPromise} and {@link Observable#forEach forEach}\n * methods.\n *\n * @deprecated As of version 8, RxJS will no longer support this sort of injection of a\n * Promise constructor. If you need a Promise implementation other than native promises,\n * please polyfill/patch Promise as you see appropriate. Will be removed in v8.\n */\n Promise?: PromiseConstructorLike;\n\n /**\n * If true, turns on synchronous error rethrowing, which is a deprecated behavior\n * in v6 and higher. This behavior enables bad patterns like wrapping a subscribe\n * call in a try/catch block. It also enables producer interference, a nasty bug\n * where a multicast can be broken for all observers by a downstream consumer with\n * an unhandled error. DO NOT USE THIS FLAG UNLESS IT'S NEEDED TO BUY TIME\n * FOR MIGRATION REASONS.\n *\n * @deprecated As of version 8, RxJS will no longer support synchronous throwing\n * of unhandled errors. All errors will be thrown on a separate call stack to prevent bad\n * behaviors described above. Will be removed in v8.\n */\n useDeprecatedSynchronousErrorHandling: boolean;\n\n /**\n * If true, enables an as-of-yet undocumented feature from v5: The ability to access\n * `unsubscribe()` via `this` context in `next` functions created in observers passed\n * to `subscribe`.\n *\n * This is being removed because the performance was severely problematic, and it could also cause\n * issues when types other than POJOs are passed to subscribe as subscribers, as they will likely have\n * their `this` context overwritten.\n *\n * @deprecated As of version 8, RxJS will no longer support altering the\n * context of next functions provided as part of an observer to Subscribe. Instead,\n * you will have access to a subscription or a signal or token that will allow you to do things like\n * unsubscribe and test closed status. Will be removed in v8.\n */\n useDeprecatedNextContext: boolean;\n}\n", "import type { TimerHandle } from './timerHandle';\ntype SetTimeoutFunction = (handler: () => void, timeout?: number, ...args: any[]) => TimerHandle;\ntype ClearTimeoutFunction = (handle: TimerHandle) => void;\n\ninterface TimeoutProvider {\n setTimeout: SetTimeoutFunction;\n clearTimeout: ClearTimeoutFunction;\n delegate:\n | {\n setTimeout: SetTimeoutFunction;\n clearTimeout: ClearTimeoutFunction;\n }\n | undefined;\n}\n\nexport const timeoutProvider: TimeoutProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n setTimeout(handler: () => void, timeout?: number, ...args) {\n const { delegate } = timeoutProvider;\n if (delegate?.setTimeout) {\n return delegate.setTimeout(handler, timeout, ...args);\n }\n return setTimeout(handler, timeout, ...args);\n },\n clearTimeout(handle) {\n const { delegate } = timeoutProvider;\n return (delegate?.clearTimeout || clearTimeout)(handle as any);\n },\n delegate: undefined,\n};\n", "import { config } from '../config';\nimport { timeoutProvider } from '../scheduler/timeoutProvider';\n\n/**\n * Handles an error on another job either with the user-configured {@link onUnhandledError},\n * or by throwing it on that new job so it can be picked up by `window.onerror`, `process.on('error')`, etc.\n *\n * This should be called whenever there is an error that is out-of-band with the subscription\n * or when an error hits a terminal boundary of the subscription and no error handler was provided.\n *\n * @param err the error to report\n */\nexport function reportUnhandledError(err: any) {\n timeoutProvider.setTimeout(() => {\n const { onUnhandledError } = config;\n if (onUnhandledError) {\n // Execute the user-configured error handler.\n onUnhandledError(err);\n } else {\n // Throw so it is picked up by the runtime's uncaught error mechanism.\n throw err;\n }\n });\n}\n", "/* tslint:disable:no-empty */\nexport function noop() { }\n", "import { CompleteNotification, NextNotification, ErrorNotification } from './types';\n\n/**\n * A completion object optimized for memory use and created to be the\n * same \"shape\" as other notifications in v8.\n * @internal\n */\nexport const COMPLETE_NOTIFICATION = (() => createNotification('C', undefined, undefined) as CompleteNotification)();\n\n/**\n * Internal use only. Creates an optimized error notification that is the same \"shape\"\n * as other notifications.\n * @internal\n */\nexport function errorNotification(error: any): ErrorNotification {\n return createNotification('E', undefined, error) as any;\n}\n\n/**\n * Internal use only. Creates an optimized next notification that is the same \"shape\"\n * as other notifications.\n * @internal\n */\nexport function nextNotification(value: T) {\n return createNotification('N', value, undefined) as NextNotification;\n}\n\n/**\n * Ensures that all notifications created internally have the same \"shape\" in v8.\n *\n * TODO: This is only exported to support a crazy legacy test in `groupBy`.\n * @internal\n */\nexport function createNotification(kind: 'N' | 'E' | 'C', value: any, error: any) {\n return {\n kind,\n value,\n error,\n };\n}\n", "import { config } from '../config';\n\nlet context: { errorThrown: boolean; error: any } | null = null;\n\n/**\n * Handles dealing with errors for super-gross mode. Creates a context, in which\n * any synchronously thrown errors will be passed to {@link captureError}. Which\n * will record the error such that it will be rethrown after the call back is complete.\n * TODO: Remove in v8\n * @param cb An immediately executed function.\n */\nexport function errorContext(cb: () => void) {\n if (config.useDeprecatedSynchronousErrorHandling) {\n const isRoot = !context;\n if (isRoot) {\n context = { errorThrown: false, error: null };\n }\n cb();\n if (isRoot) {\n const { errorThrown, error } = context!;\n context = null;\n if (errorThrown) {\n throw error;\n }\n }\n } else {\n // This is the general non-deprecated path for everyone that\n // isn't crazy enough to use super-gross mode (useDeprecatedSynchronousErrorHandling)\n cb();\n }\n}\n\n/**\n * Captures errors only in super-gross mode.\n * @param err the error to capture\n */\nexport function captureError(err: any) {\n if (config.useDeprecatedSynchronousErrorHandling && context) {\n context.errorThrown = true;\n context.error = err;\n }\n}\n", "import { isFunction } from './util/isFunction';\nimport { Observer, ObservableNotification } from './types';\nimport { isSubscription, Subscription } from './Subscription';\nimport { config } from './config';\nimport { reportUnhandledError } from './util/reportUnhandledError';\nimport { noop } from './util/noop';\nimport { nextNotification, errorNotification, COMPLETE_NOTIFICATION } from './NotificationFactories';\nimport { timeoutProvider } from './scheduler/timeoutProvider';\nimport { captureError } from './util/errorContext';\n\n/**\n * Implements the {@link Observer} interface and extends the\n * {@link Subscription} class. While the {@link Observer} is the public API for\n * consuming the values of an {@link Observable}, all Observers get converted to\n * a Subscriber, in order to provide Subscription-like capabilities such as\n * `unsubscribe`. Subscriber is a common type in RxJS, and crucial for\n * implementing operators, but it is rarely used as a public API.\n *\n * @class Subscriber\n */\nexport class Subscriber extends Subscription implements Observer {\n /**\n * A static factory for a Subscriber, given a (potentially partial) definition\n * of an Observer.\n * @param next The `next` callback of an Observer.\n * @param error The `error` callback of an\n * Observer.\n * @param complete The `complete` callback of an\n * Observer.\n * @return A Subscriber wrapping the (partially defined)\n * Observer represented by the given arguments.\n * @nocollapse\n * @deprecated Do not use. Will be removed in v8. There is no replacement for this\n * method, and there is no reason to be creating instances of `Subscriber` directly.\n * If you have a specific use case, please file an issue.\n */\n static create(next?: (x?: T) => void, error?: (e?: any) => void, complete?: () => void): Subscriber {\n return new SafeSubscriber(next, error, complete);\n }\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n protected isStopped: boolean = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n protected destination: Subscriber | Observer; // this `any` is the escape hatch to erase extra type param (e.g. R)\n\n /**\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n * There is no reason to directly create an instance of Subscriber. This type is exported for typings reasons.\n */\n constructor(destination?: Subscriber | Observer) {\n super();\n if (destination) {\n this.destination = destination;\n // Automatically chain subscriptions together here.\n // if destination is a Subscription, then it is a Subscriber.\n if (isSubscription(destination)) {\n destination.add(this);\n }\n } else {\n this.destination = EMPTY_OBSERVER;\n }\n }\n\n /**\n * The {@link Observer} callback to receive notifications of type `next` from\n * the Observable, with a value. The Observable may call this method 0 or more\n * times.\n * @param {T} [value] The `next` value.\n * @return {void}\n */\n next(value?: T): void {\n if (this.isStopped) {\n handleStoppedNotification(nextNotification(value), this);\n } else {\n this._next(value!);\n }\n }\n\n /**\n * The {@link Observer} callback to receive notifications of type `error` from\n * the Observable, with an attached `Error`. Notifies the Observer that\n * the Observable has experienced an error condition.\n * @param {any} [err] The `error` exception.\n * @return {void}\n */\n error(err?: any): void {\n if (this.isStopped) {\n handleStoppedNotification(errorNotification(err), this);\n } else {\n this.isStopped = true;\n this._error(err);\n }\n }\n\n /**\n * The {@link Observer} callback to receive a valueless notification of type\n * `complete` from the Observable. Notifies the Observer that the Observable\n * has finished sending push-based notifications.\n * @return {void}\n */\n complete(): void {\n if (this.isStopped) {\n handleStoppedNotification(COMPLETE_NOTIFICATION, this);\n } else {\n this.isStopped = true;\n this._complete();\n }\n }\n\n unsubscribe(): void {\n if (!this.closed) {\n this.isStopped = true;\n super.unsubscribe();\n this.destination = null!;\n }\n }\n\n protected _next(value: T): void {\n this.destination.next(value);\n }\n\n protected _error(err: any): void {\n try {\n this.destination.error(err);\n } finally {\n this.unsubscribe();\n }\n }\n\n protected _complete(): void {\n try {\n this.destination.complete();\n } finally {\n this.unsubscribe();\n }\n }\n}\n\n/**\n * This bind is captured here because we want to be able to have\n * compatibility with monoid libraries that tend to use a method named\n * `bind`. In particular, a library called Monio requires this.\n */\nconst _bind = Function.prototype.bind;\n\nfunction bind any>(fn: Fn, thisArg: any): Fn {\n return _bind.call(fn, thisArg);\n}\n\n/**\n * Internal optimization only, DO NOT EXPOSE.\n * @internal\n */\nclass ConsumerObserver implements Observer {\n constructor(private partialObserver: Partial>) {}\n\n next(value: T): void {\n const { partialObserver } = this;\n if (partialObserver.next) {\n try {\n partialObserver.next(value);\n } catch (error) {\n handleUnhandledError(error);\n }\n }\n }\n\n error(err: any): void {\n const { partialObserver } = this;\n if (partialObserver.error) {\n try {\n partialObserver.error(err);\n } catch (error) {\n handleUnhandledError(error);\n }\n } else {\n handleUnhandledError(err);\n }\n }\n\n complete(): void {\n const { partialObserver } = this;\n if (partialObserver.complete) {\n try {\n partialObserver.complete();\n } catch (error) {\n handleUnhandledError(error);\n }\n }\n }\n}\n\nexport class SafeSubscriber extends Subscriber {\n constructor(\n observerOrNext?: Partial> | ((value: T) => void) | null,\n error?: ((e?: any) => void) | null,\n complete?: (() => void) | null\n ) {\n super();\n\n let partialObserver: Partial>;\n if (isFunction(observerOrNext) || !observerOrNext) {\n // The first argument is a function, not an observer. The next\n // two arguments *could* be observers, or they could be empty.\n partialObserver = {\n next: (observerOrNext ?? undefined) as (((value: T) => void) | undefined),\n error: error ?? undefined,\n complete: complete ?? undefined,\n };\n } else {\n // The first argument is a partial observer.\n let context: any;\n if (this && config.useDeprecatedNextContext) {\n // This is a deprecated path that made `this.unsubscribe()` available in\n // next handler functions passed to subscribe. This only exists behind a flag\n // now, as it is *very* slow.\n context = Object.create(observerOrNext);\n context.unsubscribe = () => this.unsubscribe();\n partialObserver = {\n next: observerOrNext.next && bind(observerOrNext.next, context),\n error: observerOrNext.error && bind(observerOrNext.error, context),\n complete: observerOrNext.complete && bind(observerOrNext.complete, context),\n };\n } else {\n // The \"normal\" path. Just use the partial observer directly.\n partialObserver = observerOrNext;\n }\n }\n\n // Wrap the partial observer to ensure it's a full observer, and\n // make sure proper error handling is accounted for.\n this.destination = new ConsumerObserver(partialObserver);\n }\n}\n\nfunction handleUnhandledError(error: any) {\n if (config.useDeprecatedSynchronousErrorHandling) {\n captureError(error);\n } else {\n // Ideal path, we report this as an unhandled error,\n // which is thrown on a new call stack.\n reportUnhandledError(error);\n }\n}\n\n/**\n * An error handler used when no error handler was supplied\n * to the SafeSubscriber -- meaning no error handler was supplied\n * do the `subscribe` call on our observable.\n * @param err The error to handle\n */\nfunction defaultErrorHandler(err: any) {\n throw err;\n}\n\n/**\n * A handler for notifications that cannot be sent to a stopped subscriber.\n * @param notification The notification being sent\n * @param subscriber The stopped subscriber\n */\nfunction handleStoppedNotification(notification: ObservableNotification, subscriber: Subscriber) {\n const { onStoppedNotification } = config;\n onStoppedNotification && timeoutProvider.setTimeout(() => onStoppedNotification(notification, subscriber));\n}\n\n/**\n * The observer used as a stub for subscriptions where the user did not\n * pass any arguments to `subscribe`. Comes with the default error handling\n * behavior.\n */\nexport const EMPTY_OBSERVER: Readonly> & { closed: true } = {\n closed: true,\n next: noop,\n error: defaultErrorHandler,\n complete: noop,\n};\n", "/**\n * Symbol.observable or a string \"@@observable\". Used for interop\n *\n * @deprecated We will no longer be exporting this symbol in upcoming versions of RxJS.\n * Instead polyfill and use Symbol.observable directly *or* use https://www.npmjs.com/package/symbol-observable\n */\nexport const observable: string | symbol = (() => (typeof Symbol === 'function' && Symbol.observable) || '@@observable')();\n", "/**\n * This function takes one parameter and just returns it. Simply put,\n * this is like `(x: T): T => x`.\n *\n * ## Examples\n *\n * This is useful in some cases when using things like `mergeMap`\n *\n * ```ts\n * import { interval, take, map, range, mergeMap, identity } from 'rxjs';\n *\n * const source$ = interval(1000).pipe(take(5));\n *\n * const result$ = source$.pipe(\n * map(i => range(i)),\n * mergeMap(identity) // same as mergeMap(x => x)\n * );\n *\n * result$.subscribe({\n * next: console.log\n * });\n * ```\n *\n * Or when you want to selectively apply an operator\n *\n * ```ts\n * import { interval, take, identity } from 'rxjs';\n *\n * const shouldLimit = () => Math.random() < 0.5;\n *\n * const source$ = interval(1000);\n *\n * const result$ = source$.pipe(shouldLimit() ? take(5) : identity);\n *\n * result$.subscribe({\n * next: console.log\n * });\n * ```\n *\n * @param x Any value that is returned by this function\n * @returns The value passed as the first parameter to this function\n */\nexport function identity(x: T): T {\n return x;\n}\n", "import { identity } from './identity';\nimport { UnaryFunction } from '../types';\n\nexport function pipe(): typeof identity;\nexport function pipe(fn1: UnaryFunction): UnaryFunction;\nexport function pipe(fn1: UnaryFunction, fn2: UnaryFunction): UnaryFunction;\nexport function pipe(fn1: UnaryFunction, fn2: UnaryFunction, fn3: UnaryFunction): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction,\n fn9: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction,\n fn9: UnaryFunction,\n ...fns: UnaryFunction[]\n): UnaryFunction;\n\n/**\n * pipe() can be called on one or more functions, each of which can take one argument (\"UnaryFunction\")\n * and uses it to return a value.\n * It returns a function that takes one argument, passes it to the first UnaryFunction, and then\n * passes the result to the next one, passes that result to the next one, and so on. \n */\nexport function pipe(...fns: Array>): UnaryFunction {\n return pipeFromArray(fns);\n}\n\n/** @internal */\nexport function pipeFromArray(fns: Array>): UnaryFunction {\n if (fns.length === 0) {\n return identity as UnaryFunction;\n }\n\n if (fns.length === 1) {\n return fns[0];\n }\n\n return function piped(input: T): R {\n return fns.reduce((prev: any, fn: UnaryFunction) => fn(prev), input as any);\n };\n}\n", "import { Operator } from './Operator';\nimport { SafeSubscriber, Subscriber } from './Subscriber';\nimport { isSubscription, Subscription } from './Subscription';\nimport { TeardownLogic, OperatorFunction, Subscribable, Observer } from './types';\nimport { observable as Symbol_observable } from './symbol/observable';\nimport { pipeFromArray } from './util/pipe';\nimport { config } from './config';\nimport { isFunction } from './util/isFunction';\nimport { errorContext } from './util/errorContext';\n\n/**\n * A representation of any set of values over any amount of time. This is the most basic building block\n * of RxJS.\n *\n * @class Observable\n */\nexport class Observable implements Subscribable {\n /**\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n */\n source: Observable | undefined;\n\n /**\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n */\n operator: Operator | undefined;\n\n /**\n * @constructor\n * @param {Function} subscribe the function that is called when the Observable is\n * initially subscribed to. This function is given a Subscriber, to which new values\n * can be `next`ed, or an `error` method can be called to raise an error, or\n * `complete` can be called to notify of a successful completion.\n */\n constructor(subscribe?: (this: Observable, subscriber: Subscriber) => TeardownLogic) {\n if (subscribe) {\n this._subscribe = subscribe;\n }\n }\n\n // HACK: Since TypeScript inherits static properties too, we have to\n // fight against TypeScript here so Subject can have a different static create signature\n /**\n * Creates a new Observable by calling the Observable constructor\n * @owner Observable\n * @method create\n * @param {Function} subscribe? the subscriber function to be passed to the Observable constructor\n * @return {Observable} a new observable\n * @nocollapse\n * @deprecated Use `new Observable()` instead. Will be removed in v8.\n */\n static create: (...args: any[]) => any = (subscribe?: (subscriber: Subscriber) => TeardownLogic) => {\n return new Observable(subscribe);\n };\n\n /**\n * Creates a new Observable, with this Observable instance as the source, and the passed\n * operator defined as the new observable's operator.\n * @method lift\n * @param operator the operator defining the operation to take on the observable\n * @return a new observable with the Operator applied\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n * If you have implemented an operator using `lift`, it is recommended that you create an\n * operator by simply returning `new Observable()` directly. See \"Creating new operators from\n * scratch\" section here: https://rxjs.dev/guide/operators\n */\n lift(operator?: Operator): Observable {\n const observable = new Observable();\n observable.source = this;\n observable.operator = operator;\n return observable;\n }\n\n subscribe(observerOrNext?: Partial> | ((value: T) => void)): Subscription;\n /** @deprecated Instead of passing separate callback arguments, use an observer argument. Signatures taking separate callback arguments will be removed in v8. Details: https://rxjs.dev/deprecations/subscribe-arguments */\n subscribe(next?: ((value: T) => void) | null, error?: ((error: any) => void) | null, complete?: (() => void) | null): Subscription;\n /**\n * Invokes an execution of an Observable and registers Observer handlers for notifications it will emit.\n *\n * Use it when you have all these Observables, but still nothing is happening.\n *\n * `subscribe` is not a regular operator, but a method that calls Observable's internal `subscribe` function. It\n * might be for example a function that you passed to Observable's constructor, but most of the time it is\n * a library implementation, which defines what will be emitted by an Observable, and when it be will emitted. This means\n * that calling `subscribe` is actually the moment when Observable starts its work, not when it is created, as it is often\n * the thought.\n *\n * Apart from starting the execution of an Observable, this method allows you to listen for values\n * that an Observable emits, as well as for when it completes or errors. You can achieve this in two\n * of the following ways.\n *\n * The first way is creating an object that implements {@link Observer} interface. It should have methods\n * defined by that interface, but note that it should be just a regular JavaScript object, which you can create\n * yourself in any way you want (ES6 class, classic function constructor, object literal etc.). In particular, do\n * not attempt to use any RxJS implementation details to create Observers - you don't need them. Remember also\n * that your object does not have to implement all methods. If you find yourself creating a method that doesn't\n * do anything, you can simply omit it. Note however, if the `error` method is not provided and an error happens,\n * it will be thrown asynchronously. Errors thrown asynchronously cannot be caught using `try`/`catch`. Instead,\n * use the {@link onUnhandledError} configuration option or use a runtime handler (like `window.onerror` or\n * `process.on('error)`) to be notified of unhandled errors. Because of this, it's recommended that you provide\n * an `error` method to avoid missing thrown errors.\n *\n * The second way is to give up on Observer object altogether and simply provide callback functions in place of its methods.\n * This means you can provide three functions as arguments to `subscribe`, where the first function is equivalent\n * of a `next` method, the second of an `error` method and the third of a `complete` method. Just as in case of an Observer,\n * if you do not need to listen for something, you can omit a function by passing `undefined` or `null`,\n * since `subscribe` recognizes these functions by where they were placed in function call. When it comes\n * to the `error` function, as with an Observer, if not provided, errors emitted by an Observable will be thrown asynchronously.\n *\n * You can, however, subscribe with no parameters at all. This may be the case where you're not interested in terminal events\n * and you also handled emissions internally by using operators (e.g. using `tap`).\n *\n * Whichever style of calling `subscribe` you use, in both cases it returns a Subscription object.\n * This object allows you to call `unsubscribe` on it, which in turn will stop the work that an Observable does and will clean\n * up all resources that an Observable used. Note that cancelling a subscription will not call `complete` callback\n * provided to `subscribe` function, which is reserved for a regular completion signal that comes from an Observable.\n *\n * Remember that callbacks provided to `subscribe` are not guaranteed to be called asynchronously.\n * It is an Observable itself that decides when these functions will be called. For example {@link of}\n * by default emits all its values synchronously. Always check documentation for how given Observable\n * will behave when subscribed and if its default behavior can be modified with a `scheduler`.\n *\n * #### Examples\n *\n * Subscribe with an {@link guide/observer Observer}\n *\n * ```ts\n * import { of } from 'rxjs';\n *\n * const sumObserver = {\n * sum: 0,\n * next(value) {\n * console.log('Adding: ' + value);\n * this.sum = this.sum + value;\n * },\n * error() {\n * // We actually could just remove this method,\n * // since we do not really care about errors right now.\n * },\n * complete() {\n * console.log('Sum equals: ' + this.sum);\n * }\n * };\n *\n * of(1, 2, 3) // Synchronously emits 1, 2, 3 and then completes.\n * .subscribe(sumObserver);\n *\n * // Logs:\n * // 'Adding: 1'\n * // 'Adding: 2'\n * // 'Adding: 3'\n * // 'Sum equals: 6'\n * ```\n *\n * Subscribe with functions ({@link deprecations/subscribe-arguments deprecated})\n *\n * ```ts\n * import { of } from 'rxjs'\n *\n * let sum = 0;\n *\n * of(1, 2, 3).subscribe(\n * value => {\n * console.log('Adding: ' + value);\n * sum = sum + value;\n * },\n * undefined,\n * () => console.log('Sum equals: ' + sum)\n * );\n *\n * // Logs:\n * // 'Adding: 1'\n * // 'Adding: 2'\n * // 'Adding: 3'\n * // 'Sum equals: 6'\n * ```\n *\n * Cancel a subscription\n *\n * ```ts\n * import { interval } from 'rxjs';\n *\n * const subscription = interval(1000).subscribe({\n * next(num) {\n * console.log(num)\n * },\n * complete() {\n * // Will not be called, even when cancelling subscription.\n * console.log('completed!');\n * }\n * });\n *\n * setTimeout(() => {\n * subscription.unsubscribe();\n * console.log('unsubscribed!');\n * }, 2500);\n *\n * // Logs:\n * // 0 after 1s\n * // 1 after 2s\n * // 'unsubscribed!' after 2.5s\n * ```\n *\n * @param {Observer|Function} observerOrNext (optional) Either an observer with methods to be called,\n * or the first of three possible handlers, which is the handler for each value emitted from the subscribed\n * Observable.\n * @param {Function} error (optional) A handler for a terminal event resulting from an error. If no error handler is provided,\n * the error will be thrown asynchronously as unhandled.\n * @param {Function} complete (optional) A handler for a terminal event resulting from successful completion.\n * @return {Subscription} a subscription reference to the registered handlers\n * @method subscribe\n */\n subscribe(\n observerOrNext?: Partial> | ((value: T) => void) | null,\n error?: ((error: any) => void) | null,\n complete?: (() => void) | null\n ): Subscription {\n const subscriber = isSubscriber(observerOrNext) ? observerOrNext : new SafeSubscriber(observerOrNext, error, complete);\n\n errorContext(() => {\n const { operator, source } = this;\n subscriber.add(\n operator\n ? // We're dealing with a subscription in the\n // operator chain to one of our lifted operators.\n operator.call(subscriber, source)\n : source\n ? // If `source` has a value, but `operator` does not, something that\n // had intimate knowledge of our API, like our `Subject`, must have\n // set it. We're going to just call `_subscribe` directly.\n this._subscribe(subscriber)\n : // In all other cases, we're likely wrapping a user-provided initializer\n // function, so we need to catch errors and handle them appropriately.\n this._trySubscribe(subscriber)\n );\n });\n\n return subscriber;\n }\n\n /** @internal */\n protected _trySubscribe(sink: Subscriber): TeardownLogic {\n try {\n return this._subscribe(sink);\n } catch (err) {\n // We don't need to return anything in this case,\n // because it's just going to try to `add()` to a subscription\n // above.\n sink.error(err);\n }\n }\n\n /**\n * Used as a NON-CANCELLABLE means of subscribing to an observable, for use with\n * APIs that expect promises, like `async/await`. You cannot unsubscribe from this.\n *\n * **WARNING**: Only use this with observables you *know* will complete. If the source\n * observable does not complete, you will end up with a promise that is hung up, and\n * potentially all of the state of an async function hanging out in memory. To avoid\n * this situation, look into adding something like {@link timeout}, {@link take},\n * {@link takeWhile}, or {@link takeUntil} amongst others.\n *\n * #### Example\n *\n * ```ts\n * import { interval, take } from 'rxjs';\n *\n * const source$ = interval(1000).pipe(take(4));\n *\n * async function getTotal() {\n * let total = 0;\n *\n * await source$.forEach(value => {\n * total += value;\n * console.log('observable -> ' + value);\n * });\n *\n * return total;\n * }\n *\n * getTotal().then(\n * total => console.log('Total: ' + total)\n * );\n *\n * // Expected:\n * // 'observable -> 0'\n * // 'observable -> 1'\n * // 'observable -> 2'\n * // 'observable -> 3'\n * // 'Total: 6'\n * ```\n *\n * @param next a handler for each value emitted by the observable\n * @return a promise that either resolves on observable completion or\n * rejects with the handled error\n */\n forEach(next: (value: T) => void): Promise;\n\n /**\n * @param next a handler for each value emitted by the observable\n * @param promiseCtor a constructor function used to instantiate the Promise\n * @return a promise that either resolves on observable completion or\n * rejects with the handled error\n * @deprecated Passing a Promise constructor will no longer be available\n * in upcoming versions of RxJS. This is because it adds weight to the library, for very\n * little benefit. If you need this functionality, it is recommended that you either\n * polyfill Promise, or you create an adapter to convert the returned native promise\n * to whatever promise implementation you wanted. Will be removed in v8.\n */\n forEach(next: (value: T) => void, promiseCtor: PromiseConstructorLike): Promise;\n\n forEach(next: (value: T) => void, promiseCtor?: PromiseConstructorLike): Promise {\n promiseCtor = getPromiseCtor(promiseCtor);\n\n return new promiseCtor((resolve, reject) => {\n const subscriber = new SafeSubscriber({\n next: (value) => {\n try {\n next(value);\n } catch (err) {\n reject(err);\n subscriber.unsubscribe();\n }\n },\n error: reject,\n complete: resolve,\n });\n this.subscribe(subscriber);\n }) as Promise;\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): TeardownLogic {\n return this.source?.subscribe(subscriber);\n }\n\n /**\n * An interop point defined by the es7-observable spec https://github.com/zenparsing/es-observable\n * @method Symbol.observable\n * @return {Observable} this instance of the observable\n */\n [Symbol_observable]() {\n return this;\n }\n\n /* tslint:disable:max-line-length */\n pipe(): Observable;\n pipe(op1: OperatorFunction): Observable;\n pipe(op1: OperatorFunction, op2: OperatorFunction): Observable;\n pipe(op1: OperatorFunction, op2: OperatorFunction, op3: OperatorFunction): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction,\n op9: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction,\n op9: OperatorFunction,\n ...operations: OperatorFunction[]\n ): Observable;\n /* tslint:enable:max-line-length */\n\n /**\n * Used to stitch together functional operators into a chain.\n * @method pipe\n * @return {Observable} the Observable result of all of the operators having\n * been called in the order they were passed in.\n *\n * ## Example\n *\n * ```ts\n * import { interval, filter, map, scan } from 'rxjs';\n *\n * interval(1000)\n * .pipe(\n * filter(x => x % 2 === 0),\n * map(x => x + x),\n * scan((acc, x) => acc + x)\n * )\n * .subscribe(x => console.log(x));\n * ```\n */\n pipe(...operations: OperatorFunction[]): Observable {\n return pipeFromArray(operations)(this);\n }\n\n /* tslint:disable:max-line-length */\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(): Promise;\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(PromiseCtor: typeof Promise): Promise;\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(PromiseCtor: PromiseConstructorLike): Promise;\n /* tslint:enable:max-line-length */\n\n /**\n * Subscribe to this Observable and get a Promise resolving on\n * `complete` with the last emission (if any).\n *\n * **WARNING**: Only use this with observables you *know* will complete. If the source\n * observable does not complete, you will end up with a promise that is hung up, and\n * potentially all of the state of an async function hanging out in memory. To avoid\n * this situation, look into adding something like {@link timeout}, {@link take},\n * {@link takeWhile}, or {@link takeUntil} amongst others.\n *\n * @method toPromise\n * @param [promiseCtor] a constructor function used to instantiate\n * the Promise\n * @return A Promise that resolves with the last value emit, or\n * rejects on an error. If there were no emissions, Promise\n * resolves with undefined.\n * @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise\n */\n toPromise(promiseCtor?: PromiseConstructorLike): Promise {\n promiseCtor = getPromiseCtor(promiseCtor);\n\n return new promiseCtor((resolve, reject) => {\n let value: T | undefined;\n this.subscribe(\n (x: T) => (value = x),\n (err: any) => reject(err),\n () => resolve(value)\n );\n }) as Promise;\n }\n}\n\n/**\n * Decides between a passed promise constructor from consuming code,\n * A default configured promise constructor, and the native promise\n * constructor and returns it. If nothing can be found, it will throw\n * an error.\n * @param promiseCtor The optional promise constructor to passed by consuming code\n */\nfunction getPromiseCtor(promiseCtor: PromiseConstructorLike | undefined) {\n return promiseCtor ?? config.Promise ?? Promise;\n}\n\nfunction isObserver(value: any): value is Observer {\n return value && isFunction(value.next) && isFunction(value.error) && isFunction(value.complete);\n}\n\nfunction isSubscriber(value: any): value is Subscriber {\n return (value && value instanceof Subscriber) || (isObserver(value) && isSubscription(value));\n}\n", "import { Observable } from '../Observable';\nimport { Subscriber } from '../Subscriber';\nimport { OperatorFunction } from '../types';\nimport { isFunction } from './isFunction';\n\n/**\n * Used to determine if an object is an Observable with a lift function.\n */\nexport function hasLift(source: any): source is { lift: InstanceType['lift'] } {\n return isFunction(source?.lift);\n}\n\n/**\n * Creates an `OperatorFunction`. Used to define operators throughout the library in a concise way.\n * @param init The logic to connect the liftedSource to the subscriber at the moment of subscription.\n */\nexport function operate(\n init: (liftedSource: Observable, subscriber: Subscriber) => (() => void) | void\n): OperatorFunction {\n return (source: Observable) => {\n if (hasLift(source)) {\n return source.lift(function (this: Subscriber, liftedSource: Observable) {\n try {\n return init(liftedSource, this);\n } catch (err) {\n this.error(err);\n }\n });\n }\n throw new TypeError('Unable to lift unknown Observable type');\n };\n}\n", "import { Subscriber } from '../Subscriber';\n\n/**\n * Creates an instance of an `OperatorSubscriber`.\n * @param destination The downstream subscriber.\n * @param onNext Handles next values, only called if this subscriber is not stopped or closed. Any\n * error that occurs in this function is caught and sent to the `error` method of this subscriber.\n * @param onError Handles errors from the subscription, any errors that occur in this handler are caught\n * and send to the `destination` error handler.\n * @param onComplete Handles completion notification from the subscription. Any errors that occur in\n * this handler are sent to the `destination` error handler.\n * @param onFinalize Additional teardown logic here. This will only be called on teardown if the\n * subscriber itself is not already closed. This is called after all other teardown logic is executed.\n */\nexport function createOperatorSubscriber(\n destination: Subscriber,\n onNext?: (value: T) => void,\n onComplete?: () => void,\n onError?: (err: any) => void,\n onFinalize?: () => void\n): Subscriber {\n return new OperatorSubscriber(destination, onNext, onComplete, onError, onFinalize);\n}\n\n/**\n * A generic helper for allowing operators to be created with a Subscriber and\n * use closures to capture necessary state from the operator function itself.\n */\nexport class OperatorSubscriber extends Subscriber {\n /**\n * Creates an instance of an `OperatorSubscriber`.\n * @param destination The downstream subscriber.\n * @param onNext Handles next values, only called if this subscriber is not stopped or closed. Any\n * error that occurs in this function is caught and sent to the `error` method of this subscriber.\n * @param onError Handles errors from the subscription, any errors that occur in this handler are caught\n * and send to the `destination` error handler.\n * @param onComplete Handles completion notification from the subscription. Any errors that occur in\n * this handler are sent to the `destination` error handler.\n * @param onFinalize Additional finalization logic here. This will only be called on finalization if the\n * subscriber itself is not already closed. This is called after all other finalization logic is executed.\n * @param shouldUnsubscribe An optional check to see if an unsubscribe call should truly unsubscribe.\n * NOTE: This currently **ONLY** exists to support the strange behavior of {@link groupBy}, where unsubscription\n * to the resulting observable does not actually disconnect from the source if there are active subscriptions\n * to any grouped observable. (DO NOT EXPOSE OR USE EXTERNALLY!!!)\n */\n constructor(\n destination: Subscriber,\n onNext?: (value: T) => void,\n onComplete?: () => void,\n onError?: (err: any) => void,\n private onFinalize?: () => void,\n private shouldUnsubscribe?: () => boolean\n ) {\n // It's important - for performance reasons - that all of this class's\n // members are initialized and that they are always initialized in the same\n // order. This will ensure that all OperatorSubscriber instances have the\n // same hidden class in V8. This, in turn, will help keep the number of\n // hidden classes involved in property accesses within the base class as\n // low as possible. If the number of hidden classes involved exceeds four,\n // the property accesses will become megamorphic and performance penalties\n // will be incurred - i.e. inline caches won't be used.\n //\n // The reasons for ensuring all instances have the same hidden class are\n // further discussed in this blog post from Benedikt Meurer:\n // https://benediktmeurer.de/2018/03/23/impact-of-polymorphism-on-component-based-frameworks-like-react/\n super(destination);\n this._next = onNext\n ? function (this: OperatorSubscriber, value: T) {\n try {\n onNext(value);\n } catch (err) {\n destination.error(err);\n }\n }\n : super._next;\n this._error = onError\n ? function (this: OperatorSubscriber, err: any) {\n try {\n onError(err);\n } catch (err) {\n // Send any errors that occur down stream.\n destination.error(err);\n } finally {\n // Ensure finalization.\n this.unsubscribe();\n }\n }\n : super._error;\n this._complete = onComplete\n ? function (this: OperatorSubscriber) {\n try {\n onComplete();\n } catch (err) {\n // Send any errors that occur down stream.\n destination.error(err);\n } finally {\n // Ensure finalization.\n this.unsubscribe();\n }\n }\n : super._complete;\n }\n\n unsubscribe() {\n if (!this.shouldUnsubscribe || this.shouldUnsubscribe()) {\n const { closed } = this;\n super.unsubscribe();\n // Execute additional teardown if we have any and we didn't already do so.\n !closed && this.onFinalize?.();\n }\n }\n}\n", "import { Subscription } from '../Subscription';\n\ninterface AnimationFrameProvider {\n schedule(callback: FrameRequestCallback): Subscription;\n requestAnimationFrame: typeof requestAnimationFrame;\n cancelAnimationFrame: typeof cancelAnimationFrame;\n delegate:\n | {\n requestAnimationFrame: typeof requestAnimationFrame;\n cancelAnimationFrame: typeof cancelAnimationFrame;\n }\n | undefined;\n}\n\nexport const animationFrameProvider: AnimationFrameProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n schedule(callback) {\n let request = requestAnimationFrame;\n let cancel: typeof cancelAnimationFrame | undefined = cancelAnimationFrame;\n const { delegate } = animationFrameProvider;\n if (delegate) {\n request = delegate.requestAnimationFrame;\n cancel = delegate.cancelAnimationFrame;\n }\n const handle = request((timestamp) => {\n // Clear the cancel function. The request has been fulfilled, so\n // attempting to cancel the request upon unsubscription would be\n // pointless.\n cancel = undefined;\n callback(timestamp);\n });\n return new Subscription(() => cancel?.(handle));\n },\n requestAnimationFrame(...args) {\n const { delegate } = animationFrameProvider;\n return (delegate?.requestAnimationFrame || requestAnimationFrame)(...args);\n },\n cancelAnimationFrame(...args) {\n const { delegate } = animationFrameProvider;\n return (delegate?.cancelAnimationFrame || cancelAnimationFrame)(...args);\n },\n delegate: undefined,\n};\n", "import { createErrorClass } from './createErrorClass';\n\nexport interface ObjectUnsubscribedError extends Error {}\n\nexport interface ObjectUnsubscribedErrorCtor {\n /**\n * @deprecated Internal implementation detail. Do not construct error instances.\n * Cannot be tagged as internal: https://github.com/ReactiveX/rxjs/issues/6269\n */\n new (): ObjectUnsubscribedError;\n}\n\n/**\n * An error thrown when an action is invalid because the object has been\n * unsubscribed.\n *\n * @see {@link Subject}\n * @see {@link BehaviorSubject}\n *\n * @class ObjectUnsubscribedError\n */\nexport const ObjectUnsubscribedError: ObjectUnsubscribedErrorCtor = createErrorClass(\n (_super) =>\n function ObjectUnsubscribedErrorImpl(this: any) {\n _super(this);\n this.name = 'ObjectUnsubscribedError';\n this.message = 'object unsubscribed';\n }\n);\n", "import { Operator } from './Operator';\nimport { Observable } from './Observable';\nimport { Subscriber } from './Subscriber';\nimport { Subscription, EMPTY_SUBSCRIPTION } from './Subscription';\nimport { Observer, SubscriptionLike, TeardownLogic } from './types';\nimport { ObjectUnsubscribedError } from './util/ObjectUnsubscribedError';\nimport { arrRemove } from './util/arrRemove';\nimport { errorContext } from './util/errorContext';\n\n/**\n * A Subject is a special type of Observable that allows values to be\n * multicasted to many Observers. Subjects are like EventEmitters.\n *\n * Every Subject is an Observable and an Observer. You can subscribe to a\n * Subject, and you can call next to feed values as well as error and complete.\n */\nexport class Subject extends Observable implements SubscriptionLike {\n closed = false;\n\n private currentObservers: Observer[] | null = null;\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n observers: Observer[] = [];\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n isStopped = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n hasError = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n thrownError: any = null;\n\n /**\n * Creates a \"subject\" by basically gluing an observer to an observable.\n *\n * @nocollapse\n * @deprecated Recommended you do not use. Will be removed at some point in the future. Plans for replacement still under discussion.\n */\n static create: (...args: any[]) => any = (destination: Observer, source: Observable): AnonymousSubject => {\n return new AnonymousSubject(destination, source);\n };\n\n constructor() {\n // NOTE: This must be here to obscure Observable's constructor.\n super();\n }\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n lift(operator: Operator): Observable {\n const subject = new AnonymousSubject(this, this);\n subject.operator = operator as any;\n return subject as any;\n }\n\n /** @internal */\n protected _throwIfClosed() {\n if (this.closed) {\n throw new ObjectUnsubscribedError();\n }\n }\n\n next(value: T) {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n if (!this.currentObservers) {\n this.currentObservers = Array.from(this.observers);\n }\n for (const observer of this.currentObservers) {\n observer.next(value);\n }\n }\n });\n }\n\n error(err: any) {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n this.hasError = this.isStopped = true;\n this.thrownError = err;\n const { observers } = this;\n while (observers.length) {\n observers.shift()!.error(err);\n }\n }\n });\n }\n\n complete() {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n this.isStopped = true;\n const { observers } = this;\n while (observers.length) {\n observers.shift()!.complete();\n }\n }\n });\n }\n\n unsubscribe() {\n this.isStopped = this.closed = true;\n this.observers = this.currentObservers = null!;\n }\n\n get observed() {\n return this.observers?.length > 0;\n }\n\n /** @internal */\n protected _trySubscribe(subscriber: Subscriber): TeardownLogic {\n this._throwIfClosed();\n return super._trySubscribe(subscriber);\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n this._throwIfClosed();\n this._checkFinalizedStatuses(subscriber);\n return this._innerSubscribe(subscriber);\n }\n\n /** @internal */\n protected _innerSubscribe(subscriber: Subscriber) {\n const { hasError, isStopped, observers } = this;\n if (hasError || isStopped) {\n return EMPTY_SUBSCRIPTION;\n }\n this.currentObservers = null;\n observers.push(subscriber);\n return new Subscription(() => {\n this.currentObservers = null;\n arrRemove(observers, subscriber);\n });\n }\n\n /** @internal */\n protected _checkFinalizedStatuses(subscriber: Subscriber) {\n const { hasError, thrownError, isStopped } = this;\n if (hasError) {\n subscriber.error(thrownError);\n } else if (isStopped) {\n subscriber.complete();\n }\n }\n\n /**\n * Creates a new Observable with this Subject as the source. You can do this\n * to create custom Observer-side logic of the Subject and conceal it from\n * code that uses the Observable.\n * @return {Observable} Observable that the Subject casts to\n */\n asObservable(): Observable {\n const observable: any = new Observable();\n observable.source = this;\n return observable;\n }\n}\n\n/**\n * @class AnonymousSubject\n */\nexport class AnonymousSubject extends Subject {\n constructor(\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n public destination?: Observer,\n source?: Observable\n ) {\n super();\n this.source = source;\n }\n\n next(value: T) {\n this.destination?.next?.(value);\n }\n\n error(err: any) {\n this.destination?.error?.(err);\n }\n\n complete() {\n this.destination?.complete?.();\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n return this.source?.subscribe(subscriber) ?? EMPTY_SUBSCRIPTION;\n }\n}\n", "import { Subject } from './Subject';\nimport { Subscriber } from './Subscriber';\nimport { Subscription } from './Subscription';\n\n/**\n * A variant of Subject that requires an initial value and emits its current\n * value whenever it is subscribed to.\n *\n * @class BehaviorSubject\n */\nexport class BehaviorSubject extends Subject {\n constructor(private _value: T) {\n super();\n }\n\n get value(): T {\n return this.getValue();\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n const subscription = super._subscribe(subscriber);\n !subscription.closed && subscriber.next(this._value);\n return subscription;\n }\n\n getValue(): T {\n const { hasError, thrownError, _value } = this;\n if (hasError) {\n throw thrownError;\n }\n this._throwIfClosed();\n return _value;\n }\n\n next(value: T): void {\n super.next((this._value = value));\n }\n}\n", "import { TimestampProvider } from '../types';\n\ninterface DateTimestampProvider extends TimestampProvider {\n delegate: TimestampProvider | undefined;\n}\n\nexport const dateTimestampProvider: DateTimestampProvider = {\n now() {\n // Use the variable rather than `this` so that the function can be called\n // without being bound to the provider.\n return (dateTimestampProvider.delegate || Date).now();\n },\n delegate: undefined,\n};\n", "import { Subject } from './Subject';\nimport { TimestampProvider } from './types';\nimport { Subscriber } from './Subscriber';\nimport { Subscription } from './Subscription';\nimport { dateTimestampProvider } from './scheduler/dateTimestampProvider';\n\n/**\n * A variant of {@link Subject} that \"replays\" old values to new subscribers by emitting them when they first subscribe.\n *\n * `ReplaySubject` has an internal buffer that will store a specified number of values that it has observed. Like `Subject`,\n * `ReplaySubject` \"observes\" values by having them passed to its `next` method. When it observes a value, it will store that\n * value for a time determined by the configuration of the `ReplaySubject`, as passed to its constructor.\n *\n * When a new subscriber subscribes to the `ReplaySubject` instance, it will synchronously emit all values in its buffer in\n * a First-In-First-Out (FIFO) manner. The `ReplaySubject` will also complete, if it has observed completion; and it will\n * error if it has observed an error.\n *\n * There are two main configuration items to be concerned with:\n *\n * 1. `bufferSize` - This will determine how many items are stored in the buffer, defaults to infinite.\n * 2. `windowTime` - The amount of time to hold a value in the buffer before removing it from the buffer.\n *\n * Both configurations may exist simultaneously. So if you would like to buffer a maximum of 3 values, as long as the values\n * are less than 2 seconds old, you could do so with a `new ReplaySubject(3, 2000)`.\n *\n * ### Differences with BehaviorSubject\n *\n * `BehaviorSubject` is similar to `new ReplaySubject(1)`, with a couple of exceptions:\n *\n * 1. `BehaviorSubject` comes \"primed\" with a single value upon construction.\n * 2. `ReplaySubject` will replay values, even after observing an error, where `BehaviorSubject` will not.\n *\n * @see {@link Subject}\n * @see {@link BehaviorSubject}\n * @see {@link shareReplay}\n */\nexport class ReplaySubject extends Subject {\n private _buffer: (T | number)[] = [];\n private _infiniteTimeWindow = true;\n\n /**\n * @param bufferSize The size of the buffer to replay on subscription\n * @param windowTime The amount of time the buffered items will stay buffered\n * @param timestampProvider An object with a `now()` method that provides the current timestamp. This is used to\n * calculate the amount of time something has been buffered.\n */\n constructor(\n private _bufferSize = Infinity,\n private _windowTime = Infinity,\n private _timestampProvider: TimestampProvider = dateTimestampProvider\n ) {\n super();\n this._infiniteTimeWindow = _windowTime === Infinity;\n this._bufferSize = Math.max(1, _bufferSize);\n this._windowTime = Math.max(1, _windowTime);\n }\n\n next(value: T): void {\n const { isStopped, _buffer, _infiniteTimeWindow, _timestampProvider, _windowTime } = this;\n if (!isStopped) {\n _buffer.push(value);\n !_infiniteTimeWindow && _buffer.push(_timestampProvider.now() + _windowTime);\n }\n this._trimBuffer();\n super.next(value);\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n this._throwIfClosed();\n this._trimBuffer();\n\n const subscription = this._innerSubscribe(subscriber);\n\n const { _infiniteTimeWindow, _buffer } = this;\n // We use a copy here, so reentrant code does not mutate our array while we're\n // emitting it to a new subscriber.\n const copy = _buffer.slice();\n for (let i = 0; i < copy.length && !subscriber.closed; i += _infiniteTimeWindow ? 1 : 2) {\n subscriber.next(copy[i] as T);\n }\n\n this._checkFinalizedStatuses(subscriber);\n\n return subscription;\n }\n\n private _trimBuffer() {\n const { _bufferSize, _timestampProvider, _buffer, _infiniteTimeWindow } = this;\n // If we don't have an infinite buffer size, and we're over the length,\n // use splice to truncate the old buffer values off. Note that we have to\n // double the size for instances where we're not using an infinite time window\n // because we're storing the values and the timestamps in the same array.\n const adjustedBufferSize = (_infiniteTimeWindow ? 1 : 2) * _bufferSize;\n _bufferSize < Infinity && adjustedBufferSize < _buffer.length && _buffer.splice(0, _buffer.length - adjustedBufferSize);\n\n // Now, if we're not in an infinite time window, remove all values where the time is\n // older than what is allowed.\n if (!_infiniteTimeWindow) {\n const now = _timestampProvider.now();\n let last = 0;\n // Search the array for the first timestamp that isn't expired and\n // truncate the buffer up to that point.\n for (let i = 1; i < _buffer.length && (_buffer[i] as number) <= now; i += 2) {\n last = i;\n }\n last && _buffer.splice(0, last + 1);\n }\n }\n}\n", "import { Scheduler } from '../Scheduler';\nimport { Subscription } from '../Subscription';\nimport { SchedulerAction } from '../types';\n\n/**\n * A unit of work to be executed in a `scheduler`. An action is typically\n * created from within a {@link SchedulerLike} and an RxJS user does not need to concern\n * themselves about creating and manipulating an Action.\n *\n * ```ts\n * class Action extends Subscription {\n * new (scheduler: Scheduler, work: (state?: T) => void);\n * schedule(state?: T, delay: number = 0): Subscription;\n * }\n * ```\n *\n * @class Action\n */\nexport class Action extends Subscription {\n constructor(scheduler: Scheduler, work: (this: SchedulerAction, state?: T) => void) {\n super();\n }\n /**\n * Schedules this action on its parent {@link SchedulerLike} for execution. May be passed\n * some context object, `state`. May happen at some point in the future,\n * according to the `delay` parameter, if specified.\n * @param {T} [state] Some contextual data that the `work` function uses when\n * called by the Scheduler.\n * @param {number} [delay] Time to wait before executing the work, where the\n * time unit is implicit and defined by the Scheduler.\n * @return {void}\n */\n public schedule(state?: T, delay: number = 0): Subscription {\n return this;\n }\n}\n", "import type { TimerHandle } from './timerHandle';\ntype SetIntervalFunction = (handler: () => void, timeout?: number, ...args: any[]) => TimerHandle;\ntype ClearIntervalFunction = (handle: TimerHandle) => void;\n\ninterface IntervalProvider {\n setInterval: SetIntervalFunction;\n clearInterval: ClearIntervalFunction;\n delegate:\n | {\n setInterval: SetIntervalFunction;\n clearInterval: ClearIntervalFunction;\n }\n | undefined;\n}\n\nexport const intervalProvider: IntervalProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n setInterval(handler: () => void, timeout?: number, ...args) {\n const { delegate } = intervalProvider;\n if (delegate?.setInterval) {\n return delegate.setInterval(handler, timeout, ...args);\n }\n return setInterval(handler, timeout, ...args);\n },\n clearInterval(handle) {\n const { delegate } = intervalProvider;\n return (delegate?.clearInterval || clearInterval)(handle as any);\n },\n delegate: undefined,\n};\n", "import { Action } from './Action';\nimport { SchedulerAction } from '../types';\nimport { Subscription } from '../Subscription';\nimport { AsyncScheduler } from './AsyncScheduler';\nimport { intervalProvider } from './intervalProvider';\nimport { arrRemove } from '../util/arrRemove';\nimport { TimerHandle } from './timerHandle';\n\nexport class AsyncAction extends Action {\n public id: TimerHandle | undefined;\n public state?: T;\n // @ts-ignore: Property has no initializer and is not definitely assigned\n public delay: number;\n protected pending: boolean = false;\n\n constructor(protected scheduler: AsyncScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n public schedule(state?: T, delay: number = 0): Subscription {\n if (this.closed) {\n return this;\n }\n\n // Always replace the current state with the new state.\n this.state = state;\n\n const id = this.id;\n const scheduler = this.scheduler;\n\n //\n // Important implementation note:\n //\n // Actions only execute once by default, unless rescheduled from within the\n // scheduled callback. This allows us to implement single and repeat\n // actions via the same code path, without adding API surface area, as well\n // as mimic traditional recursion but across asynchronous boundaries.\n //\n // However, JS runtimes and timers distinguish between intervals achieved by\n // serial `setTimeout` calls vs. a single `setInterval` call. An interval of\n // serial `setTimeout` calls can be individually delayed, which delays\n // scheduling the next `setTimeout`, and so on. `setInterval` attempts to\n // guarantee the interval callback will be invoked more precisely to the\n // interval period, regardless of load.\n //\n // Therefore, we use `setInterval` to schedule single and repeat actions.\n // If the action reschedules itself with the same delay, the interval is not\n // canceled. If the action doesn't reschedule, or reschedules with a\n // different delay, the interval will be canceled after scheduled callback\n // execution.\n //\n if (id != null) {\n this.id = this.recycleAsyncId(scheduler, id, delay);\n }\n\n // Set the pending flag indicating that this action has been scheduled, or\n // has recursively rescheduled itself.\n this.pending = true;\n\n this.delay = delay;\n // If this action has already an async Id, don't request a new one.\n this.id = this.id ?? this.requestAsyncId(scheduler, this.id, delay);\n\n return this;\n }\n\n protected requestAsyncId(scheduler: AsyncScheduler, _id?: TimerHandle, delay: number = 0): TimerHandle {\n return intervalProvider.setInterval(scheduler.flush.bind(scheduler, this), delay);\n }\n\n protected recycleAsyncId(_scheduler: AsyncScheduler, id?: TimerHandle, delay: number | null = 0): TimerHandle | undefined {\n // If this action is rescheduled with the same delay time, don't clear the interval id.\n if (delay != null && this.delay === delay && this.pending === false) {\n return id;\n }\n // Otherwise, if the action's delay time is different from the current delay,\n // or the action has been rescheduled before it's executed, clear the interval id\n if (id != null) {\n intervalProvider.clearInterval(id);\n }\n\n return undefined;\n }\n\n /**\n * Immediately executes this action and the `work` it contains.\n * @return {any}\n */\n public execute(state: T, delay: number): any {\n if (this.closed) {\n return new Error('executing a cancelled action');\n }\n\n this.pending = false;\n const error = this._execute(state, delay);\n if (error) {\n return error;\n } else if (this.pending === false && this.id != null) {\n // Dequeue if the action didn't reschedule itself. Don't call\n // unsubscribe(), because the action could reschedule later.\n // For example:\n // ```\n // scheduler.schedule(function doWork(counter) {\n // /* ... I'm a busy worker bee ... */\n // var originalAction = this;\n // /* wait 100ms before rescheduling the action */\n // setTimeout(function () {\n // originalAction.schedule(counter + 1);\n // }, 100);\n // }, 1000);\n // ```\n this.id = this.recycleAsyncId(this.scheduler, this.id, null);\n }\n }\n\n protected _execute(state: T, _delay: number): any {\n let errored: boolean = false;\n let errorValue: any;\n try {\n this.work(state);\n } catch (e) {\n errored = true;\n // HACK: Since code elsewhere is relying on the \"truthiness\" of the\n // return here, we can't have it return \"\" or 0 or false.\n // TODO: Clean this up when we refactor schedulers mid-version-8 or so.\n errorValue = e ? e : new Error('Scheduled action threw falsy error');\n }\n if (errored) {\n this.unsubscribe();\n return errorValue;\n }\n }\n\n unsubscribe() {\n if (!this.closed) {\n const { id, scheduler } = this;\n const { actions } = scheduler;\n\n this.work = this.state = this.scheduler = null!;\n this.pending = false;\n\n arrRemove(actions, this);\n if (id != null) {\n this.id = this.recycleAsyncId(scheduler, id, null);\n }\n\n this.delay = null!;\n super.unsubscribe();\n }\n }\n}\n", "import { Action } from './scheduler/Action';\nimport { Subscription } from './Subscription';\nimport { SchedulerLike, SchedulerAction } from './types';\nimport { dateTimestampProvider } from './scheduler/dateTimestampProvider';\n\n/**\n * An execution context and a data structure to order tasks and schedule their\n * execution. Provides a notion of (potentially virtual) time, through the\n * `now()` getter method.\n *\n * Each unit of work in a Scheduler is called an `Action`.\n *\n * ```ts\n * class Scheduler {\n * now(): number;\n * schedule(work, delay?, state?): Subscription;\n * }\n * ```\n *\n * @class Scheduler\n * @deprecated Scheduler is an internal implementation detail of RxJS, and\n * should not be used directly. Rather, create your own class and implement\n * {@link SchedulerLike}. Will be made internal in v8.\n */\nexport class Scheduler implements SchedulerLike {\n public static now: () => number = dateTimestampProvider.now;\n\n constructor(private schedulerActionCtor: typeof Action, now: () => number = Scheduler.now) {\n this.now = now;\n }\n\n /**\n * A getter method that returns a number representing the current time\n * (at the time this function was called) according to the scheduler's own\n * internal clock.\n * @return {number} A number that represents the current time. May or may not\n * have a relation to wall-clock time. May or may not refer to a time unit\n * (e.g. milliseconds).\n */\n public now: () => number;\n\n /**\n * Schedules a function, `work`, for execution. May happen at some point in\n * the future, according to the `delay` parameter, if specified. May be passed\n * some context object, `state`, which will be passed to the `work` function.\n *\n * The given arguments will be processed an stored as an Action object in a\n * queue of actions.\n *\n * @param {function(state: ?T): ?Subscription} work A function representing a\n * task, or some unit of work to be executed by the Scheduler.\n * @param {number} [delay] Time to wait before executing the work, where the\n * time unit is implicit and defined by the Scheduler itself.\n * @param {T} [state] Some contextual data that the `work` function uses when\n * called by the Scheduler.\n * @return {Subscription} A subscription in order to be able to unsubscribe\n * the scheduled work.\n */\n public schedule(work: (this: SchedulerAction, state?: T) => void, delay: number = 0, state?: T): Subscription {\n return new this.schedulerActionCtor(this, work).schedule(state, delay);\n }\n}\n", "import { Scheduler } from '../Scheduler';\nimport { Action } from './Action';\nimport { AsyncAction } from './AsyncAction';\nimport { TimerHandle } from './timerHandle';\n\nexport class AsyncScheduler extends Scheduler {\n public actions: Array> = [];\n /**\n * A flag to indicate whether the Scheduler is currently executing a batch of\n * queued actions.\n * @type {boolean}\n * @internal\n */\n public _active: boolean = false;\n /**\n * An internal ID used to track the latest asynchronous task such as those\n * coming from `setTimeout`, `setInterval`, `requestAnimationFrame`, and\n * others.\n * @type {any}\n * @internal\n */\n public _scheduled: TimerHandle | undefined;\n\n constructor(SchedulerAction: typeof Action, now: () => number = Scheduler.now) {\n super(SchedulerAction, now);\n }\n\n public flush(action: AsyncAction): void {\n const { actions } = this;\n\n if (this._active) {\n actions.push(action);\n return;\n }\n\n let error: any;\n this._active = true;\n\n do {\n if ((error = action.execute(action.state, action.delay))) {\n break;\n }\n } while ((action = actions.shift()!)); // exhaust the scheduler queue\n\n this._active = false;\n\n if (error) {\n while ((action = actions.shift()!)) {\n action.unsubscribe();\n }\n throw error;\n }\n }\n}\n", "import { AsyncAction } from './AsyncAction';\nimport { AsyncScheduler } from './AsyncScheduler';\n\n/**\n *\n * Async Scheduler\n *\n * Schedule task as if you used setTimeout(task, duration)\n *\n * `async` scheduler schedules tasks asynchronously, by putting them on the JavaScript\n * event loop queue. It is best used to delay tasks in time or to schedule tasks repeating\n * in intervals.\n *\n * If you just want to \"defer\" task, that is to perform it right after currently\n * executing synchronous code ends (commonly achieved by `setTimeout(deferredTask, 0)`),\n * better choice will be the {@link asapScheduler} scheduler.\n *\n * ## Examples\n * Use async scheduler to delay task\n * ```ts\n * import { asyncScheduler } from 'rxjs';\n *\n * const task = () => console.log('it works!');\n *\n * asyncScheduler.schedule(task, 2000);\n *\n * // After 2 seconds logs:\n * // \"it works!\"\n * ```\n *\n * Use async scheduler to repeat task in intervals\n * ```ts\n * import { asyncScheduler } from 'rxjs';\n *\n * function task(state) {\n * console.log(state);\n * this.schedule(state + 1, 1000); // `this` references currently executing Action,\n * // which we reschedule with new state and delay\n * }\n *\n * asyncScheduler.schedule(task, 3000, 0);\n *\n * // Logs:\n * // 0 after 3s\n * // 1 after 4s\n * // 2 after 5s\n * // 3 after 6s\n * ```\n */\n\nexport const asyncScheduler = new AsyncScheduler(AsyncAction);\n\n/**\n * @deprecated Renamed to {@link asyncScheduler}. Will be removed in v8.\n */\nexport const async = asyncScheduler;\n", "import { AsyncAction } from './AsyncAction';\nimport { Subscription } from '../Subscription';\nimport { QueueScheduler } from './QueueScheduler';\nimport { SchedulerAction } from '../types';\nimport { TimerHandle } from './timerHandle';\n\nexport class QueueAction extends AsyncAction {\n constructor(protected scheduler: QueueScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n public schedule(state?: T, delay: number = 0): Subscription {\n if (delay > 0) {\n return super.schedule(state, delay);\n }\n this.delay = delay;\n this.state = state;\n this.scheduler.flush(this);\n return this;\n }\n\n public execute(state: T, delay: number): any {\n return delay > 0 || this.closed ? super.execute(state, delay) : this._execute(state, delay);\n }\n\n protected requestAsyncId(scheduler: QueueScheduler, id?: TimerHandle, delay: number = 0): TimerHandle {\n // If delay exists and is greater than 0, or if the delay is null (the\n // action wasn't rescheduled) but was originally scheduled as an async\n // action, then recycle as an async action.\n\n if ((delay != null && delay > 0) || (delay == null && this.delay > 0)) {\n return super.requestAsyncId(scheduler, id, delay);\n }\n\n // Otherwise flush the scheduler starting with this action.\n scheduler.flush(this);\n\n // HACK: In the past, this was returning `void`. However, `void` isn't a valid\n // `TimerHandle`, and generally the return value here isn't really used. So the\n // compromise is to return `0` which is both \"falsy\" and a valid `TimerHandle`,\n // as opposed to refactoring every other instanceo of `requestAsyncId`.\n return 0;\n }\n}\n", "import { AsyncScheduler } from './AsyncScheduler';\n\nexport class QueueScheduler extends AsyncScheduler {\n}\n", "import { QueueAction } from './QueueAction';\nimport { QueueScheduler } from './QueueScheduler';\n\n/**\n *\n * Queue Scheduler\n *\n * Put every next task on a queue, instead of executing it immediately\n *\n * `queue` scheduler, when used with delay, behaves the same as {@link asyncScheduler} scheduler.\n *\n * When used without delay, it schedules given task synchronously - executes it right when\n * it is scheduled. However when called recursively, that is when inside the scheduled task,\n * another task is scheduled with queue scheduler, instead of executing immediately as well,\n * that task will be put on a queue and wait for current one to finish.\n *\n * This means that when you execute task with `queue` scheduler, you are sure it will end\n * before any other task scheduled with that scheduler will start.\n *\n * ## Examples\n * Schedule recursively first, then do something\n * ```ts\n * import { queueScheduler } from 'rxjs';\n *\n * queueScheduler.schedule(() => {\n * queueScheduler.schedule(() => console.log('second')); // will not happen now, but will be put on a queue\n *\n * console.log('first');\n * });\n *\n * // Logs:\n * // \"first\"\n * // \"second\"\n * ```\n *\n * Reschedule itself recursively\n * ```ts\n * import { queueScheduler } from 'rxjs';\n *\n * queueScheduler.schedule(function(state) {\n * if (state !== 0) {\n * console.log('before', state);\n * this.schedule(state - 1); // `this` references currently executing Action,\n * // which we reschedule with new state\n * console.log('after', state);\n * }\n * }, 0, 3);\n *\n * // In scheduler that runs recursively, you would expect:\n * // \"before\", 3\n * // \"before\", 2\n * // \"before\", 1\n * // \"after\", 1\n * // \"after\", 2\n * // \"after\", 3\n *\n * // But with queue it logs:\n * // \"before\", 3\n * // \"after\", 3\n * // \"before\", 2\n * // \"after\", 2\n * // \"before\", 1\n * // \"after\", 1\n * ```\n */\n\nexport const queueScheduler = new QueueScheduler(QueueAction);\n\n/**\n * @deprecated Renamed to {@link queueScheduler}. Will be removed in v8.\n */\nexport const queue = queueScheduler;\n", "import { AsyncAction } from './AsyncAction';\nimport { AnimationFrameScheduler } from './AnimationFrameScheduler';\nimport { SchedulerAction } from '../types';\nimport { animationFrameProvider } from './animationFrameProvider';\nimport { TimerHandle } from './timerHandle';\n\nexport class AnimationFrameAction extends AsyncAction {\n constructor(protected scheduler: AnimationFrameScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n protected requestAsyncId(scheduler: AnimationFrameScheduler, id?: TimerHandle, delay: number = 0): TimerHandle {\n // If delay is greater than 0, request as an async action.\n if (delay !== null && delay > 0) {\n return super.requestAsyncId(scheduler, id, delay);\n }\n // Push the action to the end of the scheduler queue.\n scheduler.actions.push(this);\n // If an animation frame has already been requested, don't request another\n // one. If an animation frame hasn't been requested yet, request one. Return\n // the current animation frame request id.\n return scheduler._scheduled || (scheduler._scheduled = animationFrameProvider.requestAnimationFrame(() => scheduler.flush(undefined)));\n }\n\n protected recycleAsyncId(scheduler: AnimationFrameScheduler, id?: TimerHandle, delay: number = 0): TimerHandle | undefined {\n // If delay exists and is greater than 0, or if the delay is null (the\n // action wasn't rescheduled) but was originally scheduled as an async\n // action, then recycle as an async action.\n if (delay != null ? delay > 0 : this.delay > 0) {\n return super.recycleAsyncId(scheduler, id, delay);\n }\n // If the scheduler queue has no remaining actions with the same async id,\n // cancel the requested animation frame and set the scheduled flag to\n // undefined so the next AnimationFrameAction will request its own.\n const { actions } = scheduler;\n if (id != null && actions[actions.length - 1]?.id !== id) {\n animationFrameProvider.cancelAnimationFrame(id as number);\n scheduler._scheduled = undefined;\n }\n // Return undefined so the action knows to request a new async id if it's rescheduled.\n return undefined;\n }\n}\n", "import { AsyncAction } from './AsyncAction';\nimport { AsyncScheduler } from './AsyncScheduler';\n\nexport class AnimationFrameScheduler extends AsyncScheduler {\n public flush(action?: AsyncAction): void {\n this._active = true;\n // The async id that effects a call to flush is stored in _scheduled.\n // Before executing an action, it's necessary to check the action's async\n // id to determine whether it's supposed to be executed in the current\n // flush.\n // Previous implementations of this method used a count to determine this,\n // but that was unsound, as actions that are unsubscribed - i.e. cancelled -\n // are removed from the actions array and that can shift actions that are\n // scheduled to be executed in a subsequent flush into positions at which\n // they are executed within the current flush.\n const flushId = this._scheduled;\n this._scheduled = undefined;\n\n const { actions } = this;\n let error: any;\n action = action || actions.shift()!;\n\n do {\n if ((error = action.execute(action.state, action.delay))) {\n break;\n }\n } while ((action = actions[0]) && action.id === flushId && actions.shift());\n\n this._active = false;\n\n if (error) {\n while ((action = actions[0]) && action.id === flushId && actions.shift()) {\n action.unsubscribe();\n }\n throw error;\n }\n }\n}\n", "import { AnimationFrameAction } from './AnimationFrameAction';\nimport { AnimationFrameScheduler } from './AnimationFrameScheduler';\n\n/**\n *\n * Animation Frame Scheduler\n *\n * Perform task when `window.requestAnimationFrame` would fire\n *\n * When `animationFrame` scheduler is used with delay, it will fall back to {@link asyncScheduler} scheduler\n * behaviour.\n *\n * Without delay, `animationFrame` scheduler can be used to create smooth browser animations.\n * It makes sure scheduled task will happen just before next browser content repaint,\n * thus performing animations as efficiently as possible.\n *\n * ## Example\n * Schedule div height animation\n * ```ts\n * // html:
\n * import { animationFrameScheduler } from 'rxjs';\n *\n * const div = document.querySelector('div');\n *\n * animationFrameScheduler.schedule(function(height) {\n * div.style.height = height + \"px\";\n *\n * this.schedule(height + 1); // `this` references currently executing Action,\n * // which we reschedule with new state\n * }, 0, 0);\n *\n * // You will see a div element growing in height\n * ```\n */\n\nexport const animationFrameScheduler = new AnimationFrameScheduler(AnimationFrameAction);\n\n/**\n * @deprecated Renamed to {@link animationFrameScheduler}. Will be removed in v8.\n */\nexport const animationFrame = animationFrameScheduler;\n", "import { Observable } from '../Observable';\nimport { SchedulerLike } from '../types';\n\n/**\n * A simple Observable that emits no items to the Observer and immediately\n * emits a complete notification.\n *\n * Just emits 'complete', and nothing else.\n *\n * ![](empty.png)\n *\n * A simple Observable that only emits the complete notification. It can be used\n * for composing with other Observables, such as in a {@link mergeMap}.\n *\n * ## Examples\n *\n * Log complete notification\n *\n * ```ts\n * import { EMPTY } from 'rxjs';\n *\n * EMPTY.subscribe({\n * next: () => console.log('Next'),\n * complete: () => console.log('Complete!')\n * });\n *\n * // Outputs\n * // Complete!\n * ```\n *\n * Emit the number 7, then complete\n *\n * ```ts\n * import { EMPTY, startWith } from 'rxjs';\n *\n * const result = EMPTY.pipe(startWith(7));\n * result.subscribe(x => console.log(x));\n *\n * // Outputs\n * // 7\n * ```\n *\n * Map and flatten only odd numbers to the sequence `'a'`, `'b'`, `'c'`\n *\n * ```ts\n * import { interval, mergeMap, of, EMPTY } from 'rxjs';\n *\n * const interval$ = interval(1000);\n * const result = interval$.pipe(\n * mergeMap(x => x % 2 === 1 ? of('a', 'b', 'c') : EMPTY),\n * );\n * result.subscribe(x => console.log(x));\n *\n * // Results in the following to the console:\n * // x is equal to the count on the interval, e.g. (0, 1, 2, 3, ...)\n * // x will occur every 1000ms\n * // if x % 2 is equal to 1, print a, b, c (each on its own)\n * // if x % 2 is not equal to 1, nothing will be output\n * ```\n *\n * @see {@link Observable}\n * @see {@link NEVER}\n * @see {@link of}\n * @see {@link throwError}\n */\nexport const EMPTY = new Observable((subscriber) => subscriber.complete());\n\n/**\n * @param scheduler A {@link SchedulerLike} to use for scheduling\n * the emission of the complete notification.\n * @deprecated Replaced with the {@link EMPTY} constant or {@link scheduled} (e.g. `scheduled([], scheduler)`). Will be removed in v8.\n */\nexport function empty(scheduler?: SchedulerLike) {\n return scheduler ? emptyScheduled(scheduler) : EMPTY;\n}\n\nfunction emptyScheduled(scheduler: SchedulerLike) {\n return new Observable((subscriber) => scheduler.schedule(() => subscriber.complete()));\n}\n", "import { SchedulerLike } from '../types';\nimport { isFunction } from './isFunction';\n\nexport function isScheduler(value: any): value is SchedulerLike {\n return value && isFunction(value.schedule);\n}\n", "import { SchedulerLike } from '../types';\nimport { isFunction } from './isFunction';\nimport { isScheduler } from './isScheduler';\n\nfunction last(arr: T[]): T | undefined {\n return arr[arr.length - 1];\n}\n\nexport function popResultSelector(args: any[]): ((...args: unknown[]) => unknown) | undefined {\n return isFunction(last(args)) ? args.pop() : undefined;\n}\n\nexport function popScheduler(args: any[]): SchedulerLike | undefined {\n return isScheduler(last(args)) ? args.pop() : undefined;\n}\n\nexport function popNumber(args: any[], defaultValue: number): number {\n return typeof last(args) === 'number' ? args.pop()! : defaultValue;\n}\n", "export const isArrayLike = ((x: any): x is ArrayLike => x && typeof x.length === 'number' && typeof x !== 'function');", "import { isFunction } from \"./isFunction\";\n\n/**\n * Tests to see if the object is \"thennable\".\n * @param value the object to test\n */\nexport function isPromise(value: any): value is PromiseLike {\n return isFunction(value?.then);\n}\n", "import { InteropObservable } from '../types';\nimport { observable as Symbol_observable } from '../symbol/observable';\nimport { isFunction } from './isFunction';\n\n/** Identifies an input as being Observable (but not necessary an Rx Observable) */\nexport function isInteropObservable(input: any): input is InteropObservable {\n return isFunction(input[Symbol_observable]);\n}\n", "import { isFunction } from './isFunction';\n\nexport function isAsyncIterable(obj: any): obj is AsyncIterable {\n return Symbol.asyncIterator && isFunction(obj?.[Symbol.asyncIterator]);\n}\n", "/**\n * Creates the TypeError to throw if an invalid object is passed to `from` or `scheduled`.\n * @param input The object that was passed.\n */\nexport function createInvalidObservableTypeError(input: any) {\n // TODO: We should create error codes that can be looked up, so this can be less verbose.\n return new TypeError(\n `You provided ${\n input !== null && typeof input === 'object' ? 'an invalid object' : `'${input}'`\n } where a stream was expected. You can provide an Observable, Promise, ReadableStream, Array, AsyncIterable, or Iterable.`\n );\n}\n", "export function getSymbolIterator(): symbol {\n if (typeof Symbol !== 'function' || !Symbol.iterator) {\n return '@@iterator' as any;\n }\n\n return Symbol.iterator;\n}\n\nexport const iterator = getSymbolIterator();\n", "import { iterator as Symbol_iterator } from '../symbol/iterator';\nimport { isFunction } from './isFunction';\n\n/** Identifies an input as being an Iterable */\nexport function isIterable(input: any): input is Iterable {\n return isFunction(input?.[Symbol_iterator]);\n}\n", "import { ReadableStreamLike } from '../types';\nimport { isFunction } from './isFunction';\n\nexport async function* readableStreamLikeToAsyncGenerator(readableStream: ReadableStreamLike): AsyncGenerator {\n const reader = readableStream.getReader();\n try {\n while (true) {\n const { value, done } = await reader.read();\n if (done) {\n return;\n }\n yield value!;\n }\n } finally {\n reader.releaseLock();\n }\n}\n\nexport function isReadableStreamLike(obj: any): obj is ReadableStreamLike {\n // We don't want to use instanceof checks because they would return\n // false for instances from another Realm, like an