Source code for fme.core.data_loading.config

import dataclasses
from typing import Literal, Optional, Sequence, Union

import xarray as xr

from fme.core.distributed import Distributed


@dataclasses.dataclass
class Slice:
    """
    Dataclass stand-in for the python `slice` built-in.

    Exists because dacite cannot initialize a `slice` directly; the three
    fields are forwarded unchanged to `slice` by the `slice` property.

    Attributes:
        start: Start index of the slice.
        stop: Stop index of the slice.
        step: Step of the slice.
    """

    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    # NOTE: keep this property below the fields. While the class body is
    # being executed the name `slice` still refers to the built-in, which is
    # what the return annotation needs.
    @property
    def slice(self) -> slice:
        """Built-in `slice` constructed from (start, stop, step)."""
        bounds = (self.start, self.stop, self.step)
        return slice(*bounds)
@dataclasses.dataclass
class TimeSlice:
    """
    Configuration of a slice over a time index.

    The bounds are time strings, while the step is an integer-valued index
    step.

    Note:
        start_time and stop_time may be provided as partial time strings,
        and the stop_time will be included in the slice. See more details
        in `Xarray docs`_.

    Attributes:
        start_time: Start time of the slice.
        stop_time: Stop time of the slice.
        step: Step of the slice.

    .. _Xarray docs:
       https://docs.xarray.dev/en/latest/user-guide/weather-climate.html#non-standard-calendars-and-dates-outside-the-nanosecond-precision-range  # noqa
    """

    start_time: Optional[str] = None
    stop_time: Optional[str] = None
    step: Optional[int] = None

    def slice(self, times: xr.CFTimeIndex) -> slice:
        """
        Resolve this configuration against a concrete time index.

        Args:
            times: Time index whose `slice_indexer` method is used to turn
                the configured start/stop/step into a `slice`.

        Returns:
            Whatever `times.slice_indexer` returns for the configured
            (start_time, stop_time, step).
        """
        first, last, stride = self.start_time, self.stop_time, self.step
        return times.slice_indexer(first, last, stride)
@dataclasses.dataclass
class XarrayDataConfig:
    """
    Configuration describing a dataset to be loaded with xarray.

    Attributes:
        data_path: Path to the data.
        file_pattern: Glob pattern to match files in the data_path.
        n_repeats: Number of times to repeat the dataset (in time). It is up
            to the user to ensure that the input dataset to repeat results
            in data that is reasonably continuous across repetitions.
        engine: Backend for xarray.open_dataset. The accepted values are
            "netcdf4", "h5netcdf" and "zarr". The field defaults to None;
            how None is resolved is decided outside this class (earlier
            documentation names "netcdf4" as the default, so presumably
            that is the fallback -- confirm against the code that opens the
            dataset). Only valid when using XarrayDataset.
        spatial_dimensions: Specifies the spatial dimensions for the grid,
            either "healpix" or "latlon" (the default).
        subset: Slice defining a subset of the XarrayDataset to load. This
            can either be a `Slice` of integer indices or a `TimeSlice` of
            timestamps. Defaults to an empty `Slice()`.
        infer_timestep: Whether to infer the timestep from the provided
            data. This should be set to True (the default) for ACE training.
            It may be useful to toggle this to False for applications like
            downscaling, which do not depend on the timestep of the data
            and therefore lack the additional requirement that the data be
            ordered and evenly spaced in time. It must be set to True if
            n_repeats > 1 in order to be able to infer the full time
            coordinate.

    Raises:
        ValueError: On construction, if n_repeats is greater than 1 while
            infer_timestep is False.
    """

    data_path: str
    file_pattern: str = "*.nc"
    n_repeats: int = 1
    engine: Optional[Literal["netcdf4", "h5netcdf", "zarr"]] = None
    spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
    subset: Union[Slice, TimeSlice] = dataclasses.field(default_factory=Slice)
    infer_timestep: bool = True

    def __post_init__(self):
        # Repeating the dataset in time is only allowed when the timestep
        # is inferred (see the `infer_timestep` attribute description).
        if self.n_repeats > 1:
            if not self.infer_timestep:
                raise ValueError(
                    "infer_timestep must be True if n_repeats is greater than 1"
                )
@dataclasses.dataclass
class DataLoaderConfig:
    """
    Configuration of a data loader built from one or more datasets.

    Attributes:
        dataset: A sequence of configurations each defining a dataset
            to be loaded. This sequence of datasets will be concatenated.
        batch_size: Number of samples per batch. Must be divisible by the
            `world_size` reported by `Distributed.get_instance()`.
        num_data_workers: Number of parallel workers to use for data loading.
        prefetch_factor: How many batches a single data worker will attempt
            to hold in host memory at a given time.
        strict_ensemble: Whether to enforce that the ensemble members have
            the same dimensions and coordinates.

    Raises:
        ValueError: On construction, if batch_size is not a multiple of the
            distributed world size.
    """

    dataset: Sequence[XarrayDataConfig]
    batch_size: int
    num_data_workers: int
    prefetch_factor: Optional[int] = None
    strict_ensemble: bool = True

    def __post_init__(self):
        distributed = Distributed.get_instance()
        # Nothing to report when the batch splits evenly across the world.
        if self.batch_size % distributed.world_size == 0:
            return
        raise ValueError(
            "batch_size must be divisible by the number of parallel "
            f"workers, got {self.batch_size} and {distributed.world_size}"
        )