Source code for fme.downscaling.data.config

import dataclasses
from collections.abc import Sequence

from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from fme.core.coordinates import LatLonCoordinates
from fme.core.dataset.concat import XarrayConcat, get_dataset
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.schedule import IntSchedule
from fme.core.dataset.xarray import XarrayDataConfig, get_raw_paths
from fme.core.device import using_gpu
from fme.core.distributed import Distributed
from fme.downscaling.data.datasets import (
    BatchData,
    BatchItemDatasetAdapter,
    ContiguousDistributedSampler,
    FineCoarsePairedDataset,
    GriddedData,
    HorizontalSubsetDataset,
    PairedBatchData,
    PairedGriddedData,
)
from fme.downscaling.data.topography import (
    StaticInputs,
    Topography,
    get_normalized_topography,
    get_topography_downscale_factor,
)
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.requirements import DataRequirements


def enforce_lat_bounds(lat: ClosedInterval):
    if lat.start < -88.0 or lat.stop > 88.0:
        raise ValueError(
            "Latitude bounds must be within +/-88 degrees, "
            f"got {lat.start} to {lat.stop}."
            "This is enforced because the 3 km X-SHiELD dataset "
            "does not have 32 fine grid midpoints between the last two "
            "coarse latitude midpoints of the 100 km dataset, which breaks "
            "the assumption used for subsetting fine grid latitudes."
        )


@dataclasses.dataclass
class XarrayEnsembleDataConfig:
    """
    Configuration for an ensemble dataset.
    This config's expand method returns a sequence of xarray datasets, each
    with the same data_config, where each individual dataset is an ensemble member
    selected from the ensemble dimension.

    Parameters:
        data_config: XarrayDataConfig for the dataset.
        ensemble_dim: Name of the ensemble dimension in the dataset.
        n_ensemble_members: Number of ensemble members to load. They will be taken
            in order from index 0 of the ensemble_dim.
    """

    data_config: XarrayDataConfig
    ensemble_dim: str
    n_ensemble_members: int

    def __post_init__(self):
        if self.n_ensemble_members <= 0:
            raise ValueError(
                f"n_ensemble_members must be > 0, got {self.n_ensemble_members}"
            )
        if self.ensemble_dim in self.data_config.isel:
            raise ValueError(
                f"Ensemble dimension {self.ensemble_dim} cannot be in the "
                "base data_config.isel"
            )

    def expand(self) -> list[XarrayDataConfig]:
        configs = []
        for i in range(self.n_ensemble_members):
            configs.append(
                dataclasses.replace(
                    self.data_config,
                    isel={self.ensemble_dim: i},
                )
            )
        return configs


[docs]@dataclasses.dataclass class DataLoaderConfig: """ Configuration for loading downscaling data for generation. Input coarse dataset will be processed into batches, usually with a horizontal extent to define a portion of the full domain for use in generation. If the model requires topography, the dataset to use should be specified in the `topography` field. Topography data may be at higher resolution than the data, e.g. when fine topography is loaded as an input. Args: coarse: The dataset configuration. batch_size: The batch size to use for the dataloader. num_data_workers: The number of data workers to use for the dataloader. (For multi-GPU runtime, it's the number of workers per GPU.) strict_ensemble: Whether to enforce that the datasets to be concatened have the same dimensions and coordinates. topography: The dataset path for the topography data. This may be at a higher resolution than the coarse data, e.g. when fine topography is loaded as an input for predictions that have no fine-res paired targets. If None, no topography data will be loaded. lat_extent: The latitude extent to use for the dataset specified in degrees, limited to (-88.0, 88.0). The extent is inclusive, so the start and stop values are included in the extent. Defaults to [-66, 70] which covers continental land masses aside from Antarctica. lon_extent: The longitude extent to use for the dataset specified in degrees (0, 360). The extent is inclusive, so the start and stop values are included in the extent. repeat: The number of times to repeat the underlying xarray dataset time dimension. Useful to include longer sequences of small data for testing. drop_last: Use drop_last option in sampler. Defaults to False. If True, drop the last samples required to have even batch sizes across ranks. If false, pad with extra samples to make ranks have the same size batches. """ coarse: Sequence[XarrayDataConfig | XarrayEnsembleDataConfig] batch_size: int num_data_workers: int strict_ensemble: bool topography: str | None = None lat_extent: ClosedInterval = dataclasses.field( default_factory=lambda: ClosedInterval(-66, 70) ) lon_extent: ClosedInterval = dataclasses.field( default_factory=lambda: ClosedInterval(float("-inf"), float("inf")) ) repeat: int = 1 drop_last: bool = False def __post_init__(self): enforce_lat_bounds(self.lat_extent) @property def full_config(self) -> Sequence[XarrayDataConfig]: # Expands any XarrayEnsembleDataConfig so it is converted # to the equivalent sequence of XarrayDataConfig. all_configs = [] for config in self.coarse: if isinstance(config, XarrayEnsembleDataConfig): all_configs += config.expand() else: all_configs.append(config) return all_configs @property def mp_context(self): context = None if self.num_data_workers == 0: return None for config in self.full_config: if config.engine == "zarr": context = "forkserver" return context def _repeat_if_requested(self, dataset: XarrayConcat) -> XarrayConcat: return XarrayConcat([dataset] * self.repeat) def get_xarray_dataset( self, names: list[str], n_timesteps: int, ) -> tuple[XarrayConcat, DatasetProperties]: return get_dataset( self.full_config, names, IntSchedule.from_constant(n_timesteps), strict=self.strict_ensemble, ) def build_topography( self, coarse_coords: LatLonCoordinates, requires_topography: bool, static_inputs_from_checkpoint: StaticInputs | None = None, ) -> Topography | None: if requires_topography is False: return None if static_inputs_from_checkpoint is not None: # TODO: change to use full static inputs list topography = static_inputs_from_checkpoint[0] else: if self.topography is None: raise ValueError( "Topography is required for this model, but no topography " "dataset was specified in the configuration nor provided " "in model checkpoint." ) topography = get_normalized_topography(self.topography) # Fine grid boundaries are adjusted to exactly match the coarse grid fine_lat_interval = adjust_fine_coord_range( self.lat_extent, full_coarse_coord=coarse_coords.lat, full_fine_coord=topography.coords.lat, ) fine_lon_interval = adjust_fine_coord_range( self.lon_extent, full_coarse_coord=coarse_coords.lon, full_fine_coord=topography.coords.lon, ) subset_topography = topography.subset_latlon( lat_interval=fine_lat_interval, lon_interval=fine_lon_interval ) return subset_topography.to_device() def build_batchitem_dataset( self, dataset: XarrayConcat, properties: DatasetProperties, ) -> BatchItemDatasetAdapter: # n_timesteps is hardcoded to 1 for downscaling, so the sample_start_times # are the full time range for the dataset if dataset.sample_n_times != 1: raise ValueError( "Downscaling data loading should always have n_timesteps=1 " "in model data requirements." f" Got {dataset.sample_n_times} instead." ) dataset = self._repeat_if_requested(dataset) dataset_subset = HorizontalSubsetDataset( dataset, properties=properties, lat_interval=self.lat_extent, lon_interval=self.lon_extent, ) return BatchItemDatasetAdapter( dataset_subset, dataset_subset.subset_latlon_coordinates, properties=properties, ) def build( self, requirements: DataRequirements, dist: Distributed | None = None, static_inputs_from_checkpoint: StaticInputs | None = None, ) -> GriddedData: # TODO: static_inputs_from_checkpoint is currently passed from the model # to allow loading fine topography when no fine data is available. # See PR https://github.com/ai2cm/ace/pull/728 # In the future we could disentangle this dependency between the data loader # and model by enabling the built GriddedData objects to take in full static # input fields and subset them to the same coordinate range as data. xr_dataset, properties = self.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 ) if not isinstance(properties.horizontal_coordinates, LatLonCoordinates): raise ValueError( "Downscaling data loader only supports datasets with latlon coords." ) latlon_coords = properties.horizontal_coordinates dataset = self.build_batchitem_dataset( dataset=xr_dataset, properties=properties, ) all_times = xr_dataset.sample_start_times if dist is None: dist = Distributed.get_instance() # Shuffle is not used for generation, it is set to False. sampler = ( ContiguousDistributedSampler(dataset, drop_last=self.drop_last) if dist.is_distributed() else None ) dataloader = DataLoader( dataset, batch_size=dist.local_batch_size(int(self.batch_size)), num_workers=self.num_data_workers, shuffle=False, sampler=sampler, drop_last=True, collate_fn=BatchData.from_sequence, pin_memory=using_gpu(), multiprocessing_context=self.mp_context, persistent_workers=True if self.num_data_workers > 0 else False, ) example = dataset[0] subset_topography = self.build_topography( coarse_coords=latlon_coords, requires_topography=requirements.use_fine_topography, static_inputs_from_checkpoint=static_inputs_from_checkpoint, ) return GriddedData( _loader=dataloader, topography=subset_topography, shape=example.horizontal_shape, dims=example.latlon_coordinates.dims, variable_metadata=dataset.variable_metadata, all_times=all_times, )
@dataclasses.dataclass class PairedDataLoaderConfig: """ Configuration for loading downscaling datasets. The input fine and coarse Xarray datasets will be processed into batches, usually with a horizontal extent to define a portion of the full domain for use in training or validation. Additionally, a user may specify to take random subsets of the initial domain by using the coarse random extent arguments. The build ensures the compatibility of the fine/coarse datasets by checking that the fine coordinates are evenly divisible by the coarse coordinates, and that the scale factors are equal. Args: fine: The fine dataset configuration. coarse: The coarse dataset configuration. XarrayEnsembleDataConfig is supported to load multiple ensemble members. batch_size: The batch size to use for the dataloader. num_data_workers: The number of data workers to use for the dataloader. (For multi-GPU runtime, it's the number of workers per GPU.) strict_ensemble: Whether to enforce that the datasets to be concatened have the same dimensions and coordinates. lat_extent: The latitude extent to use for the dataset specified in degrees [-88, 88]. The extent is inclusive, so the start and stop values are included in the extent. Defaults to [-66, 70] which covers continental land masses aside from Antarctica. lon_extent: The longitude extent to use for the dataset specified in degrees (0, 360). The extent is inclusive, so the start and stop values are included in the extent. repeat: The number of times to repeat the underlying xarray dataset time dimension. Useful to include longer sequences of small data for testing. topography: Optional path to dataset to load for topography. If not provided and model has requires_topography=True, the data loader will default to trying to load the variable from the fine data. sample_with_replacement: If provided, the dataset will be sampled randomly with replacement to the given size each period, instead of retrieving each sample once (either shuffled or not). drop_last: Use drop_last option in sampler. Defaults to False. If True, drop the last samples required to have even batch sizes across ranks. If false, pad with extra samples to make ranks have the same size batches. """ fine: Sequence[XarrayDataConfig] coarse: Sequence[XarrayDataConfig | XarrayEnsembleDataConfig] batch_size: int num_data_workers: int strict_ensemble: bool lat_extent: ClosedInterval = dataclasses.field( default_factory=lambda: ClosedInterval(-66.0, 70.0) ) lon_extent: ClosedInterval = dataclasses.field( default_factory=lambda: ClosedInterval(float("-inf"), float("inf")) ) repeat: int = 1 topography: str | None = None sample_with_replacement: int | None = None drop_last: bool = False def __post_init__(self): enforce_lat_bounds(self.lat_extent) def _repeat_if_requested(self, dataset: XarrayConcat) -> XarrayConcat: return XarrayConcat([dataset] * self.repeat) def _mp_context(self): mp_context = None if self.num_data_workers == 0: return None for config in self.fine: if config.engine == "zarr": mp_context = "forkserver" for config in self.coarse_full_config: if config.engine == "zarr": mp_context = "forkserver" return mp_context @property def coarse_full_config(self) -> Sequence[XarrayDataConfig]: # Expands the coarse dataset configs so that any XarrayEnsembleDataConfig # is converted to the equivalent sequence of XarrayDataConfig. coarse_configs = [] for config in self.coarse: if isinstance(config, XarrayEnsembleDataConfig): coarse_configs += config.expand() else: coarse_configs.append(config) return coarse_configs def build( self, train: bool, requirements: DataRequirements, dist: Distributed | None = None, static_inputs_from_checkpoint: StaticInputs | None = None, ) -> PairedGriddedData: # TODO: static_inputs_from_checkpoint is currently passed from the model # to allow loading fine topography when no fine data is available. # See PR https://github.com/ai2cm/ace/pull/728 # In the future we could disentangle this dependency between the data loader # and model by enabling the built GriddedData objects to take in full static # input fields and subset them to the same coordinate range as data. if dist is None: dist = Distributed.get_instance() # Load initial datasets dataset_fine, properties_fine = get_dataset( self.fine, requirements.fine_names, IntSchedule.from_constant(requirements.n_timesteps), strict=self.strict_ensemble, ) dataset_coarse, properties_coarse = get_dataset( self.coarse_full_config, requirements.coarse_names, IntSchedule.from_constant(requirements.n_timesteps), strict=self.strict_ensemble, ) # Ensure that bounds for subselecting on latlon grids return fine grid data # that aligns with the coarse grid. if not isinstance( properties_coarse.horizontal_coordinates, LatLonCoordinates ) or not isinstance(properties_fine.horizontal_coordinates, LatLonCoordinates): raise ValueError( "Downscaling data loader only supports datasets with latlon coords." ) # n_timesteps is hardcoded to 1 for downscaling, so the sample_start_times # are the full time range for the dataset if dataset_fine.sample_n_times != 1: raise ValueError( "Downscaling data loading should always have n_timesteps=1 " "in model data requirements." f" Got {dataset_fine.sample_n_times} instead." ) all_times = dataset_fine.sample_start_times dataset_fine = self._repeat_if_requested(dataset_fine) dataset_coarse = self._repeat_if_requested(dataset_coarse) # Ensure fine data subselection lines up exactly with coarse data fine_lat_extent = adjust_fine_coord_range( self.lat_extent, full_coarse_coord=properties_coarse.horizontal_coordinates.lat, full_fine_coord=properties_fine.horizontal_coordinates.lat, ) fine_lon_extent = adjust_fine_coord_range( self.lon_extent, full_coarse_coord=properties_coarse.horizontal_coordinates.lon, full_fine_coord=properties_fine.horizontal_coordinates.lon, ) if requirements.use_fine_topography: if static_inputs_from_checkpoint is not None: # TODO: change to use full static inputs list fine_topography = static_inputs_from_checkpoint[0] elif self.topography is None: data_path = self.fine[0].data_path file_pattern = self.fine[0].file_pattern raw_paths = get_raw_paths(data_path, file_pattern) if len(raw_paths) == 0: raise ValueError( f"No files found matching '{data_path}/{file_pattern}'." ) fine_topography = get_normalized_topography(raw_paths[0]) else: fine_topography = get_normalized_topography(self.topography) fine_topography = fine_topography.to_device() if ( get_topography_downscale_factor( fine_topography.data.shape, properties_fine.horizontal_coordinates.shape, ) != 1 ): raise ValueError( f"Fine topography shape {fine_topography.shape} does not match " f"fine data shape {properties_fine.horizontal_coordinates.shape}." ) fine_topography = fine_topography.subset_latlon( lat_interval=fine_lat_extent, lon_interval=fine_lon_extent ) else: fine_topography = None # TODO: horizontal subsetting should probably live in the XarrayDatast level # Subset to overall horizontal domain # TODO: Follow up PR will remove topography from batch items dataset_fine_subset = HorizontalSubsetDataset( dataset_fine, properties=properties_fine, lat_interval=fine_lat_extent, lon_interval=fine_lon_extent, ) dataset_coarse_subset = HorizontalSubsetDataset( dataset_coarse, properties=properties_coarse, lat_interval=self.lat_extent, lon_interval=self.lon_extent, ) # Convert datasets to produce BatchItems dataset_fine_subset = BatchItemDatasetAdapter( dataset_fine_subset, dataset_fine_subset.subset_latlon_coordinates, properties=properties_fine, ) dataset_coarse_subset = BatchItemDatasetAdapter( dataset_coarse_subset, dataset_coarse_subset.subset_latlon_coordinates, properties=properties_coarse, ) dataset = FineCoarsePairedDataset( dataset_fine_subset, dataset_coarse_subset, ) sampler = self._get_sampler( dataset=dataset, dist=dist, train=train, drop_last=self.drop_last ) dataloader = DataLoader( dataset, batch_size=dist.local_batch_size(int(self.batch_size)), num_workers=self.num_data_workers, shuffle=(sampler is None) and train, sampler=sampler, drop_last=True, pin_memory=using_gpu(), collate_fn=PairedBatchData.from_sequence, multiprocessing_context=self._mp_context(), persistent_workers=True if self.num_data_workers > 0 else False, ) example = dataset[0] common_metadata_keys = set(dataset_fine_subset.variable_metadata).intersection( dataset_coarse_subset.variable_metadata ) assert all( dataset_fine_subset.variable_metadata[key] == dataset_coarse_subset.variable_metadata[key] for key in common_metadata_keys ), "Metadata for variables common to coarse and fine datasets must match." variable_metadata = { **dataset_fine_subset.variable_metadata, **dataset_coarse_subset.variable_metadata, } return PairedGriddedData( _loader=dataloader, topography=fine_topography, coarse_shape=example.coarse.horizontal_shape, downscale_factor=example.downscale_factor, dims=example.fine.latlon_coordinates.dims, variable_metadata=variable_metadata, all_times=all_times, ) def _get_sampler( self, dataset: Dataset, dist: Distributed, train: bool, drop_last: bool = False ) -> RandomSampler | DistributedSampler | None: # Use RandomSampler with replacement for both distributed and # non-distributed cases if self.sample_with_replacement is not None: local_sample_with_replacement_dataset_size = ( self.sample_with_replacement // dist.world_size ) return RandomSampler( dataset, num_samples=local_sample_with_replacement_dataset_size, replacement=True, ) if dist.is_distributed(): if train: sampler = DistributedSampler( dataset, shuffle=train, drop_last=drop_last ) else: sampler = ContiguousDistributedSampler(dataset, drop_last=drop_last) else: sampler = None return sampler