Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions notebooks/Reprojection.ipynb

Large diffs are not rendered by default.

111 changes: 73 additions & 38 deletions rasters/kdtree.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from __future__ import annotations

import warnings
from typing import Dict, TYPE_CHECKING
from typing import Union, Dict, TYPE_CHECKING


import msgpack
import msgpack_numpy
import numpy as np
from pyresample import SwathDefinition, AreaDefinition
from pyresample.kd_tree import get_neighbour_info, get_sample_from_neighbour_info

if TYPE_CHECKING:
from .raster import Raster
from .multi_raster import MultiRaster
from .raster_geometry import RasterGeometry
from .raster_grid import RasterGrid
from .raster_geolocation import RasterGeolocation

if TYPE_CHECKING:
from .raster_geometry import RasterGeometry
from .raster_grid import RasterGrid
Expand Down Expand Up @@ -43,8 +51,8 @@ class KDTree:
"""
def __init__(
self,
source_geometry: RasterGeometry,
target_geometry: RasterGeometry,
source_geometry: 'RasterGeometry',
target_geometry: 'RasterGeometry',
radius_of_influence: float = None,
neighbours: int = 1,
epsilon: float = 0,
Expand Down Expand Up @@ -248,43 +256,70 @@ def load(cls, filename: str) -> KDTree:

def resample(
self,
source,
source: Union["Raster", "MultiRaster"],
fill_value=0,
**kwargs):
**kwargs) -> Union["Raster", "MultiRaster"]:
from .raster import Raster

source = np.array(source)

if not source.shape == self.source_geo_def.shape:
raise ValueError("source data does not match source geometry")

bool_conversion = str(source.dtype) == "bool"

if bool_conversion:
source = source.astype(np.uint16)

with warnings.catch_warnings():
warnings.filterwarnings("ignore")
# Support both Raster (2D) and MultiRaster (3D) objects
from .multi_raster import MultiRaster
from .raster import Raster

resampled_data = get_sample_from_neighbour_info(
resample_type=self.resample_type,
output_shape=self.target_geo_def.shape,
data=source,
valid_input_index=self.valid_input_index,
valid_output_index=self.valid_output_index,
index_array=self.index_array,
distance_array=self.distance_array,
fill_value=fill_value,
**kwargs
# Determine if input is Raster or MultiRaster
is_multiraster = False
if isinstance(source, MultiRaster):
arr = np.array(source.array)
is_multiraster = True
elif isinstance(source, Raster):
arr = np.array(source.array)
else:
arr = np.array(source)
if arr.ndim > 2:
is_multiraster = True

# Always iterate over bands, whether Raster (single band) or MultiRaster (multi-band)
if arr.ndim == 2:
arr = arr[np.newaxis, ...] # shape becomes (1, rows, cols)
bands = arr.shape[0]

src_rows, src_cols = arr.shape[-2], arr.shape[-1]
geo_rows, geo_cols = self.source_geo_def.shape
if (src_rows, src_cols) != (geo_rows, geo_cols):
raise ValueError(f"source data rows/cols ({src_rows}, {src_cols}) do not match source geometry rows/cols ({geo_rows}, {geo_cols})")

band_results = []
for i in range(bands):
band_arr = arr[i]
band_bool_conversion = str(band_arr.dtype) == "bool"
if band_bool_conversion:
band_arr = band_arr.astype(np.uint16)
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
band_resampled = get_sample_from_neighbour_info(
resample_type=self.resample_type,
output_shape=self.target_geo_def.shape,
data=band_arr,
valid_input_index=self.valid_input_index,
valid_output_index=self.valid_output_index,
index_array=self.index_array,
distance_array=self.distance_array,
fill_value=fill_value,
**kwargs
)
if band_bool_conversion:
band_resampled = band_resampled.astype(bool)
band_results.append(band_resampled)

resampled_data = np.stack(band_results, axis=0)
if bands == 1:
return Raster(
array=resampled_data[0],
geometry=self.target_geometry,
nodata=fill_value
)
else:
return MultiRaster(
array=resampled_data,
geometry=self.target_geometry,
nodata=fill_value
)

if bool_conversion:
resampled_data = resampled_data.astype(bool)

output_raster = Raster(
array=resampled_data,
geometry=self.target_geometry,
nodata=fill_value
)

return output_raster