Skip to content

Commit

Permalink
add create_index kwarg to geo stages
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Dec 28, 2024
1 parent 30b7469 commit f989a0e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
19 changes: 17 additions & 2 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5277,6 +5277,7 @@ def geo_near(
min_distance=None,
max_distance=None,
query=None,
create_index=True,
):
"""Sorts the samples in the collection by their proximity to a
specified geolocation.
Expand Down Expand Up @@ -5359,6 +5360,8 @@ def geo_near(
query (None): an optional dict defining a
`MongoDB read query <https://docs.mongodb.com/manual/tutorial/query-documents/#read-operations-query-argument>`_
that samples must match in order to be included in this view
create_index (True): whether to create the required spherical
index, if necessary
Returns:
a :class:`fiftyone.core.view.DatasetView`
Expand All @@ -5370,11 +5373,18 @@ def geo_near(
min_distance=min_distance,
max_distance=max_distance,
query=query,
create_index=create_index,
)
)

@view_stage
def geo_within(self, boundary, location_field=None, strict=True):
def geo_within(
self,
boundary,
location_field=None,
strict=True,
create_index=True,
):
"""Filters the samples in this collection to only include samples whose
geolocation is within a specified boundary.
Expand Down Expand Up @@ -5420,13 +5430,18 @@ def geo_within(self, boundary, location_field=None, strict=True):
strict (True): whether a sample's location data must strictly fall
within boundary (True) in order to match, or whether any
intersection suffices (False)
create_index (True): whether to create the required spherical
index, if necessary
Returns:
a :class:`fiftyone.core.view.DatasetView`
"""
return self._add_view_stage(
fos.GeoWithin(
boundary, location_field=location_field, strict=strict
boundary,
location_field=location_field,
strict=strict,
create_index=create_index,
)
)

Expand Down
49 changes: 43 additions & 6 deletions fiftyone/core/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3017,15 +3017,21 @@ def _extract_filter_field(val):


class _GeoStage(ViewStage):
def __init__(self, location_field):
def __init__(self, location_field=None, create_index=True):
self._location_field = location_field
self._location_key = None
self._create_index = create_index

@property
def location_field(self):
"""The location field."""
return self._location_field

@property
def create_index(self):
"""Whether to create the required spherical index, if necessary."""
return self._create_index

def validate(self, sample_collection):
if self._location_field is None:
self._location_field = sample_collection._get_geo_location_field()
Expand All @@ -3039,8 +3045,9 @@ def validate(self, sample_collection):
# Assume the user directly specified the subfield to use
self._location_key = self._location_field

# These operations require a spherical index
sample_collection.create_index([(self._location_key, "2dsphere")])
if self._create_index:
# These operations require a spherical index
sample_collection.create_index([(self._location_key, "2dsphere")])


class GeoNear(_GeoStage):
Expand Down Expand Up @@ -3128,6 +3135,8 @@ class GeoNear(_GeoStage):
query (None): an optional dict defining a
`MongoDB read query <https://docs.mongodb.com/manual/tutorial/query-documents/#read-operations-query-argument>`_
that samples must match in order to be included in this view
create_index (True): whether to create the required spherical index,
if necessary
"""

def __init__(
Expand All @@ -3137,8 +3146,12 @@ def __init__(
min_distance=None,
max_distance=None,
query=None,
create_index=True,
):
super().__init__(location_field)
super().__init__(
location_field=location_field,
create_index=create_index,
)
self._point = foug.parse_point(point)
self._min_distance = min_distance
self._max_distance = max_distance
Expand Down Expand Up @@ -3195,6 +3208,7 @@ def _kwargs(self):
["min_distance", self._min_distance],
["max_distance", self._max_distance],
["query", self._query],
["create_index", self._create_index],
]

@classmethod
Expand Down Expand Up @@ -3225,6 +3239,12 @@ def _params(cls):
"placeholder": "",
"default": "None",
},
{
"name": "create_index",
"type": "bool",
"default": "True",
"placeholder": "create_index (default=True)",
},
]


Expand Down Expand Up @@ -3275,10 +3295,20 @@ class GeoWithin(_GeoStage):
strict (True): whether a sample's location data must strictly fall
within boundary (True) in order to match, or whether any
intersection suffices (False)
"""

def __init__(self, boundary, location_field=None, strict=True):
super().__init__(location_field)
def __init__(
self,
boundary,
location_field=None,
strict=True,
create_index=True,
):
super().__init__(
location_field=location_field,
create_index=create_index,
)
self._boundary = foug.parse_polygon(boundary)
self._strict = strict

Expand Down Expand Up @@ -3307,6 +3337,7 @@ def _kwargs(self):
["boundary", self._boundary],
["location_field", self._location_field],
["strict", self._strict],
["create_index", self._create_index],
]

@classmethod
Expand All @@ -3325,6 +3356,12 @@ def _params(cls):
"default": "True",
"placeholder": "strict (default=True)",
},
{
"name": "create_index",
"type": "bool",
"default": "True",
"placeholder": "create_index (default=True)",
},
]


Expand Down

0 comments on commit f989a0e

Please sign in to comment.