Skip to content

Commit

Permalink
Merge pull request #5289 from voxel51/fix/detections-sources
Browse files Browse the repository at this point in the history
add support for collection overlays
  • Loading branch information
sashankaryal authored Dec 18, 2024
2 parents ec2c2e8 + bc44e3e commit 2115658
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 24 deletions.
37 changes: 28 additions & 9 deletions app/packages/looker/src/worker/disk-overlay-decoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,29 @@ export const decodeOverlayOnDisk = async (
sources: { [path: string]: string },
cls: string,
maskPathDecodingPromises: Promise<void>[] = [],
maskTargetsBuffers: ArrayBuffer[] = []
maskTargetsBuffers: ArrayBuffer[] = [],
overlayCollectionProcessingParams:
| { idx: number; cls: string }
| undefined = undefined
) => {
// handle all list types here
if (cls === DETECTIONS) {
if (cls === DETECTIONS && label.detections) {
const promises: Promise<void>[] = [];
for (const detection of label.detections) {

for (let i = 0; i < label.detections.length; i++) {
const detection = label.detections[i];
promises.push(
decodeOverlayOnDisk(
field,
detection,
coloring,
customizeColorSetting,
colorscale,
{},
sources,
DETECTION,
maskPathDecodingPromises,
maskTargetsBuffers
maskTargetsBuffers,
{ idx: i, cls: DETECTIONS }
)
);
}
Expand Down Expand Up @@ -74,16 +80,29 @@ export const decodeOverlayOnDisk = async (
return;
}

// if we have an explicit source defined from sample.urls, use that
// otherwise, use the path field from the label
let source = sources[`${field}.${overlayPathField}`];

if (typeof overlayCollectionProcessingParams !== "undefined") {
// example: for detections, we need to access the source from the parent label
// like: if field is "prediction_masks", we're trying to get "predictiion_masks.detections[INDEX].mask"
source =
sources[
`${field}.${overlayCollectionProcessingParams.cls.toLocaleLowerCase()}[${
overlayCollectionProcessingParams.idx
}].${overlayPathField}`
];
}

// convert absolute file path to a URL that we can "fetch" from
const overlayImageUrl = getSampleSrc(
sources[`${field}.${overlayPathField}`] || label[overlayPathField]
);
const overlayImageUrl = getSampleSrc(source || label[overlayPathField]);
const urlTokens = overlayImageUrl.split("?");

let baseUrl = overlayImageUrl;

// remove query params if not local URL
if (!urlTokens.at(1)?.startsWith("filepath=")) {
if (!urlTokens.at(1)?.startsWith("filepath=") && !source) {
baseUrl = overlayImageUrl.split("?")[0];
}

Expand Down
56 changes: 41 additions & 15 deletions fiftyone/server/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import typing as t

from functools import reduce
from pydash import get

import asyncio
import aiofiles
Expand All @@ -31,6 +32,8 @@
logger = logging.getLogger(__name__)

_ADDITIONAL_MEDIA_FIELDS = {
fol.Detection: "mask_path",
fol.Detections: "mask_path",
fol.Heatmap: "map_path",
fol.Segmentation: "mask_path",
OrthographicProjectionMetadata: "filepath",
Expand Down Expand Up @@ -68,7 +71,11 @@ async def get_metadata(
filepath = sample["filepath"]
metadata = sample.get("metadata", None)

opm_field, additional_fields = _get_additional_media_fields(collection)
(
opm_field,
detections_fields,
additional_fields,
) = _get_additional_media_fields(collection)

filepath_result, filepath_source, urls = _create_media_urls(
collection,
Expand All @@ -77,6 +84,7 @@ async def get_metadata(
url_cache,
additional_fields=additional_fields,
opm_field=opm_field,
detections_fields=detections_fields,
)
if filepath_result is not None:
filepath = filepath_result
Expand Down Expand Up @@ -389,13 +397,31 @@ def _create_media_urls(
cache: t.Dict,
additional_fields: t.Optional[t.List[str]] = None,
opm_field: t.Optional[str] = None,
detections_fields: t.Optional[t.List[str]] = None,
) -> t.Dict[str, str]:
filepath_source = None
media_fields = collection.app_config.media_fields.copy()

if additional_fields is not None:
media_fields.extend(additional_fields)

if detections_fields is not None:
for field in detections_fields:
detections = get(sample, field)

if not detections:
continue

detections_list = get(detections, "detections")

if not detections_list or len(detections_list) == 0:
continue

len_detections = len(detections_list)

for i in range(len_detections):
media_fields.append(f"{field}.detections[{i}].mask_path")

if (
sample_media_type == fom.POINT_CLOUD
or sample_media_type == fom.THREE_D
Expand All @@ -413,7 +439,10 @@ def _create_media_urls(
media_urls = []

for field in media_fields:
path = _deep_get(sample, field)
path = get(sample, field)

if not path:
continue

if path not in cache:
cache[path] = path
Expand All @@ -435,6 +464,8 @@ def _get_additional_media_fields(
) -> t.List[str]:
additional = []
opm_field = None
detections_fields = None

for cls, subfield_name in _ADDITIONAL_MEDIA_FIELDS.items():
for field_name, field in collection.get_field_schema(
flat=True
Expand All @@ -447,18 +478,13 @@ def _get_additional_media_fields(
if cls == OrthographicProjectionMetadata:
opm_field = field_name

additional.append(f"{field_name}.{subfield_name}")

return opm_field, additional
if cls == fol.Detections:
if detections_fields is None:
detections_fields = [field_name]
else:
detections_fields.append(field_name)

else:
additional.append(f"{field_name}.{subfield_name}")

def _deep_get(sample, keys, default=None):
"""
Get a value from a nested dictionary by specifying keys delimited by '.',
similar to lodash's ``_.get()``.
"""
return reduce(
lambda d, key: d.get(key, default) if isinstance(d, dict) else default,
keys.split("."),
sample,
)
return opm_field, detections_fields, additional

0 comments on commit 2115658

Please sign in to comment.