diff --git a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx
index 819fdbd5b5..cb22817d63 100644
--- a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx
+++ b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx
@@ -99,15 +99,32 @@ export default function Evaluation(props: EvaluationProps) {
     const evaluation = data?.[`evaluation_${compareKey}_error`];
     return evaluation;
   }, [data]);
+  const evaluationMaskTargets = useMemo(() => {
+    return evaluation?.mask_targets || {};
+  }, [evaluation]);
+  const compareEvaluationMaskTargets = useMemo(() => {
+    return compareEvaluation?.mask_targets || {};
+  }, [compareEvaluation]);
   const confusionMatrix = useMemo(() => {
-    return getMatrix(evaluation?.confusion_matrices, confusionMatrixConfig);
-  }, [evaluation, confusionMatrixConfig]);
+    return getMatrix(
+      evaluation?.confusion_matrices,
+      confusionMatrixConfig,
+      evaluationMaskTargets
+    );
+  }, [evaluation, confusionMatrixConfig, evaluationMaskTargets]);
   const compareConfusionMatrix = useMemo(() => {
     return getMatrix(
       compareEvaluation?.confusion_matrices,
-      confusionMatrixConfig
+      confusionMatrixConfig,
+      evaluationMaskTargets,
+      compareEvaluationMaskTargets
     );
-  }, [compareEvaluation, confusionMatrixConfig]);
+  }, [
+    compareEvaluation,
+    confusionMatrixConfig,
+    evaluationMaskTargets,
+    compareEvaluationMaskTargets,
+  ]);
   const compareKeys = useMemo(() => {
     const keys: string[] = [];
     const evaluations = data?.evaluations || [];
@@ -452,9 +469,12 @@ export default function Evaluation(props: EvaluationProps) {
         if (!perClassPerformance[metric]) {
           perClassPerformance[metric] = [];
         }
+        const maskTarget = evaluationMaskTargets?.[key];
+        const compareMaskTarget = compareEvaluationMaskTargets?.[key];
         perClassPerformance[metric].push({
           id: key,
-          property: key,
+          property: maskTarget || key,
+          compareProperty: compareMaskTarget || maskTarget || key,
           value: metrics[metric],
           compareValue: compareMetrics[metric],
         });
@@ -1059,7 +1079,10 @@ export default function Evaluation(props: EvaluationProps) {
                       y: classPerformance.map(
                         (metrics) => metrics.compareValue
                       ),
-                      x: classPerformance.map((metrics) => metrics.property),
+                      x: classPerformance.map(
+                        (metrics) =>
+                          metrics.compareProperty || metrics.property
+                      ),
                       type: "histogram",
                       name: `${CLASS_LABELS[performanceClass]} per class`,
                       marker: {
@@ -1218,6 +1241,10 @@ export default function Evaluation(props: EvaluationProps) {
                 layout={{
                   yaxis: {
                     autorange: "reversed",
+                    type: "category",
+                  },
+                  xaxis: {
+                    type: "category",
                   },
                 }}
               />
@@ -1258,6 +1285,10 @@ export default function Evaluation(props: EvaluationProps) {
                 layout={{
                   yaxis: {
                     autorange: "reversed",
+                    type: "category",
+                  },
+                  xaxis: {
+                    type: "category",
                   },
                 }}
               />
@@ -1613,14 +1644,17 @@ function formatPerClassPerformance(perClassPerformance, barConfig) {
   return computedPerClassPerformance;
 }
 
-function getMatrix(matrices, config) {
+function getMatrix(matrices, config, maskTargets, compareMaskTargets?) {
   if (!matrices) return;
   const { sortBy = "az", limit } = config;
   const parsedLimit = typeof limit === "number" ? limit : undefined;
   const classes = matrices[`${sortBy}_classes`].slice(0, parsedLimit);
   const matrix = matrices[`${sortBy}_matrix`].slice(0, parsedLimit);
   const colorscale = matrices[`${sortBy}_colorscale`];
-  return { labels: classes, matrix, colorscale };
+  const labels = classes.map((c) => {
+    return compareMaskTargets?.[c] || maskTargets?.[c] || c;
+  });
+  return { labels, matrix, colorscale };
 }
 
 function getConfigLabel({ config, type, dashed }) {
diff --git a/fiftyone/operators/builtins/panels/model_evaluation/__init__.py b/fiftyone/operators/builtins/panels/model_evaluation/__init__.py
index ef8d9b1c47..7b27aa45a8 100644
--- a/fiftyone/operators/builtins/panels/model_evaluation/__init__.py
+++ b/fiftyone/operators/builtins/panels/model_evaluation/__init__.py
@@ -288,6 +288,16 @@ def get_confusion_matrices(self, results):
             "lc_colorscale": lc_colorscale,
         }
 
+    def get_mask_targets(self, dataset, gt_field):
+        mask_targets = dataset.mask_targets.get(gt_field, None)
+        if mask_targets:
+            return mask_targets
+
+        if dataset.default_mask_targets:
+            return dataset.default_mask_targets
+
+        return None
+
     def load_evaluation(self, ctx):
         view_state = ctx.panel.get_state("view") or {}
         eval_key = view_state.get("key")
@@ -300,7 +310,6 @@ def load_evaluation(self, ctx):
         )
         if evaluation_data is None:
             info = ctx.dataset.get_evaluation_info(computed_eval_key)
-            serialized_info = info.serialize()
             evaluation_type = info.config.type
             if evaluation_type not in SUPPORTED_EVALUATION_TYPES:
                 ctx.panel.set_data(
@@ -308,6 +317,13 @@ def load_evaluation(self, ctx):
                     {"error": "unsupported", "info": serialized_info},
                 )
                 return
+            serialized_info = info.serialize()
+            gt_field = info.config.gt_field
+            mask_targets = (
+                self.get_mask_targets(ctx.dataset, gt_field)
+                if evaluation_type == "segmentation"
+                else None
+            )
             results = ctx.dataset.load_evaluation_results(computed_eval_key)
             metrics = results.metrics()
             per_class_metrics = self.get_per_class_metrics(info, results)
@@ -323,6 +339,7 @@ def load_evaluation(self, ctx):
                 "info": serialized_info,
                 "confusion_matrices": self.get_confusion_matrices(results),
                 "per_class_metrics": per_class_metrics,
+                "mask_targets": mask_targets,
             }
             if ENABLE_CACHING:
                 # Cache the evaluation data