Skip to content

Commit

Permalink
Fix numba wasserstein
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff1995 committed Jul 23, 2024
1 parent 1f446cd commit 148aa6e
Showing 1 changed file with 13 additions and 41 deletions.
54 changes: 13 additions & 41 deletions Cell_BLAST/blast.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
MINIMAL = 0


def _wasserstein_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cover
@numba.jit(nopython=True, nogil=True, cache=True)
def wasserstein_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover
x_sorter = np.argsort(x)
y_sorter = np.argsort(y)
xy = np.concatenate((x, y))
Expand All @@ -34,37 +35,6 @@ def _wasserstein_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cove
return np.sum(np.multiply(np.abs(x_cdf - y_cdf), deltas))


def _energy_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cover
x_sorter = np.argsort(x)
y_sorter = np.argsort(y)
xy = np.concatenate((x, y))
xy.sort()
deltas = np.diff(xy)
x_cdf = np.searchsorted(x[x_sorter], xy[:-1], "right") / x.size
y_cdf = np.searchsorted(y[y_sorter], xy[:-1], "right") / y.size
return np.sqrt(2 * np.sum(np.multiply(np.square(x_cdf - y_cdf), deltas)))


@numba.extending.overload(
scipy.stats.wasserstein_distance, jit_options={"nogil": True, "cache": True}
)
def _wasserstein_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover
if (
x == numba.float32[::1] and y == numba.float32[::1]
) or (
x == numba.float64[::1] and y == numba.float64[::1]
):
return _wasserstein_distance_impl


@numba.extending.overload(
scipy.stats.energy_distance, jit_options={"nogil": True, "cache": True}
)
def _energy_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover
if x == numba.float32[::1] and y == numba.float32[::1]:
return _energy_distance_impl


@numba.jit(nopython=True, nogil=True, cache=True)
def ed(x: np.ndarray, y: np.ndarray): # pragma: no cover
r"""
Expand Down Expand Up @@ -230,10 +200,10 @@ def npd_v1(
np.std(y_posterior) + np.float32(eps)
)
return 0.5 * (
scipy.stats.wasserstein_distance(
wasserstein_distance(
xy_posterior1[: len(x_posterior)], xy_posterior1[-len(y_posterior) :]
)
+ scipy.stats.wasserstein_distance(
+ wasserstein_distance(
xy_posterior2[: len(x_posterior)], xy_posterior2[-len(y_posterior) :]
)
)
Expand Down Expand Up @@ -939,13 +909,15 @@ def query(
hits,
dist,
pval,
query
if store_dataset
else anndata.AnnData(
X=scipy.sparse.csr_matrix((query.shape[0], 0)),
obs=pd.DataFrame(index=query.obs.index),
var=pd.DataFrame(),
uns={},
(
query
if store_dataset
else anndata.AnnData(
X=scipy.sparse.csr_matrix((query.shape[0], 0)),
obs=pd.DataFrame(index=query.obs.index),
var=pd.DataFrame(),
uns={},
)
),
)

Expand Down

0 comments on commit 148aa6e

Please sign in to comment.