evaluation_gde_based.py
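"""Agreement-based (GDE) accuracy estimation, with and without distance-based
sample filtering, evaluated across pairs of models trained with different seeds."""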
from classification.load_model_and_config import get_config_data_model_for_eval
import numpy as np
import pandas as pd
from tabulate import tabulate
from pathlib import Path
from evaluation.confidence_estimates import ConfidenceBasedAccuracyEstimator
from evaluation.distance_checker import DistanceChecker
from evaluation.inference_utils import (
    get_train_and_val_predictions,
    get_ood_predictions,
)
from collections import defaultdict
from itertools import combinations, permutations
from scipy.stats import wilcoxon
from yacs.config import CfgNode
from typing import Optional, Union


def run_evaluation_agreement(config_name_or_path: Union[CfgNode, str], dataset: str) -> Optional[pd.DataFrame]:
    """
    Run evaluation of agreement-based accuracy estimation (GDE) for a given training configuration.

    Assumes at least 2 models trained with this configuration are available (two different seeds).
    """
    config, data_modules, models, output_dirs = get_config_data_model_for_eval(config_name_or_path, dataset)
    config.dataset = dataset
    metrics = pd.DataFrame()
    # Need at least two trained models for agreement-based accuracy estimation
    if len(models) <= 1:
        return None
    ood_results = defaultdict(list)
    kept_by_distance = defaultdict(list)
    kept_by_cs_distance = defaultdict(list)
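    # For each seed: fit the distance checker and the confidence-based estimators on in-distribution data, then run OOD inference.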
    for model, output_dir, data_module in zip(models, output_dirs, data_modules):
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        train_results, val_results = get_train_and_val_predictions(output_dir, dataset, data_module, model)
        # Fit Distance Checker
        distance_checker = DistanceChecker(output_dir)
        distance_checker.fit(train_results, val_results)
        # Fit TS, ATC, DOC
        accuracy_estimator = ConfidenceBasedAccuracyEstimator()
        accuracy_estimator.fit(val_results)
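        # Collect predictions for every OOD dataloader and record which samples the distance checks keep.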
        for (name_eval_loader, eval_loader) in data_module.get_all_ood_dataloaders():
            ood_results[name_eval_loader].append(
                get_ood_predictions(
                    eval_loader,
                    name_eval_loader,
                    model,
                    output_dir,
                    accuracy_estimator.ts,
                    accuracy_estimator.cs_ts,
                )
            )
            # Get DIST-estimate
            kept_by_distance[name_eval_loader].append(
                distance_checker.get_kept_samples(ood_results[name_eval_loader][-1], name_eval_loader)
            )
            if distance_checker.adt_cs is not None:
                kept_by_cs_distance[name_eval_loader].append(
                    distance_checker.get_kept_cs_samples(ood_results[name_eval_loader][-1], name_eval_loader)
                )
    for name_eval_loader in ood_results.keys():
        for i, j in combinations(np.arange(len(ood_results[name_eval_loader])), 2):
            for ref, aux in permutations([i, j]):
                current_metrics_dict = {
                    "dataset": name_eval_loader,
                    "ref": config.seed[ref],
                    "aux": config.seed[aux],
                }
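                # True OOD accuracy of the reference model.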
current_metrics_dict["accuracy"] = (
(
ood_results[name_eval_loader][ref]["predictions"]
== ood_results[name_eval_loader][ref]["targets"]
)
.float()
.mean()
.item()
)
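                # GDE estimate: fraction of samples on which the two models agree.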
current_metrics_dict["predicted_aggreement"] = (
(
ood_results[name_eval_loader][ref]["predictions"]
== ood_results[name_eval_loader][aux]["predictions"]
)
.float()
.mean()
.item()
)
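                # Agreement restricted to samples kept by the distance check (DIST).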
current_metrics_dict["predicted_aggreement_w_dist"] = (
(
(
ood_results[name_eval_loader][ref]["predictions"]
== ood_results[name_eval_loader][aux]["predictions"]
)
& kept_by_distance[name_eval_loader][ref]
)
.float()
.mean()
.item()
)
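                # Agreement restricted to samples kept by the CS distance check, when available.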
                if distance_checker.adt_cs is not None:
                    current_metrics_dict["predicted_agreement_w_csdist"] = (
                        (
                            (
                                ood_results[name_eval_loader][ref]["predictions"]
                                == ood_results[name_eval_loader][aux]["predictions"]
                            )
                            & kept_by_cs_distance[name_eval_loader][ref]
                        )
                        .float()
                        .mean()
                        .item()
                    )
                current_metrics = pd.DataFrame(current_metrics_dict, index=[0])
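                # Absolute estimation error of each predicted_* estimate w.r.t. the true accuracy.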
                for c in current_metrics.columns:
                    if c.startswith("predicted"):
                        current_metrics[f"error_{c[10:]}"] = current_metrics[c].apply(
                            lambda x: np.abs(x - current_metrics_dict["accuracy"]) if x is not None else np.nan
                        )
                metrics = pd.concat([metrics, current_metrics], ignore_index=True)
    error_cols = [i for i in current_metrics.columns if i.startswith("error")]
    print(
        tabulate(
            metrics[["dataset"] + error_cols]
            .groupby("dataset")
            .aggregate(func=lambda x: np.nanmean(x * 100))
            .dropna(),
            headers="keys",
        )
    )
    metrics.to_csv(output_dir.parent / "metrics_agreement.csv")
    return metrics


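# Two-sided Wilcoxon signed-rank test between two vectors of estimation errors; returns a significance marker string.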
def is_significant(ref, new):
    if (ref - new).sum() == 0:
        return ""
    p = wilcoxon(ref, new, nan_policy="omit", alternative="two-sided")[1]
    if p < 1e-3:
        return f" **({p:.0E})"
    if p < 0.05:
        return f" *({p:.0E})"
    return f" ({p:.3f})"


if __name__ == "__main__":
    """
    Main script to run the analysis of GDE versus GDE+DistCS.

    Usage:
        python evaluation/evaluation_gde_based.py --dataset [TEST_DATASET]
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        dest="dataset",
        type=str,
        required=True,
    )
    args = parser.parse_args()
    all_metrics = pd.DataFrame()
    if args.dataset != "imagenet":
        config_dir = Path(__file__).parent.parent / "classification" / "configs" / "general"
        for config_file in config_dir.glob("scratch/*.yml"):
            print(config_file)
            metrics = run_evaluation_agreement(config_file, args.dataset)
            all_metrics = pd.concat([all_metrics, metrics], ignore_index=True)
        if args.dataset not in ["entity13", "living17", "nonliving26"]:
            for config_file in config_dir.glob("pretrained/*.yml"):
                print(config_file)
                metrics = run_evaluation_agreement(config_file, args.dataset)
                all_metrics = pd.concat([all_metrics, metrics], ignore_index=True)
    else:
        raise ValueError(
            "Can't run this analysis with ImageNet: we only have one seed for each trained model in timm."
        )
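    # Mean error (in %) for plain GDE and GDE with CS distance filtering, with significance tested against plain GDE.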
    error_cols = ["error_agreement", "error_agreement_w_csdist"]
    results = all_metrics[error_cols].aggregate(
        func=lambda x: f"{np.nanmean(x * 100, keepdims=True)[0]:.2f}"
        + is_significant(all_metrics["error_agreement"].values, x)
    )
    results.to_csv(f"/data/performance_estimation/outputs/{args.dataset}/agreement_summary.csv")
    print(tabulate(results))