"""Source code for fme.core.dataset.merged."""

import dataclasses
from collections.abc import Sequence

import xarray as xr

from fme.core.dataset.concat import ConcatDatasetConfig
from fme.core.dataset.config import DatasetConfigABC
from fme.core.dataset.dataset import DatasetABC, DatasetItem
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.schedule import IntSchedule
from fme.core.dataset.time import RepeatedInterval, TimeSlice
from fme.core.dataset.utils import accumulate_labels
from fme.core.dataset.xarray import XarrayDataConfig, get_raw_paths
from fme.core.typing_ import Slice, TensorDict


class MergedXarrayDataset(DatasetABC):
    """
    Dataset whose items combine the variables of several datasets.

    All wrapped datasets must have disjoint variable names, identical sample
    start times, and the same number of times per sample. Each item contains
    the tensors of every wrapped dataset and the union of their labels.

    Parameters:
        datasets: The datasets to merge. Must contain at least one dataset.
    """

    def __init__(self, datasets: Sequence[DatasetABC]):
        """
        Validate and store the datasets to merge.

        Raises:
            ValueError: If no datasets are given, if a variable name appears in
                more than one dataset, or if the datasets disagree on sample
                start times or on the number of times per sample.
        """
        if not datasets:
            raise ValueError("At least one dataset must be provided.")
        self.datasets = datasets

        # Variable names are read from the first item of each dataset, so this
        # loads one sample per wrapped dataset.
        combined_names = [
            name for dataset in self.datasets for name in dataset[0][0].keys()
        ]
        # Single pass instead of list.count per name (which is O(n^2)).
        seen: set[str] = set()
        duplicates: list[str] = []
        for name in combined_names:
            if name in seen and name not in duplicates:
                duplicates.append(name)
            seen.add(name)
        if duplicates:
            raise ValueError(
                "Variable names must be unique across merged datasets.\n"
                f"Duplicates found: {duplicates}"
            )

        self._sample_start_times = self.datasets[0].sample_start_times
        self._sample_n_times = self.datasets[0].sample_n_times
        for dataset in self.datasets:
            if not dataset.sample_start_times.equals(self._sample_start_times):
                raise ValueError(
                    "All datasets in a merged dataset must have the same sample "
                    "start times."
                )
            if dataset.sample_n_times != self._sample_n_times:
                raise ValueError(
                    "All datasets in the merged datasets must have the same "
                    "number of steps per sample item."
                )

    @staticmethod
    def _merge_items(items: Sequence[DatasetItem]) -> DatasetItem:
        """
        Combine the per-dataset items for one sample into a single item.

        Tensor dicts are combined. Labels are unioned, with ``None`` meaning
        "no labels from this source". The time coordinate of the last item is
        returned, since all sources share sample start times (checked in
        ``__init__``).

        Raises:
            ValueError: If the items report different epochs.
        """
        tensors: TensorDict = {}
        labels = None
        epochs = []
        time = None
        for ds_tensors, ds_time, ds_labels, ds_epoch in items:
            tensors.update(ds_tensors)
            time = ds_time
            if labels is None:
                labels = ds_labels
            elif ds_labels is not None:
                labels = labels.union(ds_labels)
            epochs.append(ds_epoch)
        if not all(epoch == epochs[0] for epoch in epochs):
            raise ValueError(
                "All datasets in a merged dataset must have the same epoch."
            )
        return tensors, time, labels, epochs[0]

    def __getitem__(self, idx: int) -> DatasetItem:
        """
        Return the merged item at ``idx``.

        Raises:
            ValueError: If the wrapped datasets report different epochs.
        """
        return self._merge_items([dataset[idx] for dataset in self.datasets])

    def get_sample_by_time_slice(self, time_slice: slice) -> DatasetItem:
        """
        Return the merged sample for ``time_slice``.

        Labels and epochs are merged the same way as in ``__getitem__``, so
        labels from every wrapped dataset are kept, not only the last one's.

        Raises:
            ValueError: If the wrapped datasets report different epochs.
        """
        items = [
            dataset.get_sample_by_time_slice(time_slice)
            for dataset in self.datasets
        ]
        return self._merge_items(items)

    @property
    def all_times(self) -> xr.CFTimeIndex:
        """
        Like sample_start_times, but includes all times in the dataset, including
        final times which are not valid as a start index.

        This is relevant for inference, where we may use get_sample_by_time_slice to
        retrieve time windows directly.

        If this dataset does not support inference,
        this will raise a NotImplementedError.
        """
        # Only the first dataset is consulted; start times are checked to
        # match across datasets in __init__.
        return self.datasets[0].all_times

    @property
    def sample_start_times(self):
        """Sample start times shared by all wrapped datasets."""
        return self._sample_start_times

    @property
    def sample_n_times(self) -> int:
        """Number of times per sample, shared by all wrapped datasets."""
        return self._sample_n_times

    def validate_inference_length(self, max_start_index: int, max_window_len: int):
        """Run inference-length validation on every wrapped dataset."""
        for dataset in self.datasets:
            dataset.validate_inference_length(max_start_index, max_window_len)

    @property
    def properties(self) -> DatasetProperties:
        """
        Properties of the first dataset, merged with those of the others.

        Raises:
            ValueError: If there are no wrapped datasets.
        """
        data_properties = None
        for dataset in self.datasets:
            if data_properties is None:
                data_properties = dataset.properties
            else:
                # NOTE(review): update_merged_dataset's return value is unused,
                # so it presumably mutates ``data_properties`` in place. If the
                # first dataset's ``properties`` returns a shared/cached object,
                # that object is modified here. Confirm whether a copy is needed.
                data_properties.update_merged_dataset(dataset.properties)
        if data_properties is None:
            raise ValueError("No dataset available to determine properties")
        return data_properties

    def set_epoch(self, epoch: int):
        """Propagate ``epoch`` to every wrapped dataset."""
        for dataset in self.datasets:
            dataset.set_epoch(epoch)


@dataclasses.dataclass
class MergeDatasetConfig(DatasetConfigABC):
    """
    Configuration for merging multiple datasets.

    Merging combines the variables of several datasets, all of which must share
    the same time coordinate. When a data variable exists in more than one
    dataset, it is loaded from the first source that provides it and the other
    sources are ignored for that variable.

    Parameters:
        merge: List of dataset configurations to merge.
    """

    merge: Sequence[ConcatDatasetConfig | XarrayDataConfig]

    @property
    def zarr_engine_used(self) -> bool:
        """Whether any of the merged dataset configurations uses zarr."""
        return any(source.zarr_engine_used for source in self.merge)

    def build(
        self,
        names: Sequence[str],
        n_timesteps: IntSchedule,
    ):
        """Build the merged dataset and its properties via get_merged_datasets."""
        return get_merged_datasets(self, names, n_timesteps)

    @property
    def available_labels(self) -> set[str] | None:
        """
        Return the labels that are available in the dataset.
        """
        per_source_labels = [source.available_labels for source in self.merge]
        return accumulate_labels(per_source_labels)
@dataclasses.dataclass
class MergeNoConcatDatasetConfig(DatasetConfigABC):
    """
    Configuration for merging multiple datasets.

    Merging combines the variables of several datasets, all of which must share
    the same time coordinate. When a data variable exists in more than one
    dataset, it is loaded from the first source that provides it and the other
    sources are ignored for that variable.

    For `MergeNoConcatDatasetConfig`, the datasets being merged must not be
    concatenated datasets.

    Parameters:
        merge: List of dataset configurations to merge.
    """

    merge: Sequence[XarrayDataConfig]

    def update_subset(self, subset: Slice | TimeSlice | RepeatedInterval):
        """Apply the same subset to every merged dataset configuration."""
        for source in self.merge:
            source.update_subset(subset)

    @property
    def subset(self) -> Slice | TimeSlice | RepeatedInterval:
        """The subset of the first merged dataset configuration."""
        first_source = self.merge[0]
        return first_source.subset

    def build(
        self,
        names: Sequence[str],
        n_timesteps: IntSchedule,
    ) -> tuple[MergedXarrayDataset, DatasetProperties]:
        """Build the merged dataset by delegating to a MergeDatasetConfig."""
        as_merge_config = MergeDatasetConfig(merge=self.merge)
        return get_merged_datasets(as_merge_config, names, n_timesteps)

    @property
    def available_labels(self) -> set[str] | None:
        """
        Return the labels that are available in the dataset.
        """
        per_source_labels = [source.available_labels for source in self.merge]
        return accumulate_labels(per_source_labels)

    @property
    def zarr_engine_used(self) -> bool:
        """Whether any merged dataset configuration has engine "zarr"."""
        return any(source.engine == "zarr" for source in self.merge)
def get_merged_datasets(
    merged_config: MergeDatasetConfig | MergeNoConcatDatasetConfig,
    names: Sequence[str],
    n_timesteps: IntSchedule,
) -> tuple[MergedXarrayDataset, DatasetProperties]:
    """
    Build every dataset in a merge configuration and combine them.

    Parameters:
        merged_config: Configuration listing the datasets to merge.
        names: Variable names to load. Each name is loaded from the first
            configured dataset that contains it (see get_per_dataset_names).
        n_timesteps: Passed unchanged to each dataset configuration's build.

    Returns:
        The merged dataset, and the properties of the first source merged
        with those of the remaining sources.

    Raises:
        ValueError: If ``merged_config.merge`` is empty.
    """
    per_dataset_names = get_per_dataset_names(merged_config, names)
    merged_xarray_datasets = []
    merged_properties: DatasetProperties | None = None
    # get_per_dataset_names returns exactly one name list per config.
    for config, config_names in zip(merged_config.merge, per_dataset_names):
        dataset, properties = config.build(config_names, n_timesteps)
        merged_xarray_datasets.append(dataset)
        if merged_properties is None:
            merged_properties = properties
        else:
            merged_properties.update_merged_dataset(properties)
    if merged_properties is None:
        raise ValueError("At least one dataset must be provided.")
    merged_datasets = MergedXarrayDataset(datasets=merged_xarray_datasets)
    return merged_datasets, merged_properties


def _infer_available_variables(config: XarrayDataConfig) -> set[str]:
    """
    Return the names of the data variables in the first file of a dataset.

    Only the first path returned by get_raw_paths for the configured
    ``data_path`` and ``file_pattern`` is inspected. It is opened without
    time/timedelta decoding, and the file is closed before returning.

    Raises:
        ValueError: If no files match the configured path and pattern.
    """
    raw_paths = get_raw_paths(config.data_path, config.file_pattern)
    if len(raw_paths) == 0:
        raise ValueError(
            f"No files found matching '{config.data_path}/{config.file_pattern}'."
        )
    # Use the dataset as a context manager so the underlying file handle is
    # released; previously the opened dataset was never closed.
    with xr.open_dataset(
        raw_paths[0],
        decode_times=False,
        decode_timedelta=False,
        engine=config.engine,
        chunks=None,
    ) as dataset:
        return set(dataset.data_vars)


def get_per_dataset_names(
    merged_config: MergeDatasetConfig | MergeNoConcatDatasetConfig,
    names: Sequence[str],
) -> list[list[str]]:
    """
    Assign each requested variable name to the first dataset providing it.

    Parameters:
        merged_config: Configuration listing the datasets to merge.
        names: Variable names to distribute among the datasets.

    Returns:
        One list per entry of ``merged_config.merge``, in the same order,
        holding the names that should be loaded from that entry.

    Raises:
        TypeError: If an entry of ``merged_config.merge`` is neither an
            XarrayDataConfig nor a ConcatDatasetConfig.
    """
    remaining_names = list(names)
    per_dataset_names = []
    for config in merged_config.merge:
        if isinstance(config, XarrayDataConfig):
            available = _infer_available_variables(config)
        elif isinstance(config, ConcatDatasetConfig):
            # Only the first member of a concat is inspected; presumably all
            # members share the same variables. TODO confirm.
            available = _infer_available_variables(config.concat[0])
        else:
            # Without this, the variables of the previous source (or an
            # unbound name on the first entry) would be used silently.
            raise TypeError(
                "Unsupported dataset configuration type in merge: "
                f"{type(config).__name__}"
            )
        current_source_names = [
            name for name in remaining_names if name in available
        ]
        per_dataset_names.append(current_source_names)
        for name in current_source_names:
            remaining_names.remove(name)
    # NOTE(review): any names left in remaining_names were not found in any
    # source and are silently dropped here.
    return per_dataset_names