From eb68cc68678aa0cd79b84221fd77550310acfd6b Mon Sep 17 00:00:00 2001
From: qh681248 <181246904+qh681248@users.noreply.github.com>
Date: Fri, 15 Nov 2024 12:02:07 +0000
Subject: [PATCH 1/9] feat: Added the initial implementation of KT-split

---
 coreax/solvers/__init__.py   |   2 +
 coreax/solvers/coresubset.py | 184 +++++++++++++++++++++++++++++++++++
 2 files changed, 186 insertions(+)

diff --git a/coreax/solvers/__init__.py b/coreax/solvers/__init__.py
index a2e15e79d..11728f3ea 100644
--- a/coreax/solvers/__init__.py
+++ b/coreax/solvers/__init__.py
@@ -27,6 +27,7 @@
     GreedyKernelPointsState,
     HerdingState,
     KernelHerding,
+    KernelThinning,
     RandomSample,
     RPCholesky,
     RPCholeskyState,
@@ -57,4 +58,5 @@
     "RecombinationSolver",
     "CaratheodoryRecombination",
     "TreeRecombination",
+    "KernelThinning",
 ]

diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index 6a6e23ada..69e44c042 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -33,6 +33,7 @@
     MinimalEuclideanNormSolver,
     RegularisedLeastSquaresSolver,
 )
+from coreax.metrics import MMD
 from coreax.score_matching import ScoreMatching, convert_stein_kernel
 from coreax.solvers.base import (
     CoresubsetSolver,
@@ -816,3 +817,186 @@ def _greedy_body(
     return Coresubset(updated_coreset_indices, dataset), GreedyKernelPointsState(
         padded_feature_gramian
     )
+
+
+class KernelThinning(CoresubsetSolver):
+    r"""
+    Kernel Thinning - a hierarchical coreset construction solver.
+
+    `Kernel Thinning` is a hierarchical, probabilistic algorithm for coreset
+    construction. It builds a coreset by splitting the dataset into several candidate
+    coresets by repeatedly halving the dataset and applying probabilistic swapping.
+    The method selects the best of these candidate coresets, which is further refined to minimise
+    the Maximum Mean Discrepancy (MMD) between the original dataset and the coreset.
+
+    :param kernel: A `~coreax.kernels.ScalarValuedKernel` instance defining the primary
+        kernel function used for choosing the best coreset and refining it.
+    :param sqrt_kernel: A `~coreax.kernels.ScalarValuedKernel` instance representing the
+        square root kernel used for splitting the original dataset.
+    :param m: An integer specifying the number of hierarchical halving steps in the
+        coreset construction.
+    :param delta: A float between 0 and 1 used to compute the swapping probability
+        during the splitting process. A recommended value is :math:`1 / \log(\log(n))`,
+        where :math:`n` is the length of the original dataset.
+    :param random_key: Key for random number generation, enabling reproducibility of
+        probabilistic components in the algorithm.
+    """
+
+    kernel: ScalarValuedKernel
+    sqrt_kernel: ScalarValuedKernel
+    m: int
+    delta: float
+    random_key: KeyArrayLike
+
+    @classmethod
+    def get_swap_params(
+        cls, sigma: Array, b: Array, delta: float
+    ) -> tuple[Array, Array]:
+        r"""
+        Compute the swap threshold and update the scaling parameter for swapping.
+
+        :param sigma: The current scaling parameter used in the swapping process.
+        :param b: The kernel-based distance between two points in the dataset.
+        :param delta: A parameter used in calculation of the swapping probability.
+        :return: The swap threshold and the updated scaling parameter.
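+
+        As a sketch of the update rule implemented below (same symbols as the
+        parameters above)::
+
+            a = max(b * sigma * sqrt(2 * log(2 / delta)), b**2)
+            sigma_new = sqrt(sigma**2 + max(b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2), 0))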
+        """
+        a = jnp.maximum(b * sigma * jnp.sqrt(2 * jnp.log(2 / delta)), b**2)
+
+        # Update sigma
+        new_sigma = jnp.sqrt(
+            sigma**2 + jnp.maximum(b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2), 0)
+        )
+
+        return a, new_sigma
+
+    def reduce(
+        self, dataset: _Data, solver_state: None = None
+    ) -> tuple[Coresubset[_Data], None]:
+        r"""
+        Reduce the input dataset to a single refined coreset.
+
+        :param dataset: The original dataset to be reduced.
+        :param solver_state: The state of the solver (currently not used).
+
+        :return: A tuple containing the final coreset and the solver state (None).
+        """
+        final_coresets = self.kt_split(dataset)
+        return self.kt_refine(self.kt_choose(final_coresets, dataset)), solver_state
+
+    def kt_split(self, points: _Data) -> list[Coresubset[_Data]]:
+        r"""
+        Perform hierarchical splitting of the input dataset into multiple coresets.
+
+        This method splits the dataset recursively halving at each level. At each step,
+        a probabilistic swapping is applied to refine the distribution of points across
+        the coresets.
+
+        :param points: The dataset to be split into coresets.
+        :return: A list of refined coresets representing different hierarchical levels.
+        """
+        n = len(points)
+        coresets = {(j, k): [] for j in range(self.m + 1) for k in range(1, 2**j + 1)}
+        sigma = {
+            (j, k): jnp.zeros(1)
+            for j in range(1, self.m + 1)
+            for k in range(1, 2 ** (j - 1) + 1)
+        }
+
+        # Step 1: Initialize with pairs of indices in the lowest level coreset S(0,1)
+        for i in range(1, n // 2 + 1):
+            idx1, idx2 = 2 * i - 2, 2 * i - 1
+            coresets[(0, 1)].extend([idx1, idx2])
+
+            # Step 2: Distribute indices hierarchically across levels
+            for j in range(1, self.m + 1):
+                if i % (2 ** (j - 1)) == 0:
+                    for k in range(1, 2 ** (j - 1) + 1):
+                        parent_set = coresets[(j - 1, k)]
+                        if len(parent_set) <= 1:
+                            continue
+
+                        idx_x, idx_x_prime = parent_set[-2], parent_set[-1]
+                        x, x_prime = points[idx_x], points[idx_x_prime]
+
+                        # Calculate kernel values and b^2
+                        b_squared = (
+                            self.sqrt_kernel.compute_elementwise(x, x)
+                            + self.sqrt_kernel.compute_elementwise(x_prime, x_prime)
+                            - 2 * self.sqrt_kernel.compute_elementwise(x, x_prime)
+                        )
+
+                        # Compute swap threshold a and update sigma
+                        a, sigma[(j, k)] = self.get_swap_params(
+                            sigma[(j, k)], b_squared, self.delta
+                        )
+
+                        # Calculate alpha for probabilistic swapping
+                        alpha = (
+                            self.sqrt_kernel.compute_elementwise(x_prime, x_prime)
+                            - self.sqrt_kernel.compute_elementwise(x, x)
+                            + sum(
+                                self.sqrt_kernel.compute_elementwise(points[y], x)
+                                - self.sqrt_kernel.compute_elementwise(
+                                    points[y], x_prime
+                                )
+                                for y in parent_set
+                            )
+                            - 2
+                            * sum(
+                                self.sqrt_kernel.compute_elementwise(points[z], x)
+                                - self.sqrt_kernel.compute_elementwise(
+                                    points[z], x_prime
+                                )
+                                for z in coresets[(j, 2 * k - 1)]
+                            )
+                        )
+
+                        # Compute swap probability
+                        swap_probability = min(
+                            1, max(0.5 * (1 - (alpha / a).item()), 0)
+                        )
+
+                        # Apply probabilistic swap
+                        if jax.random.uniform(self.random_key) < swap_probability:
+                            idx_x, idx_x_prime = idx_x_prime, idx_x
+
+                        # Assign indices to child coresets
+                        coresets[(j, 2 * k - 1)].append(idx_x)
+                        coresets[(j, 2 * k)].append(idx_x_prime)
+
+        # Collect the indices of the final level's coresets
+        final_coresets = [
+            Coresubset(Data(jnp.array(coresets[(self.m, k)])), points)
+            for k in range(1, 2**self.m + 1)
+        ]
+
+        return final_coresets
+
+    def kt_choose(
+        self, candidate_coresets: list[Coresubset[_Data]], points: _Data
+    ) -> Coresubset[_Data]:
+        r"""
+        Select the best coreset from a list of candidate coresets based on MMD.
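+
+        For reference, the standard (biased) squared MMD between a dataset
+        ``X`` of size ``n`` and a coreset ``S`` of size ``m`` under kernel
+        ``k``, with means taken over all pairs, is::
+
+            MMD^2(X, S) = mean(k(X, X)) + mean(k(S, S)) - 2 * mean(k(X, S))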
+
+        :param candidate_coresets: A list of candidate coresets to be evaluated.
+        :param points: The original dataset against which the coresets are compared.
+        :return: The coreset with the smallest MMD relative to the input dataset.
+        """
+        mmd = MMD(kernel=self.kernel)
+
+        best_coreset = min(
+            candidate_coresets, key=lambda coreset: mmd.compute(points, coreset.coreset)
+        )
+
+        return best_coreset
+
+    def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]:
+        r"""
+        Refine the selected candidate coreset.
+
+        It is not yet implemented and serves as a placeholder for future implementation.
+
+        :param candidate_coreset: The candidate coreset to be refined.
+        :return: The refined coreset.
+        """
+        return candidate_coreset

From e008146ec63547c1926ab1037c49b7b9e6652e95 Mon Sep 17 00:00:00 2001
From: qh681248 <181246904+qh681248@users.noreply.github.com>
Date: Mon, 25 Nov 2024 10:33:44 +0000
Subject: [PATCH 2/9] feat: Add recursive kernel halving

---
 .cspell/custom_misc.txt      |   1 +
 coreax/solvers/coresubset.py | 147 +++++++++++++++++++++++++++++++++++
 2 files changed, 148 insertions(+)

diff --git a/.cspell/custom_misc.txt b/.cspell/custom_misc.txt
index dd8d4dafd..6674d5b6b 100644
--- a/.cspell/custom_misc.txt
+++ b/.cspell/custom_misc.txt
@@ -37,6 +37,7 @@ pmatrix
 PMLR
 primaryclass
 PRNG
+prob
 propto
 rcond
 recomb

diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index 69e44c042..29c9a9937 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -23,6 +23,7 @@
 import jax.random as jr
 import jax.scipy as jsp
 import jax.tree_util as jtu
+from jax import lax
 from jaxtyping import Array, ArrayLike, Shaped
 from typing_extensions import override

@@ -883,6 +884,152 @@ def reduce(
         final_coresets = self.kt_split(dataset)
         return self.kt_refine(self.kt_choose(final_coresets, dataset)), solver_state

+    def kt_half_recursive(self, points, m, original_dataset):
+        r"""
+        Recursively halve the original dataset into coresets.
+
+        :param points: The current coreset or dataset being partitioned.
+        :param m: The remaining depth of recursion.
+        :param original_dataset: The original dataset.
+        :return: Fully partitioned list of :class:`Coresubset` instances.
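+
+        For example (illustrative only): with ``m = 2`` and 16 points, the
+        recursion returns ``2**2 = 4`` coresets of 4 points each.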
+        """
+        if m == 0:
+            return [points]
+
+        # Recursively call self.kt_half on the coreset (or the dataset)
+        if hasattr(points, "coreset"):
+            subset1, subset2 = self.kt_half(points.coreset)
+        else:
+            subset1, subset2 = self.kt_half(points)
+
+        # Update pre_coreset_data for both subsets to point to the original dataset
+        subset1 = eqx.tree_at(lambda x: x.pre_coreset_data, subset1, original_dataset)
+        subset2 = eqx.tree_at(lambda x: x.pre_coreset_data, subset2, original_dataset)
+
+        # Update indices: map current subset's indices to original dataset
+        if hasattr(points, "nodes") and hasattr(points.nodes, "data"):
+            parent_indices = points.nodes.data  # Parent subset's indices
+            subset1_indices = subset1.nodes.data  # Indices relative to parent
+            subset2_indices = subset2.nodes.data  # Indices relative to parent
+
+            # Map subset indices back to original dataset
+            subset1_indices = parent_indices[subset1_indices]
+            subset2_indices = parent_indices[subset2_indices]
+
+            # Update the subsets with the remapped indices
+            subset1 = eqx.tree_at(lambda x: x.nodes.data, subset1, subset1_indices)
+            subset2 = eqx.tree_at(lambda x: x.nodes.data, subset2, subset2_indices)
+
+        # Recur for both subsets and concatenate results
+        return self.kt_half_recursive(
+            subset1, m - 1, original_dataset
+        ) + self.kt_half_recursive(subset2, m - 1, original_dataset)
+
+    def kt_half(self, points: _Data) -> list[Coresubset[_Data]]:
+        """
+        Partition the given dataset into two subsets.
+
+        :param points: The input dataset to be halved.
+        :return: A tuple containing two :class:`Coresubset` instances.
+        """
+        n = len(points) // 2
+        original_array = points.data
+        arr1 = jnp.zeros(n, dtype=jnp.int32)
+        arr2 = jnp.zeros(n, dtype=jnp.int32)
+
+        bool_arr_1 = jnp.zeros(2 * n)
+        bool_arr_2 = jnp.zeros(n)
+
+        # Initialize parameter
+        param = jnp.float32(0)
+        k = self.sqrt_kernel.compute_elementwise
+
+        def compute_b(x1, x2):
+            """
+            Compute b.
+
+            :param x1: The first data point.
+            :param x2: The second data point.
+            :return: The kernel distance between `x1` and `x2`.
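+
+            In RKHS terms (a standard identity), this equals the feature-space
+            norm ``||k(x1, .) - k(x2, .)||``, computed below as
+            ``sqrt(k(x1, x1) + k(x2, x2) - 2 * k(x1, x2))``.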
+ """ + return jnp.sqrt(k(x1, x1) + k(x2, x2) - 2 * k(x1, x2)) + + def get_a_and_param(b, sigma): + """Compute 'a' and new parameter.""" + a = jnp.maximum(b * sigma * jnp.sqrt(2 * jnp.log(2 / self.delta)), b**2) + + # Update sigma + new_sigma = jnp.sqrt( + sigma**2 + jnp.maximum(b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2), 0) + ) + + return a, new_sigma + + def get_alpha(x1, x2, i, bool_arr_1, bool_arr_2): + k_vec_x1 = jax.vmap(lambda y: k(y, x1)) + k_vec_x2 = jax.vmap(lambda y: k(y, x2)) + + k_vec_x1_idx = jax.vmap(lambda y: k(original_array[y], x1)) + k_vec_x2_idx = jax.vmap(lambda y: k(original_array[y], x2)) + # Apply to original array and sum + term1 = jnp.dot( + (k_vec_x1(original_array) - k_vec_x2(original_array)), bool_arr_1 + ) + term2 = -2 * jnp.dot((k_vec_x1_idx(arr1) - k_vec_x2_idx(arr1)), bool_arr_2) + # For bool_arr_1, set 2i and 2i+1 positions to 1 + bool_arr_1 = bool_arr_1.at[2 * i].set(1) + bool_arr_1 = bool_arr_1.at[2 * i + 1].set(1) + # For bool_arr_2, set i-th position to 1 + bool_arr_2 = bool_arr_2.at[i].set(1) + # Combine all terms + alpha = term1 + term2 + return alpha, bool_arr_1, bool_arr_2 + + def final_function(i, a, alpha, random_key): + key1, key2 = jax.random.split(random_key) + + prob = jax.random.uniform(key1) + return lax.cond( + prob < 1 / 2 * (1 - alpha / a), + lambda _: (2 * i, 2 * i + 1), # first case: val1 = x1, val2 = x2 + lambda _: (2 * i + 1, 2 * i), # second case: val1 = x2, val2 = x1 + None, + ), key2 + + def body_fun(i, state): + arr1, arr2, param, bool_arr_1, bool_arr_2, random_key = state + # Step 1: Get values from original array + x1 = original_array[i * 2] + x2 = original_array[i * 2 + 1] + # Step 2: Get a and new parameter + a, new_param = get_a_and_param(compute_b(x1, x2), param) + # Step 3: Compute alpha + alpha, new_bool_arr_1, new_bool_arr_2 = get_alpha( + x1, x2, i, bool_arr_1, bool_arr_2 + ) + # Step 4: Get final values + (val1, val2), new_random_key = final_function(i, a, alpha, random_key) + # Step 5: Update arrays + new_arr1 = arr1.at[i].set(val1) + new_arr2 = arr2.at[i].set(val2) + return ( + new_arr1, + new_arr2, + new_param, + new_bool_arr_1, + new_bool_arr_2, + new_random_key, + ) + + (final_arr1, final_arr2, _, _, _, _) = lax.fori_loop( + 0, # start index + n, # end index + body_fun, # body function + (arr1, arr2, param, bool_arr_1, bool_arr_2, self.random_key), + ) + + return Coresubset(final_arr1, points), Coresubset(final_arr2, points) + def kt_split(self, points: _Data) -> list[Coresubset[_Data]]: r""" Perform hierarchical splitting of the input dataset into multiple coresets. From 14750268cc07d0c59854884013f71b6f3aa5bd6b Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Mon, 25 Nov 2024 11:05:36 +0000 Subject: [PATCH 3/9] doc: Add missing docstrings for Kernel Thinning --- coreax/solvers/coresubset.py | 84 +++++++++++++++++++++++++++++++----- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index 29c9a9937..2ff8aff65 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -853,7 +853,7 @@ class KernelThinning(CoresubsetSolver): def get_swap_params( cls, sigma: Array, b: Array, delta: float ) -> tuple[Array, Array]: - r""" + """ Compute the swap threshold and update the scaling parameter for swapping. :param sigma: The current scaling parameter used in the swapping process. 
@@ -873,7 +873,7 @@ def get_swap_params( def reduce( self, dataset: _Data, solver_state: None = None ) -> tuple[Coresubset[_Data], None]: - r""" + """ Reduce the input dataset to a single refined coreset. :param dataset: The original dataset to be reduced. @@ -885,7 +885,7 @@ def reduce( return self.kt_refine(self.kt_choose(final_coresets, dataset)), solver_state def kt_half_recursive(self, points, m, original_dataset): - r""" + """ Recursively halve the original dataset into coresets. :param points: The current coreset or dataset being partitioned. @@ -940,7 +940,7 @@ def kt_half(self, points: _Data) -> list[Coresubset[_Data]]: bool_arr_1 = jnp.zeros(2 * n) bool_arr_2 = jnp.zeros(n) - # Initialize parameter + # Initialise parameter param = jnp.float32(0) k = self.sqrt_kernel.compute_elementwise @@ -965,7 +965,26 @@ def get_a_and_param(b, sigma): return a, new_sigma - def get_alpha(x1, x2, i, bool_arr_1, bool_arr_2): + def get_alpha( + x1: jnp.ndarray, + x2: jnp.ndarray, + i: int, + bool_arr_1: jnp.ndarray, + bool_arr_2: jnp.ndarray, + ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Calculate the value of alpha and update the boolean arrays. + + :param x1: The first data point in the kernel evaluation. + :param x2: The second data point in the kernel evaluation. + :param i: The current index in the iteration. + :param bool_arr_1: A boolean array that tracks indices. + :param bool_arr_2: A boolean array that tracks indices. + :return: A tuple containing: + - `alpha`: The computed value of alpha. + - `bool_arr_1`: Updated boolean array for the original dataset. + - `bool_arr_2`: Updated boolean array for the subset. + """ k_vec_x1 = jax.vmap(lambda y: k(y, x1)) k_vec_x2 = jax.vmap(lambda y: k(y, x2)) @@ -985,7 +1004,20 @@ def get_alpha(x1, x2, i, bool_arr_1, bool_arr_2): alpha = term1 + term2 return alpha, bool_arr_1, bool_arr_2 - def final_function(i, a, alpha, random_key): + def final_function( + i: int, a: jnp.ndarray, alpha: jnp.ndarray, random_key: KeyArrayLike + ) -> tuple[tuple[int, int], KeyArrayLike]: + """ + Perform a probabilistic swap based on the given parameters. + + :param i: The current index in the dataset. + :param a: The swap threshold computed based on kernel parameters. + :param alpha: The calculated value for probabilistic swapping. + :param random_key: A random key for generating random numbers. + :return: A tuple containing: + - A tuple of indices indicating the swapped values. + - The updated random key. + """ key1, key2 = jax.random.split(random_key) prob = jax.random.uniform(key1) @@ -996,7 +1028,37 @@ def final_function(i, a, alpha, random_key): None, ), key2 - def body_fun(i, state): + def body_fun( + i: int, + state: tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + KeyArrayLike, + ], + ) -> tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + KeyArrayLike, + ]: + """ + Perform one iteration of the halving process. + + :param i: The current iteration index. + :param state: A tuple containing: + - arr1: The first array of indices. + - arr2: The second array of indices. + - param: The scaling parameter. + - bool_arr_1: Boolean array indicating selected indices. + - bool_arr_2: Boolean array indicating selected indices. + - random_key: A JAX random key. + :return: The updated state tuple after processing the current iteration. 
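+
+            .. note::
+                ``jax.lax.fori_loop`` threads this state tuple through every
+                iteration, so each element must keep a fixed shape and dtype
+                (a general JAX constraint, noted here for clarity).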
+            """
             arr1, arr2, param, bool_arr_1, bool_arr_2, random_key = state
             # Step 1: Get values from original array
             x1 = original_array[i * 2]
             x2 = original_array[i * 2 + 1]
             # Step 2: Get a and new parameter
             a, new_param = get_a_and_param(compute_b(x1, x2), param)
             # Step 3: Compute alpha
             alpha, new_bool_arr_1, new_bool_arr_2 = get_alpha(
                 x1, x2, i, bool_arr_1, bool_arr_2
             )
             # Step 4: Get final values
             (val1, val2), new_random_key = final_function(i, a, alpha, random_key)
             # Step 5: Update arrays
             new_arr1 = arr1.at[i].set(val1)
             new_arr2 = arr2.at[i].set(val2)
             return (
                 new_arr1,
                 new_arr2,
                 new_param,
                 new_bool_arr_1,
                 new_bool_arr_2,
                 new_random_key,
             )

         (final_arr1, final_arr2, _, _, _, _) = lax.fori_loop(
             0,  # start index
             n,  # end index
             body_fun,  # body function
             (arr1, arr2, param, bool_arr_1, bool_arr_2, self.random_key),
         )

         return Coresubset(final_arr1, points), Coresubset(final_arr2, points)

     def kt_split(self, points: _Data) -> list[Coresubset[_Data]]:
-        r"""
+        """
         Perform hierarchical splitting of the input dataset into multiple coresets.

         This method splits the dataset recursively halving at each level. At each step,
         a probabilistic swapping is applied to refine the distribution of points across
         the coresets.

         :param points: The dataset to be split into coresets.
         :return: A list of refined coresets representing different hierarchical levels.
         """
         n = len(points)
         coresets = {(j, k): [] for j in range(self.m + 1) for k in range(1, 2**j + 1)}
         sigma = {
             (j, k): jnp.zeros(1)
             for j in range(1, self.m + 1)
             for k in range(1, 2 ** (j - 1) + 1)
         }

-        # Step 1: Initialize with pairs of indices in the lowest level coreset S(0,1)
+        # Step 1: Initialise with pairs of indices in the lowest level coreset S(0,1)
         for i in range(1, n // 2 + 1):
             idx1, idx2 = 2 * i - 2, 2 * i - 1
             coresets[(0, 1)].extend([idx1, idx2])
@@ -1122,7 +1184,7 @@ def kt_split(self, points: _Data) -> list[Coresubset[_Data]]:
     def kt_choose(
         self, candidate_coresets: list[Coresubset[_Data]], points: _Data
     ) -> Coresubset[_Data]:
-        r"""
+        """
         Select the best coreset from a list of candidate coresets based on MMD.

         :param candidate_coresets: A list of candidate coresets to be evaluated.
         :param points: The original dataset against which the coresets are compared.
         :return: The coreset with the smallest MMD relative to the input dataset.
         """
         mmd = MMD(kernel=self.kernel)
@@ -1138,7 +1200,7 @@ def kt_choose(
         return best_coreset

     def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]:
-        r"""
+        """
         Refine the selected candidate coreset.

         It is not yet implemented and serves as a placeholder for future implementation.

From e1dedb43d3a1766d2bdb572115eb8e70be297406 Mon Sep 17 00:00:00 2001
From: qh681248 <181246904+qh681248@users.noreply.github.com>
Date: Mon, 25 Nov 2024 11:39:18 +0000
Subject: [PATCH 4/9] doc: fix build documentation error

---
 coreax/solvers/coresubset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index 2ff8aff65..46b5a417c 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -891,7 +891,7 @@ def kt_half_recursive(self, points, m, original_dataset):
         :param points: The current coreset or dataset being partitioned.
         :param m: The remaining depth of recursion.
         :param original_dataset: The original dataset.
-        :return: Fully partitioned list of :class:`Coresubset` instances.
+        :return: Fully partitioned list of coresets.
         """
         if m == 0:
             return [points]
@@ -930,7 +930,7 @@ def kt_half(self, points: _Data) -> list[Coresubset[_Data]]:
         Partition the given dataset into two subsets.

         :param points: The input dataset to be halved.
-        :return: A tuple containing two :class:`Coresubset` instances.
+        :return: A tuple containing the two partitioned coresets.
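+
+        Note (illustrative): because of the ``n = len(points) // 2`` pairing
+        below, an odd-length input silently drops its final point, e.g. 11
+        points yield two coresets of 5 indices each.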
""" n = len(points) // 2 original_array = points.data From 2902757e13fd4fb1727969dfb98ecb712d490ec4 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:28:31 +0000 Subject: [PATCH 5/9] feat: kt_half returns list instead of tuple --- coreax/solvers/coresubset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index 46b5a417c..7e47a1115 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -1090,7 +1090,7 @@ def body_fun( (arr1, arr2, param, bool_arr_1, bool_arr_2, self.random_key), ) - return Coresubset(final_arr1, points), Coresubset(final_arr2, points) + return [Coresubset(final_arr1, points), Coresubset(final_arr2, points)] def kt_split(self, points: _Data) -> list[Coresubset[_Data]]: """ From 2af3980d79f1e287fe9a229b937a7d7b53d360ca Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:38:00 +0000 Subject: [PATCH 6/9] chore: add `__post_init__` to set the square root kernel. --- coreax/solvers/coresubset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index 7e47a1115..8cb75ca9e 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -820,7 +820,7 @@ def _greedy_body( ) -class KernelThinning(CoresubsetSolver): +class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver): r""" Kernel Thinning - a hierarchical coreset construction solver. @@ -849,6 +849,16 @@ class KernelThinning(CoresubsetSolver): delta: float random_key: KeyArrayLike + def __post_init__(self): + """ + Initialise square-root kernel. + + If square-root kernel is not provided, check if square-root kernel of the given + kernel is implemented and set that as the square root, otherwise raise an error. + """ + if not hasattr(self, "sqrt_kernel"): + self.sqrt_kernel = self.kernel + @classmethod def get_swap_params( cls, sigma: Array, b: Array, delta: float From 26a4d4454a7641dd30b622b7d73c75015a826089 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Wed, 11 Dec 2024 10:28:06 +0000 Subject: [PATCH 7/9] feat: Initial draft of fixed coreset size KT algorithm. #688 --- coreax/solvers/coresubset.py | 164 ++++++++++++----------------------- 1 file changed, 57 insertions(+), 107 deletions(-) diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index f57cf3a99..f5cb70ac3 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -14,6 +14,7 @@ """Solvers for constructing coresubsets.""" +import math from collections.abc import Callable from typing import Optional, TypeVar, Union @@ -869,27 +870,27 @@ class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver): `Kernel Thinning` is a hierarchical, and probabilistic algorithm for coreset construction. It builds a coreset by splitting the dataset into several candidate coresets by repeatedly halving the dataset and applying probabilistic swapping. - The method the best of these candidate coresets which is further refined to minimise + The best of these candidates is chosen which is further refined to minimise the Maximum Mean Discrepancy (MMD) between the original dataset and the coreset. :param kernel: A `~coreax.kernels.ScalarValuedKernel` instance defining the primary kernel function used for choosing the best coreset and refining it. 
- :param sqrt_kernel: A `~coreax.kernels.ScalarValuedKernel` instance representing the - square root kernel used for splitting the original dataset. + :param m: An integer specifying the number of hierarchical halving steps in the coreset construction. + :param random_key: Key for random number generation, enabling reproducibility of + probabilistic components in the algorithm. :param delta: A float between 0 and 1 used to compute the swapping probability during the splitting process. A recommended value is :math:`1 / \log(\log(n))`, where :math:`n` is the length of the original dataset. - :param random_key: Key for random number generation, enabling reproducibility of - probabilistic components in the algorithm. + :param sqrt_kernel: A `~coreax.kernels.ScalarValuedKernel` instance representing the + square root kernel used for splitting the original dataset. """ kernel: ScalarValuedKernel - sqrt_kernel: ScalarValuedKernel - m: int - delta: float random_key: KeyArrayLike + delta: Optional[float] = None + sqrt_kernel: Optional[ScalarValuedKernel] = None def __post_init__(self): """ @@ -898,7 +899,7 @@ def __post_init__(self): If square-root kernel is not provided, check if square-root kernel of the given kernel is implemented and set that as the square root, otherwise raise an error. """ - if not hasattr(self, "sqrt_kernel"): + if self.sqrt_kernel is None: self.sqrt_kernel = self.kernel @classmethod @@ -933,8 +934,24 @@ def reduce( :return: A tuple containing the final coreset and the solver state (None). """ - final_coresets = self.kt_split(dataset) - return self.kt_refine(self.kt_choose(final_coresets, dataset)), solver_state + n = len(dataset) + if self.delta is None: + log_n = math.log(n) + if log_n > 0: + log_log_n = math.log(log_n) + if log_log_n > 0: + self.delta = 1 / (n * log_log_n) + else: + self.delta = 1 / (n * log_n) + else: + self.delta = 1 / n + m = math.floor(math.log2(n) - math.log2(self.coreset_size)) + clipped_original_dataset = dataset[: self.coreset_size * 2**m] + partition = self.kt_half_recursive(clipped_original_dataset, m, dataset) + baseline_coreset = self.get_baseline_coreset(dataset, self.coreset_size) + partition.append(baseline_coreset) + best_coreset_indices = self.kt_choose(partition, dataset) + return self.kt_refine(Coresubset(best_coreset_indices, dataset)) def kt_half_recursive(self, points, m, original_dataset): """ @@ -946,7 +963,7 @@ def kt_half_recursive(self, points, m, original_dataset): :return: Fully partitioned list of coresets. 
""" if m == 0: - return [points] + return [Coresubset(Data(jnp.arange(len(points))), original_dataset)] # Recursively call self.kt_half on the coreset (or the dataset) if hasattr(points, "coreset"): @@ -961,8 +978,8 @@ def kt_half_recursive(self, points, m, original_dataset): # Update indices: map current subset's indices to original dataset if hasattr(points, "nodes") and hasattr(points.nodes, "data"): parent_indices = points.nodes.data # Parent subset's indices - subset1_indices = subset1.nodes.data # Indices relative to parent - subset2_indices = subset2.nodes.data # Indices relative to parent + subset1_indices = subset1.nodes.data.flatten() # Indices relative to parent + subset2_indices = subset2.nodes.data.flatten() # Indices relative to parent # Map subset indices back to original dataset subset1_indices = parent_indices[subset1_indices] @@ -1141,101 +1158,26 @@ def body_fun( body_fun, # body function (arr1, arr2, param, bool_arr_1, bool_arr_2, self.random_key), ) - return [Coresubset(final_arr1, points), Coresubset(final_arr2, points)] - def kt_split(self, points: _Data) -> list[Coresubset[_Data]]: + def get_baseline_coreset( + self, dataset: Data, baseline_coreset_size: int + ) -> Coresubset[_Data]: """ - Perform hierarchical splitting of the input dataset into multiple coresets. + Generate a baseline coreset by randomly sampling from the dataset. - This method splits the dataset recursively halving at each level. At each step, - a probabilistic swapping is applied to refine the distribution of points across - the coresets. - - :param points: The dataset to be split into coresets. - :return: A list of refined coresets representing different hierarchical levels. + :param dataset: The input dataset from which the baseline coreset is sampled. + :param baseline_coreset_size: The number of points in the baseline coreset. + :return: A randomly sampled baseline coreset with the specified size. 
""" - n = len(points) - coresets = {(j, k): [] for j in range(self.m + 1) for k in range(1, 2**j + 1)} - sigma = { - (j, k): jnp.zeros(1) - for j in range(1, self.m + 1) - for k in range(1, 2 ** (j - 1) + 1) - } - - # Step 1: Initialise with pairs of indices in the lowest level coreset S(0,1) - for i in range(1, n // 2 + 1): - idx1, idx2 = 2 * i - 2, 2 * i - 1 - coresets[(0, 1)].extend([idx1, idx2]) - - # Step 2: Distribute indices hierarchically across levels - for j in range(1, self.m + 1): - if i % (2 ** (j - 1)) == 0: - for k in range(1, 2 ** (j - 1) + 1): - parent_set = coresets[(j - 1, k)] - if len(parent_set) <= 1: - continue - - idx_x, idx_x_prime = parent_set[-2], parent_set[-1] - x, x_prime = points[idx_x], points[idx_x_prime] - - # Calculate kernel values and b^2 - b_squared = ( - self.sqrt_kernel.compute_elementwise(x, x) - + self.sqrt_kernel.compute_elementwise(x_prime, x_prime) - - 2 * self.sqrt_kernel.compute_elementwise(x, x_prime) - ) - - # Compute swap threshold a and update sigma - a, sigma[(j, k)] = self.get_swap_params( - sigma[(j, k)], b_squared, self.delta - ) - - # Calculate alpha for probabilistic swapping - alpha = ( - self.sqrt_kernel.compute_elementwise(x_prime, x_prime) - - self.sqrt_kernel.compute_elementwise(x, x) - + sum( - self.sqrt_kernel.compute_elementwise(points[y], x) - - self.sqrt_kernel.compute_elementwise( - points[y], x_prime - ) - for y in parent_set - ) - - 2 - * sum( - self.sqrt_kernel.compute_elementwise(points[z], x) - - self.sqrt_kernel.compute_elementwise( - points[z], x_prime - ) - for z in coresets[(j, 2 * k - 1)] - ) - ) - - # Compute swap probability - swap_probability = min( - 1, max(0.5 * (1 - (alpha / a).item()), 0) - ) - - # Apply probabilistic swap - if jax.random.uniform(self.random_key) < swap_probability: - idx_x, idx_x_prime = idx_x_prime, idx_x - - # Assign indices to child coresets - coresets[(j, 2 * k - 1)].append(idx_x) - coresets[(j, 2 * k)].append(idx_x_prime) - - # Collect the indices of the final level's coresets - final_coresets = [ - Coresubset(Data(jnp.array(coresets[(self.m, k)])), points) - for k in range(1, 2**self.m + 1) - ] - - return final_coresets + baseline_coreset, _ = RandomSample( + coreset_size=baseline_coreset_size, random_key=self.random_key + ).reduce(dataset) + return baseline_coreset def kt_choose( self, candidate_coresets: list[Coresubset[_Data]], points: _Data - ) -> Coresubset[_Data]: + ) -> Shaped[Array, " coreset_size"]: """ Select the best coreset from a list of candidate coresets based on MMD. @@ -1244,20 +1186,28 @@ def kt_choose( :return: The coreset with the smallest MMD relative to the input dataset. """ mmd = MMD(kernel=self.kernel) + candidate_coresets_jax = jnp.array([c.coreset.data for c in candidate_coresets]) + candidate_coresets_indices = jnp.array([c.nodes for c in candidate_coresets]) + mmd_values = jax.vmap(lambda c: mmd.compute(c, points))(candidate_coresets_jax) - best_coreset = min( - candidate_coresets, key=lambda coreset: mmd.compute(points, coreset.coreset) - ) + best_index = jnp.argmin(mmd_values) - return best_coreset + return candidate_coresets_indices[best_index] def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]: """ Refine the selected candidate coreset. - It is not yet implemented and serves as a placeholder for future implementation. 
+ Use meth:`~coreax.solvers.KernelHerding.refine` which achieves the result of + looping through each element in coreset replacing that element with a point in + the original dataset to minimise MMD in each step. :param candidate_coreset: The candidate coreset to be refined. :return: The refined coreset. """ - return candidate_coreset + print("before refine", candidate_coreset.nodes.data) + refined_coreset, _ = KernelHerding( + coreset_size=self.coreset_size, kernel=self.kernel + ).refine(candidate_coreset) + print("after refine", refined_coreset.nodes.data) + return refined_coreset, None From 5da0a7541eb8d044e5d692f7a324d9c99e089859 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 12 Dec 2024 10:44:19 +0000 Subject: [PATCH 8/9] feat: refactor reduce method of KernelThinning #688 --- coreax/solvers/coresubset.py | 122 +++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 50 deletions(-) diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index f5cb70ac3..6b4b7fc09 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -872,12 +872,11 @@ class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver): coresets by repeatedly halving the dataset and applying probabilistic swapping. The best of these candidates is chosen which is further refined to minimise the Maximum Mean Discrepancy (MMD) between the original dataset and the coreset. + This implementation is a modification of the Kernel Thinning algorithm in + :cite:`dwivedi2024kernelthinning` to make it an ExplicitSizeSolver. :param kernel: A `~coreax.kernels.ScalarValuedKernel` instance defining the primary kernel function used for choosing the best coreset and refining it. - - :param m: An integer specifying the number of hierarchical halving steps in the - coreset construction. :param random_key: Key for random number generation, enabling reproducibility of probabilistic components in the algorithm. :param delta: A float between 0 and 1 used to compute the swapping probability @@ -892,92 +891,115 @@ class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver): delta: Optional[float] = None sqrt_kernel: Optional[ScalarValuedKernel] = None - def __post_init__(self): + def reduce( + self, dataset: _Data, solver_state: None = None + ) -> tuple[Coresubset[_Data], None]: """ - Initialise square-root kernel. + Reduce 'dataset' to a :class:`~coreax.coreset.Coresubset` with 'KernelThinning'. - If square-root kernel is not provided, check if square-root kernel of the given - kernel is implemented and set that as the square root, otherwise raise an error. - """ - if self.sqrt_kernel is None: - self.sqrt_kernel = self.kernel + This is done by first computing the square root kernel if not already provided, + then using the recommended value for the delta parameter according to + :cite:`dwivedi2024kernelthinning` if not specified. Next, the number of halving + steps required, referred to as `m`, is calculated. The original data is clipped + so that it is divisible by a power of two. The kernel halving algorithm is then + recursively applied to halve the data. - @classmethod - def get_swap_params( - cls, sigma: Array, b: Array, delta: float - ) -> tuple[Array, Array]: - """ - Compute the swap threshold and update the scaling parameter for swapping. + Subsequently, a `baseline_coreset` is added to the ensemble of coresets. 
The + best coreset is selected to minimise the Maximum Mean Discrepancy (MMD) and + finally, it is refined further for optimal results. + + + :param dataset: The original dataset to be reduced. + :param solver_state: The state of the solver. - :param sigma: The current scaling parameter used in the swapping process. - :param b: The kernel-based distance between two points in the dataset. - :param delta: A parameter used in calculation of the swapping probability. - :return: The swap threshold and the updated scaling parameter. + :return: A tuple containing the final coreset and the solver state (None). """ - a = jnp.maximum(b * sigma * jnp.sqrt(2 * jnp.log(2 / delta)), b**2) + n = len(dataset) + + sqrt_kernel = self.sqrt_kernel + if sqrt_kernel is None: + if hasattr(self.kernel, "get_sqrt_kernel"): + sqrt_kernel = self.kernel.get_sqrt_kernel(dataset.data.ndim) + else: + raise NotImplementedError( + f"The square root of the " + f"{self.kernel.__class__.__name__} is not" + f" implemented. Please provide the square" + f" root kernel if known." + ) + + delta = self.delta + if delta is None: + log_n = math.log(n) + if log_n > 0: + log_log_n = math.log(log_n) + if log_log_n > 0: + delta = 1 / (n * log_log_n) + else: + delta = 1 / (n * log_n) + else: + delta = 1 / n - # Update sigma - new_sigma = jnp.sqrt( - sigma**2 + jnp.maximum(b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2), 0) + # Create a new instance with updated parameters + new_instance = KernelThinning( + coreset_size=self.coreset_size, + kernel=self.kernel, + random_key=self.random_key, + delta=delta, + sqrt_kernel=sqrt_kernel, ) - return a, new_sigma + return new_instance.reduce_internal(dataset) - def reduce( - self, dataset: _Data, solver_state: None = None + def reduce_internal( + self, + dataset: _Data, ) -> tuple[Coresubset[_Data], None]: """ - Reduce the input dataset to a single refined coreset. + Implement `reduce` method for the new instance with set parameters. :param dataset: The original dataset to be reduced. - :param solver_state: The state of the solver (currently not used). :return: A tuple containing the final coreset and the solver state (None). """ n = len(dataset) - if self.delta is None: - log_n = math.log(n) - if log_n > 0: - log_log_n = math.log(log_n) - if log_log_n > 0: - self.delta = 1 / (n * log_log_n) - else: - self.delta = 1 / (n * log_n) - else: - self.delta = 1 / n m = math.floor(math.log2(n) - math.log2(self.coreset_size)) clipped_original_dataset = dataset[: self.coreset_size * 2**m] + partition = self.kt_half_recursive(clipped_original_dataset, m, dataset) baseline_coreset = self.get_baseline_coreset(dataset, self.coreset_size) partition.append(baseline_coreset) + best_coreset_indices = self.kt_choose(partition, dataset) return self.kt_refine(Coresubset(best_coreset_indices, dataset)) - def kt_half_recursive(self, points, m, original_dataset): + def kt_half_recursive(self, current_coreset, m, original_dataset): """ Recursively halve the original dataset into coresets. - :param points: The current coreset or dataset being partitioned. + :param current_coreset: The current coreset or dataset being partitioned. :param m: The remaining depth of recursion. :param original_dataset: The original dataset. :return: Fully partitioned list of coresets. 
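+
+        For example (illustrative only): if the parent level holds original
+        indices ``[0, 2, 5, 7]`` and one half selects positions ``[1, 3]``
+        within the parent, the remapping ``parent_indices[subset_indices]``
+        below yields original indices ``[2, 7]``.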
""" if m == 0: - return [Coresubset(Data(jnp.arange(len(points))), original_dataset)] + return [ + Coresubset(Data(jnp.arange(len(current_coreset))), original_dataset) + ] # Recursively call self.kt_half on the coreset (or the dataset) - if hasattr(points, "coreset"): - subset1, subset2 = self.kt_half(points.coreset) + if hasattr(current_coreset, "coreset"): + subset1, subset2 = self.kt_half(current_coreset.coreset) else: - subset1, subset2 = self.kt_half(points) + subset1, subset2 = self.kt_half(current_coreset) # Update pre_coreset_data for both subsets to point to the original dataset subset1 = eqx.tree_at(lambda x: x.pre_coreset_data, subset1, original_dataset) subset2 = eqx.tree_at(lambda x: x.pre_coreset_data, subset2, original_dataset) # Update indices: map current subset's indices to original dataset - if hasattr(points, "nodes") and hasattr(points.nodes, "data"): - parent_indices = points.nodes.data # Parent subset's indices + if hasattr(current_coreset, "nodes") and hasattr(current_coreset.nodes, "data"): + parent_indices = current_coreset.nodes.data # Parent subset's indices subset1_indices = subset1.nodes.data.flatten() # Indices relative to parent subset2_indices = subset2.nodes.data.flatten() # Indices relative to parent @@ -1194,7 +1216,9 @@ def kt_choose( return candidate_coresets_indices[best_index] - def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]: + def kt_refine( + self, candidate_coreset: Coresubset[_Data] + ) -> tuple[Coresubset[_Data], None]: """ Refine the selected candidate coreset. @@ -1205,9 +1229,7 @@ def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]: :param candidate_coreset: The candidate coreset to be refined. :return: The refined coreset. """ - print("before refine", candidate_coreset.nodes.data) refined_coreset, _ = KernelHerding( coreset_size=self.coreset_size, kernel=self.kernel ).refine(candidate_coreset) - print("after refine", refined_coreset.nodes.data) return refined_coreset, None From 9c50aa851a3f726bdf6fbfd701d5d6d78113d83e Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:33:23 +0000 Subject: [PATCH 9/9] feat: changed the parameters delta and square root kernel in Kernel Thinning back to required. #893 --- coreax/solvers/coresubset.py | 44 ++++++++---------------------------- 1 file changed, 9 insertions(+), 35 deletions(-) diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index 6b4b7fc09..a0a3e4369 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -888,8 +888,8 @@ class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver): kernel: ScalarValuedKernel random_key: KeyArrayLike - delta: Optional[float] = None - sqrt_kernel: Optional[ScalarValuedKernel] = None + delta: float + sqrt_kernel: ScalarValuedKernel def reduce( self, dataset: _Data, solver_state: None = None @@ -915,41 +915,15 @@ def reduce( :return: A tuple containing the final coreset and the solver state (None). """ n = len(dataset) + m = math.floor(math.log2(n) - math.log2(self.coreset_size)) + clipped_original_dataset = dataset[: self.coreset_size * 2**m] - sqrt_kernel = self.sqrt_kernel - if sqrt_kernel is None: - if hasattr(self.kernel, "get_sqrt_kernel"): - sqrt_kernel = self.kernel.get_sqrt_kernel(dataset.data.ndim) - else: - raise NotImplementedError( - f"The square root of the " - f"{self.kernel.__class__.__name__} is not" - f" implemented. Please provide the square" - f" root kernel if known." 
-            )
-
-        delta = self.delta
-        if delta is None:
-            log_n = math.log(n)
-            if log_n > 0:
-                log_log_n = math.log(log_n)
-                if log_log_n > 0:
-                    delta = 1 / (n * log_log_n)
-                else:
-                    delta = 1 / (n * log_n)
-            else:
-                delta = 1 / n
-
-        # Create a new instance with updated parameters
-        new_instance = KernelThinning(
-            coreset_size=self.coreset_size,
-            kernel=self.kernel,
-            random_key=self.random_key,
-            delta=delta,
-            sqrt_kernel=sqrt_kernel,
-        )
+        m = math.floor(math.log2(n) - math.log2(self.coreset_size))
+        clipped_original_dataset = dataset[: self.coreset_size * 2**m]

-        return new_instance.reduce_internal(dataset)
+        partition = self.kt_half_recursive(clipped_original_dataset, m, dataset)
+        baseline_coreset = self.get_baseline_coreset(dataset, self.coreset_size)
+        partition.append(baseline_coreset)
+
+        best_coreset_indices = self.kt_choose(partition, dataset)
+        return self.kt_refine(Coresubset(best_coreset_indices, dataset))

     def reduce_internal(
         self,
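
Example usage (illustrative sketch, not part of the patches above; the kernel
choice, data, and `delta` value are assumptions for demonstration):

    import jax.random as jr

    from coreax import Data
    from coreax.kernels import SquaredExponentialKernel
    from coreax.solvers import KernelThinning

    key = jr.key(0)
    dataset = Data(jr.normal(key, (1024, 2)))  # illustrative data
    solver = KernelThinning(
        coreset_size=64,
        kernel=SquaredExponentialKernel(),
        random_key=key,
        delta=0.5,  # roughly 1 / log(log(n)) for n = 1024, per the docstring
        sqrt_kernel=SquaredExponentialKernel(),  # assumed known square-root kernel
    )
    coreset, _ = solver.reduce(dataset)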