
Merge pull request #7 from PierreBoyeau/ppi_mean_multid
[WIP] fix multid bug + fast `_calc_lhat_glm`
aangelopoulos authored Mar 7, 2024
2 parents 8066bab + 1216262 commit 493512f
Showing 3 changed files with 161 additions and 69 deletions.
171 changes: 102 additions & 69 deletions ppi_py/ppi.py
@@ -7,6 +7,7 @@
 from statsmodels.stats.weightstats import _zconfint_generic, _zstat_generic
 from sklearn.linear_model import LogisticRegression
 from .utils import (
+    construct_weight_vector,
     safe_expit,
     safe_log1pexp,
     compute_cdf,
@@ -15,6 +16,7 @@
     linfty_dkw,
     linfty_binom,
     form_discrete_distribution,
+    reshape_to_2d,
 )


@@ -55,7 +57,14 @@ def rectified_p_value(


 def ppi_mean_pointestimate(
-    Y, Yhat, Yhat_unlabeled, lhat=None, coord=None, w=None, w_unlabeled=None
+    Y,
+    Yhat,
+    Yhat_unlabeled,
+    lhat=None,
+    coord=None,
+    w=None,
+    w_unlabeled=None,
+    lambd_optim_mode="overall",
 ):
     """Computes the prediction-powered point estimate of the d-dimensional mean.
@@ -74,19 +83,20 @@ def ppi_mean_pointestimate(
     Notes:
         `[ADZ23] <https://arxiv.org/abs/2311.01453>`__ A. N. Angelopoulos, J. C. Duchi, and T. Zrnic. PPI++: Efficient Prediction Powered Inference. arxiv:2311.01453, 2023.
     """
+    Y = reshape_to_2d(Y)
+    Yhat = reshape_to_2d(Yhat)
+    Yhat_unlabeled = reshape_to_2d(Yhat_unlabeled)
     n = Y.shape[0]
     N = Yhat_unlabeled.shape[0]
-    d = Yhat.shape[1] if len(Yhat.shape) > 1 else 1
-    w = np.ones(n) if w is None else w / w.sum() * n
-    w_unlabeled = (
-        np.ones(N)
-        if w_unlabeled is None
-        else w_unlabeled / w_unlabeled.sum() * N
-    )
+    d = Yhat.shape[1]
+
+    w = construct_weight_vector(n, w, vectorized=True)
+    w_unlabeled = construct_weight_vector(N, w_unlabeled, vectorized=True)

     if lhat is None:
-        ppi_pointest = (w_unlabeled * Yhat_unlabeled).mean() + (
+        ppi_pointest = (w_unlabeled * Yhat_unlabeled).mean(0) + (
             w * (Y - Yhat)
-        ).mean()
+        ).mean(0)
         grads = w * (Y - ppi_pointest)
         grads_hat = w * (Yhat - ppi_pointest)
         grads_hat_unlabeled = w_unlabeled * (Yhat_unlabeled - ppi_pointest)
@@ -98,6 +108,7 @@
             inv_hessian,
             coord=None,
             clip=True,
+            optim_mode=lambd_optim_mode,
         )
         return ppi_mean_pointestimate(
             Y,
@@ -111,7 +122,7 @@
     else:
         return (w_unlabeled * lhat * Yhat_unlabeled).mean(axis=0) + (
             w * (Y - lhat * Yhat)
-        ).mean(axis=0)
+        ).mean(axis=0).squeeze()
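For orientation, the change above is what makes the mean estimators accept `(n, d)` arrays end to end; the trailing `.squeeze()` preserves the old scalar-like return for 1-D input. A minimal sketch of the multi-dimensional call path on synthetic data (it assumes the estimators are importable from the package top level, as the new tests below do):

```python
import numpy as np
from ppi_py import ppi_mean_ci, ppi_mean_pointestimate

n, N, d = 1_000, 10_000, 3
Y = np.random.normal(0, 1, (n, d))               # gold labels
Yhat = Y + np.random.normal(0, 0.5, (n, d))      # predictions on labeled data
Yhat_unlabeled = np.random.normal(0, 1, (N, d))  # predictions on unlabeled data

theta_hat = ppi_mean_pointestimate(Y, Yhat, Yhat_unlabeled)     # shape (d,)
lower, upper = ppi_mean_ci(Y, Yhat, Yhat_unlabeled, alpha=0.1)  # two (d,) arrays
```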


@@ -124,6 +135,7 @@ def ppi_mean_ci(
     coord=None,
     w=None,
     w_unlabeled=None,
+    lambd_optim_mode="overall",
 ):
     """Computes the prediction-powered confidence interval for a d-dimensional mean.
@@ -147,12 +159,13 @@
     n = Y.shape[0]
     N = Yhat_unlabeled.shape[0]
     d = Y.shape[1] if len(Y.shape) > 1 else 1
-    w = np.ones(n) if w is None else w / w.sum() * n
-    w_unlabeled = (
-        np.ones(N)
-        if w_unlabeled is None
-        else w_unlabeled / w_unlabeled.sum() * N
-    )
+
+    Y = reshape_to_2d(Y)
+    Yhat = reshape_to_2d(Yhat)
+    Yhat_unlabeled = reshape_to_2d(Yhat_unlabeled)
+
+    w = construct_weight_vector(n, w, vectorized=True)
+    w_unlabeled = construct_weight_vector(N, w_unlabeled, vectorized=True)

     if lhat is None:
         ppi_pointest = ppi_mean_pointestimate(
@@ -174,6 +187,7 @@
             inv_hessian,
             coord=None,
             clip=True,
+            optim_mode=lambd_optim_mode,
         )
         return ppi_mean_ci(
             Y,
@@ -196,8 +210,8 @@
             w_unlabeled=w_unlabeled,
         )

-    imputed_std = (w_unlabeled * (lhat * Yhat_unlabeled)).std() / np.sqrt(N)
-    rectifier_std = (w * (Y - lhat * Yhat)).std() / np.sqrt(n)
+    imputed_std = (w_unlabeled * (lhat * Yhat_unlabeled)).std(0) / np.sqrt(N)
+    rectifier_std = (w * (Y - lhat * Yhat)).std(0) / np.sqrt(n)

     return _zconfint_generic(
         ppi_pointest,
@@ -217,6 +231,7 @@ def ppi_mean_pval(
     coord=None,
     w=None,
     w_unlabeled=None,
+    lambd_optim_mode="overall",
 ):
     """Computes the prediction-powered p-value for a 1D mean.
@@ -239,35 +254,39 @@
     """
     n = Y.shape[0]
     N = Yhat_unlabeled.shape[0]
-    w = np.ones(n) if w is None else w / w.sum() * n
-    w_unlabeled = (
-        np.ones(N)
-        if w_unlabeled is None
-        else w_unlabeled / w_unlabeled.sum() * N
-    )
+    # w = np.ones(n) if w is None else w / w.sum() * n
+    w = construct_weight_vector(n, w, vectorized=True)
+    w_unlabeled = construct_weight_vector(N, w_unlabeled, vectorized=True)
+
+    Y = reshape_to_2d(Y)
+    Yhat = reshape_to_2d(Yhat)
+    Yhat_unlabeled = reshape_to_2d(Yhat_unlabeled)
+    d = Y.shape[1]

     if lhat is None:
-        if len(Y.shape) > 1 and Y.shape[1] > 1:
-            lhat = 1
-        else:
-            ppi_pointest = (w_unlabeled * Yhat_unlabeled).mean() + (
-                w * (Y - Yhat)
-            ).mean()
-            grads = w * (Y - ppi_pointest)
-            grads_hat = w * (Yhat - ppi_pointest)
-            grads_hat_unlabeled = w_unlabeled * (Yhat_unlabeled - ppi_pointest)
-            inv_hessian = np.ones((1, 1))
-            lhat = _calc_lhat_glm(
-                grads, grads_hat, grads_hat_unlabeled, inv_hessian, coord=None
-            )
+        ppi_pointest = (w_unlabeled * Yhat_unlabeled).mean(0) + (
+            w * (Y - Yhat)
+        ).mean(0)
+        grads = w * (Y - ppi_pointest)
+        grads_hat = w * (Yhat - ppi_pointest)
+        grads_hat_unlabeled = w_unlabeled * (Yhat_unlabeled - ppi_pointest)
+        inv_hessian = np.eye(d)
+        lhat = _calc_lhat_glm(
+            grads,
+            grads_hat,
+            grads_hat_unlabeled,
+            inv_hessian,
+            coord=None,
+            optim_mode=lambd_optim_mode,
+        )

     return rectified_p_value(
-        (w * Y - lhat * w * Yhat).mean(),
-        (w * Y - lhat * w * Yhat).std() / np.sqrt(n),
-        (w_unlabeled * lhat * Yhat_unlabeled).mean(),
-        (w_unlabeled * lhat * Yhat_unlabeled).std() / np.sqrt(N),
-        null,
-        alternative,
+        rectifier=(w * Y - lhat * w * Yhat).mean(0),
+        rectifier_std=(w * Y - lhat * w * Yhat).std(0) / np.sqrt(n),
+        imputed_mean=(w_unlabeled * lhat * Yhat_unlabeled).mean(0),
+        imputed_std=(w_unlabeled * lhat * Yhat_unlabeled).std(0) / np.sqrt(N),
+        null=null,
+        alternative=alternative,
     )
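The new `lambd_optim_mode` argument, threaded through `ppi_mean_pointestimate`, `ppi_mean_ci`, and `ppi_mean_pval`, selects how the power-tuning parameter lambda is chosen: `"overall"` (the default) picks a single scalar minimizing the total variance, while `"element"` picks one value per coordinate. A usage sketch mirroring the new `test_ppi_mean_elem` below (top-level imports assumed, as in the tests):

```python
import numpy as np
from ppi_py import ppi_mean_ci, ppi_mean_pointestimate, ppi_mean_pval

Y = np.random.normal(0, 1, (10_000, 5))
Yhat = np.random.normal(-2, 1, (10_000, 5))
Yhat_unlabeled = np.random.normal(-2, 1, (10_000, 5))

# One lambda per coordinate instead of a single shared scalar.
est = ppi_mean_pointestimate(Y, Yhat, Yhat_unlabeled, lambd_optim_mode="element")
ci = ppi_mean_ci(Y, Yhat, Yhat_unlabeled, alpha=0.1, lambd_optim_mode="element")
p = ppi_mean_pval(Y, Yhat, Yhat_unlabeled, lambd_optim_mode="element")
```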


@@ -1049,7 +1068,13 @@ def ppi_logistic_ci(


 def _calc_lhat_glm(
-    grads, grads_hat, grads_hat_unlabeled, inv_hessian, coord=None, clip=False
+    grads,
+    grads_hat,
+    grads_hat_unlabeled,
+    inv_hessian,
+    coord=None,
+    clip=False,
+    optim_mode="overall",
 ):
     """
     Calculates the optimal value of lhat for the prediction-powered confidence interval for GLMs.
@@ -1059,38 +1084,41 @@
         grads_hat (ndarray): Gradient of the loss function with respect to the model parameter evaluated using predictions on the labeled data.
         grads_hat_unlabeled (ndarray): Gradient of the loss function with respect to the parameter evaluated using predictions on the unlabeled data.
         inv_hessian (ndarray): Inverse of the Hessian of the loss function with respect to the parameter.
-        coord (int, optional): Coordinate for which to optimize `lhat`. If `None`, it optimizes the total variance over all coordinates. Must be in {1, ..., d} where d is the shape of the estimand.
+        coord (int, optional): Coordinate for which to optimize `lhat`, when `optim_mode="overall"`.
+            If `None`, it optimizes the total variance over all coordinates. Must be in {1, ..., d} where d is the shape of the estimand.
         clip (bool, optional): Whether to clip the value of lhat to be non-negative. Defaults to `False`.
+        optim_mode (str, optional): Mode for which to optimize `lhat`, either `overall` or `element`.
+            If `overall`, it optimizes the total variance over all coordinates, and the function returns a scalar.
+            If `element`, it optimizes the variance for each coordinate separately, and the function returns a vector.

     Returns:
         float: Optimal value of `lhat`. Lies in [0,1].
     """
+    grads = reshape_to_2d(grads)
+    grads_hat = reshape_to_2d(grads_hat)
+    grads_hat_unlabeled = reshape_to_2d(grads_hat_unlabeled)
     n = grads.shape[0]
     N = grads_hat_unlabeled.shape[0]
     d = inv_hessian.shape[0]
-    cov_grads = np.zeros((d, d))
-
-    for i in range(n):
-        cov_grads += (1 / n) * (
-            np.outer(
-                grads[i] - grads.mean(axis=0),
-                grads_hat[i] - grads_hat.mean(axis=0),
-            )
-            + np.outer(
-                grads_hat[i] - grads_hat.mean(axis=0),
-                grads[i] - grads.mean(axis=0),
-            )
-        )
+    if grads.shape[1] != d:
+        raise ValueError(
+            "Dimension mismatch between the gradient and the inverse Hessian."
+        )
+
+    grads_cent = grads - grads.mean(axis=0)
+    grad_hat_cent = grads_hat - grads_hat.mean(axis=0)
+    cov_grads = (1 / n) * (
+        grads_cent.T @ grad_hat_cent + grad_hat_cent.T @ grads_cent
+    )

     var_grads_hat = np.cov(
         np.concatenate([grads_hat, grads_hat_unlabeled], axis=0).T
     )
+    var_grads_hat = var_grads_hat.reshape(d, d)

-    if coord is None:
-        vhat = inv_hessian
-    else:
-        vhat = inv_hessian @ np.eye(d)[coord]
-
-    if d > 1:
+    vhat = inv_hessian if coord is None else inv_hessian[coord, coord]
+    if optim_mode == "overall":
         num = (
             np.trace(vhat @ cov_grads @ vhat)
             if coord is None
@@ -1101,14 +1129,19 @@ def _calc_lhat_glm(
             if coord is None
             else 2 * (1 + (n / N)) * vhat @ var_grads_hat @ vhat
         )
         lhat = num / denom
+        lhat = lhat.item()
+    elif optim_mode == "element":
+        num = np.diag(vhat @ cov_grads @ vhat)
+        denom = 2 * (1 + (n / N)) * np.diag(vhat @ var_grads_hat @ vhat)
+        lhat = num / denom
     else:
-        num = vhat * cov_grads * vhat
-        denom = 2 * (1 + (n / N)) * vhat * var_grads_hat * vhat
-
-        lhat = num / denom
+        raise ValueError(
+            "Invalid value for optim_mode. Must be either 'overall' or 'element'."
+        )
     if clip:
         lhat = np.clip(lhat, 0, 1)
-    return lhat.item()
+    return lhat


"""
Expand Down
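The fast `_calc_lhat_glm` half of the commit message comes down to the hunk above: the per-sample loop of symmetrized outer products is replaced by two matrix products on centered arrays, which is the same cross-covariance computed without Python-loop overhead. A self-contained sketch of that equivalence (plain NumPy; the variable names here are illustrative, not part of the library):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 3
grads = rng.normal(size=(n, d))
grads_hat = rng.normal(size=(n, d))

# Old approach: accumulate symmetrized outer products one sample at a time.
cov_loop = np.zeros((d, d))
for i in range(n):
    cov_loop += (1 / n) * (
        np.outer(grads[i] - grads.mean(axis=0), grads_hat[i] - grads_hat.mean(axis=0))
        + np.outer(grads_hat[i] - grads_hat.mean(axis=0), grads[i] - grads.mean(axis=0))
    )

# New approach: center once, then two matrix products.
grads_cent = grads - grads.mean(axis=0)
grads_hat_cent = grads_hat - grads_hat.mean(axis=0)
cov_vec = (1 / n) * (grads_cent.T @ grads_hat_cent + grads_hat_cent.T @ grads_cent)

assert np.allclose(cov_loop, cov_vec)  # same matrix, up to floating-point error
```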
16 changes: 16 additions & 0 deletions ppi_py/utils/statistics_utils.py
@@ -4,6 +4,22 @@
 from scipy.optimize import brentq


+def construct_weight_vector(n_obs, existing_weight, vectorized=False):
+    res = (
+        np.ones(n_obs)
+        if existing_weight is None
+        else existing_weight / existing_weight.sum() * n_obs
+    )
+    if vectorized and (len(res.shape) == 1):
+        res = res[:, None]
+    return res
+
+
+def reshape_to_2d(x):
+    """Reshapes a 1D array to a 2D array."""
+    return x.reshape(-1, 1) if len(x.shape) == 1 else x.copy()
+
+
 @njit
 def safe_expit(x):
     """Computes the sigmoid function in a numerically stable way."""
Expand Down
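For reference, a small shape-behavior sketch of the two new helpers (assuming they are re-exported from `ppi_py.utils`, as the import block at the top of `ppi.py` suggests):

```python
import numpy as np
from ppi_py.utils import construct_weight_vector, reshape_to_2d

# Missing weights default to ones; vectorized=True returns a column vector
# so it broadcasts against (n, d) data.
w = construct_weight_vector(4, None, vectorized=True)
print(w.shape)  # (4, 1)

# Existing weights are renormalized to sum to n_obs.
w = construct_weight_vector(4, np.array([1.0, 1.0, 2.0, 4.0]), vectorized=True)
print(w.sum(), w.shape)  # 4.0 (4, 1)

# 1D arrays become column vectors; 2D arrays come back as a copy.
print(reshape_to_2d(np.zeros(3)).shape)       # (3, 1)
print(reshape_to_2d(np.zeros((3, 2))).shape)  # (3, 2)
```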
43 changes: 43 additions & 0 deletions tests/test_mean.py
@@ -30,6 +30,49 @@ def test_ppi_mean_ci():
     assert not failed


+def test_ppi_mean_multid():
+    trials = 100
+    alphas = np.array([0.5, 0.2, 0.1, 0.05, 0.01])
+    n_alphas = alphas.shape[0]
+    n_dims = 5
+    epsilon = 0.1
+    includeds = np.zeros((n_alphas, n_dims))
+    for _ in range(trials):
+        Y = np.random.normal(0, 1, (10000, n_dims))
+        Yhat = np.random.normal(-2, 1, (10000, n_dims))
+        Yhat_unlabeled = np.random.normal(-2, 1, (10000, n_dims))
+        for j in range(alphas.shape[0]):
+            ci = ppi_mean_ci(Y, Yhat, Yhat_unlabeled, alpha=alphas[j])
+
+            included = (ci[0] <= 0) & (ci[1] >= 0)
+            includeds[j] += included.astype(int)
+    failed = np.any(includeds / trials < 1 - alphas - epsilon)
+    assert not failed
+
+
+def test_ppi_mean_elem():
+    alpha = 0.1
+    Y = np.random.normal(0, 1, 10000)
+    Yhat = np.random.normal(-2, 1, 10000)
+    Yhat_unlabeled = np.random.normal(-2, 1, 10000)
+
+    ppi_mean_pointestimate(Y, Yhat, Yhat_unlabeled, lambd_optim_mode="element")
+    ppi_mean_ci(
+        Y, Yhat, Yhat_unlabeled, alpha=alpha, lambd_optim_mode="element"
+    )
+    ppi_mean_pval(Y, Yhat, Yhat_unlabeled, lambd_optim_mode="element")
+
+    Y = np.random.normal(0, 1, (10000, 5))
+    Yhat = np.random.normal(-2, 1, (10000, 5))
+    Yhat_unlabeled = np.random.normal(-2, 1, (10000, 5))
+
+    ppi_mean_pointestimate(Y, Yhat, Yhat_unlabeled, lambd_optim_mode="element")
+    ppi_mean_ci(
+        Y, Yhat, Yhat_unlabeled, alpha=alpha, lambd_optim_mode="element"
+    )
+    ppi_mean_pval(Y, Yhat, Yhat_unlabeled, lambd_optim_mode="element")
+
+
 def test_ppi_mean_pval():
     trials = 1000
     alphas = np.array([0.5, 0.2, 0.1, 0.05, 0.01])
