From 686be45e78255804c3a2ad7c42c24b13f82b9ece Mon Sep 17 00:00:00 2001 From: brimoor Date: Fri, 13 Dec 2024 01:15:18 -0500 Subject: [PATCH 1/2] fix #5254 --- .../panels/model_evaluation/__init__.py | 58 ++++++++++--------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/fiftyone/operators/builtins/panels/model_evaluation/__init__.py b/fiftyone/operators/builtins/panels/model_evaluation/__init__.py index cb33082f9d..b91efbe01c 100644 --- a/fiftyone/operators/builtins/panels/model_evaluation/__init__.py +++ b/fiftyone/operators/builtins/panels/model_evaluation/__init__.py @@ -5,16 +5,17 @@ | `voxel51.com `_ | """ - +from collections import defaultdict, Counter import os import traceback -import fiftyone.operators.types as types -from collections import defaultdict, Counter +import numpy as np + from fiftyone import ViewField as F from fiftyone.operators.categories import Categories from fiftyone.operators.panel import Panel, PanelConfig from fiftyone.core.plots.plotly import _to_log_colorscale +import fiftyone.operators.types as types STORE_NAME = "model_evaluation_panel_builtin" @@ -104,29 +105,32 @@ def get_avg_confidence(self, per_class_metrics): total += metrics["confidence"] return total / count if count > 0 else None - def get_tp_fp_fn(self, ctx): - view_state = ctx.panel.get_state("view") or {} - key = view_state.get("key") - dataset = ctx.dataset - tp_key = f"{key}_tp" - fp_key = f"{key}_fp" - fn_key = f"{key}_fn" - tp_total = ( - sum(ctx.dataset.values(tp_key)) - if dataset.has_field(tp_key) - else None - ) - fp_total = ( - sum(ctx.dataset.values(fp_key)) - if dataset.has_field(fp_key) - else None - ) - fn_total = ( - sum(ctx.dataset.values(fn_key)) - if dataset.has_field(fn_key) - else None - ) - return tp_total, fp_total, fn_total + def get_tp_fp_fn(self, info, results): + # Binary classification + if ( + info.config.type == "classification" + and info.config.method == "binary" + ): + neg_label, pos_label = results.classes + tp_count = np.count_nonzero( + (results.ytrue == pos_label) & (results.ypred == pos_label) + ) + fp_count = np.count_nonzero( + (results.ytrue != pos_label) & (results.ypred == pos_label) + ) + fn_count = np.count_nonzero( + (results.ytrue == pos_label) & (results.ypred != pos_label) + ) + return tp_count, fp_count, fn_count + + # Object detection + if info.config.type == "detection": + tp_count = np.count_nonzero(results.ytrue == results.ypred) + fp_count = np.count_nonzero(results.ytrue == results.missing) + fn_count = np.count_nonzero(results.ypred == results.missing) + return tp_count, fp_count, fn_count + + return None, None, None def get_map(self, results): try: @@ -298,7 +302,7 @@ def load_evaluation(self, ctx): per_class_metrics ) metrics["tp"], metrics["fp"], metrics["fn"] = self.get_tp_fp_fn( - ctx + info, results ) metrics["mAP"] = self.get_map(results) evaluation_data = { From d41fffce5d5ba8a34361bf3de2f3b5ad238eca8a Mon Sep 17 00:00:00 2001 From: imanjra Date: Fri, 13 Dec 2024 14:51:13 -0500 Subject: [PATCH 2/2] TP/FP/NP support for binary classification model evaluation --- .../NativeModelEvaluationView/Evaluation.tsx | 10 +++++--- .../panels/model_evaluation/__init__.py | 25 +++++++++++++------ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx index d7f932b23f..c3ee377dab 100644 --- a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx +++ b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx @@ -166,6 +166,7 @@ export default function Evaluation(props: EvaluationProps) { const evaluationConfig = evaluationInfo.config; const evaluationMetrics = evaluation.metrics; const evaluationType = evaluationConfig.type; + const evaluationMethod = evaluationConfig.method; const compareEvaluationInfo = compareEvaluation?.info || {}; const compareEvaluationKey = compareEvaluationInfo?.key; const compareEvaluationTimestamp = compareEvaluationInfo?.timestamp; @@ -174,6 +175,9 @@ export default function Evaluation(props: EvaluationProps) { const compareEvaluationType = compareEvaluationConfig.type; const isObjectDetection = evaluationType === "detection"; const isSegmentation = evaluationType === "segmentation"; + const isBinaryClassification = + evaluationType === "classification" && evaluationMethod === "binary"; + const showTpFpFn = isObjectDetection || isBinaryClassification; const infoRows = [ { id: "evaluation_key", @@ -385,7 +389,7 @@ export default function Evaluation(props: EvaluationProps) { ? "compare" : "selected" : false, - hide: !isObjectDetection, + hide: !showTpFpFn, }, { id: "fp", @@ -400,7 +404,7 @@ export default function Evaluation(props: EvaluationProps) { ? "compare" : "selected" : false, - hide: !isObjectDetection, + hide: !showTpFpFn, }, { id: "fn", @@ -415,7 +419,7 @@ export default function Evaluation(props: EvaluationProps) { ? "compare" : "selected" : false, - hide: !isObjectDetection, + hide: !showTpFpFn, }, ]; diff --git a/fiftyone/operators/builtins/panels/model_evaluation/__init__.py b/fiftyone/operators/builtins/panels/model_evaluation/__init__.py index b91efbe01c..e8a1aff301 100644 --- a/fiftyone/operators/builtins/panels/model_evaluation/__init__.py +++ b/fiftyone/operators/builtins/panels/model_evaluation/__init__.py @@ -5,6 +5,7 @@ | `voxel51.com `_ | """ + from collections import defaultdict, Counter import os import traceback @@ -96,6 +97,12 @@ def on_load(self, ctx): ctx.panel.set_data("permissions", permissions) self.load_pending_evaluations(ctx) + def is_binary_classification(self, info): + return ( + info.config.type == "classification" + and info.config.method == "binary" + ) + def get_avg_confidence(self, per_class_metrics): count = 0 total = 0 @@ -107,10 +114,7 @@ def get_avg_confidence(self, per_class_metrics): def get_tp_fp_fn(self, info, results): # Binary classification - if ( - info.config.type == "classification" - and info.config.method == "binary" - ): + if self.is_binary_classification(info): neg_label, pos_label = results.classes tp_count = np.count_nonzero( (results.ytrue == pos_label) & (results.ypred == pos_label) @@ -422,10 +426,15 @@ def load_view(self, ctx): gt_field, F("label") == y ).filter_labels(pred_field, F("label") == x) elif view_type == "field": - view = ctx.dataset.filter_labels( - pred_field, F(computed_eval_key) == field - ) - + if self.is_binary_classification(info): + uppercase_field = field.upper() + view = ctx.dataset.match( + {computed_eval_key: {"$eq": uppercase_field}} + ) + else: + view = ctx.dataset.filter_labels( + pred_field, F(computed_eval_key) == field + ) if view is not None: ctx.ops.set_view(view)