@dataclasses.dataclass
class DataLoaderConfig:
    r"""
    Configuration for a data loader for training/validation.

    Parameters:
        dataset: Could be a single dataset configuration,
            or a sequence of datasets to be concatenated using the keyword `concat`,
            or datasets from different sources to be merged using the keyword `merge`.
        batch_size: Number of samples per batch.
        num_data_workers: Number of parallel workers to use for data loading.
        prefetch_factor: How many batches a single data worker will attempt to
            hold in host memory at a given time.
        augmentation: Configuration for data augmentation.
        sample_with_replacement: If provided, the dataset will be
            sampled randomly with replacement to the given size each period,
            instead of retrieving each sample once (either shuffled or not).
        time_buffer: How many more continuous timesteps to load in memory than
            the required number of timesteps for a single batch. Setting this
            to greater than 0 should improve data loading performance; however,
            it also decreases the independence of subsequent batches if
            shuffled batches are desired.

    Note:
        Setting `time_buffer` to a value greater than 0 results in pre-loading
        samples of length `time_buffer + n_timesteps_required`, where
        `n_timesteps_required` is the number of timesteps required for training
        the model (initial condition(s) plus forward step(s)). These pre-loaded
        samples become a window from which samples of the required length are
        drawn without replacement. The windows will overlap by an amount such
        that no samples are skipped, with the exception of the last window,
        which is dropped if incomplete. This is useful for improving data
        loading throughput and reducing the number of reads. There must be
        enough pre-loaded samples in the dataset to produce at least one batch
        at the configured batch size. Independent data will be seen every
        `time_buffer + 1` batches, i.e., this is the number of samples in each
        pre-loaded window.
    """

    dataset: ConcatDatasetConfig | MergeDatasetConfig | XarrayDataConfig
    batch_size: int
    num_data_workers: int = 0
    prefetch_factor: int | None = None
    augmentation: AugmentationConfig = dataclasses.field(
        default_factory=lambda: AugmentationConfig()
    )
    sample_with_replacement: int | None = None
    time_buffer: int = 0

    @property
    def using_labels(self) -> bool:
        return self.available_labels is not None

    def get_dataset(
        self,
        names: Sequence[str],
        n_timesteps: IntSchedule,
    ) -> tuple[DatasetABC, DatasetProperties]:
        return self.dataset.build(names, n_timesteps)

    @property
    def available_labels(self) -> set[str] | None:
        """
        Return the labels that are available in the dataset.
        """
        return self.dataset.available_labels

    def __post_init__(self):
        dist = Distributed.get_instance()
        # The global batch is split evenly across distributed processes.
        if self.batch_size % dist.world_size != 0:
            raise ValueError(
                "batch_size must be divisible by the number of parallel "
                f"workers, got {self.batch_size} and {dist.world_size}"
            )
        self._zarr_engine_used = self.dataset.zarr_engine_used
        if self.time_buffer < 0:
            raise ValueError(
                "time_buffer must be greater than or equal to 0. "
                f"Got {self.time_buffer}"
            )

    @property
    def zarr_engine_used(self) -> bool:
        """
        Whether any of the configured datasets are using the Zarr engine.
        """
        return self._zarr_engine_used