diff --git a/pysheds/grid.py b/pysheds/grid.py index 4915171..b585cb3 100644 --- a/pysheds/grid.py +++ b/pysheds/grid.py @@ -550,7 +550,7 @@ def resize(self, data, new_shape, out_suffix='_resized', inplace=True, return self._output_handler(data=data, out_name=out_name, properties=grid_props, inplace=inplace, metadata=metadata) - def nearest_cell(self, x, y, affine=None): + def nearest_cell(self, x, y, affine=None, snap='corner'): """ Returns the index of the cell (column, row) closest to a given geographical coordinate. @@ -565,6 +565,11 @@ def nearest_cell(self, x, y, affine=None): Affine transformation that defines the translation between geographic x/y coordinate and array row/column coordinate. Defaults to self.affine. + snap : str + Indicates the cell indexing method. If "corner", will resolve to + snapping the (x,y) geometry to the index of the nearest top-left + cell corner. If "center", will return the index of the cell that + the geometry falls within. Returns ------- x_i, y_i : tuple of ints @@ -576,7 +581,8 @@ def nearest_cell(self, x, y, affine=None): assert isinstance(affine, Affine) except: raise TypeError('affine must be an Affine instance.') - col, row = np.around(~affine * (x, y)).astype(int) + snap_dict = {'corner': np.around, 'center': np.floor} + col, row = snap_dict[snap](~affine * (x, y)).astype(int) return col, row def set_bbox(self, new_bbox): @@ -837,7 +843,7 @@ def facet_flow(self, e0, e1, e2, d1=1, d2=1): def catchment(self, x, y, data, pour_value=None, out_name='catch', dirmap=None, nodata_in=None, nodata_out=0, xytype='index', routing='d8', recursionlimit=15000, inplace=True, apply_mask=False, ignore_metadata=False, - **kwargs): + snap='corner', **kwargs): """ Delineates a watershed from a given pour point (x, y). @@ -884,6 +890,10 @@ def catchment(self, x, y, data, pour_value=None, out_name='catch', dirmap=None, If True, "mask" the output using self.mask. ignore_metadata : bool If False, require a valid affine transform and crs. + snap : str + Function to use on array for indexing: + 'corner' : numpy.around() + 'center' : numpy.floor() """ # TODO: Why does this use set_dirmap but flowdir doesn't? dirmap = self._set_dirmap(dirmap, data) @@ -909,7 +919,7 @@ def catchment(self, x, y, data, pour_value=None, out_name='catch', dirmap=None, dirmap=dirmap, nodata_in=nodata_in, nodata_out=nodata_out, xytype=xytype, recursionlimit=recursionlimit, inplace=inplace, apply_mask=apply_mask, ignore_metadata=ignore_metadata, - properties=properties, metadata=metadata, **kwargs) + properties=properties, metadata=metadata, snap=snap, **kwargs) elif routing.lower() == 'dinf': return self._dinf_catchment(x, y, fdir=fdir, pour_value=pour_value, out_name=out_name, dirmap=dirmap, nodata_in=nodata_in, nodata_out=nodata_out, @@ -920,7 +930,7 @@ def catchment(self, x, y, data, pour_value=None, out_name='catch', dirmap=None, def _d8_catchment(self, x, y, fdir=None, pour_value=None, out_name='catch', dirmap=None, nodata_in=None, nodata_out=0, xytype='index', recursionlimit=15000, inplace=True, apply_mask=False, ignore_metadata=False, properties={}, - metadata={}, **kwargs): + metadata={}, snap='corner', **kwargs): # Vectorized Recursive algorithm: # for each cell j, recursively search through grid to determine @@ -943,7 +953,7 @@ def d8_catchment_search(cells): # to given geographic coordinate # Valid if the dataset is a view. if xytype == 'label': - x, y = self.nearest_cell(x, y, fdir.affine) + x, y = self.nearest_cell(x, y, fdir.affine, snap) # get the flattened index of the pour point pour_point = np.ravel_multi_index(np.array([y, x]), fdir.shape) @@ -977,7 +987,7 @@ def d8_catchment_search(cells): def _dinf_catchment(self, x, y, fdir=None, pour_value=None, out_name='catch', dirmap=None, nodata_in=None, nodata_out=0, xytype='index', recursionlimit=15000, inplace=True, apply_mask=False, ignore_metadata=False, properties={}, - metadata={}, **kwargs): + metadata={}, snap='corner', **kwargs): # Filter warnings due to invalid values np.warnings.filterwarnings(action='ignore', message='Invalid value encountered', category=RuntimeWarning) @@ -1028,7 +1038,7 @@ def dinf_catchment_search(cells): # TODO: This relies on the bbox of the grid instance, not the dataset # Valid if the dataset is a view. if xytype == 'label': - x, y = self.nearest_cell(x, y, fdir.affine) + x, y = self.nearest_cell(x, y, fdir.affine, snap) # get the flattened index of the pour point pour_point = np.ravel_multi_index(np.array([y, x]), fdir.shape) @@ -1496,7 +1506,7 @@ def _remove_dinf_cycles(self, fdir_0, fdir_1, startnodes, max_cycles=2): def flow_distance(self, x, y, data, weights=None, dirmap=None, nodata_in=None, nodata_out=0, out_name='dist', routing='d8', method='shortest', inplace=True, xytype='index', apply_mask=True, ignore_metadata=False, - **kwargs): + snap='corner', **kwargs): """ Generates an array representing the topological distance from each cell to the outlet. @@ -1540,6 +1550,10 @@ def flow_distance(self, x, y, data, weights=None, dirmap=None, nodata_in=None, If True, "mask" the output using self.mask. ignore_metadata : bool If False, require a valid affine transform and CRS. + snap : str + Function to use on array for indexing: + 'corner' : numpy.around() + 'center' : numpy.floor() """ if not _HAS_SCIPY: raise ImportError('flow_distance requires scipy.sparse module') @@ -1565,19 +1579,21 @@ def flow_distance(self, x, y, data, weights=None, dirmap=None, nodata_in=None, out_name=out_name, method=method, inplace=inplace, xytype=xytype, apply_mask=apply_mask, ignore_metadata=ignore_metadata, - properties=properties, metadata=metadata, **kwargs) + properties=properties, metadata=metadata, + snap=snap, **kwargs) elif routing.lower() == 'dinf': return self._dinf_flow_distance(x, y, fdir, weights=weights, dirmap=dirmap, nodata_in=nodata_in, nodata_out=nodata_out, out_name=out_name, method=method, inplace=inplace, xytype=xytype, apply_mask=apply_mask, ignore_metadata=ignore_metadata, - properties=properties, metadata=metadata, **kwargs) + properties=properties, metadata=metadata, + snap=snap, **kwargs) def _d8_flow_distance(self, x, y, fdir, weights=None, dirmap=None, nodata_in=None, nodata_out=0, out_name='dist', method='shortest', inplace=True, xytype='index', apply_mask=True, ignore_metadata=False, properties={}, - metadata={}, **kwargs): + metadata={}, snap='corner', **kwargs): # Construct flat index onto flow direction array domain = np.arange(fdir.size) fdir_orig_type = fdir.dtype @@ -1595,7 +1611,7 @@ def _d8_flow_distance(self, x, y, fdir, weights=None, dirmap=None, nodata_in=Non startnodes, endnodes = self._construct_matching(fdir, domain, dirmap=dirmap) if xytype == 'label': - x, y = self.nearest_cell(x, y, fdir.affine) + x, y = self.nearest_cell(x, y, fdir.affine, snap) # TODO: Currently the size of weights is hard to understand if weights is not None: weights = weights.ravel() @@ -1625,7 +1641,7 @@ def _d8_flow_distance(self, x, y, fdir, weights=None, dirmap=None, nodata_in=Non def _dinf_flow_distance(self, x, y, fdir, weights=None, dirmap=None, nodata_in=None, nodata_out=0, out_name='dist', method='shortest', inplace=True, xytype='index', apply_mask=True, ignore_metadata=False, - properties={}, metadata={}, **kwargs): + properties={}, metadata={}, snap='corner', **kwargs): # Filter warnings due to invalid values np.warnings.filterwarnings(action='ignore', message='Invalid value encountered', category=RuntimeWarning) @@ -1658,7 +1674,7 @@ def _dinf_flow_distance(self, x, y, fdir, weights=None, dirmap=None, nodata_in=N assert(startnodes.size == endnodes_0.size) assert(startnodes.size == endnodes_1.size) if xytype == 'label': - x, y = self.nearest_cell(x, y, fdir.affine) + x, y = self.nearest_cell(x, y, fdir.affine, snap) # TODO: Currently the size of weights is hard to understand if weights is not None: if isinstance(weights, list) or isinstance(weights, tuple): diff --git a/tests/test_grid.py b/tests/test_grid.py index d35d480..34d200c 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -51,6 +51,16 @@ def test_constructors(): def test_dtype(): assert(grid.dir.dtype == np.uint8) +def test_nearest_cell(): + ''' + corner: snaps to nearest top/left + center: snaps to index of cell that contains the geometry + ''' + col, row = grid.nearest_cell(x, y, snap='corner') + assert (col, row) == (229, 101) + col, row = grid.nearest_cell(x, y, snap='center') + assert (col, row) == (228, 100) + def test_catchment(): # Reference routing grid.catchment(x, y, data='dir', dirmap=dirmap, out_name='catch',