diff --git a/climada/engine/test/test_impact_forecast.py b/climada/engine/test/test_impact_forecast.py index ac47deab99..f571018598 100644 --- a/climada/engine/test/test_impact_forecast.py +++ b/climada/engine/test/test_impact_forecast.py @@ -19,10 +19,14 @@ Tests for Impact Forecast. """ +import datetime as dt +from pathlib import Path + import numpy as np import numpy.testing as npt import pandas as pd import pytest +import xarray as xr from scipy.sparse import csr_matrix from climada.engine import Impact, ImpactForecast @@ -58,6 +62,7 @@ def impact_forecast(impact, lead_time, member): class TestImpactForecastInit: + def assert_impact_kwargs(self, impact: Impact, **kwargs): for key, value in kwargs.items(): attr = getattr(impact, key) diff --git a/climada/hazard/forecast.py b/climada/hazard/forecast.py index cf6980f5ff..bebd4ddd5d 100644 --- a/climada/hazard/forecast.py +++ b/climada/hazard/forecast.py @@ -20,9 +20,14 @@ """ import logging +import pathlib +from typing import Any, Dict, List, Optional import numpy as np import scipy.sparse as sparse +import xarray as xr + +from climada.hazard.xarray import HazardXarrayReader from ..util.checker import size from ..util.forecast import Forecast @@ -86,9 +91,9 @@ def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray): event_id=hazard.event_id, frequency=hazard.frequency, frequency_unit=hazard.frequency_unit, + orig=hazard.orig, event_name=hazard.event_name, date=hazard.date, - orig=hazard.orig, intensity=hazard.intensity, fraction=hazard.fraction, ) @@ -282,6 +287,126 @@ def select( reset_frequency=reset_frequency, ) + @classmethod + def from_xarray_raster( + cls, + data: xr.Dataset | pathlib.Path | str, + hazard_type: str, + intensity_unit: str, + *, + intensity: Optional[str] = None, + coordinate_vars: Optional[Dict[str, str]] = None, + crs: str = "EPSG:4326", + open_dataset_kws: dict[str, Any] | None = None, + ): + """Read forecast hazard data from an xarray Dataset + + This extends the parent :py:meth:`~climada.hazard.base.Hazard.from_xarray_raster` + to handle forecast dimensions (lead_time and member). For forecast data, the + "event" dimension is constructed from the Cartesian product of lead_time and + member dimensions, so you don't need to specify an "event" coordinate. + + Parameters + ---------- + data : xarray.Dataset or Path or str + The filepath to read the data from or the already opened dataset + hazard_type : str + The type identifier of the hazard + intensity_unit : str + The physical units of the intensity + intensity : str, optional + Identifier of the DataArray containing the hazard intensity data + coordinate_vars : dict(str, str), optional + Mapping from default coordinate names to coordinate names in the data. + For HazardForecast, should include: + - ``"lead_time"``: name of the lead time coordinate (required) + - ``"member"``: name of the ensemble member coordinate (required) + - ``"longitude"``: name of longitude coordinate (default: "longitude") + - ``"latitude"``: name of latitude coordinate (default: "latitude") + + Note: The "event" coordinate is automatically constructed from lead_time + and member, so it should not be specified. + crs : str, optional + Coordinate reference system identifier. Defaults to "EPSG:4326" + open_dataset_kws : dict, optional + Keyword arguments passed to xarray.open_dataset if data is a file path + A forecast hazard object with lead_time and member attributes populated + + See Also + -------- + :py:meth:`climada.hazard.base.Hazard.from_xarray_raster` + Parent method documentation for standard hazard loading + """ + + # Open dataset if needed + if isinstance(data, (pathlib.Path, str)): + open_dataset_kws = open_dataset_kws or {} + open_dataset_kws = {"chunks": "auto"} | open_dataset_kws + dset = xr.open_dataset(data, **open_dataset_kws) + else: + dset = data + + if intensity is None: + data_var_names = list(dset.data_vars.keys()) + if len(data_var_names) == 0: + raise ValueError("Dataset has no data variables") + intensity = data_var_names[0] + LOGGER.info( + "No intensity variable specified. " + "Assuming intensity variable is '%s'", + intensity, + ) + + # Extract forecast coordinates + coordinate_vars = coordinate_vars or {} + for key in ["lead_time", "member"]: + if key not in coordinate_vars: + raise ValueError( + f"coordinate_vars must include '{key}' key. " + f"Available coordinates: {list(dset.coords.keys())}" + ) + leadtime_var = coordinate_vars["lead_time"] + member_var = coordinate_vars["member"] + + dset = dset.assign_coords( + event=( + (leadtime_var, member_var), + np.zeros((len(dset[leadtime_var]), len(dset[member_var]))), + ) + ) + + dset_squeezed = dset.squeeze() + + # Prepare coordinate_vars for parent call + parent_coord_vars = { + k: v for k, v in coordinate_vars.items() if k not in ["member", "lead_time"] + } + parent_coord_vars["event"] = "event" + + reader = HazardXarrayReader( + data=dset_squeezed, + coordinate_vars=parent_coord_vars, + intensity=intensity, + crs=crs, + ) + + kwargs = reader.get_hazard_kwargs() | { + "haz_type": hazard_type, + "units": intensity_unit, + "lead_time": reader.data_stacked[leadtime_var].to_numpy(), + "member": reader.data_stacked[member_var].to_numpy(), + } + + # Generate from lead_time/member + kwargs["event_name"] = [ + f"lt_{lt / np.timedelta64(1, 'h'):.0f}h_m_{m}" + for lt, m in zip(kwargs["lead_time"], kwargs["member"]) + ] + kwargs["date"] = np.zeros_like(kwargs["date"], dtype=int) + + # Convert to HazardForecast with forecast attributes + return cls(**Hazard._check_and_cast_attrs(kwargs)) + def _quantile(self, q: float, event_name: str | None = None): """ Reduce the impact matrix and at_event of a HazardForecast to the quantile value. diff --git a/climada/hazard/test/test_forecast.py b/climada/hazard/test/test_forecast.py index 26f26de4b1..c667db664c 100644 --- a/climada/hazard/test/test_forecast.py +++ b/climada/hazard/test/test_forecast.py @@ -19,16 +19,29 @@ Tests for Hazard Forecast. """ +import datetime as dt +from pathlib import Path + import numpy as np import numpy.testing as npt import pandas as pd import pytest +import xarray as xr +from packaging.version import Version from scipy.sparse import csr_matrix from climada.hazard.base import Hazard +from climada.hazard.centroids.centr import Centroids from climada.hazard.forecast import HazardForecast from climada.hazard.test.test_base import hazard_kwargs +# See https://docs.xarray.dev/en/stable/whats-new.html#id80 +xarray_leadtime = pytest.mark.skipif( + (Version(xr.__version__) < Version("2025.07.0")) + and (Version(xr.__version__) >= Version("2025.04.0")), + reason="xarray timedelta bug", +) + @pytest.fixture def haz_kwargs(): @@ -128,6 +141,156 @@ def test_type_fail(self, haz_fc, hazard): Hazard.concat([haz_fc, hazard]) +class TestXarrayReader: + + @pytest.fixture() + def forecast_netcdf_file(self, tmp_path_factory): + """Create a NetCDF file with forecast data structure""" + tmpdir = tmp_path_factory.mktemp("forecast_data") + netcdf_path = tmpdir / "forecast_data.nc" + + crs = "EPSG:4326" + + n_eps = 5 + n_lead_time = 4 + n_lat = 3 + n_lon = 4 + + eps = np.array([3, 8, 13, 16, 20]) + ref_time = np.array([dt.datetime(2025, 12, 8, 6, 0, 0)], dtype="datetime64[ns]") + lead_time_vals = pd.timedelta_range( + "3h", periods=n_lead_time, freq="2h" + ).to_numpy() + lon = np.array([10.0, 10.5, 11.0, 11.5]) + lat = np.array([45.0, 45.5, 46.0]) + + valid_time = ref_time[0] + lead_time_vals + + np.random.seed(42) + intensity = np.random.rand(n_eps, 1, n_lead_time, n_lat, n_lon) * 10 + + # Create xarray Dataset + dset = xr.Dataset( + { + "__xarray_dataarray_variable__": ( + ["eps", "ref_time", "lead_time", "lat", "lon"], + intensity, + ), + }, + coords={ + "eps": eps, + "ref_time": ref_time, + "lead_time": lead_time_vals, + "lon": lon, + "lat": lat, + "valid_time": (["lead_time"], valid_time), + }, + ) + dset.to_netcdf(netcdf_path) + + return { + "path": netcdf_path, + "n_eps": n_eps, + "n_lead_time": n_lead_time, + "n_lat": n_lat, + "n_lon": n_lon, + "eps": eps, + "lead_time": lead_time_vals, + "lon": lon, + "lat": lat, + "crs": crs, + } + + @xarray_leadtime + def test_from_xarray_raster_basic(self, forecast_netcdf_file): + """Test basic loading of forecast hazard from xarray""" + haz_fc = HazardForecast.from_xarray_raster( + forecast_netcdf_file["path"], + hazard_type="PR", + intensity_unit="mm/h", + coordinate_vars={ + "longitude": "lon", + "latitude": "lat", + "lead_time": "lead_time", + "member": "eps", + }, + ) + + # Check that it's a HazardForecast instance + assert isinstance(haz_fc, HazardForecast) + + # Check dimensions - after stacking, we should have n_eps * n_lead_time events + expected_n_events = ( + forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"] + ) + assert len(haz_fc.event_id) == expected_n_events + assert len(haz_fc.lead_time) == expected_n_events + assert len(haz_fc.member) == expected_n_events + + # Check that lead_time and member are correctly extracted + npt.assert_array_equal(np.unique(haz_fc.member), forecast_netcdf_file["eps"]) + + # Check intensity shape (events x centroids) + expected_n_centroids = ( + forecast_netcdf_file["n_lat"] * forecast_netcdf_file["n_lon"] + ) + assert haz_fc.intensity.shape == (expected_n_events, expected_n_centroids) + + # Check centroids + assert len(haz_fc.centroids.lat) == expected_n_centroids + assert len(haz_fc.centroids.lon) == expected_n_centroids + + @xarray_leadtime + def test_from_xarray_raster_event_names(self, forecast_netcdf_file): + """Test that event names are auto-generated from lead_time and member""" + haz_fc = HazardForecast.from_xarray_raster( + forecast_netcdf_file["path"], + hazard_type="PR", + intensity_unit="mm/h", + coordinate_vars={ + "longitude": "lon", + "latitude": "lat", + "lead_time": "lead_time", + "member": "eps", + }, + crs=forecast_netcdf_file["crs"], + ) + + # Check that event names are generated with lead_time in hours + expected_n_events = ( + forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"] + ) + assert len(haz_fc.event_name) == expected_n_events + + event_names_expected = [ + f"lt_{lt / np.timedelta64(1, 'h'):.0f}h_m_{mm}" + for lt, mm in zip(haz_fc.lead_time, haz_fc.member) + ] + npt.assert_array_equal(haz_fc.event_name, event_names_expected) + + @xarray_leadtime + def test_from_xarray_raster_dates(self, forecast_netcdf_file): + """Test that dates are set to 0 for forecast events""" + haz_fc = HazardForecast.from_xarray_raster( + forecast_netcdf_file["path"], + hazard_type="PR", + intensity_unit="mm/h", + coordinate_vars={ + "longitude": "lon", + "latitude": "lat", + "lead_time": "lead_time", + "member": "eps", + }, + crs=forecast_netcdf_file["crs"], + ) + + # Check that all dates are 0 (undefined for forecast) + expected_n_events = ( + forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"] + ) + npt.assert_array_equal(haz_fc.date, np.zeros(expected_n_events, dtype=int)) + + class TestSelect: @pytest.mark.parametrize( diff --git a/climada/hazard/xarray.py b/climada/hazard/xarray.py index df7fc9bf67..7e04f430ce 100644 --- a/climada/hazard/xarray.py +++ b/climada/hazard/xarray.py @@ -238,6 +238,8 @@ class HazardXarrayReader: ---------- data : xr.Dataset The data to be read as hazard. + data_stacked : xr.Dataset + The internally stacked (vectorized) version of ``data``. intensity : str The name of the variable containing the hazard intensity information. Default: ``"intensity"`` @@ -254,6 +256,7 @@ class HazardXarrayReader: """ data: xr.Dataset + data_stacked: xr.Dataset = field(init=False) intensity: str = "intensity" coordinate_vars: InitVar[dict[str, str] | None] = field(default=None, kw_only=True) data_vars: dict[str, str] | None = field(default=None, kw_only=True) @@ -344,7 +347,7 @@ def get_hazard_kwargs(self) -> dict[str, Any]: # preserve order. However, we want longitude to run faster than latitude. # So we use 'dict' without values, as 'dict' preserves insertion order # (dict keys behave like a set). - data = data.stack( + self.data_stacked = data.stack( event=self.data_dims["event"], lat_lon=list( dict.fromkeys( @@ -355,20 +358,20 @@ def get_hazard_kwargs(self) -> dict[str, Any]: # Transform coordinates into centroids centroids = Centroids( - lat=data[self.coords["latitude"]].to_numpy(), - lon=data[self.coords["longitude"]].to_numpy(), + lat=self.data_stacked[self.coords["latitude"]].to_numpy(), + lon=self.data_stacked[self.coords["longitude"]].to_numpy(), crs=self.crs, ) # Read the intensity data LOGGER.debug("Loading Hazard intensity from DataArray '%s'", self.intensity) - intensity_matrix = _to_csr_matrix(data[self.intensity]) + intensity_matrix = _to_csr_matrix(self.data_stacked[self.intensity]) # Create a DataFrame storing access information for each of data_vars # NOTE: Each row will be passed as arguments to # `load_from_xarray_or_return_default`, see its docstring for further # explanation of the DataFrame columns / keywords. - num_events = data.sizes["event"] + num_events = self.data_stacked.sizes["event"] data_ident = pd.DataFrame( data={ # The attribute of the Hazard class where the data will be stored @@ -384,10 +387,12 @@ def get_hazard_kwargs(self) -> dict[str, Any]: np.array(range(num_events), dtype=int) + 1, list( _year_month_day_accessor( - data[self.coords["event"]], strict=False + self.data_stacked[self.coords["event"]], strict=False ).flat ), - _date_to_ordinal_accessor(data[self.coords["event"]], strict=False), + _date_to_ordinal_accessor( + self.data_stacked[self.coords["event"]], strict=False + ), ], # The accessor for the data in the Dataset "accessor": [ @@ -411,10 +416,11 @@ def get_hazard_kwargs(self) -> dict[str, Any]: # Set the Hazard attributes for _, ident in data_ident.iterrows(): self.hazard_kwargs[ident["hazard_attr"]] = ( - _load_from_xarray_or_return_default(data=data, **ident) + _load_from_xarray_or_return_default(data=self.data_stacked, **ident) ) # Done! + LOGGER.debug("Hazard successfully loaded. Number of events: %i", num_events) self.hazard_kwargs.update(centroids=centroids, intensity=intensity_matrix) return self.hazard_kwargs