Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for batch size and multi processing to evaluate_detections #5376

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 65 additions & 32 deletions fiftyone/utils/eval/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
import inspect
import itertools
import logging
import multiprocessing as mp

import numpy as np

Expand Down Expand Up @@ -41,6 +43,8 @@ def evaluate_detections(
classwise=True,
dynamic=True,
progress=None,
batch_size=None,
num_workers=None,
**kwargs,
):
"""Evaluates the predicted detections in the given samples with respect to
Expand Down Expand Up @@ -135,6 +139,9 @@ def evaluate_detections(
progress (None): whether to render a progress bar (True/False), use the
default value ``fiftyone.config.show_progress_bars`` (None), or a
progress callback function to invoke instead
batch_size (None): the batch size at which to process samples. By
default, all samples are processed in a single (1) batch
num_workers (None): number of parallel workers. Defaults to CPU count - 1
**kwargs: optional keyword arguments for the constructor of the
:class:`DetectionEvaluationConfig` being used

Expand Down Expand Up @@ -176,45 +183,50 @@ def evaluate_detections(
processing_frames = samples._is_frame_field(pred_field)
save = eval_key is not None

if save:
tp_field = "%s_tp" % eval_key
fp_field = "%s_fp" % eval_key
fn_field = "%s_fn" % eval_key

if config.requires_additional_fields:
_samples = samples
else:
_samples = samples.select_fields([gt_field, pred_field])

# Determine number of workers
if num_workers is None:
num_workers = max(1, mp.cpu_count() - 1)

logger.info(
f"Evaluating detections in parallel with {num_workers} workers..."
)

matches = []
logger.info("Evaluating detections...")
for sample in _samples.iter_samples(progress=progress, autosave=save):
if processing_frames:
docs = sample.frames.values()
else:
docs = [sample]

sample_tp = 0
sample_fp = 0
sample_fn = 0
for doc in docs:
doc_matches = eval_method.evaluate(doc, eval_key=eval_key)
matches.extend(doc_matches)
tp, fp, fn = _tally_matches(doc_matches)
sample_tp += tp
sample_fp += fp
sample_fn += fn

if processing_frames and save:
doc[tp_field] = tp
doc[fp_field] = fp
doc[fn_field] = fn

if save:
sample[tp_field] = sample_tp
sample[fp_field] = sample_fp
sample[fn_field] = sample_fn
# Create a thread pool
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []

# Submit samples to the worker pool
for sample in _samples.iter_samples(
progress=progress, batch_size=batch_size, autosave=save
):
if processing_frames:
docs = sample.frames.values()
else:
docs = [sample]

future = executor.submit(
_process_sample, docs, eval_method, eval_key, processing_frames
)
futures.append((future, sample))

# Collect results
for future, sample in futures:
sample_matches, sample_stats = future.result()
matches.extend(sample_matches)

if save:
sample_tp, sample_fp, sample_fn = sample_stats
sample[f"{eval_key}_tp"] = sample_tp
sample[f"{eval_key}_fp"] = sample_fp
sample[f"{eval_key}_fn"] = sample_fn

# Generate and save final results
results = eval_method.generate_results(
samples,
matches,
Expand All @@ -228,6 +240,27 @@ def evaluate_detections(
return results


def _process_sample(docs, eval_method, eval_key, processing_frames):
    """Evaluates detections for one sample (or the frames of one video
    sample) and aggregates its statistics.

    Args:
        docs: an iterable of documents to evaluate — frame documents when
            ``processing_frames`` is True, otherwise a single-element list
            containing the sample itself
        eval_method: the evaluation method whose ``evaluate()`` is applied
            to each doc
        eval_key: the evaluation key, or None if results are not being saved
        processing_frames: whether ``docs`` are frame documents; when True
            and ``eval_key`` is set, per-frame TP/FP/FN counts are written
            onto each doc

    Returns:
        a tuple of ``(matches, (tp, fp, fn))`` aggregated across all docs
    """
    all_matches = []
    total_tp = total_fp = total_fn = 0

    # Per-frame counts are only persisted when saving results for frames
    write_frame_counts = processing_frames and eval_key is not None

    for doc in docs:
        results = eval_method.evaluate(doc, eval_key=eval_key)
        all_matches.extend(results)

        tp, fp, fn = _tally_matches(results)
        total_tp += tp
        total_fp += fp
        total_fn += fn

        if write_frame_counts:
            doc[f"{eval_key}_tp"] = tp
            doc[f"{eval_key}_fp"] = fp
            doc[f"{eval_key}_fn"] = fn

    return all_matches, (total_tp, total_fp, total_fn)


class DetectionEvaluationConfig(foe.EvaluationMethodConfig):
"""Base class for configuring :class:`DetectionEvaluation` instances.

Expand Down
Loading