Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexible coordinate transform #9543

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

benbovy
Copy link
Member

@benbovy benbovy commented Sep 24, 2024

This PR is a first step towards adding generic support for coordinate transforms in Xarray (i.e., analytical coordinates, functional index, etc.), which has been discussed already in different issues or threads:

(I might miss other issues / discussions)

I started with a few rough experimentations but ended up with something more concrete that seems to work reasonably well, hence directly opening a (draft) PR. The design & implementation detailled below is still very much open to discussion, though! There's an usage example further below using a 2D affine transformation. It would be nice to test this with other examples.

cc @rabernat @dcherian @TomNicholas @martindurant

Design / Implementation

This PR adds three new classes that should facilitate integrating any coordinate transform into Xarray:

CoordinateTransform

Abstract (wrapper) class to handle coordinate transformation with support of dimension and coordinate names.

  • many transforms should be pluggable via this class (by subclassing it)
  • supports bulk (vectorized) transformation, both in forward and reverse direction
  • supports any arbitrary number of coordinates / dimensions
    • lon, lat = f(x, y)
    • lon, lat, time = f2(x, y, t)
    • etc.
  • one restriction is that the coordinates of a same transform must all have the same dimensions
    • lon(x,y) / lat(x,y)
    • lon(x,y,t) / lat(x,y,t) / time(x,y,t)
    • etc.
    • much simpler!
    • In some cases however, the transform parameter values are such that it can be applied independently over each dimension (e.g., rioxarray creates x(x) / y(y) dimension coordinates when the affine transform is rectilinear with no rotation). For those cases we cannot use a single CoordinateTransform instance, but it is still possible to wrap the same underlying transform object in several instances and link their respective coordinates at the xarray Index level (see below).

CoordinateTransformIndexingAdapter

Internal class for creating indexable coordinate variables from a transform (no need to change the Xarray data model!).

  • wraps a CoordinateTransform instance
  • coordinate labels are computed on-demand (lazy coordinates)
  • supports both (explicit) orthogonal and vectorized indexing
  • supports dimension re-ordering (transpose)
  • doesn't support item assignment (of course)

CoordinateTransformIndex

Helper class for creating Xarray (custom) indexes based on coordinate transforms.

  • wraps a CoordinateTransform instance
  • takes care of creating the index (lazy) coordinates
  • supports label-based selection (i.e., using "physical" or "world" labels)
    • only advanced (point-wise) indexing for now
    • any idea on what else would be nice here and how to implement it?
  • supports alignment by comparing indexes based on their transform (not on their explicit coordinate labels)
    • only exact alignment for now (no join)
  • may be used directly, although should mostly be either subclassed or encapsulated in another Xarray Index class
    • in the rioxarray example (see above), we might want to encapsulate two instances of CoordinateTransformIndex into a custom RasterIndex for the x and y coordinates respectively
    • a custom Xarray index is the right place for encoding / decoding coordinate definition

Usage Example (Affine 2D)

CoordinateTransform subclass

Let's write a subclass of CoordinateTransform that handles 2-d affine transformation (using affine). It is basically boilerplate code that takes care of dimension or coordinate names around the unlabelled input/output arrays of the underlying affine.Affine object.

import affine
import xarray as xr


class Affine2DCoordinateTransform(xr.CoordinateTransform):
    """Affine 2D coordinate transform."""

    affine: affine.Affine
    xy_dims = tuple[str]
    
    def __init__(
        self,
        affine: affine.Affine,
        coord_names: Iterable[Hashable],
        dim_size: Mapping[str, int],
        dtype: Any = np.dtype(np.float64),
    ):
        # two dimensions
        assert len(coord_names) == 2
        assert len(dim_size) == 2

        super().__init__(coord_names, dim_size, dtype=dtype)
        self.affine = affine

        # array dimensions in reverse order (y = rows, x = cols)
        self.xy_dims = tuple(self.dims)
        self.dims = (self.dims[1], self.dims[0])

    def forward(self, dim_positions):
        positions = [dim_positions[dim] for dim in self.xy_dims]
        x_labels, y_labels = self.affine * tuple(positions)

        results = {}
        for name, labels in zip(self.coord_names, [x_labels, y_labels]):
            results[name] = labels

        return results

    def reverse(self, coord_labels):
        labels = [coord_labels[name] for name in self.coord_names]
        x_positions, y_positions = ~self.affine * tuple(labels)

        results = {}
        for dim, positions in zip(self.xy_dims, [x_positions, y_positions]):
            results[dim] = positions

        return results
    
    def equals(self, other):
        return self.affine == other.affine and self.dim_size == other.dim_size

Dataset, coordinates and index creation

In this example the index and the lazy coordinates are created from scratch, no pre-existing (explicit) coordinates are required!

from xarray.indexes import CoordinateTransformIndex

transform = Affine2DCoordinateTransform(
    affine.Affine.scale(1.0, 2.0),
    coord_names=("xc", "yc"),
    dim_size={"x": 10_000, "y": 20_000},
)

index = CoordinateTransformIndex(transform)
ds = xr.Dataset(coords=index.create_coordinates())

The resulting Dataset:

>>> ds
<xarray.Dataset> Size: 3GB
Dimensions:  (y: 10000, x: 20000)
Coordinates:
  * xc       (y, x) float64 2GB 0.0 1.0 2.0 3.0 4.0 ... 2e+04 2e+04 2e+04 2e+04
  * yc       (y, x) float64 2GB 0.0 2.0 4.0 6.0 ... 1.999e+04 2e+04 2e+04
Dimensions without coordinates: y, x
Data variables:
    *empty*
Indexes:
  ┌ xc       CoordinateTransformIndexyc

Coordinates "xc" and "yc" are big but they are lazy!

>>> ds.xc
<xarray.DataArray 'xc' (y: 10000, x: 20000)> Size: 2GB
[200000000 values with dtype=float64]
Coordinates:
  * xc       (y, x) float64 2GB 0.0 1.0 2.0 3.0 4.0 ... 2e+04 2e+04 2e+04 2e+04
  * yc       (y, x) float64 2GB 0.0 2.0 4.0 6.0 ... 1.999e+04 2e+04 2e+04
Dimensions without coordinates: y, x
Indexes:
  ┌ xc       CoordinateTransformIndexyc

>>> ds["xc"].variable._data
CoordinateTransformIndexingAdapter(transform=<__main__.Affine2DCoordinateTransform object at 0x15fb63790>)

Indexing

Orthogonal indexing (it is fast, it only computes 2x6 coordinate values below):

>>> ds.yc.isel(y=[0, 1, 3], x=slice(0, 2))
<xarray.DataArray 'yc' (y: 3, x: 2)> Size: 48B
array([[0., 0.],
       [2., 2.],
       [6., 6.]])
Coordinates:
    xc       (y, x) float64 48B 0.0 1.0 0.0 1.0 0.0 1.0
    yc       (y, x) float64 48B 0.0 0.0 2.0 2.0 6.0 6.0
Dimensions without coordinates: y, x

Also works after re-ordering the dimensions:

>>> ds.transpose().yc.isel(y=[0, 1, 3], x=slice(0, 2))
<xarray.DataArray 'yc' (x: 2, y: 3)> Size: 48B
array([[0., 2., 6.],
       [0., 2., 6.]])
Coordinates:
    xc       (x, y) float64 48B 0.0 0.0 0.0 1.0 1.0 1.0
    yc       (x, y) float64 48B 0.0 2.0 6.0 0.0 2.0 6.0
Dimensions without coordinates: x, y

Vectorized indexing:

>>> ds.yc.isel(
...     y=xr.Variable("points", [0, 1, 3]),
...     x=xr.Variable("points", [0, 1, 3]),
... )
<xarray.DataArray 'yc' (points: 3)> Size: 24B
array([0., 2., 6.])
Coordinates:
    xc       (points) float64 24B 0.0 1.0 3.0
    yc       (points) float64 24B 0.0 2.0 6.0
Dimensions without coordinates: points

Label-based selection

Point-wise selection:

>>> ds.sel(
...     xc=xr.Variable("points", [101.34, 545.23, 876.76]),
...     yc=xr.Variable("points", [13.12, 54.98, 76.43]),
...     method="nearest",
... )
<xarray.Dataset> Size: 48B
Dimensions:  (points: 3)
Coordinates:
    xc       (points) float64 24B 101.0 545.0 877.0
    yc       (points) float64 24B 14.0 54.0 76.0
Dimensions without coordinates: points
Data variables:
    *empty*

What's next?

A few potential improvements from here:

  • allow returning or re-calculating the transform instead of computing the coordinate labels while indexing
    • when possible, this should keep the xarray coordinates lazy and should also preserve their index
    • the obvious case if when a full slice is given for each dimension... Are there other less obvious cases?
    • allow CoordinateTransform.forward() to return a new instance of CoordinateTransform?
  • allow special handling of dimension reduction (e.g., return another transform in the reduced space)
  • possible to add generic support for joining / concatenating coordinate transforms? I.e., implement CoordinateTransformIndex.concat and CoordinateTransformIndex.join
  • handle chunking
    • maybe best solved at a higher level? I.e., one transform instance per chunk
  • add some convenient API for setting a new Xarray index from existing dimensions in a Dataset or DataArray?

@mdsumner
Copy link

Nice!! Thanks, I'm having fun with this - appreciate all the detail and functionality here it really helps a (non-native) Python learner.

@mdsumner
Copy link

mdsumner commented Sep 25, 2024

One thing is that the coordinate values are currently "left"/"top" aligned, not the centre, so here we start at left/top 0,0 and end at 4,4.

from xarray.indexes import CoordinateTransformIndex

transform = Affine2DCoordinateTransform(
    affine.Affine.scale(1.0, 1.0),
    coord_names=("xc", "yc"),
    dim_size={"x": 5, "y": 5},
)

index = CoordinateTransformIndex(transform)
ds = xr.Dataset(coords=index.create_coordinates())

ds.yc.values
#array([[0., 0., 0., 0., 0.],
#       [1., 1., 1., 1., 1.],
#       [2., 2., 2., 2., 2.],
#       [3., 3., 3., 3., 3.],
#       [4., 4., 4., 4., 4.]])

I'm assuming that a pure-scale puts us in the cell-area context of 0, 0, shape[0], shape[1] and so I pursued a world-realistic-ish context to convince myself.

I prefer to think in shape+bbox than in transforms when no shear is needed, so I'm using the gdal transform as an intermediate with a helper fun:

https://gist.github.com/mdsumner/dde0b611a4523e3485006c0df0143c2d

(fwiw, I'm sure this is obvious and not exactly priority rn but I'm excited to be able to delve into this and flesh out how I think about it in this context)

edit: I appreciate there's no absolutely right answer here, you might want (even decoupled per dimension) different alignment for your lazy coords in different contexts.

@benbovy
Copy link
Member Author

benbovy commented Sep 25, 2024

Thanks for the feedback @mdsumner.

Yes the idea is to have something very generic in Xarray such that we can build domain-specific applications on top of it. It is very useful to test this functionality in various contexts now to make sure we are providing the right levels of abstractions. So please keep having fun with this :-) !

Regarding your example, I think that rioxarray combines the input affine transformation with affine.Affine.translation(0.5, 0.5) to make coordinate values center aligned. If it is more natural to think in shape+bbox than in transforms in the geo domain, let's build something on top of CoordinateTransformIndex, e.g., something like below adapted from your gist example:

class GeoIndex(CoordinateTransformIndex):

    @classmethod
    def from_shape(cls, shape, bbox=None, center=True):
        if bbox is None: 
            bbox = (0.0, 0.0, shape[0], shape[1])
    
        gdal = (
            bbox[0], (bbox[2] - bbox[0]) / shape[0], 0.0, 
            bbox[3], 0.0, (bbox[1] - bbox[3]) / shape[1]
        )

        aff = affine.Affine.from_gdal(*gdal)

        if center:
            coord_names = ("xc", "yc")
            aff *= affine.Affine.translation(0.5, 0.5)
        else:
            # left/top
            coord_names = ("xlt", "ylt")
            
        transform = Affine2DCoordinateTransform(
            aff,
            coord_names=coord_names,
            dim_size={"x": shape[1], "y": shape[0]},
        )
            
        return cls(transform=transform)
>>> bbox = (-3950000, -3950000, 3950000, 4350000)
>>> shape = (316, 332)
>>> index = GeoIndex.from_shape(shape, bbox=bbox)
>>> ds = xr.Dataset(coords=index.create_coordinates())
>>> ds.isel(x=slice(0, 159), y=slice(0, 167))
<xarray.Dataset> Size: 425kB
Dimensions:  (y: 167, x: 159)
Coordinates:
    xc       (y, x) float64 212kB -3.938e+06 -3.912e+06 ... -1.25e+04 1.25e+04
    yc       (y, x) float64 212kB 4.338e+06 4.338e+06 ... 1.875e+05 1.875e+05
Dimensions without coordinates: y, x
Data variables:
    *empty*

(This gives the same coordinate values than the "ice" dataset loaded from the .tif file using the rasterio engine in your 2nd gist)

Comment on lines +1479 to +1482
# TODO: rounding the decimal positions is not always the behavior we expect
# (there are different ways to represent implicit intervals)
# we should probably make this customizable.
pos = np.round(pos).astype("int")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is important I think.

If the coordinates values correspond to the physical values at the top/left pixel corners in the 2D case, we may rather want np.floor(pos).astype("int") when converting decimal positions (obtained by inverse transformation) to integer indexers.

@martindurant
Copy link
Contributor

Great to see this. I haven't looked at the implementation yet, but I think I agree with the description whole heartedly.

It would be the place of the various IO backends to instantiate the affine (or whatever) transform from the the metadata standards of the respective formats.

@benbovy
Copy link
Member Author

benbovy commented Sep 25, 2024

For completeness, here is an implementation of the 1-dimensional "range index" discussed in #8955.

The coordinate transform subclass:

class Range1DCoordinateTransform(xr.CoordinateTransform):
    """Simple bounded interval 1-d coordinate transform."""

    left: float
    right: float
    dim: str
    size: int

    def __init__(
        self,
        left: float,
        right: float,
        coord_name: Hashable,
        dim: str,
        size: int,
        dtype: Any = None,
    ):  
        if dtype is None:
            dtype = np.dtype(np.float64)

        super().__init__([coord_name], {dim: size}, dtype=dtype)

        self.left = left
        self.right = right
        self.dim = dim
        self.size = size

    def forward(self, dim_positions):
        positions = dim_positions[self.dim]
        labels = self.left + positions * (self.right - self.left) / self.size
        return {self.dim: labels}
        
    def reverse(self, coord_labels):
        labels = coord_labels[self.coord_names[0]]
        positions = (labels - self.left) * self.size / (self.right - self.left)
        return {self.dim: positions}

    def equals(self, other):
        return (
            self.left == other.left
            and self.right == other.right
            and self.size == other.size
        )

Dataset creation:

>>> range_tr = Range1DCoordinateTransform(1.0, 2.0, "x", "x", 100)
>>> index = CoordinateTransformIndex(range_tr)
>>> ds = xr.Dataset(data_vars={"foo": ("x", np.arange(100))}, coords=index.create_coordinates())
>>> ds
<xarray.Dataset> Size: 2kB
Dimensions:  (x: 100)
Coordinates:
  * x        (x) float64 800B 1.0 1.01 1.02 1.03 1.04 ... 1.96 1.97 1.98 1.99
Data variables:
    foo      (x) int64 800B 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Indexes:
    x        CoordinateTransformIndex

This example is interesting because in this simple case we would expect a few more operations to work than in the case of more complex transformations such as 2D affine with rotation and/or shear, e.g.,

  • indexing with a slice (step=1) should preserve the coordinate index but it doesn't:
>>> ds.isel(x=slice(5, 10)).xindexes
Indexes:
    *empty*
  • basic label-based selection should also work, but it is not supported:
>>> ds.sel(x=1.65, method="nearest")
TypeError: CoordinateTransformIndex only supports advanced (point-wise) indexing with either xarray.DataArray or xarray.Variable objects.

Perhaps we could try adding support for this in CoordinateTransform and/or CoordinateTransformIndex? My concern is that we may end up cluttering the interface / implementation of those classes with many special cases.

An alternative option is building on top of it, e.g., in this case also provide a Range1DIndex class like so:

---- expand here to see the implementation of Range1DIndex ----
from xarray.core.indexes import IndexSelResult


class Range1DIndex(CoordinateTransformIndex):

    transform: Range1DCoordinateTransform
    dim: str
    coord_name: Hashable
    size: int

    def __init__(
        self,
        left: float,
        right: float,
        coord_name: Hashable,
        dim: str,
        size: int,
        dtype: Any = None,
    ):
        self.transform = Range1DCoordinateTransform(
            left, right, coord_name, dim, size, dtype
        )
        self.dim = dim
        self.coord_name = coord_name
        self.size = size

    def isel(self, indexers):
        idxer = indexers[self.dim]

        # straightforward to generate a new index if a slice is given with step 1
        if isinstance(idxer, slice) and (idxer.step == 1 or idxer.step is None):
            start = max(idxer.start, 0)
            stop = min(idxer.stop, self.size)
            
            new_left = self.transform.forward({self.dim: start})[self.coord_name]
            new_right = self.transform.forward({self.dim: stop})[self.coord_name]
            new_size = stop - start

            return Range1DIndex(new_left, new_right, self.coord_name, self.dim, new_size)

        return None

    def sel(self, labels, method=None, tolerance=None):
        label = labels[self.dim]

        if isinstance(label, slice):
            if label.step is None:
                # slice indexing (preserve the index)
                pos = self.transform.reverse({self.dim: np.array([label.start, label.stop])})
                pos = np.round(pos[self.coord_name]).astype("int")
                new_start = max(pos[0], 0)
                new_stop = min(pos[1], self.size)
                return IndexSelResult({self.dim: slice(new_start, new_stop)})
            else:
                # otherwise convert to basic (array) indexing
                label = np.arange(label.start, label.stop, label.step)

        # support basic indexing (in the 1D case basic vs. vectorized indexing
        # are pretty much similar)
        unwrap_xr = False
        if not isinstance(label, xr.Variable | xr.DataArray):
            # basic indexing -> either scalar or 1-d array
            try:
                var = xr.Variable("_", label)
            except ValueError:
                var = xr.Variable((), label)
            labels = {self.dim: var}
            unwrap_xr = True

        result = super().sel(labels, method=method, tolerance=tolerance)

        if unwrap_xr:
            dim_indexers = {self.dim: result.dim_indexers[self.dim].values}
            result = IndexSelResult(dim_indexers)
        
        return result
>>> index = Range1DIndex(1.0, 2.0, "x", "x", 100)
>>> ds2 = xr.Dataset(data_vars={"foo": ("x", np.arange(100))}, coords=index.create_coordinates())

Slicing (notice the preserved Range1DIndex):

>>> ds2.isel(x=slice(5, 10))
<xarray.Dataset> Size: 80B
Dimensions:  (x: 5)
Coordinates:
  * x        (x) float64 40B 1.05 1.06 1.07 1.08 1.09
Data variables:
    foo      (x) int64 40B 5 6 7 8 9
Indexes:
    x        Range1DIndex

Some basic label-based selection:

>>> ds2.sel(x=1.654, method="nearest")
<xarray.Dataset> Size: 16B
Dimensions:  ()
Coordinates:
    x        float64 8B 1.65
Data variables:
    foo      int64 8B 65

>>> ds2.sel(x=slice(1.465, 1.874), method="nearest")   # preserves the index!
<xarray.Dataset> Size: 640B
Dimensions:  (x: 40)
Coordinates:
  * x        (x) float64 320B 1.47 1.48 1.49 1.5 1.51 ... 1.83 1.84 1.85 1.86
Data variables:
    foo      (x) int64 320B 47 48 49 50 51 52 53 54 ... 79 80 81 82 83 84 85 86
Indexes:
    x        Range1DIndex

For such a simple 1-d range example, the coordinate transform abstraction is actually a bit overkill but still has the advantage of providing the lazy coordinate variable "for free".

More consistent with the rest of Xarray API where `coords` is used
everywhere.
@astrofrog
Copy link

@benbovy - @Cadair and I have been playing around with trying to get this to work with the astropy APE 14 WCS specification. Here is a minimal example:

https://gist.github.com/Cadair/4a03750868e044ac4bdd6f3a04ed7abc

We are running into a bug in the __repr__ which is causing an out of bounds error. It seems that accessing the coordinates directly works so it seems to be a problem specific to the __repr__?

Another unrelated comment: it would be nice to have the CoordinateTransform class be a proper abc class, and have the methods that need to be implemented be defined as abstract methods (e.g. forward and reverse)

@benbovy
Copy link
Member Author

benbovy commented Oct 2, 2024

Thanks for the feedback @astrofrog!

I'll look into the __repr__ issue. Could you provide a minimal reproducible example or a link where I can download the FITS file used in your example, please?

@Cadair
Copy link

Cadair commented Oct 2, 2024

@benbovy the fits file is included with astropy , so the code in the notebook should run as-is I believe.

@benbovy
Copy link
Member Author

benbovy commented Oct 2, 2024

Ah thanks. The __repr__ issue should now be fixed in 09667c5.

coord_labels = {
name: labels[name].values for name in self.transform.coord_names
}
dim_positions = self.transform.reverse(coord_labels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to guarantee that out of bounds indexing raises an informative error, rather than silently attempting to access invalid data (or indexing from the end instead of the start of arrays).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still need to add unit tests but bound checking is done when indexing the coordinate variables in CoordinateTransformIndexingAdapter.

return None

def sel(
self, labels: dict[Any, Any], method=None, tolerance=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard would it be to support tolerance in some form? This is a common and useful form of error checking.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty tricky to support it here I think, probably better to handle it on a per case basis.

For basic transformations I guess it could be possible to calculate a single, uniform tolerance value in decimal array index units and validate the selected elements using those units (cheap). In other cases we would need to compute the forward transformation of the extracted array indices and then validate the selected elements based on distances in physical units (more expensive).

Also, there may be cases where the coordinates of a same transform object don’t have all the same physical units (e.g., both degrees and radians coordinates in an Astropy WCS object). Unless we forbid that in xarray.CoordinateTransform, it doesn’t make much sense to pass a single tolerance value. Passing a dictionary tolerance={coord_name: value} doesn’t look very nice either IMO. A {unit: value} dict looks better but adding explicit support for units here might be opening a can of worms.

@shoyer
Copy link
Member

shoyer commented Oct 2, 2024

This very exciting! Nice work.

For indexing, it may be worth considering if you can implement .interp(). In practice I think that is often more desirable than nearest neighbor lookup.

@rbavery
Copy link

rbavery commented Oct 21, 2024

rioxarray creates x(x) / y(y) dimension coordinates when the affine transform is rectilinear with no rotation). For those cases we cannot use a single CoordinateTransform instance

I think rioxarray does this for performance reasons because it is faster and possible to correctly apply the affine transformation without calling numpy.meshgrid when the affine is rectilinear with no rotation. But with flexible coordinates, I think both approaches could be replaced with only a single CoordinateTransform with some refactoring.

import numpy
import affine

transform = affine.Affine.translation(.5,.5).scale(1.0)
width = 512
height = 200
%%time
x_coords, _ = transform * (numpy.arange(width), numpy.zeros(width))
_, y_coords = transform * (numpy.zeros(height), numpy.arange(height))

Wall time: 290 μs

%%time
x_coords_mesh, y_coords_mesh = transform * numpy.meshgrid(
    numpy.arange(width),
    numpy.arange(height),
)

Wall time: 3.09 ms

possible to add generic support for joining / concatenating coordinate transforms? I.e., implement CoordinateTransformIndex.concat and CoordinateTransformIndex.join

This sounds very valuable but want to make sure I understand what is meant. If I have a dataset of rasters across different UTM projections, would this allow me to read each with rioxarray and then concatenate the raster arrays such that each raster maintains it's original CRS? Or would this enable concatenating rasters that are already in the same CRS? Or something else?

My use case for this is I'd like to avoid reprojection and have a single xarray.DataArray representing rasters spread over global extents. And I'd like to be able to save this concatenated xarray DataArray to a Zarr v3 store with sharding in a way that preserves each CRS, with GeoZarr.

@benbovy
Copy link
Member Author

benbovy commented Oct 22, 2024

both approaches could be replaced with only a single CoordinateTransform with some refactoring.

Hmm do you have an idea on how this refactoring would look like? I've tried implementing a version of CoordinateTransform that supports coordinates with different dimensions but I eventually gave up because it was too complicated.

Here is one way to support the rectilinear / no rotation affine transform with independent x, y 1-dimensional coordinates without any refactoring:

  • a CoordinateTransform subclass that wraps an affine.Affine instance for either the x or y coordinate
---- expand here to see the implementation of AxisAffineCoordinateTransform ----
class AxisAffineCoordinateTransform(xr.CoordinateTransform):
    """1-axis wrapper of an affine 2D coordinate transform
    with no skew/rotation.
    
    """

    affine: affine.Affine
    is_xaxis: bool
    coord_name: Hashable
    dim: str
    size: int
    
    def __init__(
        self,
        affine: affine.Affine,
        coord_name: Hashable,
        dim: str,
        size: int,
        is_xaxis: bool,
        dtype: Any = np.dtype(np.float64),
    ):
        if (not affine.is_rectilinear or (affine.b == affine.d != 0)):
            raise ValueError("affine must be rectilinear with no rotation")

        super().__init__((coord_name,), {dim: size}, dtype=dtype)
        self.affine = affine
        self.is_xaxis = is_xaxis
        self.coord_name = coord_name
        self.dim = dim
        self.size = size

    def forward(self, dim_positions):
        positions = dim_positions[self.dim]

        if self.is_xaxis:
            labels, _ = self.affine * (positions, np.zeros_like(positions))
        else:
            _, labels = self.affine * (np.zeros_like(positions), positions)

        return {self.coord_name: labels}

    def reverse(self, coord_labels):
        labels = coord_labels[self.coord_name]

        if self.is_xaxis:
            positions, _ = ~self.affine * (labels, np.zeros_like(labels))
        else:
            _, positions = ~self.affine * (np.zeros_like(labels), labels)

        return {self.dim: positions}
    
    def equals(self, other):
        return self.affine == other.affine and self.dim_size == other.dim_size
  • an Xarray Index that encapsulates two CoordinateTransformIndex instances (sharing the same Affine object) for the x and y axis respectively
---- expand here to see the implementation of RasterIndex ----
from xarray import Variable
from xarray.indexes import CoordinateTransformIndex
from xarray.core.indexing import IndexSelResult, merge_sel_results


class RasterIndex(xr.indexes.Index):

    def __init__(
        self,
        x_index: CoordinateTransformIndex,
        y_index: CoordinateTransformIndex,
    ):
        self.x_index = x_index
        self.y_index = y_index

    @classmethod
    def from_transform(
        cls,
        affine: affine.Affine,
        shape: tuple[int, int],
        xy_coord_names: tuple[Hashable, Hashable] = ("x", "y"),
    ):
        # shape is in y, x order
        xtr = AxisAffineCoordinateTransform(
            affine, xy_coord_names[0], xy_coord_names[0], shape[1], is_xaxis=True
        )
        ytr = AxisAffineCoordinateTransform(
            affine, xy_coord_names[1], xy_coord_names[1], shape[0], is_xaxis=False
        )

        return cls(CoordinateTransformIndex(xtr), CoordinateTransformIndex(ytr))

    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> dict[Hashable, Variable]:
        return {**self.x_index.create_variables(), **self.y_index.create_variables()}

    def create_coords(self) -> xr.Coordinates:
        variables = self.create_variables()
        indexes = {name: self for name in variables}
        return xr.Coordinates(coords=variables, indexes=indexes)
    
    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        results = []

        xlabels = {k: v for k, v in labels if k in self.x_index.transform.coord_names}
        if xlabels:
            results.append(self.x_index.sel(xlabels))
        
        ylabels = {k: v for k, v in labels if k in self.y_index.transform.coord_names}
        if ylabels:
            results.append(self.y_index.sel(ylabels))
        
        return merge_sel_results(results)
       
     def equals(self, other: Self) -> bool:
        return self.x_index.equals(other.x_index) and self.y_index.equals(other.y_index)

Usage example:

>>> index = RasterIndex.from_transform(affine.Affine.translation(0.5, 0.5), (1000, 2000))
>>> ds = xr.Dataset(coords=index.create_coords())
>>> ds
<xarray.Dataset> Size: 24kB
Dimensions:  (x: 2000, y: 1000)
Coordinates:
  * x        (x) float64 16kB 0.5 1.5 2.5 3.5 ... 1.998e+03 1.998e+03 2e+03
  * y        (y) float64 8kB 0.5 1.5 2.5 3.5 4.5 ... 996.5 997.5 998.5 999.5
Data variables:
    *empty*
Indexes:
  ┌ x        RasterIndexy

>>> ds.isel(x=slice(100, 200), y=500)
<xarray.Dataset> Size: 808B
Dimensions:  (x: 100)
Coordinates:
    x        (x) float64 800B 100.5 101.5 102.5 103.5 ... 197.5 198.5 199.5
    y        float64 8B 500.5
Data variables:
    *empty*

@benbovy
Copy link
Member Author

benbovy commented Oct 22, 2024

My use case for this is I'd like to avoid reprojection and have a single xarray.DataArray representing rasters spread over global extents. And I'd like to be able to save this concatenated xarray DataArray to a Zarr v3 store with sharding in a way that preserves each CRS, with GeoZarr.

This seems complicated to me. It would be easier in this case to have a unique CRS per DataArray and provide a virtual layer built on top of DataArray to handle lazy reprojection (e.g., similarly to GDAL VRT I guess?).

Also, IIUC Zarr doesn't allow per-chunk or per-shard metadata so it isn't clear to me how GeoZarr would support multiple CRS datasets (zarr-developers/geozarr-spec#4).

@martindurant
Copy link
Contributor

Zarr doesn't allow per-chunk or per-shard metadata

perhaps virtualizarr can do this (@TomNicholas )

@keewis
Copy link
Collaborator

keewis commented Oct 22, 2024

xarray doesn't support per-chunk / per-shard metadata, either, so we'd have to add a coordinate containing all the different crs.

@TomNicholas
Copy link
Member

VirtualiZarr isn't the right layer for that feature - Zarr itself (or Icechunk) would have to support it on disk.

@benbovy
Copy link
Member Author

benbovy commented Oct 22, 2024

Technically it could be possible to implement an Xarray index that keeps track of the original CRSs and that encapsulates custom CoordinateTransform objects, such that from the user point of view concat() returns geospatial (lazy) coordinates with a unique "virtual" CRS (i.e., with lazy re-projection).

This might be useful in some cases... Although we'll always have to choose the final CRS and re-project the data at some point (e.g., visualization, load in memory, write the data to a store or file, etc...).

Besides the index coordinates themselves, all the (data) variables sharing the same geospatial dimensions would also need to be lazy and somehow wrap the logic to re-project (resample) their data. A custom Xarray index based on coordinate transforms is therefore not enough, we would also need a custom Xarray IO backend, accessors, etc.

@RichardScottOZ
Copy link
Contributor

rioxarray creates x(x) / y(y) dimension coordinates when the affine transform is rectilinear with no rotation). For those cases we cannot use a single CoordinateTransform instance

I think rioxarray does this for performance reasons because it is faster and possible to correctly apply the affine transformation without calling numpy.meshgrid when the affine is rectilinear with no rotation. But with flexible coordinates, I think both approaches could be replaced with only a single CoordinateTransform with some refactoring.

import numpy
import affine

transform = affine.Affine.translation(.5,.5).scale(1.0)
width = 512
height = 200
%%time
x_coords, _ = transform * (numpy.arange(width), numpy.zeros(width))
_, y_coords = transform * (numpy.zeros(height), numpy.arange(height))

Wall time: 290 μs

%%time
x_coords_mesh, y_coords_mesh = transform * numpy.meshgrid(
    numpy.arange(width),
    numpy.arange(height),
)

Wall time: 3.09 ms

possible to add generic support for joining / concatenating coordinate transforms? I.e., implement CoordinateTransformIndex.concat and CoordinateTransformIndex.join

This sounds very valuable but want to make sure I understand what is meant. If I have a dataset of rasters across different UTM projections, would this allow me to read each with rioxarray and then concatenate the raster arrays such that each raster maintains it's original CRS? Or would this enable concatenating rasters that are already in the same CRS? Or something else?

My use case for this is I'd like to avoid reprojection and have a single xarray.DataArray representing rasters spread over global extents. And I'd like to be able to save this concatenated xarray DataArray to a Zarr v3 store with sharding in a way that preserves each CRS, with GeoZarr.

Good question - and have global scale problems currently - sounds like this would need to be invented though?

@mdsumner
Copy link

On mixed crs collections, suggest looking at the GDAL warper which takes multiple inputs of any kind, and wrappers like odc-geo loader for xarray, and collection formats like GTI, STACIT, and the geoparquet for STAC, you always have a latent grid spec for the warper to target, but that can be overridden at write time. It doesn't have to just be stac or utm/mgrs and could include geolocation-array sources.

https://gdal.org/en/latest/drivers/raster/gti.html

I don't think I've seen model grids included in a collection like this, but odc-geo is working towards multidimensional support, which at least should be on the radar here 🙏

@rbavery
Copy link

rbavery commented Oct 23, 2024

Thanks for the explanations and comments all. I made an issue to continue discussion here so as not to distract from this PR.

Hmm do you have an idea on how this refactoring would look like? I've tried implementing a version of CoordinateTransform that supports coordinates with different dimensions but I eventually gave up because it was too complicated.

Ok disregard what I said previously, this seems much simpler to implement, I'll try to test this out this week.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: In progress
Development

Successfully merging this pull request may close these issues.

10 participants