import dataclasses
from collections.abc import Sequence
from typing import Self

import torch
import xarray as xr

from fme.core.dataset.config import DatasetConfigABC
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.xarray import (
    XarrayDataConfig,
    XarrayDataset,
    XarraySubset,
    get_xarray_datasets,
)
from fme.core.typing_ import TensorDict


class XarrayConcat(torch.utils.data.Dataset):
    """
    Dataset that chains several datasets end to end.

    Items are served by a ``torch.utils.data.ConcatDataset`` built from the
    given datasets. The sample start times of the members are appended in
    the same order, and ``dims`` / ``sample_n_times`` are taken from the
    first member after checking that all members agree.
    """

    def __init__(self, datasets: Sequence[XarrayDataset | XarraySubset | Self]):
        """
        Parameters:
            datasets: Datasets to concatenate, in order. Each must expose
                ``sample_start_times``, ``sample_n_times`` and ``dims``.

        Raises:
            ValueError: If any dataset has a different ``sample_n_times``
                or different ``dims`` than the first dataset.
        """
        self._dataset = torch.utils.data.ConcatDataset(datasets)
        first = datasets[0]

        sample_start_times = first.sample_start_times
        for dataset in datasets[1:]:
            sample_start_times = sample_start_times.append(dataset.sample_start_times)
            # Raise ValueError directly: a preceding assert on the same
            # condition would pre-empt it (or vanish under ``python -O``).
            if dataset.sample_n_times != first.sample_n_times:
                raise ValueError(
                    "All concatenated datasets "
                    "must have the same number of steps per sample item."
                )
        self._sample_start_times = sample_start_times
        # Internal invariant: one start time per item of the concatenation.
        assert len(self._dataset) == len(sample_start_times)
        self._sample_n_times = first.sample_n_times

        for dataset in datasets[1:]:
            if dataset.dims != first.dims:
                raise ValueError(
                    "Datasets being concatenated do not have the same dimensions: "
                    f"{dataset.dims} != {first.dims}"
                )
        self.dims: list[str] = first.dims

    def __len__(self):
        """Return the total number of items across all concatenated datasets."""
        return len(self._dataset)

    def __getitem__(self, idx: int) -> tuple[TensorDict, xr.DataArray, set[str]]:
        """Return the item at ``idx`` from the underlying ``ConcatDataset``."""
        return self._dataset[idx]

    @property
    def sample_start_times(self):
        """Sample start times of all members, appended in dataset order."""
        return self._sample_start_times

    @property
    def sample_n_times(self) -> int:
        """The length of the time dimension of each sample."""
        return self._sample_n_times


def get_dataset(
    dataset_configs: Sequence[XarrayDataConfig],
    names: Sequence[str],
    n_timesteps: int,
    strict: bool = True,
) -> tuple[XarrayConcat, DatasetProperties]:
    """
    Build the configured datasets and concatenate them.

    Parameters:
        dataset_configs: Configurations of the datasets to build.
        names: Passed through to ``get_xarray_datasets``.
        n_timesteps: Passed through to ``get_xarray_datasets``.
        strict: Passed through to ``get_xarray_datasets`` as ``strict``.

    Returns:
        The concatenated dataset and the properties returned by
        ``get_xarray_datasets``.
    """
    datasets, properties = get_xarray_datasets(
        dataset_configs, names, n_timesteps, strict=strict
    )
    concatenated = XarrayConcat(datasets)
    return concatenated, properties
@dataclasses.dataclass
class ConcatDatasetConfig(DatasetConfigABC):
    """
    Settings for joining several datasets into one along time.

    Parameters:
        concat: The XarrayDataConfig entries whose datasets are concatenated.
        strict: If True, require the datasets being concatenated to share
            the same dimensions and spatial coordinates.
    """

    concat: Sequence[XarrayDataConfig]
    strict: bool = True

    def __post_init__(self):
        # Record whether at least one configured dataset uses the zarr engine.
        self.zarr_engine_used = any(
            config.engine == "zarr" for config in self.concat
        )

    def build(
        self,
        names: Sequence[str],
        n_timesteps: int,
    ) -> tuple[torch.utils.data.Dataset, DatasetProperties]:
        """
        Construct the concatenated dataset described by this configuration.

        Parameters:
            names: Forwarded to ``get_dataset``.
            n_timesteps: Forwarded to ``get_dataset``.

        Returns:
            The dataset and properties produced by ``get_dataset``.
        """
        built = get_dataset(
            self.concat,
            names,
            n_timesteps,
            strict=self.strict,
        )
        return built