
Improve KNN Graph Construction for Handling Exact Duplicates and Numerical Precision #1119

Merged · 97 commits into cleanlab:master on May 24, 2024

Conversation

elisno (Member) commented May 6, 2024

Summary

This PR introduces a new argument, correct_exact_duplicates: bool = True, to construct_knn_graph_from_features and related functions. It ensures that exact duplicates in the feature array are handled correctly during k-nearest-neighbors (KNN) graph construction: duplicates of a point should appear as its nearest neighbors, at distance zero.
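For illustration, here is a minimal sketch of the new flag using create_knn_graph_and_index (the same entry point exercised in the benchmark code below); treat it as a usage sketch rather than canonical API documentation:

```python
import numpy as np
from cleanlab.internal.neighbor.knn_graph import create_knn_graph_and_index

features = np.random.rand(100, 10)
features[1] = features[0]  # plant one exact duplicate

# New default behavior: exact duplicates end up as each other's
# nearest neighbors, at distance zero, in the resulting graph.
knn_graph, knn = create_knn_graph_and_index(features)

# Opting out reproduces the previous, uncorrected behavior.
knn_graph_raw, _ = create_knn_graph_and_index(features, correct_exact_duplicates=False)
```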

New Functions

  • correct_knn_graph:
A wrapper around correct_knn_distances_and_indices that corrects an existing KNN graph (csr_matrix) based on the feature array.

    def correct_knn_graph(features: FeatureArray, knn_graph: csr_matrix) -> csr_matrix:
        ...
  • correct_knn_distances_and_indices_with_exact_duplicate_sets_inplace:
    Main logic for in-place correction of distances and indices arrays.

    def correct_knn_distances_and_indices_with_exact_duplicate_sets_inplace(
        distances: np.ndarray,
        indices: np.ndarray,
        exact_duplicate_sets: List[np.ndarray],
    ) -> None:
        ...
  • correct_knn_distances_and_indices:
Corrects KNN distances and indices, optionally taking precomputed exact-duplicate sets and a flag to emit a warning when corrections are applied.

    def correct_knn_distances_and_indices(
        features: FeatureArray,
        distances: np.ndarray,
        indices: np.ndarray,
        exact_duplicate_sets: Optional[List[np.ndarray]] = None,
        enable_warning: bool = False,
    ) -> tuple[np.ndarray, np.ndarray]:
        ...
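For a precomputed graph, the wrapper can be applied directly. A minimal sketch based on the signatures above:

```python
from cleanlab.internal.neighbor.knn_graph import correct_knn_graph

# knn_graph: a scipy.sparse csr_matrix of KNN distances built from `features`.
# The corrected graph gives exact duplicates zero distance and makes them
# point at one another as nearest neighbors.
corrected_graph = correct_knn_graph(features, knn_graph)
```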

Impact

  • The default behavior of most functions now includes correction for exact duplicates unless a knn: NearestNeighbors object is passed explicitly.
  • This change affects the outlier detection in cleanlab/outliers.py, where correction is applied manually if no knn object is provided.
  • The noniid check in Datalab is updated to construct a KNN graph without relying on NearestNeighbors from sklearn.

What this PR does not address:

  • The correction logic covers exact duplicates, but not near-duplicates or small feature variations that might warrant similar handling.
  • Performance optimization for large duplicated datasets. Users with large datasets and numerous sets of exact duplicates might experience slower performance due to the iteration across all the different duplicate sets.
  • The PR adds corrections for exact duplicates during KNN graph construction on training data, but it does not provide built-in support for other k-nearest-neighbor search libraries. Such graphs should be constructed by the user and then corrected with the same features (a sketch follows this list).
    • The same applies to correcting KNN graphs on test data.
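As a sketch of that user-side workflow, using FAISS as an illustrative (not endorsed) external library:

```python
import faiss  # illustrative external KNN library, not part of this PR
import numpy as np
from cleanlab.internal.neighbor.knn_graph import correct_knn_distances_and_indices

k = 10
features = features.astype(np.float32)
index = faiss.IndexFlatL2(features.shape[1])
index.add(features)

# FAISS typically returns each point as its own first neighbor, so query
# k+1 neighbors and drop the first column. IndexFlatL2 reports squared
# L2 distances, so take the square root to get Euclidean distances.
D, I = index.search(features, k + 1)
distances, indices = np.sqrt(D[:, 1:]), I[:, 1:]

# Correct the user-built graph with the same features.
distances, indices = correct_knn_distances_and_indices(
    features, distances=distances, indices=indices, enable_warning=True
)
```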

Benchmark Results

Two benchmark scenarios were tested:

  1. All-Identical Dataset:

    • One unique point duplicated N times.
  2. Copied Dataset:

    • Several points duplicated a few times.

The graphs below compare runtime and memory usage for different functions:

  • Top Graph: Runtime vs. Number of Points
  • Bottom Graph: Memory Usage vs. Number of Points

For the All-Identical Dataset, the correction function spends most of its time constructing a small circulant matrix to find the nearest neighbors of the first k+1 elements; all other points simply refer to the first k points. The purple line shows how the correction algorithm performs when all duplicate information is precomputed and the output can be modified in place; no KNN-graph construction occurs in that function.
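To make the circulant-matrix trick concrete, here is a sketch (not the library's exact implementation) of the neighbor block for the first k+1 duplicates:

```python
import numpy as np

def duplicate_neighbor_block(k: int) -> np.ndarray:
    # Row i lists the other k points among 0..k in circulant order,
    # skipping i itself: (i+1) % (k+1), ..., (i+k) % (k+1).
    rows = np.arange(k + 1)[:, None]   # shape (k+1, 1)
    offsets = np.arange(1, k + 1)      # shape (k,)
    return (rows + offsets) % (k + 1)  # shape (k+1, k)

print(duplicate_neighbor_block(3))
# [[1 2 3]
#  [2 3 0]
#  [3 0 1]
#  [0 1 2]]
```

In an all-identical dataset with N > k+1 points, every remaining row can simply reuse np.arange(k) as its neighbor indices, with all distances set to zero.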

[Image: runtime and memory benchmarks for the All-Identical Dataset]

For the Copied Dataset, there are far more duplicate sets to iterate over, which impacts the performance of the correction function. Even so, the correction should not exceed the runtime of the exhaustive search algorithm by much.

[Image: runtime and memory benchmarks for the Copied Dataset]

The benchmark code and additional results are provided in the expandable sections.

Benchmark Code for All-Identical Dataset

```python
from __future__ import annotations
import time
import tracemalloc
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from cleanlab.internal.neighbor.knn_graph import (
    _compute_exact_duplicate_sets,
    correct_knn_distances_and_indices,
    correct_knn_distances_and_indices_with_exact_duplicate_sets_inplace,
    features_to_knn,
    construct_knn_graph_from_index,
    create_knn_graph_and_index,
)

# Define the sizes of feature arrays for the benchmark
feature_sizes = np.logspace(1, 7.0, num=25, base=10, dtype=int)

# Define a function to benchmark the memory and runtime of a given function
def benchmark_function(func, *args, **kwargs):
    # Record the start time and memory usage
    start_time = time.time()
    tracemalloc.start()
    
    # Run the function
    result = func(*args, **kwargs)
    
    # Record the end time and memory usage
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    
    # Calculate runtime and peak memory usage
    runtime = end_time - start_time
    peak_memory = peak / 10**6  # Convert to MB
    
    return runtime, peak_memory, result

# Column labels for the benchmark results
f_col, N_col, runtime_col, memory_col = 'Function', 'Num points', 'Runtime (s)', 'Memory (MB)'

# Define functions to benchmark
functions = {
    # 'features_to_knn': features_to_knn,
    'construct_knn_graph_from_index_without_correction': construct_knn_graph_from_index,
    'construct_knn_graph_from_index_with_correction': construct_knn_graph_from_index,
    # 'create_knn_graph_and_index': create_knn_graph_and_index,
    'correct_knn_distances_and_indices': correct_knn_distances_and_indices,
    'correct_with_precomputed_exact_duplicate_sets': correct_knn_distances_and_indices,
    'correct_with_precomputed_exact_duplicate_sets_inplace': correct_knn_distances_and_indices_with_exact_duplicate_sets_inplace,
}
results_list_of_dicts = []
# Run the benchmark for each function and feature size
N_max_slow = 20000
for N in tqdm(feature_sizes):
    # features = np.random.rand(N, 10)  # Generate random feature array
    features = np.ones((N, 10))
    if N > N_max_slow:
        k = min(10, N-1)
        distances = np.tile(np.ones(k), (N, 1))
        indices = np.tile(np.arange(k), (N, 1))
        exact_duplicate_sets = [np.arange(N)]
    else:
        knn_graph, knn = create_knn_graph_and_index(features, correct_exact_duplicates=False)
        distances = knn_graph.data.reshape(knn_graph.shape[0], -1)
        indices = knn_graph.indices.reshape(knn_graph.shape[0], -1)
        exact_duplicate_sets = _compute_exact_duplicate_sets(features)
    
    for func_name, func in functions.items():
        if func_name == 'construct_knn_graph_from_index_without_correction':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, knn)
        elif func_name == 'construct_knn_graph_from_index_with_correction':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, knn, correct_exact_duplicates=True)
        elif func_name == 'create_knn_graph_and_index':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, features)
        elif func_name == 'correct_knn_distances_and_indices':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, features=features, distances=distances, indices=indices)
        elif func_name == 'correct_with_precomputed_exact_duplicate_sets':
            runtime, peak_memory, _ = benchmark_function(func, features=features, distances=distances, indices=indices, exact_duplicate_sets=exact_duplicate_sets)
        elif func_name == 'correct_with_precomputed_exact_duplicate_sets_inplace':
            runtime, peak_memory, _ = benchmark_function(func, distances=distances, indices=indices, exact_duplicate_sets=exact_duplicate_sets)
        else:
            if N > 1000:
                continue
            runtime, peak_memory, _ = benchmark_function(func, features)

        # Store the results
        results_list_of_dicts.append({
            f_col: func_name,
            N_col: N,
            runtime_col: runtime,
            memory_col: peak_memory,
        })

results = pd.DataFrame(results_list_of_dicts)
# Save the results to a CSV file
results.to_csv('benchmark_results.csv', index=False)

# Print the results
print(results)

# Plot the results
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 1, figsize=(10, 10))
for func_name in functions.keys():
    subset = results[results[f_col] == func_name]
    axes[0].plot(subset[N_col], subset[runtime_col], label=func_name)
    axes[1].plot(subset[N_col], subset[memory_col], label=func_name)

axes[0].set_xlabel(N_col)
axes[0].set_ylabel(runtime_col)
axes[0].set_title('Runtime vs. Num Points')
axes[0].legend()
axes[0].set_xscale('log')
axes[0].set_yscale('log')

axes[1].set_xlabel(N_col)
axes[1].set_ylabel(memory_col)
axes[1].set_title('Memory Usage vs. Num Points')
axes[1].legend()
axes[1].set_xscale('log')
axes[1].set_yscale('log')

plt.tight_layout()
plt.show()
```

Benchmark Code for Copied Dataset

```python
from __future__ import annotations
import time
import tracemalloc
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from cleanlab.internal.neighbor.knn_graph import (
    _compute_exact_duplicate_sets,
    correct_knn_distances_and_indices,
    correct_knn_distances_and_indices_with_exact_duplicate_sets_inplace,
    features_to_knn,
    construct_knn_graph_from_index,
    create_knn_graph_and_index,
)

# Define the sizes of feature arrays for the benchmark
feature_sizes = np.logspace(1, 4.6, num=25, base=10, dtype=int)

# Define a function to benchmark the memory and runtime of a given function
def benchmark_function(func, *args, **kwargs):
    # Record the start time and memory usage
    start_time = time.time()
    tracemalloc.start()
    
    # Run the function
    result = func(*args, **kwargs)
    
    # Record the end time and memory usage
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    
    # Calculate runtime and peak memory usage
    runtime = end_time - start_time
    peak_memory = peak / 10**6  # Convert to MB
    
    return runtime, peak_memory, result

# Column labels for the benchmark results
f_col, N_col, runtime_col, memory_col = 'Function', 'Num points', 'Runtime (s)', 'Memory (MB)'

# Define functions to benchmark
functions = {
    # 'features_to_knn': features_to_knn,
    'construct_knn_graph_from_index_without_correction': construct_knn_graph_from_index,
    'construct_knn_graph_from_index_with_correction': construct_knn_graph_from_index,
    # 'create_knn_graph_and_index': create_knn_graph_and_index,
    'correct_knn_distances_and_indices': correct_knn_distances_and_indices,
    'correct_with_precomputed_exact_duplicate_sets': correct_knn_distances_and_indices,
    'correct_with_precomputed_exact_duplicate_sets_inplace': correct_knn_distances_and_indices_with_exact_duplicate_sets_inplace,
}
results_list_of_dicts = []
# Run the benchmark for each function and feature size
N_max_slow = 20000
num_copies = 5
for N in tqdm(feature_sizes):
    # Stack num_copies copies of a random feature array
    features = np.random.rand(N // num_copies, 10)
    features = np.vstack([features] * num_copies)
    N = features.shape[0]

    knn_graph, knn = create_knn_graph_and_index(features, correct_exact_duplicates=False)
    distances = knn_graph.data.reshape(knn_graph.shape[0], -1)
    indices = knn_graph.indices.reshape(knn_graph.shape[0], -1)
    exact_duplicate_sets = _compute_exact_duplicate_sets(features)
    
    for func_name, func in functions.items():
        if func_name == 'construct_knn_graph_from_index_without_correction':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, knn)
        elif func_name == 'construct_knn_graph_from_index_with_correction':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, knn, correct_exact_duplicates=True)
        elif func_name == 'create_knn_graph_and_index':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, features)
        elif func_name == 'correct_knn_distances_and_indices':
            if N > N_max_slow:
                continue
            runtime, peak_memory, _ = benchmark_function(func, features=features, distances=distances, indices=indices)
        elif func_name == 'correct_with_precomputed_exact_duplicate_sets':
            runtime, peak_memory, _ = benchmark_function(func, features=features, distances=distances, indices=indices, exact_duplicate_sets=exact_duplicate_sets)
        elif func_name == 'correct_with_precomputed_exact_duplicate_sets_inplace':
            runtime, peak_memory, _ = benchmark_function(func, distances=distances, indices=indices, exact_duplicate_sets=exact_duplicate_sets)
        else:
            if N > 1000:
                continue
            runtime, peak_memory, _ = benchmark_function(func, features)

        # Store the results
        results_list_of_dicts.append({
            f_col: func_name,
            N_col: N,
            runtime_col: runtime,
            memory_col: peak_memory,
        })

results = pd.DataFrame(results_list_of_dicts)
# Save the results to a CSV file
results.to_csv('benchmark_results_with_dataset_copy.csv', index=False)

# Print the results
print(results)

# Plot the results
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 1, figsize=(10, 10))
for func_name in functions.keys():
    subset = results[results[f_col] == func_name]
    axes[0].plot(subset[N_col], subset[runtime_col], label=func_name)
    axes[1].plot(subset[N_col], subset[memory_col], label=func_name)

axes[0].set_xlabel(N_col)
axes[0].set_ylabel(runtime_col)
axes[0].set_title('Runtime vs. Num Points')
axes[0].legend()
axes[0].set_xscale('log')
axes[0].set_yscale('log')

axes[1].set_xlabel(N_col)
axes[1].set_ylabel(memory_col)
axes[1].set_title('Memory Usage vs. Num Points')
axes[1].legend()
axes[1].set_xscale('log')
axes[1].set_yscale('log')

plt.tight_layout()
plt.show()
```

elisno and others added 30 commits May 3, 2024 15:49
…earch object from an array of numerical features
The aim of the function is to eliminate unnecessary memory allocations and reduce runtime through more efficient numpy operations, manipulating duplicate indices as little as possible.

Currently, the correction ends in one of two situations:

1. The number of duplicates equals or exceeds the number of neighbors:
- A single circulant matrix is enough to ensure that each row gets only other duplicates as neighbors. Any remaining duplicate points can simply point to the first k duplicate points as their neighbors. This essentially always happens when the dataset consists of all-identical examples that outnumber the neighbors (the circulant matrix takes O(k^2) space, and the first point takes O(k) space while searching through O(N) duplicate points for an all-identical dataset).

2. The number of duplicates is smaller than the number of neighbors:
- The same circulant matrix can be used to fill the first few columns of the indices matrix, but before that, we must move all non-duplicate points to the far right. In practice, every point has enough non-duplicate points as neighbors to make this work; it is just a matter of ensuring they end up on the far-right side of the array.
Also add a docstring for the helper function that generates the circulant matrix representing the neighbors of the first k+1 duplicates.
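To summarize the two cases, a sketch of the per-set correction (illustrative, not the exact library code; it assumes each row's neighbor list already excludes the row itself):

```python
import numpy as np

def correct_one_duplicate_set_inplace(
    distances: np.ndarray,  # shape (N, k), modified in place
    indices: np.ndarray,    # shape (N, k), modified in place
    dup: np.ndarray,        # indices of one exact-duplicate set, length d
) -> None:
    k = indices.shape[1]
    d = len(dup)
    if d >= k + 1:
        # Case 1: enough duplicates to fill every neighbor slot.
        block = (np.arange(k + 1)[:, None] + np.arange(1, k + 1)) % (k + 1)
        indices[dup[: k + 1]] = dup[block]  # circulant block for the first k+1 rows
        indices[dup[k + 1:]] = dup[:k]      # remaining rows reuse the first k duplicates
        distances[dup] = 0.0
    else:
        # Case 2: list the other d-1 duplicates first (distance 0), then keep
        # the nearest non-duplicate neighbors, shifted to the right.
        for row in dup:
            others = dup[dup != row]            # the d-1 duplicate neighbors
            keep = ~np.isin(indices[row], dup)  # current non-duplicate neighbors
            kept_idx = indices[row][keep][: k - (d - 1)]
            kept_dst = distances[row][keep][: k - (d - 1)]
            indices[row] = np.concatenate([others, kept_idx])
            distances[row] = np.concatenate([np.zeros(d - 1), kept_dst])
```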
@elisno elisno requested a review from huiwengoh May 21, 2024 00:25
elisno (Member Author) commented May 21, 2024

@huiwengoh I've addressed all of @jwmueller's comments.

Can you give a review and merge this?

elisno (Member Author) commented May 21, 2024

I've added some graphs that show the runtime of the KNN graph construction in the library.

The main conclusions are:

  • As the dataset grows, the KNN search dominates the runtime, and the cost of correcting for exact duplicates is amortized.
  • The green line shows the additional runtime of correcting a knn_graph for exact duplicates.
    • It involves calling np.unique on the feature array (see the sketch after this list), and also makes some copies of the output arrays distances and indices.
  • The purple line shows how long the core correction logic takes.
    • For a single exact-duplicate set, the fully optimized algorithm runs many orders of magnitude faster than the actual KNN search.
    • As the number of exact-duplicate sets increases, the runtime seems to approach that of the KNN search (at least up to ~1000 data points).
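For reference, a sketch of how the exact-duplicate sets can be derived from np.unique (the actual _compute_exact_duplicate_sets helper may differ in details):

```python
import numpy as np
from typing import List

def exact_duplicate_sets(features: np.ndarray) -> List[np.ndarray]:
    # `inverse` maps each row of `features` to the id of its unique row.
    _, inverse, counts = np.unique(
        features, axis=0, return_inverse=True, return_counts=True
    )
    inverse = inverse.reshape(-1)  # guard against shape differences across numpy versions
    # Keep only groups with more than one member.
    return [np.flatnonzero(inverse == uid) for uid in np.flatnonzero(counts > 1)]
```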

elisno (Member Author) commented May 21, 2024

The CI failure is unrelated to this PR.

Scikit-learn 1.5.0 was released just a few hours ago, and it affects only one test case involving:

cv.fit(X=DATA["X_train"], y=DATA["labels"])

huiwengoh (Contributor) left a comment:

lgtm! thanks for detailed docstrings and clarifications in the comments above :)

elisno changed the title from "Correct knn graph to better handle duplicates and numerical issues" to "Improve KNN Graph Construction for Handling Exact Duplicates and Numerical Precision" on May 24, 2024

elisno merged commit 25b7aab into cleanlab:master on May 24, 2024 · 19 checks passed