# Source code for fme.core.dataset.config

import dataclasses
from datetime import timedelta
from typing import Literal, Mapping, Optional, Union

import numpy as np
import pandas as pd
import torch
import xarray as xr

from fme.core.typing_ import Slice, TensorDict


@dataclasses.dataclass
class TimeSlice:
    """
    Configuration of a slice of times.

    Step is an integer-valued index step.

    Note:
        start_time and stop_time may be provided as partial time strings and
        the stop_time will be included in the slice. See more details in
        `Xarray docs`_.

    Parameters:
        start_time: Start time of the slice.
        stop_time: Stop time of the slice.
        step: Step of the slice.

    .. _Xarray docs: https://docs.xarray.dev/en/latest/user-guide/weather-climate.html#non-standard-calendars-and-dates-outside-the-nanosecond-precision-range
    """  # noqa: E501

    start_time: Optional[str] = None
    stop_time: Optional[str] = None
    step: Optional[int] = None

    def slice(self, time: "xr.CFTimeIndex") -> slice:
        """Translate the configured bounds into a slice over ``time``.

        The label bounds (``start_time``/``stop_time``) and the integer
        ``step`` are handed to the index's ``slice_indexer`` unchanged.

        Args:
            time: Time coordinate index to select from.

        Returns:
            The ``slice`` computed by ``time.slice_indexer``.
        """
        bounds = (self.start_time, self.stop_time)
        return time.slice_indexer(*bounds, self.step)
def _convert_interval_to_int( interval: pd.Timedelta, timestep: timedelta, ): """Convert interval to integer number of timesteps.""" if interval % timestep != timedelta(0): raise ValueError( f"Requested interval length {interval} is not a " f"multiple of the timestep {timestep}." ) return interval // timestep @dataclasses.dataclass class RepeatedInterval: """ Configuration for a repeated interval within a block. This configuration is used to generate a boolean mask for a dataset that will return values within the interval and repeat that throughout the dataset. Parameters: interval_length: Length of the interval to return values from start: Start position of the interval within the repeat block. block_length: Total length of the block to be repeated over the length of the dataset, including the interval length. Note: The interval_length, start, and block_length can be provided as either all integers or all strings representing timedeltas of the block. If provided as strings, the timestep must be provided when calling `get_boolean_mask`. Examples: To return values from the first 3 items of every 6 items, use: >>> RepeatedInterval(interval_length=3, repeat=6, start=0) To return a days worth of values starting after 2 days from every 7-day block, use: >>> RepeatedInterval(interval_length="1d", repeat="7d", start="2d") """ interval_length: Union[int, str] start: Union[int, str] block_length: Union[int, str] def __post_init__(self): types = {type(self.interval_length), type(self.block_length), type(self.start)} if len(types) > 1: raise ValueError( "All attributes of RepeatedInterval must be of the " "same type (either all int or all str)." 
) self._is_time_delta_str = isinstance(self.interval_length, str) if self._is_time_delta_str: self.interval_length = pd.Timedelta(self.interval_length) self.block_length = pd.Timedelta(self.block_length) self.start = pd.Timedelta(self.start) def get_boolean_mask( self, length: int, timestep: Optional[timedelta] = None ) -> np.ndarray: """ Return a boolean mask for the repeated interval. Args: length: Length of the dataset. timestep: Timestep of the dataset. """ if self._is_time_delta_str: if timestep is None: raise ValueError( "Timestep must be provided when using time deltas " "for RepeatedInterval." ) interval_length = _convert_interval_to_int(self.interval_length, timestep) block_length = _convert_interval_to_int(self.block_length, timestep) start = _convert_interval_to_int(self.start, timestep) else: interval_length = self.interval_length block_length = self.block_length start = self.start if start + interval_length > block_length: raise ValueError( "The interval (with start point) must fit within the repeat block." ) block = np.zeros(block_length, dtype=bool) block[start : start + interval_length] = True num_blocks = length // block_length + 1 mask = np.tile(block, num_blocks)[:length] return mask @dataclasses.dataclass class OverwriteConfig: """Configuration to overwrite field values in XarrayDataset. Parameters: constant: Fill field with constant value. multiply_scalar: Multiply field by scalar value. 
""" constant: Mapping[str, float] = dataclasses.field(default_factory=dict) multiply_scalar: Mapping[str, float] = dataclasses.field(default_factory=dict) def __post_init__(self): key_overlap = set(self.constant.keys()) & set(self.multiply_scalar.keys()) if key_overlap: raise ValueError( "OverwriteConfig cannot have the same variable in both constant " f"and multiply_scalar: {key_overlap}" ) def apply(self, tensors: TensorDict) -> TensorDict: for var, fill_value in self.constant.items(): data = tensors[var] tensors[var] = torch.ones_like(data) * torch.tensor( fill_value, dtype=data.dtype, device=data.device ) for var, multiplier in self.multiply_scalar.items(): data = tensors[var] tensors[var] = data * torch.tensor( multiplier, dtype=data.dtype, device=data.device ) return tensors @property def variables(self): return set(self.constant.keys()) | set(self.multiply_scalar.keys()) @dataclasses.dataclass class FillNaNsConfig: """ Configuration to fill NaNs with a constant value or others. Parameters: method: Type of fill operation. Currently only 'constant' is supported. value: Value to fill NaNs with. """ method: Literal["constant"] = "constant" value: float = 0.0
@dataclasses.dataclass
class XarrayDataConfig:
    """
    Parameters:
        data_path: Path to the data.
        file_pattern: Glob pattern to match files in the data_path.
        n_repeats: Number of times to repeat the dataset (in time). It is up
            to the user to ensure that the input dataset to repeat results in
            data that is reasonably continuous across repetitions.
        engine: Backend used in xarray.open_dataset call.
        spatial_dimensions: Specifies the spatial dimensions for the grid,
            default is lat/lon.
        subset: Subset of the XarrayDataset to load. This can be a `Slice` of
            integer indices, a `TimeSlice` of timestamps, or a
            `RepeatedInterval` describing a repeating interval of samples to
            return.
        infer_timestep: Whether to infer the timestep from the provided data.
            This should be set to True (the default) for ACE training. It may
            be useful to toggle this to False for applications like
            downscaling, which do not depend on the timestep of the data and
            therefore lack the additional requirement that the data be ordered
            and evenly spaced in time. It must be set to True if n_repeats > 1
            in order to be able to infer the full time coordinate.
        dtype: Data type to cast the data to. If None, no casting is done. It
            is required that 'torch.{dtype}' is a valid dtype.
        overwrite: Optional OverwriteConfig to overwrite loaded field values.
            If this is configured for a renamed field, the key should be the
            final updated name.
        renamed_variables: Optional mapping of {old_name: new_name} to rename
            variables.
        fill_nans: Optional FillNaNsConfig to fill NaNs with a constant value.

    Raises:
        ValueError: If n_repeats > 1 while infer_timestep is False, if dtype
            does not name a torch dtype, if an overwrite variable uses a name
            from before renaming, or if engine is "zarr" while file_pattern is
            left at the default "*.nc".

    Examples:
        If data is stored in a directory with multiple netCDF files which can
        be concatenated along the time dimension, use:

        >>> fme.ace.XarrayDataConfig(data_path="/some/directory", file_pattern="*.nc") # doctest: +IGNORE_OUTPUT

        If data is stored in a single zarr store at ``/some/directory/dataset.zarr``, use:

        >>> fme.ace.XarrayDataConfig(
        ...     data_path="/some/directory",
        ...     file_pattern="dataset.zarr",
        ...     engine="zarr"
        ... ) # doctest: +IGNORE_OUTPUT
    """  # noqa: E501

    data_path: str
    file_pattern: str = "*.nc"
    n_repeats: int = 1
    engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"
    spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
    subset: Union[Slice, TimeSlice, RepeatedInterval] = dataclasses.field(
        default_factory=Slice
    )
    infer_timestep: bool = True
    dtype: Optional[str] = "float32"
    overwrite: OverwriteConfig = dataclasses.field(default_factory=OverwriteConfig)
    renamed_variables: Optional[Mapping[str, str]] = None
    fill_nans: Optional[FillNaNsConfig] = None

    def _default_file_pattern_check(self):
        """Reject the NetCDF default file pattern when the zarr engine is used."""
        if self.engine == "zarr" and self.file_pattern == "*.nc":
            raise ValueError(
                "The file pattern is set to the default NetCDF file pattern *.nc "
                "but the engine is specified as 'zarr'. Please set "
                "`XarrayDataConfig.file_pattern` to match the zarr filename."
            )

    def __post_init__(self):
        if self.n_repeats > 1 and not self.infer_timestep:
            raise ValueError(
                "infer_timestep must be True if n_repeats is greater than 1"
            )
        # torch_dtype is a derived (non-field) attribute: the resolved torch
        # dtype, or None when no casting should be done.
        if self.dtype is None:
            self.torch_dtype = None
        else:
            try:
                self.torch_dtype = getattr(torch, self.dtype)
            except (AttributeError, TypeError):
                # TypeError covers non-string values (e.g. an int from YAML).
                raise ValueError(f"Invalid dtype '{self.dtype}'") from None
            # getattr can succeed for torch attributes that are not dtypes
            # (e.g. submodules), so check the resolved object's type too.
            if not isinstance(self.torch_dtype, torch.dtype):
                raise ValueError(f"Invalid dtype '{self.dtype}'")
        # Raise error if overwrite variables are in the keys of renamed variables
        if self.renamed_variables is not None:
            overlap = set(self.overwrite.variables) & set(self.renamed_variables.keys())
            if overlap:
                raise ValueError(
                    "Variables in overwrite should not be the original names before "
                    f"renaming: {overlap}. "
                    "Please use the final renamed variables in the overwrite config."
                )
        self._default_file_pattern_check()