Skip to content

Commit

Permalink
Merge pull request #5267 from voxel51/model-eval-fixes
Browse files Browse the repository at this point in the history
Fixing #5254
  • Loading branch information
brimoor authored Dec 13, 2024
2 parents 568da8a + d41fffc commit f627b82
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ export default function Evaluation(props: EvaluationProps) {
const evaluationConfig = evaluationInfo.config;
const evaluationMetrics = evaluation.metrics;
const evaluationType = evaluationConfig.type;
const evaluationMethod = evaluationConfig.method;
const compareEvaluationInfo = compareEvaluation?.info || {};
const compareEvaluationKey = compareEvaluationInfo?.key;
const compareEvaluationTimestamp = compareEvaluationInfo?.timestamp;
Expand All @@ -174,6 +175,9 @@ export default function Evaluation(props: EvaluationProps) {
const compareEvaluationType = compareEvaluationConfig.type;
const isObjectDetection = evaluationType === "detection";
const isSegmentation = evaluationType === "segmentation";
const isBinaryClassification =
evaluationType === "classification" && evaluationMethod === "binary";
const showTpFpFn = isObjectDetection || isBinaryClassification;
const infoRows = [
{
id: "evaluation_key",
Expand Down Expand Up @@ -385,7 +389,7 @@ export default function Evaluation(props: EvaluationProps) {
? "compare"
: "selected"
: false,
hide: !isObjectDetection,
hide: !showTpFpFn,
},
{
id: "fp",
Expand All @@ -400,7 +404,7 @@ export default function Evaluation(props: EvaluationProps) {
? "compare"
: "selected"
: false,
hide: !isObjectDetection,
hide: !showTpFpFn,
},
{
id: "fn",
Expand All @@ -415,7 +419,7 @@ export default function Evaluation(props: EvaluationProps) {
? "compare"
: "selected"
: false,
hide: !isObjectDetection,
hide: !showTpFpFn,
},
];

Expand Down
73 changes: 43 additions & 30 deletions fiftyone/operators/builtins/panels/model_evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
|
"""

from collections import defaultdict, Counter
import os
import traceback
import fiftyone.operators.types as types

from collections import defaultdict, Counter
import numpy as np

from fiftyone import ViewField as F
from fiftyone.operators.categories import Categories
from fiftyone.operators.panel import Panel, PanelConfig
from fiftyone.core.plots.plotly import _to_log_colorscale
import fiftyone.operators.types as types


STORE_NAME = "model_evaluation_panel_builtin"
Expand Down Expand Up @@ -95,6 +97,12 @@ def on_load(self, ctx):
ctx.panel.set_data("permissions", permissions)
self.load_pending_evaluations(ctx)

def is_binary_classification(self, info):
return (
info.config.type == "classification"
and info.config.method == "binary"
)

def get_avg_confidence(self, per_class_metrics):
count = 0
total = 0
Expand All @@ -104,29 +112,29 @@ def get_avg_confidence(self, per_class_metrics):
total += metrics["confidence"]
return total / count if count > 0 else None

def get_tp_fp_fn(self, ctx):
view_state = ctx.panel.get_state("view") or {}
key = view_state.get("key")
dataset = ctx.dataset
tp_key = f"{key}_tp"
fp_key = f"{key}_fp"
fn_key = f"{key}_fn"
tp_total = (
sum(ctx.dataset.values(tp_key))
if dataset.has_field(tp_key)
else None
)
fp_total = (
sum(ctx.dataset.values(fp_key))
if dataset.has_field(fp_key)
else None
)
fn_total = (
sum(ctx.dataset.values(fn_key))
if dataset.has_field(fn_key)
else None
)
return tp_total, fp_total, fn_total
def get_tp_fp_fn(self, info, results):
# Binary classification
if self.is_binary_classification(info):
neg_label, pos_label = results.classes
tp_count = np.count_nonzero(
(results.ytrue == pos_label) & (results.ypred == pos_label)
)
fp_count = np.count_nonzero(
(results.ytrue != pos_label) & (results.ypred == pos_label)
)
fn_count = np.count_nonzero(
(results.ytrue == pos_label) & (results.ypred != pos_label)
)
return tp_count, fp_count, fn_count

# Object detection
if info.config.type == "detection":
tp_count = np.count_nonzero(results.ytrue == results.ypred)
fp_count = np.count_nonzero(results.ytrue == results.missing)
fn_count = np.count_nonzero(results.ypred == results.missing)
return tp_count, fp_count, fn_count

return None, None, None

def get_map(self, results):
try:
Expand Down Expand Up @@ -298,7 +306,7 @@ def load_evaluation(self, ctx):
per_class_metrics
)
metrics["tp"], metrics["fp"], metrics["fn"] = self.get_tp_fp_fn(
ctx
info, results
)
metrics["mAP"] = self.get_map(results)
evaluation_data = {
Expand Down Expand Up @@ -418,10 +426,15 @@ def load_view(self, ctx):
gt_field, F("label") == y
).filter_labels(pred_field, F("label") == x)
elif view_type == "field":
view = ctx.dataset.filter_labels(
pred_field, F(computed_eval_key) == field
)

if self.is_binary_classification(info):
uppercase_field = field.upper()
view = ctx.dataset.match(
{computed_eval_key: {"$eq": uppercase_field}}
)
else:
view = ctx.dataset.filter_labels(
pred_field, F(computed_eval_key) == field
)
if view is not None:
ctx.ops.set_view(view)

Expand Down

0 comments on commit f627b82

Please sign in to comment.