use mask targets in model evaluation panel
imanjra committed Dec 17, 2024
1 parent 09bb793 commit a78dd41
Showing 2 changed files with 60 additions and 9 deletions.
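
For context: in FiftyOne, segmentation mask targets map integer pixel values to human-readable class names, either per label field (dataset.mask_targets) or dataset-wide (dataset.default_mask_targets). This commit threads those mappings into the Model Evaluation panel: the Evaluation component below resolves class keys through the mask targets of the primary and comparison runs, so confusion matrices and per-class charts show class names instead of raw pixel values. A minimal sketch of the resolution rule, assuming the targets arrive as a string-keyed record after serialization (the type and function names here are illustrative, not part of the commit):

// Illustrative only: mirrors the fallback order used throughout this commit.
type MaskTargets = Record<string, string>;

function resolveLabel(
  classKey: string,
  maskTargets?: MaskTargets,
  compareMaskTargets?: MaskTargets
): string {
  // Prefer the comparison run's targets, then the primary run's, then the raw key.
  return compareMaskTargets?.[classKey] || maskTargets?.[classKey] || classKey;
}

// resolveLabel("1", { "1": "road", "2": "sky" })  -> "road"
// resolveLabel("7", { "1": "road" })              -> "7"  (no target; raw key kept)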
@@ -99,15 +99,32 @@ export default function Evaluation(props: EvaluationProps) {
     const evaluation = data?.[`evaluation_${compareKey}_error`];
     return evaluation;
   }, [data]);
+  const evaluationMaskTargets = useMemo(() => {
+    return evaluation?.mask_targets || {};
+  }, [evaluation]);
+  const compareEvaluationMaskTargets = useMemo(() => {
+    return compareEvaluation?.mask_targets || {};
+  }, [compareEvaluation]);
   const confusionMatrix = useMemo(() => {
-    return getMatrix(evaluation?.confusion_matrices, confusionMatrixConfig);
-  }, [evaluation, confusionMatrixConfig]);
+    return getMatrix(
+      evaluation?.confusion_matrices,
+      confusionMatrixConfig,
+      evaluationMaskTargets
+    );
+  }, [evaluation, confusionMatrixConfig, evaluationMaskTargets]);
   const compareConfusionMatrix = useMemo(() => {
     return getMatrix(
       compareEvaluation?.confusion_matrices,
-      confusionMatrixConfig
+      confusionMatrixConfig,
+      evaluationMaskTargets,
+      compareEvaluationMaskTargets
     );
-  }, [compareEvaluation, confusionMatrixConfig]);
+  }, [
+    compareEvaluation,
+    confusionMatrixConfig,
+    evaluationMaskTargets,
+    compareEvaluationMaskTargets,
+  ]);
   const compareKeys = useMemo(() => {
     const keys: string[] = [];
     const evaluations = data?.evaluations || [];
@@ -452,9 +469,12 @@ export default function Evaluation(props: EvaluationProps) {
         if (!perClassPerformance[metric]) {
           perClassPerformance[metric] = [];
         }
+        const maskTarget = evaluationMaskTargets?.[key];
+        const compareMaskTarget = compareEvaluationMaskTargets?.[key];
         perClassPerformance[metric].push({
           id: key,
-          property: key,
+          property: maskTarget || key,
+          compareProperty: compareMaskTarget || maskTarget || key,
           value: metrics[metric],
           compareValue: compareMetrics[metric],
         });
@@ -1059,7 +1079,10 @@ export default function Evaluation(props: EvaluationProps) {
                 y: classPerformance.map(
                   (metrics) => metrics.compareValue
                 ),
-                x: classPerformance.map((metrics) => metrics.property),
+                x: classPerformance.map(
+                  (metrics) =>
+                    metrics.compareProperty || metrics.property
+                ),
                 type: "histogram",
                 name: `${CLASS_LABELS[performanceClass]} per class`,
                 marker: {
@@ -1218,6 +1241,10 @@ export default function Evaluation(props: EvaluationProps) {
             layout={{
               yaxis: {
                 autorange: "reversed",
+                type: "category",
               },
+              xaxis: {
+                type: "category",
+              },
             }}
           />
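
Why the categorical axes: when no mask targets exist, segmentation class labels are numeric-looking strings ("1", "2", ...), and Plotly's default axis type inference can treat them as numbers, producing a continuous axis instead of one tick per class. Pinning both axes to "category" keeps the confusion matrix ticks discrete whether the labels are raw IDs or resolved names. A tiny illustrative layout object under that assumption (which axis holds ground truth vs. predictions is not specified by this diff):

// Illustrative: categorical axes keep one discrete tick per class label.
const confusionMatrixLayout = {
  yaxis: { autorange: "reversed", type: "category" },
  xaxis: { type: "category" },
};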
@@ -1258,6 +1285,10 @@ export default function Evaluation(props: EvaluationProps) {
             layout={{
               yaxis: {
                 autorange: "reversed",
+                type: "category",
               },
+              xaxis: {
+                type: "category",
+              },
             }}
           />
@@ -1613,14 +1644,17 @@ function formatPerClassPerformance(perClassPerformance, barConfig) {
   return computedPerClassPerformance;
 }
 
-function getMatrix(matrices, config) {
+function getMatrix(matrices, config, maskTargets, compareMaskTargets?) {
   if (!matrices) return;
   const { sortBy = "az", limit } = config;
   const parsedLimit = typeof limit === "number" ? limit : undefined;
   const classes = matrices[`${sortBy}_classes`].slice(0, parsedLimit);
   const matrix = matrices[`${sortBy}_matrix`].slice(0, parsedLimit);
   const colorscale = matrices[`${sortBy}_colorscale`];
-  return { labels: classes, matrix, colorscale };
+  const labels = classes.map((c) => {
+    return compareMaskTargets?.[c] || maskTargets?.[c] || c;
+  });
+  return { labels, matrix, colorscale };
 }
 
 function getConfigLabel({ config, type, dashed }) {
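
A hedged usage sketch of the updated getMatrix, with the matrices payload shaped after the `${sortBy}_*` lookups above (the az_* values are made up for illustration):

// Hypothetical inputs; only the key pattern is taken from getMatrix itself.
const matrices = {
  az_classes: ["1", "2"],
  az_matrix: [
    [10, 3],
    [2, 12],
  ],
  az_colorscale: "viridis",
};

const result = getMatrix(matrices, {}, { "1": "road", "2": "sky" });
// result.labels -> ["road", "sky"]; with no mask targets it would stay ["1", "2"]
// result.matrix and result.colorscale pass through unchanged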
19 changes: 18 additions & 1 deletion fiftyone/operators/builtins/panels/model_evaluation/__init__.py
@@ -288,6 +288,16 @@ def get_confusion_matrices(self, results):
             "lc_colorscale": lc_colorscale,
         }
 
+    def get_mask_targets(self, dataset, gt_field):
+        mask_targets = dataset.mask_targets.get(gt_field, None)
+        if mask_targets:
+            return mask_targets
+
+        if dataset.default_mask_targets:
+            return dataset.default_mask_targets
+
+        return None
+
     def load_evaluation(self, ctx):
         view_state = ctx.panel.get_state("view") or {}
         eval_key = view_state.get("key")
@@ -300,14 +310,20 @@ def load_evaluation(self, ctx):
         )
         if evaluation_data is None:
             info = ctx.dataset.get_evaluation_info(computed_eval_key)
-            serialized_info = info.serialize()
             evaluation_type = info.config.type
             if evaluation_type not in SUPPORTED_EVALUATION_TYPES:
                 ctx.panel.set_data(
                     f"evaluation_{computed_eval_key}_error",
                     {"error": "unsupported", "info": serialized_info},
                 )
                 return
+            serialized_info = info.serialize()
+            gt_field = info.config.gt_field
+            mask_targets = (
+                self.get_mask_targets(ctx.dataset, gt_field)
+                if evaluation_type == "segmentation"
+                else None
+            )
             results = ctx.dataset.load_evaluation_results(computed_eval_key)
             metrics = results.metrics()
             per_class_metrics = self.get_per_class_metrics(info, results)
@@ -323,6 +339,7 @@ def load_evaluation(self, ctx):
                 "info": serialized_info,
                 "confusion_matrices": self.get_confusion_matrices(results),
                 "per_class_metrics": per_class_metrics,
+                "mask_targets": mask_targets,
             }
             if ENABLE_CACHING:
                 # Cache the evaluation data
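
The backend only attaches mask targets for segmentation evaluations, resolving the ground-truth field's entry in dataset.mask_targets and falling back to dataset.default_mask_targets; every other evaluation type sends None. A rough sketch of the panel data shape the frontend reads, listing only the keys visible in this diff, with assumed value types (mask target keys are typically integer pixel values in Python and become strings once serialized to JSON):

// Assumed shape of the `evaluation_<key>` panel data; value types are guesses.
type EvaluationPanelData = {
  info: unknown;
  confusion_matrices: Record<string, unknown>;
  per_class_metrics: Record<string, unknown>;
  mask_targets: Record<string, string> | null; // null for non-segmentation runs
};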
