
feat: Added the initial implementation of KT-split #871

Draft
wants to merge 4 commits into main

Conversation

qh681248
Contributor

PR Type

  • Feature

Description

How Has This Been Tested?

Checklist before requesting a review

  • I have made sure that my PR is not a duplicate.
  • My code follows the style guidelines of this project.
  • I have ensured my code is easy to understand, including docstrings and comments where necessary.
  • I have performed a self-review of my code.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • New and existing unit tests pass locally with my changes.
  • Any dependent changes have been merged and published in downstream modules.
  • I have updated CHANGELOG.md, if appropriate.

@qh681248 qh681248 linked an issue Nov 15, 2024 that may be closed by this pull request
@qh681248 qh681248 marked this pull request as draft November 15, 2024 13:08

Performance review

Commit 27887b6 - Merge eb68cc6 into 98b6142

No significant changes to performance.


Performance review

Commit 17ef126 - Merge e008146 into ed78130

No significant changes to performance.


Performance review

Commit cd03a13 - Merge 1475026 into ed78130

No significant changes to performance.

@qh681248 qh681248 linked an issue Nov 25, 2024 that may be closed by this pull request

Performance review

Commit 1cafdf8 - Merge e1dedb4 into ed78130

No significant changes to performance.

@gw265981 gw265981 self-requested a review November 27, 2024 16:40

@gw265981 gw265981 left a comment


I have added the comments, most of which we discussed previously. Once you have a final working implementation, add some unit tests or create an issue for them (though it might be helpful for you to check that everything is working).

probabilistic components in the algorithm.
"""

kernel: ScalarValuedKernel

Add defaults for some of the parameters, as discussed previously.

random_key: KeyArrayLike

@classmethod
def get_swap_params(

This seems to be the same as get_a_and_param in kt_half; do you need both?

final_coresets = self.kt_split(dataset)
return self.kt_refine(self.kt_choose(final_coresets, dataset)), solver_state

def kt_half_recursive(self, points, m, original_dataset):

I would rename points to something like current_subset.


Also, add type annotations.
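
For illustration, a possible annotated signature; the parameter names, types, and return type here are assumptions based on the surrounding snippets, not the final API:

```python
# Sketch only: suggested rename and type annotations; names and types are assumptions.
def kt_half_recursive(
    self,
    current_subset: _Data,     # renamed from `points`
    m: int,                    # number of remaining halving rounds
    original_dataset: _Data,   # the full dataset being compressed
) -> list[Coresubset[_Data]]:
    ...
```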

"""
n = len(points) // 2
original_array = points.data
arr1 = jnp.zeros(n, dtype=jnp.int32)

Use more descriptive variable names and, preferably, add a comment beforehand to explain what they are for. If the variables are simply temporary placeholders, explain that in a comment.
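
For example, something along these lines (the names are purely illustrative, and the second array is an assumed counterpart to the first):

```python
# Illustrative only: placeholder index arrays for the two candidate halves,
# to be filled in by the swapping logic that follows.
n = len(points) // 2
first_half_indices = jnp.zeros(n, dtype=jnp.int32)   # was `arr1`
second_half_indices = jnp.zeros(n, dtype=jnp.int32)  # assumed counterpart to `arr1`
```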

subset1 = eqx.tree_at(lambda x: x.nodes.data, subset1, subset1_indices)
subset2 = eqx.tree_at(lambda x: x.nodes.data, subset2, subset2_indices)

# Recur for both subsets and concatenate results

recur -> recurse

alpha = term1 + term2
return alpha, bool_arr_1, bool_arr_2

def final_function(

Again, just use a more descriptive name, e.g. apply_probabilistic_assignment (just an example; feel free to choose something more appropriate, of course).


return Coresubset(final_arr1, points), Coresubset(final_arr2, points)

def kt_split(self, points: _Data) -> list[Coresubset[_Data]]:

As discussed, we probably want to remove this from here for now but save it for later.


return final_coresets

def kt_choose(

This will likely have to be changed to be jit-compatible, e.g., using vmap to get a vector of MMD values and then jnp.argmin to select the best.
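
A minimal sketch of that pattern, assuming the candidate coresets can be stacked into a single index array and that an MMD function is available; compute_mmd and the shapes here are placeholders, not the existing API:

```python
import jax
import jax.numpy as jnp

# Sketch only: vectorise the MMD evaluation over all candidates, then pick the best.
# `candidate_indices` has shape (num_candidates, coreset_size); `compute_mmd` stands
# in for whichever MMD implementation ends up being used.
def choose_best_candidate(candidate_indices, dataset, compute_mmd):
    mmd_values = jax.vmap(lambda idx: compute_mmd(dataset[idx], dataset))(candidate_indices)
    return candidate_indices[jnp.argmin(mmd_values)]
```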


Also, implement the baseline coreset computation and expose the method for it as a parameter (random is probably a good default).
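
For the random baseline, a sketch along these lines might work (assuming a JAX PRNG key is available; purely illustrative):

```python
# Sketch only: draw a uniformly random baseline coreset of the requested size.
def random_baseline_indices(random_key, data_size, coreset_size):
    return jax.random.choice(random_key, data_size, shape=(coreset_size,), replace=False)
```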


return best_coreset

def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]:

Feel free to use the Kernel Herding refine method here.
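
If the Kernel Herding refine step is reused, the delegation might look roughly like this; the exact constructor and refine signatures are assumptions here:

```python
# Sketch only: delegate refinement to Kernel Herding. The KernelHerding
# constructor and refine signature used below are assumptions.
def kt_refine(self, candidate_coreset: Coresubset[_Data]) -> Coresubset[_Data]:
    herding_solver = KernelHerding(
        coreset_size=len(candidate_coreset), kernel=self.kernel
    )
    refined_coreset, _ = herding_solver.refine(candidate_coreset)
    return refined_coreset
```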


return a, new_sigma

def reduce(

If we want to make this an ExplicitSizeSolver, this might be the place to implement the logic for discarding and padding points. It will also provide the coreset_size parameter, so you will probably want to remove m as a parameter and compute it as log2(data_size / coreset_size) (after discarding, etc.).
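
A rough sketch of the size bookkeeping described above (the truncation step and exact rounding are assumptions):

```python
import math

# Sketch only: derive the number of halving rounds m from the requested coreset size,
# assuming the data has already been truncated/padded so that
# data_size == coreset_size * 2**m for some non-negative integer m.
def _num_halving_rounds(data_size: int, coreset_size: int) -> int:
    return max(int(math.log2(data_size / coreset_size)), 0)
```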

Labels: None yet
Projects: None yet

Development

Successfully merging this pull request may close these issues:

  • Add KT-select_best algorithm
  • Add KT-split algorithm

2 participants