diff --git a/pysheds/_sgrid.py b/pysheds/_sgrid.py index 7ec23da..f46865a 100644 --- a/pysheds/_sgrid.py +++ b/pysheds/_sgrid.py @@ -1,8 +1,10 @@ -from heapq import heappop, heappush +from heapq import heappop, heappush, heapify import math import numpy as np -from numba import njit, prange +from functools import wraps +from numba import njit, prange, from_dtype from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void +from numba.typed import typedlist # Functions for 'flowdir' @@ -1856,3 +1858,149 @@ def _fill_pits_numba(dem, pit_indices): adjustment = min(diff, adjustment) pits_filled.flat[k] += (adjustment) return pits_filled + +@njit(boundscheck=True, cache=True) +def _first_true1d(arr, start=0, end=None, step=1, invert=False): + if end is None: + end = len(arr) + + if invert: + for i in range(start, end, step): + if not arr[i]: + return i + else: + return -1 + else: + for i in range(start, end, step): + if arr[i]: + return i + else: + return -1 + +@njit(parallel=True, cache=True) +def _top(mask): + nc = mask.shape[1] + rv = np.zeros(nc, dtype='int64') + for i in prange(nc): + rv[i] = _first_true1d(mask[:, i], invert=True) + return rv + +@njit(parallel=True, cache=True) +def _bottom(mask): + nr, nc = mask.shape[0], mask.shape[1] + rv = np.zeros(nc, dtype='int64') + for i in prange(nc): + rv[i] = _first_true1d(mask[:, i], start=nr - 1, end=-1, step=-1, invert=True) + return rv + +@njit(parallel=True, cache=True) +def _left(mask): + nr = mask.shape[0] + rv = np.zeros(nr, dtype='int64') + for i in prange(nr): + rv[i] = _first_true1d(mask[i, :], invert=True) + return rv + +@njit(parallel=True, cache=True) +def _right(mask): + nr, nc = mask.shape[0], mask.shape[1] + rv = np.zeros(nr, dtype='int64') + for i in prange(nr): + rv[i] = _first_true1d(mask[i, :], start=nc - 1, end=-1, step=-1, invert=True) + return rv + + +@njit(cache=True) +def count(start=0, step=1): + # Numba accelerated count() from itertools + # count(10) --> 10 11 12 13 14 ... + # count(2.5, 0.5) --> 2.5 3.0 3.5 ... + n = start + while True: + yield n + n += step + + +def pfwrapper(func): + # Implemenation detail of priority-flood algorithm + # Needed to define the types used in priority queue + @wraps(func) + def _wrapper(dem, mask, *args): + # Tuple elements: + # 0: dem data type (for elevation priority) + # 1: int64 for insertion index (to maintain total ordering) + # 2: int64 for row index + # 3: int64 for col index + tuple_type = Tuple([from_dtype(dem.dtype), int64, int64, int64]) + return func(dem, mask, tuple_type, *args) + return _wrapper + + +@pfwrapper +@njit(cache=True) +def _priority_flood(dem, dem_mask, tuple_type): + open_cells = typedlist.List.empty_list(tuple_type) # Priority queue + pits = typedlist.List.empty_list(tuple_type) # FIFO queue + closed_cells = dem_mask.copy() + isertn = count() + + # Push the edges onto priority queue + y, x = dem.shape + + edge = _left(dem_mask)[:-1] + for row, col in zip(count(), edge): + if col >= 0: + open_cells.append((dem[row, col], next(isertn), row, col)) + closed_cells[row, col] = True + edge = _bottom(dem_mask)[:-1] + for row, col in zip(edge, count()): + if row >= 0: + open_cells.append((dem[row, col], next(isertn), row, col)) + closed_cells[row, col] = True + edge = np.flip(_right(dem_mask))[:-1] + for row, col in zip(count(y - 1, step=-1), edge): + if col >= 0: + open_cells.append((dem[row, col], next(isertn), row, col)) + closed_cells[row, col] = True + edge = np.flip(_top(dem_mask))[:-1] + for row, col in zip(edge, count(x - 1, step=-1)): + if row >= 0: + open_cells.append((dem[row, col], next(isertn), row, col)) + closed_cells[row, col] = True + heapify(open_cells) + + row_offsets = np.array([-1, -1, 0, 1, 1, 1, 0, -1]) + col_offsets = np.array([0, 1, 1, 1, 0, -1, -1, -1]) + + pits_pos = 0 + while open_cells or pits_pos < len(pits): + if pits_pos < len(pits): + elv, _, i, j = pits[pits_pos] + pits_pos += 1 + else: + elv, _, i, j = heappop(open_cells) + + for n in range(8): + row = i + row_offsets[n] + col = j + col_offsets[n] + + if row < 0 or row >= y or col < 0 or col >= x: + continue + + if dem_mask[row, col] or closed_cells[row, col]: + continue + + if dem[row, col] <= elv: + dem[row, col] = elv + pits.append((elv, next(isertn), row, col)) + else: + heappush(open_cells, (dem[row, col], next(isertn), row, col)) + closed_cells[row, col] = True + + # pits book-keeping + if pits_pos == len(pits) and len(pits) > 1024: + # Queue is empty, lets clear it out + pits.clear() + pits_pos = 0 + + return dem \ No newline at end of file diff --git a/pysheds/sgrid.py b/pysheds/sgrid.py index 1093b0c..71625e2 100644 --- a/pysheds/sgrid.py +++ b/pysheds/sgrid.py @@ -6,9 +6,11 @@ import pandas as pd import geojson from affine import Affine +from numba.types import Tuple, int64 +from numba import from_dtype + try: import skimage.measure - import skimage.morphology _HAS_SKIMAGE = True except ModuleNotFoundError: _HAS_SKIMAGE = False @@ -2113,8 +2115,6 @@ def detect_depressions(self, dem, **kwargs): depressions : Raster Boolean Raster indicating locations of depressions. """ - if not _HAS_SKIMAGE: - raise ImportError('detect_depressions requires skimage.morphology module') input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata} kwargs.update(input_overrides) dem = self._input_handler(dem, **kwargs) @@ -2148,23 +2148,13 @@ def fill_depressions(self, dem, nodata_out=np.nan, **kwargs): Raster representing digital elevation data with multi-celled depressions removed. """ - if not _HAS_SKIMAGE: - raise ImportError('resolve_flats requires skimage.morphology module') - input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata} - kwargs.update(input_overrides) - dem = self._input_handler(dem, **kwargs) dem_mask = self._get_nodata_cells(dem) - dem_mask[0, :] = True - dem_mask[-1, :] = True - dem_mask[:, 0] = True - dem_mask[:, -1] = True - # Make sure nothing flows to the nodata cells - seed = np.copy(dem) - seed[~dem_mask] = np.nanmax(dem) - dem_out = skimage.morphology.reconstruction(seed, dem, method='erosion') - dem_out = self._output_handler(data=dem_out, viewfinder=dem.viewfinder, - metadata=dem.metadata, nodata=nodata_out) - return dem_out + result = _self._priority_flood(dem, dem_mask) + dem_filled = self._output_handler(data=result, + viewfinder=dem.viewfinder, + metadata=dem.metadata, + nodata=dem.nodata) + return dem_filled def detect_flats(self, dem, **kwargs): """