import dataclasses
import warnings
from collections.abc import Sequence

import torch
import xarray as xr

from fme.core.dataset.config import DatasetConfigABC
from fme.core.dataset.dataset import DatasetABC, DatasetItem
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.schedule import IntSchedule
from fme.core.dataset.utils import accumulate_labels
from fme.core.dataset.xarray import XarrayDataConfig, get_xarray_datasets


class XarrayConcat(DatasetABC):
    """
    Dataset that concatenates several datasets end to end.

    Items are served by an underlying ``torch.utils.data.ConcatDataset`` and
    the sample start times of the member datasets are appended in the same
    order. Inference (time-slice access) is not supported.

    Parameters:
        datasets: Datasets to concatenate, in order. Must be non-empty, and
            every dataset must have the same ``sample_n_times``.
        strict: If True, a ``ValueError`` raised while merging the members'
            properties propagates. If False, it is reported as a warning and
            construction continues.

    Raises:
        ValueError: If ``datasets`` is empty, if the datasets disagree on
            ``sample_n_times``, or (when ``strict``) if their properties
            cannot be merged.
    """

    def __init__(self, datasets: Sequence[DatasetABC], strict: bool = True):
        if not datasets:
            raise ValueError("XarrayConcat requires at least one dataset.")
        self._dataset = torch.utils.data.ConcatDataset(datasets)
        self._wrapped_datasets = datasets

        first = datasets[0]
        sample_start_times = first.sample_start_times
        for dataset in datasets[1:]:
            sample_start_times = sample_start_times.append(dataset.sample_start_times)
            # Explicit check rather than ``assert``: callers get the documented
            # ValueError, and the validation is not stripped by ``python -O``.
            if dataset.sample_n_times != first.sample_n_times:
                raise ValueError(
                    "All concatenated datasets "
                    "must have the same number of steps per sample item."
                )
        self._sample_start_times = sample_start_times
        # Internal consistency: one start time per item of the concatenation.
        assert len(self._dataset) == len(sample_start_times)
        self._sample_n_times = first.sample_n_times

        # Merge every other member's properties into a copy of the first
        # member's (presumably so the first dataset's own object is untouched).
        self._properties = first.properties.copy()
        for dataset in datasets[1:]:
            try:
                self._properties.update(dataset.properties)
            except ValueError as e:
                if strict:
                    raise
                warnings.warn(
                    f"Metadata for each ensemble member are not the same: {e}"
                )

    def __getitem__(self, idx: int) -> DatasetItem:
        """Return item ``idx`` of the concatenation."""
        return self._dataset[idx]

    @property
    def sample_start_times(self):
        """Start times of all samples, in concatenation order."""
        return self._sample_start_times

    @property
    def all_times(self) -> xr.CFTimeIndex:
        """
        Like sample_start_times, but includes all times in the dataset, including
        final times which are not valid as a start index.

        This is relevant for inference, where we may use get_sample_by_time_slice
        to retrieve time windows directly.

        Concatenated datasets do not support inference, so this always raises
        a NotImplementedError.
        """
        raise NotImplementedError("Concat datasets do not support inference.")

    @property
    def sample_n_times(self) -> int:
        """The length of the time dimension of each sample."""
        return self._sample_n_times

    def get_sample_by_time_slice(self, time_slice: slice) -> DatasetItem:
        """Not supported for concatenated datasets; always raises."""
        raise NotImplementedError(
            "Concat datasets do not support getting samples by time slice, "
            "and should not be configurable for inference. Is there a bug?."
        )

    @property
    def properties(self) -> DatasetProperties:
        """Properties merged across all concatenated datasets."""
        return self._properties

    def validate_inference_length(self, max_start_index: int, max_window_len: int):
        """Not supported for concatenated datasets; always raises ValueError."""
        raise ValueError("Concat datasets do not support inference.")

    def set_epoch(self, epoch: int):
        """Forward ``epoch`` to every wrapped dataset."""
        for dataset in self._wrapped_datasets:
            dataset.set_epoch(epoch)


def get_dataset(
    dataset_configs: Sequence[XarrayDataConfig],
    names: Sequence[str],
    n_timesteps: IntSchedule,
    strict: bool = True,
) -> tuple[XarrayConcat, DatasetProperties]:
    """
    Build datasets from ``dataset_configs`` and concatenate them.

    Parameters:
        dataset_configs: Configurations passed to ``get_xarray_datasets``.
        names: Passed through to ``get_xarray_datasets``.
        n_timesteps: Passed through to ``get_xarray_datasets``.
        strict: Passed to both ``get_xarray_datasets`` and ``XarrayConcat``.

    Returns:
        The concatenated dataset and the properties returned by
        ``get_xarray_datasets``.
    """
    datasets, properties = get_xarray_datasets(
        dataset_configs, names, n_timesteps, strict=strict
    )
    concatenated = XarrayConcat(datasets, strict=strict)
    return concatenated, properties
@dataclasses.dataclass
class ConcatDatasetConfig(DatasetConfigABC):
    """
    Configuration for concatenating multiple datasets across time.

    Parameters:
        concat: List of XarrayDataConfig objects to concatenate.
        strict: Whether to enforce that the datasets to be concatenated have
            the same dimensions and spatial coordinates.
    """

    concat: Sequence[XarrayDataConfig]
    strict: bool = True

    def __post_init__(self):
        # Record whether any of the member datasets is configured with the
        # zarr engine.
        self.zarr_engine_used = any(ds.engine == "zarr" for ds in self.concat)

    def build(
        self,
        names: Sequence[str],
        n_timesteps: IntSchedule,
    ) -> tuple[DatasetABC, DatasetProperties]:
        """
        Build the concatenated dataset described by this configuration.

        Parameters:
            names: Passed through to ``get_dataset``.
            n_timesteps: Passed through to ``get_dataset``.

        Returns:
            The dataset and properties returned by ``get_dataset``.
        """
        return get_dataset(
            self.concat,
            names,
            n_timesteps,
            strict=self.strict,
        )

    @property
    def available_labels(self) -> set[str] | None:
        """
        Return the labels that are available in the dataset.

        Combines the ``available_labels`` of every member config via
        ``accumulate_labels``.
        """
        return accumulate_labels([ds.available_labels for ds in self.concat])