import logging
from collections.abc import Sequence
import torch.utils.data
from fme.ace.data_loading.batch_data import BatchData
from fme.ace.data_loading.dataloader import get_data_loader
from fme.ace.requirements import DataRequirements, PrognosticStateDataRequirements
from fme.core.dataset.dataset import DatasetItem
from fme.core.dataset.merged import MergeNoConcatDatasetConfig
from fme.core.dataset.subset import SubsetDataset
from fme.core.dataset.xarray import XarrayDataConfig, XarrayDataset
from fme.core.device import using_gpu
from fme.core.distributed import Distributed
from fme.core.labels import LabelEncoding
from .batch_data import PrognosticState
from .config import DataLoaderConfig
from .gridded_data import GriddedData, InferenceGriddedData
from .inference import (
ExplicitIndices,
ForcingDataLoaderConfig,
InferenceDataLoaderConfig,
InferenceDataset,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class CollateFn:
    """Callable that collates dataset samples into a single ``BatchData``.

    The horizontal dimension names and the optional label encoding are stored
    at construction time and forwarded to ``BatchData.from_sample_tuples`` on
    every call.
    """

    def __init__(
        self,
        horizontal_dims: list[str],
        label_encoding: LabelEncoding | None = None,
    ):
        """
        Args:
            horizontal_dims: Names of the horizontal dimensions, forwarded
                unchanged to ``BatchData.from_sample_tuples``.
            label_encoding: Optional label encoding, forwarded unchanged to
                ``BatchData.from_sample_tuples``.
        """
        self.horizontal_dims = horizontal_dims
        self.label_encoding = label_encoding

    def __call__(self, samples: Sequence[DatasetItem]) -> BatchData:
        """Combine ``samples`` into one batch using the stored settings."""
        batch = BatchData.from_sample_tuples(
            samples,
            horizontal_dims=self.horizontal_dims,
            label_encoding=self.label_encoding,
        )
        return batch
def _get_sampler(
    dataset: torch.utils.data.Dataset,
    sample_with_replacement_dataset_size: int | None,
    train: bool,
) -> torch.utils.data.Sampler:
    """Choose the sampler used to draw items from ``dataset``.

    Args:
        dataset: Dataset the sampler will index into.
        sample_with_replacement_dataset_size: If None, the sampler is obtained
            from the ``Distributed`` instance. Otherwise a random sampler with
            replacement is returned, drawing this many samples floor-divided
            by the number of data-parallel ranks.
        train: Passed as ``shuffle`` to the distributed sampler; not used when
            sampling with replacement.

    Returns:
        The sampler to hand to the data loader.
    """
    distributed = Distributed.get_instance()
    if sample_with_replacement_dataset_size is None:
        return distributed.get_sampler(dataset, shuffle=train)

    # NOTE(review): presumably raises with this message when spatial
    # parallelism is enabled -- confirm against Distributed.
    distributed.require_no_spatial_parallelism(
        "sample_with_replacement is not supported with spatial "
        "parallelism. Spatial co-ranks would draw different samples, "
        "producing corrupted data after scatter_spatial reassembly."
    )
    # Split the requested sample count across the data-parallel ranks.
    samples_per_rank = (
        sample_with_replacement_dataset_size // distributed.total_data_parallel_ranks
    )
    return torch.utils.data.RandomSampler(
        dataset,
        num_samples=samples_per_rank,
        replacement=True,
    )
def get_gridded_data(
    config: DataLoaderConfig,
    train: bool,
    requirements: DataRequirements,
    _force_forkserver: bool = False,
) -> GriddedData:
    """Build the ``GriddedData`` (loader, properties, modifier) for ``config``.

    Args:
        config: Parameters for the data loader.
        train: Whether loader is intended for training or validation data; if True,
            then data will be shuffled.
        requirements: Data requirements for the model.
        _force_forkserver: Whether to force using forkserver multiprocessing context.
            This is useful for debugging or testing in cases where forkserver is not
            the default, but should generally be unused in production code.

    Returns:
        A ``GriddedData`` wrapping the data loader, the dataset properties and
        the modifier built from ``config.augmentation``.
    """
    n_timesteps_preloaded = requirements.n_timesteps_schedule.add(config.time_buffer)
    dataset, properties = config.get_dataset(requirements.names, n_timesteps_preloaded)
    if config.time_buffer > 0:
        # include requirements.n_timesteps - 1 steps of overlap so that no samples are
        # skipped at the boundaries of the preloaded timesteps
        start_every_n = config.time_buffer + 1
        indices = list(range(0, len(dataset), start_every_n))
        dataset = SubsetDataset(dataset, indices)
    sampler = _get_sampler(dataset, config.sample_with_replacement, train)
    if _force_forkserver or (config.zarr_engine_used and config.num_data_workers > 0):
        # GCSFS and S3FS are not fork-safe, so we need to use forkserver
        # reading zarr with async from weka also requires forkserver
        mp_context = "forkserver"
        persistent_workers = True
    else:
        mp_context = None
        persistent_workers = False
    dist = Distributed.get_instance()
    batch_size = dist.local_batch_size(int(config.batch_size))
    if config.available_labels is not None:
        # sorted() accepts any iterable, so no intermediate list is needed
        label_encoding = LabelEncoding(sorted(config.available_labels))
    else:
        label_encoding = None
    dataloader = get_data_loader(
        dataset=dataset,
        batch_size=batch_size,
        n_window_timesteps=requirements.n_timesteps_schedule,
        time_buffer=config.time_buffer,
        num_workers=config.num_data_workers,
        sampler=sampler,
        shuffled=train,
        drop_last=True,
        pin_memory=using_gpu(),
        collate_fn=CollateFn(
            list(properties.horizontal_coordinates.dims),
            label_encoding,
        ),
        multiprocessing_context=mp_context,
        persistent_workers=persistent_workers,
        prefetch_factor=config.prefetch_factor,
    )
    return GriddedData(
        loader=dataloader,
        properties=properties,
        modifier=config.augmentation.build_modifier(),
    )
# Holds the worker's Distributed context at module scope so that it is not
# garbage collected and finalized.
_WORKER_DIST_CX = None


def _forkserver_worker_init_fn(worker_id: int) -> None:
    """Enter a ``Distributed`` context inside a data-loader worker process.

    Args:
        worker_id: Worker index supplied by the data loader; unused here.
    """
    global _WORKER_DIST_CX
    context = Distributed.context()
    # Bind the module-level reference before entering, keeping the context alive.
    _WORKER_DIST_CX = context
    context.__enter__()
    # don't need to exit the context on workers as they are not
    # initialized/managed by torchrun
def get_inference_data(
    config: InferenceDataLoaderConfig,
    total_forward_steps: int,
    window_requirements: DataRequirements,
    initial_condition: PrognosticState | PrognosticStateDataRequirements,
    label_override: list[str] | None = None,
    surface_temperature_name: str | None = None,
    ocean_fraction_name: str | None = None,
    _force_forkserver: bool = False,
) -> InferenceGriddedData:
    """Build the ``InferenceGriddedData`` used to drive inference.

    Args:
        config: Parameters for the data loader.
        total_forward_steps: Total number of forward steps to take over the course of
            inference.
        window_requirements: Data requirements for the model.
        initial_condition: Initial condition for the inference, or a requirements object
            specifying how to extract the initial condition from the first batch of
            data
        label_override: Labels for the forcing data to be provided on each sample
            instead of the labels in the dataset.
        surface_temperature_name: Name of the surface temperature variable. Can be
            set to None if no ocean temperature prescribing is being used.
        ocean_fraction_name: Name of the ocean fraction variable. Can be set to None
            if no ocean temperature prescribing is being used.
        _force_forkserver: Whether to force using forkserver multiprocessing context.
            This is useful for debugging or testing in cases where forkserver is not
            the default, but should generally be unused in production code.

    Returns:
        A data loader for inference with coordinates and metadata.
    """
    dataset = InferenceDataset(
        config=config,
        total_forward_steps=total_forward_steps,
        requirements=window_requirements,
        surface_temperature_name=surface_temperature_name,
        ocean_fraction_name=ocean_fraction_name,
        label_override=label_override,
    )
    properties = dataset.properties
    # torch's DataLoader rejects a multiprocessing context and persistent
    # workers when num_workers == 0, so only select forkserver for the zarr
    # engine when worker processes are actually requested (same condition as
    # get_gridded_data). _force_forkserver still forces it unconditionally.
    if _force_forkserver or (config.zarr_engine_used and config.num_data_workers > 0):
        # GCSFS and S3FS are not fork-safe, so we need to use forkserver
        # persist workers since startup is slow
        mp_context = "forkserver"
        persistent_workers = True
        worker_init_fn = _forkserver_worker_init_fn
    else:
        mp_context = None
        persistent_workers = False
        worker_init_fn = None
    logging.info(f"Multiprocessing inference context: {mp_context or 'fork'}")
    # we roll our own batching in InferenceDataset, which is why batch_size=None below
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,
        num_workers=config.num_data_workers,
        shuffle=False,
        pin_memory=using_gpu(),
        multiprocessing_context=mp_context,
        persistent_workers=persistent_workers,
        worker_init_fn=worker_init_fn,
    )
    gridded_data = InferenceGriddedData(
        loader=loader,
        initial_condition=initial_condition,
        properties=properties,
    )
    return gridded_data
def get_forcing_data(
    config: ForcingDataLoaderConfig,
    total_forward_steps: int,
    window_requirements: DataRequirements,
    initial_condition: PrognosticState,
    surface_temperature_name: str | None = None,
    ocean_fraction_name: str | None = None,
    label_override: list[str] | None = None,
) -> InferenceGriddedData:
    """Return an InferenceGriddedData for forcing data based on the initial condition.

    This function determines the start indices for the forcing data based on the initial
    time in the provided initial condition.

    Args:
        config: Parameters for the forcing data loader.
        total_forward_steps: Total number of forward steps to take over the course of
            inference.
        window_requirements: Data requirements for the forcing data.
        initial_condition: Initial condition for the inference.
        surface_temperature_name: Name of the surface temperature variable. Can be
            set to None if no ocean temperature prescribing is being used.
        ocean_fraction_name: Name of the ocean fraction variable. Can be set to None
            if no ocean temperature prescribing is being used.
        label_override: Labels for the forcing data. If provided, these labels will be
            provided on each sample instead of the labels in the dataset.

    Returns:
        A data loader for forcing data with coordinates and metadata.

    Raises:
        NotImplementedError: If the initial condition has more than one timestep.
        ValueError: If the available times cannot be read from ``config.dataset``,
            i.e. it is neither an ``XarrayDataConfig`` nor a
            ``MergeNoConcatDatasetConfig`` whose first merged entry is an
            ``XarrayDataConfig``.
    """
    initial_time = initial_condition.as_batch_data().time
    if initial_time.shape[1] != 1:
        raise NotImplementedError("code assumes initial time only has 1 timestep")
    dataset_config = config.dataset
    if isinstance(dataset_config, XarrayDataConfig):
        available_times = XarrayDataset(
            dataset_config,
            window_requirements.names,
            window_requirements.n_timesteps_schedule,
        ).all_times
    elif isinstance(dataset_config, MergeNoConcatDatasetConfig) and isinstance(
        dataset_config.merge[0], XarrayDataConfig
    ):
        # Some forcing variables may not be in the first dataset,
        # use an empty data requirements to get all times
        available_times = XarrayDataset(
            dataset_config.merge[0],
            names=[],
            n_timesteps=window_requirements.n_timesteps_schedule,
        ).all_times
    else:
        # Fail explicitly for every unsupported configuration instead of
        # leaving ``available_times`` unbound.
        raise ValueError("Forcing data cannot be concatenated.")
    # One start index per sample: the position of its initial time within the
    # times available in the forcing dataset.
    start_time_indices = [
        available_times.get_loc(time) for time in initial_time.values[:, 0]
    ]
    inference_config = config.build_inference_config(
        start_indices=ExplicitIndices(start_time_indices)
    )
    return get_inference_data(
        config=inference_config,
        total_forward_steps=total_forward_steps,
        window_requirements=window_requirements,
        initial_condition=initial_condition,
        surface_temperature_name=surface_temperature_name,
        ocean_fraction_name=ocean_fraction_name,
        label_override=label_override,
    )