Skip to content

Commit

Permalink
Merge pull request #186 from mdbartos/mfd
Browse files Browse the repository at this point in the history
Add profile extraction and fix reading of multiband rasters
  • Loading branch information
mdbartos authored Feb 27, 2022
2 parents 1951f63 + 6a04a2f commit 9d96096
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 10 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ ax.set_title('Soil types (raster)', size=14)
- `stream_order` : Compute the (strahler) stream order.
- `extract_river_network` : Extract river segments from a catchment and return a geojson
object.
- `extract_profiles` : Extract river segments and return a list of channel indices along with a dictionary describing connectivity.
- `cell_dh` : Compute the drop in elevation from each cell to its downstream neighbor.
- `cell_distances` : Compute the distance from each cell to its downstream neighbor.
- `cell_slopes` : Compute the slope between each cell and its downstream neighbor.
Expand Down Expand Up @@ -471,6 +472,7 @@ Performance benchmarks on a 2015 MacBook Pro (M: million, K: thousand):
| `hand` | MFD | 36M total, 770K channel | 29.8 [s] |
| `stream_order` | D8 | 36M total, 1M channel | 3.99 [s] |
| `extract_river_network` | D8 | 36M total, 345K channel | 4.07 [s] |
| `extract_profiles` | D8 | 36M total, 345K channel | 2.89 [s] |
| `detect_pits` | N/A | 36M | 1.80 [s] |
| `detect_flats` | N/A | 36M | 1.84 [s] |
| `fill_pits` | N/A | 36M | 2.52 [s] |
Expand Down
Binary file added data/cogeo.tiff
Binary file not shown.
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ ax.set_title('Soil types (raster)', size=14)
- `stream_order` : Compute the (strahler) stream order.
- `extract_river_network` : Extract river segments from a catchment and return a geojson
object.
- `extract_profiles` : Extract river segments and return a list of channel indices along with a dictionary describing connectivity.
- `cell_dh` : Compute the drop in elevation from each cell to its downstream neighbor.
- `cell_distances` : Compute the distance from each cell to its downstream neighbor.
- `cell_slopes` : Compute the slope between each cell and its downstream neighbor.
Expand Down Expand Up @@ -475,6 +476,7 @@ Performance benchmarks on a 2015 MacBook Pro (M: million, K: thousand):
| `hand` | MFD | 36M total, 770K channel | 29.8 [s] |
| `stream_order` | D8 | 36M total, 1M channel | 3.99 [s] |
| `extract_river_network` | D8 | 36M total, 345K channel | 4.07 [s] |
| `extract_profiles` | D8 | 36M total, 345K channel | 2.89 [s] |
| `detect_pits` | N/A | 36M | 1.80 [s] |
| `detect_flats` | N/A | 36M | 1.84 [s] |
| `fill_pits` | N/A | 36M | 2.52 [s] |
Expand Down
2 changes: 1 addition & 1 deletion pysheds/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.2"
__version__ = "0.3.3"
32 changes: 31 additions & 1 deletion pysheds/_sgrid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from heapq import heappop, heappush
import numpy as np
from numba import njit, prange
from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, void
from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void

# Functions for 'flowdir'

Expand Down Expand Up @@ -1466,6 +1466,36 @@ def _d8_stream_network_iter_numba(fdir, indegree, orig_indegree, startnodes):
endnode = fdir.flat[startnode]
return profiles

@njit(Tuple((List(List(int64)), DictType(int64, int64)))(int64[:,:], uint8[:],
uint8[:], int64[:], boolean),
cache=True)
def _d8_stream_connection_iter_numba(fdir, indegree, orig_indegree, startnodes,
include_endpoint):
n = startnodes.size
profiles = [[0]]
connections = {0 : 0}
_ = profiles.pop()
_ = connections.pop(0)
for k in range(n):
startnode = startnodes.flat[k]
endnode = fdir.flat[startnode]
profile = [startnode]
while (indegree.flat[startnode] == 0):
profile.append(endnode)
indegree.flat[endnode] -= 1
if (orig_indegree.flat[endnode] > 1):
chain_start = profile[0]
chain_end = profile[-1]
connections[chain_start] = chain_end
if not include_endpoint:
_ = profile.pop()
profiles.append(profile)
if (indegree.flat[endnode] == 0):
profile = [endnode]
startnode = endnode
endnode = fdir.flat[startnode]
return profiles, connections

@njit(float64[:,:](int64[:,:], int64[:,:], float64[:,:]),
parallel=True,
cache=True)
Expand Down
4 changes: 2 additions & 2 deletions pysheds/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def read_raster(data, band=1, window=None, window_crs=None, mask_geometry=False,
if window is None:
shape = f.shape
if len(f.indexes) > 1:
data = np.ma.filled(f.read_band(band))
data = np.ma.filled(f.read(band))
else:
data = np.ma.filled(f.read())
affine = f.transform
Expand All @@ -121,7 +121,7 @@ def read_raster(data, band=1, window=None, window_crs=None, mask_geometry=False,
# If window crs not specified, assume it is in raster crs
ix_window = f.window(*window)
if len(f.indexes) > 1:
data = np.ma.filled(f.read_band(band, window=ix_window))
data = np.ma.filled(f.read(band, window=ix_window))
else:
data = np.ma.filled(f.read(window=ix_window))
affine = f.window_transform(ix_window)
Expand Down
77 changes: 76 additions & 1 deletion pysheds/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ def _mfd_compute_hand(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
def extract_river_network(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
routing='d8', algorithm='iterative', **kwargs):
"""
Generates river segments from accumulation and flow_direction arrays.
Generates river segments from flow direction and mask.
Parameters
----------
Expand Down Expand Up @@ -1453,6 +1453,81 @@ def extract_river_network(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32)
geo = geojson.FeatureCollection(featurelist)
return geo

def extract_profiles(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
include_endpoint=True, routing='d8', algorithm='iterative',
**kwargs):
"""
Extracts river segments and connectivity of river segments from flow direction and mask.
Parameters
----------
fdir : Raster
Flow direction data.
mask : Raster
Boolean raster indicating channelized regions
dirmap : list or tuple (length 8)
List of integer values representing the following
cardinal and intercardinal directions (in order):
[N, NE, E, SE, S, SW, W, NW]
include_endpoint : bool
If True, include last cell in each river segment.
If False, do not include last cell (such that
cell indices in each profile are unique).
routing : str
Routing algorithm to use:
'd8' : D8 flow directions
algorithm : str
Algorithm type to use:
'iterative' : Use an iterative algorithm (recommended).
'recursive' : Use a recursive algorithm.
Additional keyword arguments (**kwargs) are passed to self.view.
Returns
-------
profiles : list of lists of ints
A list containing a collection of river profiles. Each river profile
is a list containing the indices of the grid cells inside the
river segment. Indices correspond to the flattened index of river segment
cells.
connections : dict (int : int)
A dictionary describing the connectivity of the profiles. For each
key-value pair, the key represents index of the upstream profile and
the value represents the index of the downstream profile that it drains to.
Indices correspond to the ordered elements of the `profiles` object.
"""
if routing.lower() == 'd8':
fdir_overrides = {'dtype' : np.int64, 'nodata' : fdir.nodata}
else:
raise NotImplementedError('Only implemented for `d8` routing.')
mask_overrides = {'dtype' : np.bool8, 'nodata' : False}
kwargs.update(fdir_overrides)
fdir = self._input_handler(fdir, **kwargs)
kwargs.update(mask_overrides)
mask = self._input_handler(mask, **kwargs)
# Find nodata cells and invalid cells
nodata_cells = self._get_nodata_cells(fdir)
invalid_cells = ~np.in1d(fdir.ravel(), dirmap).reshape(fdir.shape)
# Set nodata cells to zero
fdir[nodata_cells] = 0
fdir[invalid_cells] = 0
maskleft, maskright, masktop, maskbottom = self._pop_rim(mask, nodata=False)
masked_fdir = np.where(mask, fdir, 0).astype(np.int64)
startnodes = np.arange(fdir.size, dtype=np.int64)
endnodes = _self._flatten_fdir_numba(masked_fdir, dirmap).reshape(fdir.shape)
indegree = np.bincount(endnodes.ravel(), minlength=fdir.size).astype(np.uint8)
orig_indegree = np.copy(indegree)
startnodes = startnodes[(indegree == 0)]
profiles, connections = _self._d8_stream_connection_iter_numba(endnodes, indegree,
orig_indegree,
startnodes,
include_endpoint)
connections = dict(connections)
indices = {profile[0] : index for index, profile in enumerate(profiles)}
connections = {indices[key] : indices.setdefault(value, indices[key])
for key, value in connections.items()}
return profiles, connections

def stream_order(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
nodata_out=0, routing='d8', algorithm='iterative', **kwargs):
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from setuptools import setup

setup(name='pysheds',
version='0.3.2',
version='0.3.3',
description='🌎 Simple and fast watershed delineation in python.',
author='Matt Bartos',
author_email='[email protected]',
Expand Down
17 changes: 13 additions & 4 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
dir_path = os.path.join(data_dir, 'dir.asc')
dem_path = os.path.join(data_dir, 'dem.tif')
roi_path = os.path.join(data_dir, 'roi.tif')
multiband_path = os.path.join(data_dir, 'cogeo.tiff')
feature_geometry = [{'type': 'Polygon',
'coordinates': (((-97.29749977660477, 32.74000135435936),
(-97.29083107907053, 32.74000328969928),
Expand Down Expand Up @@ -331,21 +332,22 @@ def test_cell_slopes():
slopes_dinf = grid.cell_slopes(dem, fdir_dinf, routing='dinf')
slopes_mfd = grid.cell_slopes(dem, fdir_mfd, routing='mfd')

# def test_set_nodata():
# grid.set_nodata('dir', 0)

def test_to_ascii():
catch = d.catch
fdir = d.fdir
grid.clip_to(catch)
# np.float is depreciated
grid.to_ascii(fdir, 'test_dir.asc', target_view=fdir.viewfinder, dtype=np.float64)
fdir_out = grid.read_ascii('test_dir.asc', dtype=np.uint8)
assert((fdir_out == fdir).all())
grid.to_ascii(fdir, 'test_dir.asc', dtype=np.uint8)
fdir_out = grid.read_ascii('test_dir.asc', dtype=np.uint8)
assert((fdir_out == grid.view(fdir)).all())

def test_read_raster():
band_1 = grid.read_raster(multiband_path, band=1)
band_2 = grid.read_raster(multiband_path, band=2)
band_3 = grid.read_raster(multiband_path, band=3)

def test_to_raster():
catch = d.catch
fdir = d.fdir
Expand Down Expand Up @@ -406,6 +408,13 @@ def test_extract_river_network():
grid.extract_river_network(catch, acc > 20, algorithm='recursive')
# TODO: Need more checks here. Check if endnodes equals next startnode

def test_extract_profiles():
fdir = d.fdir
catch = d.catch
acc = d.acc
grid.clip_to(catch)
profiles, connections = grid.extract_profiles(catch, acc > 20)

def test_view_methods():
dem = d.dem
catch = d.catch
Expand Down

0 comments on commit 9d96096

Please sign in to comment.