Skip to content

Commit

Permalink
TP/FP/NP support for binary classification model evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
imanjra committed Dec 13, 2024
1 parent 686be45 commit d41fffc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ export default function Evaluation(props: EvaluationProps) {
const evaluationConfig = evaluationInfo.config;
const evaluationMetrics = evaluation.metrics;
const evaluationType = evaluationConfig.type;
const evaluationMethod = evaluationConfig.method;
const compareEvaluationInfo = compareEvaluation?.info || {};
const compareEvaluationKey = compareEvaluationInfo?.key;
const compareEvaluationTimestamp = compareEvaluationInfo?.timestamp;
Expand All @@ -174,6 +175,9 @@ export default function Evaluation(props: EvaluationProps) {
const compareEvaluationType = compareEvaluationConfig.type;
const isObjectDetection = evaluationType === "detection";
const isSegmentation = evaluationType === "segmentation";
const isBinaryClassification =
evaluationType === "classification" && evaluationMethod === "binary";
const showTpFpFn = isObjectDetection || isBinaryClassification;
const infoRows = [
{
id: "evaluation_key",
Expand Down Expand Up @@ -385,7 +389,7 @@ export default function Evaluation(props: EvaluationProps) {
? "compare"
: "selected"
: false,
hide: !isObjectDetection,
hide: !showTpFpFn,
},
{
id: "fp",
Expand All @@ -400,7 +404,7 @@ export default function Evaluation(props: EvaluationProps) {
? "compare"
: "selected"
: false,
hide: !isObjectDetection,
hide: !showTpFpFn,
},
{
id: "fn",
Expand All @@ -415,7 +419,7 @@ export default function Evaluation(props: EvaluationProps) {
? "compare"
: "selected"
: false,
hide: !isObjectDetection,
hide: !showTpFpFn,
},
];

Expand Down
25 changes: 17 additions & 8 deletions fiftyone/operators/builtins/panels/model_evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""

from collections import defaultdict, Counter
import os
import traceback
Expand Down Expand Up @@ -96,6 +97,12 @@ def on_load(self, ctx):
ctx.panel.set_data("permissions", permissions)
self.load_pending_evaluations(ctx)

def is_binary_classification(self, info):
return (
info.config.type == "classification"
and info.config.method == "binary"
)

def get_avg_confidence(self, per_class_metrics):
count = 0
total = 0
Expand All @@ -107,10 +114,7 @@ def get_avg_confidence(self, per_class_metrics):

def get_tp_fp_fn(self, info, results):
# Binary classification
if (
info.config.type == "classification"
and info.config.method == "binary"
):
if self.is_binary_classification(info):
neg_label, pos_label = results.classes
tp_count = np.count_nonzero(
(results.ytrue == pos_label) & (results.ypred == pos_label)
Expand Down Expand Up @@ -422,10 +426,15 @@ def load_view(self, ctx):
gt_field, F("label") == y
).filter_labels(pred_field, F("label") == x)
elif view_type == "field":
view = ctx.dataset.filter_labels(
pred_field, F(computed_eval_key) == field
)

if self.is_binary_classification(info):
uppercase_field = field.upper()
view = ctx.dataset.match(
{computed_eval_key: {"$eq": uppercase_field}}
)
else:
view = ctx.dataset.filter_labels(
pred_field, F(computed_eval_key) == field
)
if view is not None:
ctx.ops.set_view(view)

Expand Down

0 comments on commit d41fffc

Please sign in to comment.