import dataclasses
from collections.abc import Mapping, Sequence
from datetime import timedelta
from typing import Literal
import numpy as np
import pandas as pd
import torch
import xarray as xr
from fme.core.typing_ import Slice, TensorDict
[docs]@dataclasses.dataclass
class TimeSlice:
"""
Configuration of a slice of times. Step is an integer-valued index step.
Note: start_time and stop_time may be provided as partial time strings and the
stop_time will be included in the slice. See more details in `Xarray docs`_.
Parameters:
start_time: Start time of the slice.
stop_time: Stop time of the slice.
step: Step of the slice.
.. _Xarray docs:
https://docs.xarray.dev/en/latest/user-guide/weather-climate.html#non-standard-calendars-and-dates-outside-the-nanosecond-precision-range
""" # noqa: E501
start_time: str | None = None
stop_time: str | None = None
step: int | None = None
def slice(self, time: xr.CFTimeIndex) -> slice:
return time.slice_indexer(self.start_time, self.stop_time, self.step)
def _convert_interval_to_int(
interval: pd.Timedelta,
timestep: timedelta,
):
"""Convert interval to integer number of timesteps."""
if interval % timestep != timedelta(0):
raise ValueError(
f"Requested interval length {interval} is not a "
f"multiple of the timestep {timestep}."
)
return interval // timestep
[docs]@dataclasses.dataclass
class RepeatedInterval:
"""
Configuration for a repeated interval within a block. This configuration
is used to generate a boolean mask for a dataset that will return values
within the interval and repeat that throughout the dataset.
Parameters:
interval_length: Length of the interval to return values from
start: Start position of the interval within the repeat block.
block_length: Total length of the block to be repeated over the length of
the dataset, including the interval length.
Note:
The interval_length, start, and block_length can be provided as either
all integers or all strings representing timedeltas of the block.
If provided as strings, the timestep must be provided when calling
`get_boolean_mask`.
Examples:
To return values from the first 3 items of every 6 items, use:
>>> fme.ace.RepeatedInterval(interval_length=3, block_length=6, start=0) # doctest: +IGNORE_OUTPUT
To return a days worth of values starting after 2 days from every 7-day
block, use:
>>> fme.ace.RepeatedInterval(interval_length="1d", block_length="7d", start="2d") # doctest: +IGNORE_OUTPUT
""" # noqa: E501
interval_length: int | str
start: int | str
block_length: int | str
def __post_init__(self):
types = {type(self.interval_length), type(self.block_length), type(self.start)}
if len(types) > 1:
raise ValueError(
"All attributes of RepeatedInterval must be of the "
"same type (either all int or all str)."
)
self._is_time_delta_str = isinstance(self.interval_length, str)
if self._is_time_delta_str:
self.interval_length = pd.Timedelta(self.interval_length)
self.block_length = pd.Timedelta(self.block_length)
self.start = pd.Timedelta(self.start)
[docs] def get_boolean_mask(
self, length: int, timestep: timedelta | None = None
) -> np.ndarray:
"""
Return a boolean mask for the repeated interval.
Args:
length: Length of the dataset.
timestep: Timestep of the dataset.
"""
if self._is_time_delta_str:
if timestep is None:
raise ValueError(
"Timestep must be provided when using time deltas "
"for RepeatedInterval."
)
interval_length = _convert_interval_to_int(self.interval_length, timestep)
block_length = _convert_interval_to_int(self.block_length, timestep)
start = _convert_interval_to_int(self.start, timestep)
else:
interval_length = self.interval_length
block_length = self.block_length
start = self.start
if start + interval_length > block_length:
raise ValueError(
"The interval (with start point) must fit within the repeat block."
)
block = np.zeros(block_length, dtype=bool)
block[start : start + interval_length] = True
num_blocks = length // block_length + 1
mask = np.tile(block, num_blocks)[:length]
return mask
[docs]@dataclasses.dataclass
class OverwriteConfig:
"""Configuration to overwrite field values in XarrayDataset.
Parameters:
constant: Fill field with constant value.
multiply_scalar: Multiply field by scalar value.
"""
constant: Mapping[str, float] = dataclasses.field(default_factory=dict)
multiply_scalar: Mapping[str, float] = dataclasses.field(default_factory=dict)
def __post_init__(self):
key_overlap = set(self.constant.keys()) & set(self.multiply_scalar.keys())
if key_overlap:
raise ValueError(
"OverwriteConfig cannot have the same variable in both constant "
f"and multiply_scalar: {key_overlap}"
)
def apply(self, tensors: TensorDict) -> TensorDict:
for var, fill_value in self.constant.items():
data = tensors[var]
tensors[var] = torch.ones_like(data) * torch.tensor(
fill_value, dtype=data.dtype, device=data.device
)
for var, multiplier in self.multiply_scalar.items():
data = tensors[var]
tensors[var] = data * torch.tensor(
multiplier, dtype=data.dtype, device=data.device
)
return tensors
@property
def variables(self):
return set(self.constant.keys()) | set(self.multiply_scalar.keys())
[docs]@dataclasses.dataclass
class FillNaNsConfig:
"""
Configuration to fill NaNs with a constant value or others.
Parameters:
method: Type of fill operation. Currently only 'constant' is supported.
value: Value to fill NaNs with.
"""
method: Literal["constant"] = "constant"
value: float = 0.0
[docs]@dataclasses.dataclass
class XarrayDataConfig:
"""
Parameters:
data_path: Path to the data.
file_pattern: Glob pattern to match files in the data_path.
n_repeats: Number of times to repeat the dataset (in time). It is up
to the user to ensure that the input dataset to repeat results in
data that is reasonably continuous across repetitions.
engine: Backend used in xarray.open_dataset call.
spatial_dimensions: Specifies the spatial dimensions for the grid, default
is lat/lon. If 'latlon', it is assumed that the last two dimensions are
latitude and longitude, respectively. If 'healpix', it is assumed that the
last three dimensions are face, height, and width, respectively.
subset: Slice defining a subset of the XarrayDataset to load. This can
either be a `Slice` of integer indices or a `TimeSlice` of timestamps.
This feature is applied directly to the dataset samples. For example,
if the file(s) have the time coordinate (t0, t1, t2, t3) and
requirements.n_timesteps=2, then subset=Slice(stop=2) will
provide two samples: (t0, t1), (t1, t2).
infer_timestep: Whether to infer the timestep from the provided data.
This should be set to True (the default) for ACE training. It may
be useful to toggle this to False for applications like downscaling,
which do not depend on the timestep of the data and therefore lack
the additional requirement that the data be ordered and evenly
spaced in time. It must be set to True if n_repeats > 1 in order
to be able to infer the full time coordinate.
dtype: Data type to cast the data to. If None, no casting is done. It is
required that 'torch.{dtype}' is a valid dtype.
overwrite: Optional OverwriteConfig to overwrite loaded field values.
fill_nans: Optional FillNaNsConfig to fill NaNs with a constant value.
isel: Optional xarray isel arguments to be passed to the dataset. Will
raise ValueError if time is included here, since the subset argument
is used specifically for selecting times. Horizontal dimensions are
also not currently supported.
Examples:
If data is stored in a directory with multiple netCDF files which can be
concatenated along the time dimension, use:
>>> fme.ace.XarrayDataConfig(data_path="/some/directory", file_pattern="*.nc") # doctest: +IGNORE_OUTPUT
If data is stored in a single zarr store at ``/some/directory/dataset.zarr``,
use:
>>> fme.ace.XarrayDataConfig(
... data_path="/some/directory",
... file_pattern="dataset.zarr",
... engine="zarr"
... ) # doctest: +IGNORE_OUTPUT
""" # noqa: E501
data_path: str
file_pattern: str = "*.nc"
n_repeats: int = 1
engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"
spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
subset: Slice | TimeSlice | RepeatedInterval = dataclasses.field(
default_factory=Slice
)
infer_timestep: bool = True
dtype: str | None = "float32"
overwrite: OverwriteConfig = dataclasses.field(default_factory=OverwriteConfig)
fill_nans: FillNaNsConfig | None = None
isel: Mapping[str, Slice | int] = dataclasses.field(default_factory=dict)
def _default_file_pattern_check(self):
if self.engine == "zarr" and self.file_pattern == "*.nc":
raise ValueError(
"The file pattern is set to the default NetCDF file pattern *.nc "
"but the engine is specified as 'zarr'. Please set "
"`XarrayDataConfig.file_pattern` to match the zarr filename."
)
@property
def torch_dtype(self) -> torch.dtype | None:
if self.dtype is None:
return None
else:
try:
torch_dtype = getattr(torch, self.dtype)
except AttributeError:
raise ValueError(f"Invalid dtype '{self.dtype}'")
if not isinstance(torch_dtype, torch.dtype):
raise ValueError(f"Invalid dtype '{self.dtype}'")
return torch_dtype
def __post_init__(self):
if self.n_repeats > 1 and not self.infer_timestep:
raise ValueError(
"infer_timestep must be True if n_repeats is greater than 1"
)
if self.spatial_dimensions not in ["latlon", "healpix"]:
raise ValueError(
f"unexpected spatial_dimensions {self.spatial_dimensions},"
" should be one of 'latlon' or 'healpix'"
)
self.torch_dtype # check it can be retrieved
self._default_file_pattern_check()
self.zarr_engine_used = False
if self.engine == "zarr":
self.zarr_engine_used = True
[docs]@dataclasses.dataclass
class ConcatDatasetConfig:
"""
Configuration for concatenating multiple datasets.
Parameters:
concat: List of XarrayDataConfig objects to concatenate.
strict: Whether to enforce that the datasets to be concatenated
have the same dimensions and coordinates.
"""
concat: Sequence[XarrayDataConfig]
strict: bool = True
def __post_init__(self):
self.zarr_engine_used = any(ds.engine == "zarr" for ds in self.concat)
[docs]@dataclasses.dataclass
class MergeDatasetConfig:
"""
Configuration for merging multiple datasets.
Parameters:
merge: List of ConcatDatasetConfig or XarrayDataConfig to merge.
"""
merge: Sequence[ConcatDatasetConfig | XarrayDataConfig]
def __post_init__(self):
self.zarr_engine_used = False
for ds in self.merge:
if isinstance(ds, ConcatDatasetConfig):
if ds.zarr_engine_used:
self.zarr_engine_used = ds.zarr_engine_used
break
elif isinstance(ds, XarrayDataConfig):
if ds.engine == "zarr":
self.zarr_engine_used = True
break
[docs]@dataclasses.dataclass
class MergeNoConcatDatasetConfig:
"""
Configuration for merging multiple datasets. No concatenation is allowed.
Parameters:
merge: List of XarrayDataConfig to merge.
"""
merge: Sequence[XarrayDataConfig]
def __post_init__(self):
self.zarr_engine_used = False
for ds in self.merge:
if ds.engine == "zarr":
self.zarr_engine_used = True
break