diff --git a/.cspell/custom_misc.txt b/.cspell/custom_misc.txt
index 750db8995..901e6316f 100644
--- a/.cspell/custom_misc.txt
+++ b/.cspell/custom_misc.txt
@@ -48,6 +48,7 @@ pmatrix
 PMLR
 primaryclass
 PRNG
+prob
 propto
 rcond
 recomb
diff --git a/coreax/solvers/__init__.py b/coreax/solvers/__init__.py
index a2e15e79d..11728f3ea 100644
--- a/coreax/solvers/__init__.py
+++ b/coreax/solvers/__init__.py
@@ -27,6 +27,7 @@
     GreedyKernelPointsState,
     HerdingState,
     KernelHerding,
+    KernelThinning,
     RandomSample,
     RPCholesky,
     RPCholeskyState,
@@ -57,4 +58,5 @@
     "RecombinationSolver",
     "CaratheodoryRecombination",
     "TreeRecombination",
+    "KernelThinning",
 ]
diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index ec53789d5..a0a3e4369 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -14,6 +14,7 @@
 """Solvers for constructing coresubsets."""
 
+import math
 from collections.abc import Callable
 from typing import Optional, TypeVar, Union
 
@@ -23,6 +24,7 @@
 import jax.random as jr
 import jax.scipy as jsp
 import jax.tree_util as jtu
+from jax import lax
 from jaxtyping import Array, ArrayLike, Scalar, Shaped
 from typing_extensions import override
 
@@ -33,6 +35,7 @@
     MinimalEuclideanNormSolver,
     RegularisedLeastSquaresSolver,
 )
+from coreax.metrics import MMD
 from coreax.score_matching import ScoreMatching, convert_stein_kernel
 from coreax.solvers.base import (
     CoresubsetSolver,
@@ -858,3 +861,349 @@ def _greedy_body(
     return Coresubset(updated_coreset_indices, dataset), GreedyKernelPointsState(
         padded_feature_gramian
     )
+
+
+class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver):
+    r"""
+    Kernel Thinning - a hierarchical coreset construction solver.
+
+    `Kernel Thinning` is a hierarchical, probabilistic algorithm for coreset
+    construction. It builds a coreset by repeatedly halving the dataset with
+    probabilistic swapping, producing several candidate coresets. The best of
+    these candidates is then selected and further refined to minimise the
+    Maximum Mean Discrepancy (MMD) between the original dataset and the coreset.
+    This implementation is a modification of the Kernel Thinning algorithm in
+    :cite:`dwivedi2024kernelthinning` to make it an :class:`ExplicitSizeSolver`.
+
+    :param kernel: A :class:`~coreax.kernels.ScalarValuedKernel` instance defining
+        the primary kernel function used for choosing the best coreset and
+        refining it.
+    :param random_key: Key for random number generation, enabling reproducibility
+        of probabilistic components in the algorithm.
+    :param delta: A float between 0 and 1 used to compute the swapping probability
+        during the splitting process. A recommended value is
+        :math:`1 / \log(\log(n))`, where :math:`n` is the length of the original
+        dataset.
+    :param sqrt_kernel: A :class:`~coreax.kernels.ScalarValuedKernel` instance
+        representing the square-root kernel used for splitting the original
+        dataset.
+    """
+
+    kernel: ScalarValuedKernel
+    random_key: KeyArrayLike
+    delta: float
+    sqrt_kernel: ScalarValuedKernel
+
+    def reduce(
+        self, dataset: _Data, solver_state: None = None
+    ) -> tuple[Coresubset[_Data], None]:
+        """
+        Reduce 'dataset' to a :class:`~coreax.coreset.Coresubset` with 'KernelThinning'.
+
+        This is done by first computing the number of halving steps required,
+        referred to as `m`.
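+        Specifically, ``m = floor(log2(n) - log2(coreset_size))``, where ``n`` is
+        the length of the original dataset.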
+        The original data is clipped so that its length is ``coreset_size * 2**m``,
+        and the kernel halving algorithm is then applied recursively to halve the
+        data.
+
+        Subsequently, a randomly sampled `baseline_coreset` is added to the
+        ensemble of candidate coresets. The candidate with the smallest Maximum
+        Mean Discrepancy (MMD) with respect to the original dataset is selected
+        and, finally, refined further.
+
+        :param dataset: The original dataset to be reduced.
+        :param solver_state: The state of the solver (unused).
+
+        :return: A tuple containing the final coreset and the solver state (None).
+        """
+        return self.reduce_internal(dataset)
+
+    def reduce_internal(
+        self,
+        dataset: _Data,
+    ) -> tuple[Coresubset[_Data], None]:
+        """
+        Implement the `reduce` method for an instance with set parameters.
+
+        :param dataset: The original dataset to be reduced.
+
+        :return: A tuple containing the final coreset and the solver state (None).
+        """
+        n = len(dataset)
+        m = math.floor(math.log2(n) - math.log2(self.coreset_size))
+        clipped_original_dataset = dataset[: self.coreset_size * 2**m]
+
+        partition = self.kt_half_recursive(clipped_original_dataset, m, dataset)
+        baseline_coreset = self.get_baseline_coreset(dataset, self.coreset_size)
+        partition.append(baseline_coreset)
+
+        best_coreset_indices = self.kt_choose(partition, dataset)
+        return self.kt_refine(Coresubset(best_coreset_indices, dataset))
+
+    def kt_half_recursive(
+        self,
+        current_coreset: Union[_Data, Coresubset[_Data]],
+        m: int,
+        original_dataset: _Data,
+    ) -> list[Coresubset[_Data]]:
+        """
+        Recursively halve the original dataset into coresets.
+
+        :param current_coreset: The current coreset or dataset being partitioned.
+        :param m: The remaining depth of recursion.
+        :param original_dataset: The original dataset.
+        :return: Fully partitioned list of coresets.
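+
+        Each level halves its input, so a recursion depth of `m` yields
+        :math:`2^m` candidate coresets.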
+ """ + if m == 0: + return [ + Coresubset(Data(jnp.arange(len(current_coreset))), original_dataset) + ] + + # Recursively call self.kt_half on the coreset (or the dataset) + if hasattr(current_coreset, "coreset"): + subset1, subset2 = self.kt_half(current_coreset.coreset) + else: + subset1, subset2 = self.kt_half(current_coreset) + + # Update pre_coreset_data for both subsets to point to the original dataset + subset1 = eqx.tree_at(lambda x: x.pre_coreset_data, subset1, original_dataset) + subset2 = eqx.tree_at(lambda x: x.pre_coreset_data, subset2, original_dataset) + + # Update indices: map current subset's indices to original dataset + if hasattr(current_coreset, "nodes") and hasattr(current_coreset.nodes, "data"): + parent_indices = current_coreset.nodes.data # Parent subset's indices + subset1_indices = subset1.nodes.data.flatten() # Indices relative to parent + subset2_indices = subset2.nodes.data.flatten() # Indices relative to parent + + # Map subset indices back to original dataset + subset1_indices = parent_indices[subset1_indices] + subset2_indices = parent_indices[subset2_indices] + + # Update the subsets with the remapped indices + subset1 = eqx.tree_at(lambda x: x.nodes.data, subset1, subset1_indices) + subset2 = eqx.tree_at(lambda x: x.nodes.data, subset2, subset2_indices) + + # Recur for both subsets and concatenate results + return self.kt_half_recursive( + subset1, m - 1, original_dataset + ) + self.kt_half_recursive(subset2, m - 1, original_dataset) + + def kt_half(self, points: _Data) -> list[Coresubset[_Data]]: + """ + Partition the given dataset into two subsets. + + :param points: The input dataset to be halved. + :return: A tuple containing two the partitioned coresets. + """ + n = len(points) // 2 + original_array = points.data + arr1 = jnp.zeros(n, dtype=jnp.int32) + arr2 = jnp.zeros(n, dtype=jnp.int32) + + bool_arr_1 = jnp.zeros(2 * n) + bool_arr_2 = jnp.zeros(n) + + # Initialise parameter + param = jnp.float32(0) + k = self.sqrt_kernel.compute_elementwise + + def compute_b(x1, x2): + """ + Compute b. + + :param x1: The first data point. + :param x2: The second data point. + :return: The kernel distance between `x1` and `x2`. + """ + return jnp.sqrt(k(x1, x1) + k(x2, x2) - 2 * k(x1, x2)) + + def get_a_and_param(b, sigma): + """Compute 'a' and new parameter.""" + a = jnp.maximum(b * sigma * jnp.sqrt(2 * jnp.log(2 / self.delta)), b**2) + + # Update sigma + new_sigma = jnp.sqrt( + sigma**2 + jnp.maximum(b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2), 0) + ) + + return a, new_sigma + + def get_alpha( + x1: jnp.ndarray, + x2: jnp.ndarray, + i: int, + bool_arr_1: jnp.ndarray, + bool_arr_2: jnp.ndarray, + ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Calculate the value of alpha and update the boolean arrays. + + :param x1: The first data point in the kernel evaluation. + :param x2: The second data point in the kernel evaluation. + :param i: The current index in the iteration. + :param bool_arr_1: A boolean array that tracks indices. + :param bool_arr_2: A boolean array that tracks indices. + :return: A tuple containing: + - `alpha`: The computed value of alpha. + - `bool_arr_1`: Updated boolean array for the original dataset. + - `bool_arr_2`: Updated boolean array for the subset. 
+ """ + k_vec_x1 = jax.vmap(lambda y: k(y, x1)) + k_vec_x2 = jax.vmap(lambda y: k(y, x2)) + + k_vec_x1_idx = jax.vmap(lambda y: k(original_array[y], x1)) + k_vec_x2_idx = jax.vmap(lambda y: k(original_array[y], x2)) + # Apply to original array and sum + term1 = jnp.dot( + (k_vec_x1(original_array) - k_vec_x2(original_array)), bool_arr_1 + ) + term2 = -2 * jnp.dot((k_vec_x1_idx(arr1) - k_vec_x2_idx(arr1)), bool_arr_2) + # For bool_arr_1, set 2i and 2i+1 positions to 1 + bool_arr_1 = bool_arr_1.at[2 * i].set(1) + bool_arr_1 = bool_arr_1.at[2 * i + 1].set(1) + # For bool_arr_2, set i-th position to 1 + bool_arr_2 = bool_arr_2.at[i].set(1) + # Combine all terms + alpha = term1 + term2 + return alpha, bool_arr_1, bool_arr_2 + + def final_function( + i: int, a: jnp.ndarray, alpha: jnp.ndarray, random_key: KeyArrayLike + ) -> tuple[tuple[int, int], KeyArrayLike]: + """ + Perform a probabilistic swap based on the given parameters. + + :param i: The current index in the dataset. + :param a: The swap threshold computed based on kernel parameters. + :param alpha: The calculated value for probabilistic swapping. + :param random_key: A random key for generating random numbers. + :return: A tuple containing: + - A tuple of indices indicating the swapped values. + - The updated random key. + """ + key1, key2 = jax.random.split(random_key) + + prob = jax.random.uniform(key1) + return lax.cond( + prob < 1 / 2 * (1 - alpha / a), + lambda _: (2 * i, 2 * i + 1), # first case: val1 = x1, val2 = x2 + lambda _: (2 * i + 1, 2 * i), # second case: val1 = x2, val2 = x1 + None, + ), key2 + + def body_fun( + i: int, + state: tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + KeyArrayLike, + ], + ) -> tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + KeyArrayLike, + ]: + """ + Perform one iteration of the halving process. + + :param i: The current iteration index. + :param state: A tuple containing: + - arr1: The first array of indices. + - arr2: The second array of indices. + - param: The scaling parameter. + - bool_arr_1: Boolean array indicating selected indices. + - bool_arr_2: Boolean array indicating selected indices. + - random_key: A JAX random key. + :return: The updated state tuple after processing the current iteration. + """ + arr1, arr2, param, bool_arr_1, bool_arr_2, random_key = state + # Step 1: Get values from original array + x1 = original_array[i * 2] + x2 = original_array[i * 2 + 1] + # Step 2: Get a and new parameter + a, new_param = get_a_and_param(compute_b(x1, x2), param) + # Step 3: Compute alpha + alpha, new_bool_arr_1, new_bool_arr_2 = get_alpha( + x1, x2, i, bool_arr_1, bool_arr_2 + ) + # Step 4: Get final values + (val1, val2), new_random_key = final_function(i, a, alpha, random_key) + # Step 5: Update arrays + new_arr1 = arr1.at[i].set(val1) + new_arr2 = arr2.at[i].set(val2) + return ( + new_arr1, + new_arr2, + new_param, + new_bool_arr_1, + new_bool_arr_2, + new_random_key, + ) + + (final_arr1, final_arr2, _, _, _, _) = lax.fori_loop( + 0, # start index + n, # end index + body_fun, # body function + (arr1, arr2, param, bool_arr_1, bool_arr_2, self.random_key), + ) + return [Coresubset(final_arr1, points), Coresubset(final_arr2, points)] + + def get_baseline_coreset( + self, dataset: Data, baseline_coreset_size: int + ) -> Coresubset[_Data]: + """ + Generate a baseline coreset by randomly sampling from the dataset. + + :param dataset: The input dataset from which the baseline coreset is sampled. 
+        :param baseline_coreset_size: The number of points in the baseline coreset.
+        :return: A randomly sampled baseline coreset of the specified size.
+        """
+        baseline_coreset, _ = RandomSample(
+            coreset_size=baseline_coreset_size, random_key=self.random_key
+        ).reduce(dataset)
+        return baseline_coreset
+
+    def kt_choose(
+        self, candidate_coresets: list[Coresubset[_Data]], points: _Data
+    ) -> Shaped[Array, " coreset_size"]:
+        """
+        Select the best coreset from a list of candidate coresets based on MMD.
+
+        :param candidate_coresets: A list of candidate coresets to be evaluated.
+        :param points: The original dataset against which the coresets are compared.
+        :return: The indices of the coreset with the smallest MMD relative to the
+            input dataset.
+        """
+        mmd = MMD(kernel=self.kernel)
+        candidate_coresets_jax = jnp.array(
+            [c.coreset.data for c in candidate_coresets]
+        )
+        candidate_coresets_indices = jnp.array([c.nodes for c in candidate_coresets])
+        mmd_values = jax.vmap(lambda c: mmd.compute(c, points))(candidate_coresets_jax)
+
+        best_index = jnp.argmin(mmd_values)
+
+        return candidate_coresets_indices[best_index]
+
+    def kt_refine(
+        self, candidate_coreset: Coresubset[_Data]
+    ) -> tuple[Coresubset[_Data], None]:
+        """
+        Refine the selected candidate coreset.
+
+        Use :meth:`~coreax.solvers.KernelHerding.refine`, which loops through each
+        element of the coreset and replaces it with the point in the original
+        dataset that minimises the MMD at each step.
+
+        :param candidate_coreset: The candidate coreset to be refined.
+        :return: A tuple containing the refined coreset and the solver state (None).
+        """
+        refined_coreset, _ = KernelHerding(
+            coreset_size=self.coreset_size, kernel=self.kernel
+        ).refine(candidate_coreset)
+        return refined_coreset, None
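A minimal usage sketch of the new solver (the toy dataset, the `SquaredExponentialKernel` choices for both the primary and square-root kernels, and the sizes below are illustrative assumptions, not part of this diff):

```python
import math

import jax.random as jr

from coreax import Data
from coreax.kernels import SquaredExponentialKernel
from coreax.solvers import KernelThinning

# Toy dataset: 1024 standard-normal points in two dimensions.
n = 1024
dataset = Data(jr.normal(jr.PRNGKey(0), (n, 2)))

solver = KernelThinning(
    coreset_size=64,
    kernel=SquaredExponentialKernel(),
    random_key=jr.PRNGKey(1),
    # Recommended delta = 1 / log(log(n)) from the class docstring.
    delta=1 / math.log(math.log(n)),
    # Illustrative square-root kernel choice.
    sqrt_kernel=SquaredExponentialKernel(),
)

# Here m = floor(log2(1024) - log2(64)) = 4, so the data is halved four times,
# giving 2**4 = 16 candidate coresets plus the random baseline.
coreset, _ = solver.reduce(dataset)
print(coreset.coreset.data.shape)  # (64, 2)
```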