from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
from fme.core.dataset.time import RepeatedInterval, TimeSlice
from fme.core.dataset.xarray import XarrayDataConfig
from fme.core.distributed import Distributed
from fme.core.typing_ import Slice
from fme.core.writer import ZarrWriter
from ..data import (
ClosedInterval,
DataLoaderConfig,
LatLonCoordinates,
StaticInputs,
enforce_lat_bounds,
)
from ..data.config import XarrayEnsembleDataConfig
from ..predictors import PatchPredictionConfig
from ..requirements import DataRequirements
from .constants import DIMS
from .work_items import SliceItemDataset, SliceWorkItemGriddedData, get_work_items
from .zarr_utils import determine_zarr_chunks
def _identity_collate(batch):
"""
Collate function that returns the single batch item.
Used with batch_size=1 to extract the single item from the batch list.
Must be a module-level function (not lambda) to be picklable for multiprocessing.
"""
return batch[0]
class DownscalingOutput:
    """
    Container for a single downscaling output.

    Encapsulates all data and metadata needed to generate downscaled outputs
    for a specific region, time range, and ensemble configuration.

    Parameters:
        name: Identifier for this target (used as output filenames).
        save_vars: List of variable names to save to zarr.
        n_ens: Total number of ensemble members to generate.
        max_samples_per_gpu: Max number of time and/or ensemble samples per GPU
            batch. The breakdown of time vs ensemble per batch is determined
            automatically.
        data: GriddedData containing the input coarse data and loader.
        patch: Configuration for patching large domains.
        chunks: Zarr chunk sizes for each dimension.
        shards: Zarr shard sizes for each dimension.
        dims: Dimension names including ensemble (time, ensemble, lat, lon).
    """

    def __init__(
        self,
        name: str,
        save_vars: list[str] | None,
        n_ens: int,
        max_samples_per_gpu: int,
        data: SliceWorkItemGriddedData,
        patch: PatchPredictionConfig,
        chunks: dict[str, int],
        shards: dict[str, int],
        dims: tuple[str, ...] = DIMS,
    ) -> None:
        self.name = name
        self.save_vars = save_vars
        self.n_ens = n_ens
        self.max_samples_per_gpu = max_samples_per_gpu
        self.data = data
        self.patch = patch
        self.chunks = chunks
        self.shards = shards
        self.dims = dims

    def get_writer(
        self,
        latlon_coords: LatLonCoordinates,
        output_dir: str,
    ) -> ZarrWriter:
        """
        Create a ZarrWriter for this target.

        Args:
            latlon_coords: High-resolution spatial coordinates for outputs
            output_dir: Directory to store output zarr file

        Returns:
            A ZarrWriter pointed at ``{output_dir}/{name}.zarr`` whose
            coordinates are the data's times, the ensemble indices
            ``0..n_ens-1``, and the given lat/lon values.

        Raises:
            ValueError: If ``self.dims`` does not contain exactly one name per
                coordinate array (time, ensemble, lat, lon).
        """
        # Coordinate arrays, in the order the dimension names are expected.
        coord_values = [
            self.data.all_times.to_numpy(),
            np.arange(self.n_ens),
            latlon_coords.lat.cpu().numpy(),
            latlon_coords.lon.cpu().numpy(),
        ]
        # zip() would silently drop trailing entries on a length mismatch and
        # produce a writer with missing dimensions, so fail loudly instead.
        if len(self.dims) != len(coord_values):
            raise ValueError(
                f"Expected {len(coord_values)} dimension names "
                f"(time, ensemble, lat, lon) but got {len(self.dims)}: "
                f"{self.dims}"
            )
        coords = dict(zip(self.dims, coord_values))
        return ZarrWriter(
            path=f"{output_dir}/{self.name}.zarr",
            dims=tuple(coords),
            coords=coords,
            data_vars=self.save_vars,
            chunks=self.chunks,
            shards=self.shards,
        )
@dataclass
class DownscalingOutputConfig(ABC):
    """
    Base class for configuring downscaling output generation targets.

    Output targets define what data to generate, where to generate it, and how
    to save it.

    Parameters:
        name: Unique identifier for this target (used in output filename)
        n_ens: Number of ensemble members to generate when downscaling
        save_vars: List of variable names to save to zarr output. If None,
            all variables from the model output will be saved.
        zarr_chunks: Optional chunk sizes for zarr dimensions. If None,
            automatically calculated to target lat/lon shape <=10MB per chunk.
            Ensemble and time dimensions chunks are length 1.
        zarr_shards: Optional shard sizes for zarr dimensions. If None, defaults
            to maximum output size for a single unit of downscaling work. This
            ensures that parallel generation tasks write to separate shards.
        max_samples_per_gpu: Number of time and/or ensemble samples to include
            in a single GPU generation. Controls memory usage and time to
            generate.
    """

    name: str
    n_ens: int
    save_vars: list[str] | None = None
    zarr_chunks: dict[str, int] | None = None
    zarr_shards: dict[str, int] | None = None
    max_samples_per_gpu: int = 4

    @abstractmethod
    def build(
        self,
        loader_config: DataLoaderConfig,
        requirements: DataRequirements,
        patch: PatchPredictionConfig,
        static_inputs_from_checkpoint: StaticInputs | None = None,
    ) -> DownscalingOutput:
        """
        Build a DownscalingOutput from this configuration.

        Args:
            loader_config: Base data loader configuration to modify
            requirements: Model's data requirements (variable names, etc.)
            patch: Default patch prediction configuration
            static_inputs_from_checkpoint: Optional static inputs to use when
                building the topography.
        """
        pass

    @staticmethod
    def _single_xarray_config(
        coarse: list[XarrayDataConfig]
        | Sequence[XarrayDataConfig | XarrayEnsembleDataConfig],
    ) -> list[XarrayDataConfig]:
        """
        Ensures that the data configuration is a single xarray config.

        Necessary because we will be using the top-level DataLoaderConfig
        to build the data, and we'll be replacing time and spatial extents.

        Args:
            coarse: Coarse data configuration(s) taken from the loader config.

        Returns:
            A one-element list holding the validated XarrayDataConfig.

        Raises:
            NotImplementedError: If ``coarse`` does not contain exactly one
                entry, or that entry is not an XarrayDataConfig.
        """
        # TODO: Consider only supporting a single xarray config
        # for this run type since we use Zarr not netCDF. Just more
        # complexity to enforce all possible rather than just supporting
        # a single config.
        if len(coarse) != 1:
            raise NotImplementedError(
                "Only a single XarrayDataConfig is supported in "
                "DownscalingOutputConfig coarse specification."
            )
        data_config = coarse[0]
        if not isinstance(data_config, XarrayDataConfig):
            raise NotImplementedError(
                "Only XarrayDataConfig objects are supported in "
                "DownscalingOutputConfig coarse specification."
            )
        return [data_config]

    def _replace_loader_config(
        self,
        time: TimeSlice | RepeatedInterval | Slice,
        coarse: list[XarrayDataConfig],
        lat_extent: ClosedInterval,
        lon_extent: ClosedInterval,
        loader_config: DataLoaderConfig,
    ) -> DataLoaderConfig:
        """
        Return a copy of ``loader_config`` restricted to this target's extents.

        Args:
            time: Time selection applied as the ``subset`` of the coarse config.
            coarse: Coarse data configs; only the first entry is used.
            lat_extent: Latitude extent to set on the new loader config.
            lon_extent: Longitude extent to set on the new loader config.
            loader_config: Loader configuration to copy; it is not mutated.
        """
        new_coarse = [replace(coarse[0], subset=time)]
        # TODO: log the replacements for debugging
        return replace(
            loader_config,
            coarse=new_coarse,
            lat_extent=lat_extent,
            lon_extent=lon_extent,
        )

    def _build_gridded_data(
        self,
        loader_config: DataLoaderConfig,
        requirements: DataRequirements,
        dist: Distributed | None = None,
        static_inputs_from_checkpoint: StaticInputs | None = None,
    ) -> SliceWorkItemGriddedData:
        """
        Load the coarse dataset and wrap it in a work-item based DataLoader.

        Args:
            loader_config: Loader configuration already restricted to the
                desired time and spatial extents.
            requirements: Model's data requirements (variable names, etc.)
            dist: Distributed context; when None the global instance from
                ``Distributed.get_instance()`` is used.
            static_inputs_from_checkpoint: Optional static inputs forwarded to
                ``loader_config.build_topography``.

        Raises:
            ValueError: If the dataset does not use lat/lon coordinates or no
                topography could be built.
        """
        xr_dataset, properties = loader_config.get_xarray_dataset(
            names=requirements.coarse_names, n_timesteps=1
        )
        coords = properties.horizontal_coordinates
        if not isinstance(coords, LatLonCoordinates):
            raise ValueError(
                "Downscaling data loader only supports datasets with latlon coords."
            )
        dataset = loader_config.build_batchitem_dataset(xr_dataset, properties)
        topography = loader_config.build_topography(
            coords,
            requires_topography=requirements.use_fine_topography,
            # TODO: update to support full list of static inputs
            static_inputs_from_checkpoint=static_inputs_from_checkpoint,
        )
        if topography is None:
            raise ValueError("Topography is required for downscaling generation.")

        work_items = get_work_items(
            n_times=len(dataset),
            n_ens=self.n_ens,
            max_samples_per_gpu=self.max_samples_per_gpu,
        )
        # defer topography device placement until after batch generation
        slice_dataset = SliceItemDataset(
            slice_items=work_items,
            dataset=dataset,
            spatial_shape=topography.shape,
        )

        if dist is None:
            dist = Distributed.get_instance()
        sampler = (
            DistributedSampler(slice_dataset, shuffle=False)
            if dist.is_distributed()
            else None
        )
        # each SliceItemDataset work item loads its own full batch, so batch_size=1
        loader = DataLoader(
            slice_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=loader_config.num_data_workers,
            collate_fn=_identity_collate,
            drop_last=False,
            multiprocessing_context=loader_config.mp_context,
            persistent_workers=loader_config.num_data_workers > 0,
            sampler=sampler,
        )
        return SliceWorkItemGriddedData(
            loader,
            variable_metadata=dataset.variable_metadata,
            all_times=xr_dataset.sample_start_times,
            dtype=slice_dataset.dtype,
            max_output_shape=slice_dataset.max_output_shape,
            topography=topography,
        )

    def _build(
        self,
        time: TimeSlice | RepeatedInterval | Slice,
        lat_extent: ClosedInterval,
        lon_extent: ClosedInterval,
        loader_config: DataLoaderConfig,
        requirements: DataRequirements,
        patch: PatchPredictionConfig,
        coarse: list[XarrayDataConfig],
        static_inputs_from_checkpoint: StaticInputs | None = None,
    ) -> DownscalingOutput:
        """
        Shared implementation behind the subclasses' ``build`` methods.

        Restricts the loader config to the given time and spatial extents,
        builds the gridded data, resolves zarr chunk/shard sizes (falling back
        to computed defaults when not configured) and assembles the
        DownscalingOutput.
        """
        updated_loader_config = self._replace_loader_config(
            time,
            coarse,
            lat_extent,
            lon_extent,
            loader_config,
        )
        gridded_data = self._build_gridded_data(
            updated_loader_config,
            requirements,
            static_inputs_from_checkpoint=static_inputs_from_checkpoint,
        )

        if self.zarr_chunks is None:
            # Get element size from dtype by creating a dummy tensor
            element_size = torch.tensor([], dtype=gridded_data.dtype).element_size()
            chunks = determine_zarr_chunks(
                dims=DIMS,
                data_shape=gridded_data.max_output_shape,
                bytes_per_element=element_size,
            )
        else:
            chunks = self.zarr_chunks

        if self.zarr_shards is None:
            # default shard: the largest output of a single work item
            shards = dict(zip(DIMS, gridded_data.max_output_shape))
        else:
            shards = self.zarr_shards

        return DownscalingOutput(
            name=self.name,
            save_vars=self.save_vars,
            n_ens=self.n_ens,
            max_samples_per_gpu=self.max_samples_per_gpu,
            data=gridded_data,
            patch=patch,
            chunks=chunks,
            shards=shards,
            dims=DIMS,
        )
@dataclass
class EventConfig(DownscalingOutputConfig):
    """
    Configuration for generating a single time snapshot over a spatial region.

    Useful for capturing specific events like hurricane landfall, extreme weather
    events, or any single-timestep high-resolution snapshot of a region.
    If n_ens > max_samples_per_gpu, this event can be run in a distributed manner
    where each GPU generates a subset of the ensemble members for the event.

    Parameters:
        name: Unique identifier for this target (used in output filename)
        n_ens: Number of ensemble members to generate when downscaling
        save_vars: List of variable names to save to zarr output. If None,
            all variables from the model output will be saved.
        zarr_chunks: Optional chunk sizes for zarr dimensions. If None,
            automatically calculated to target lat/lon shape <=10MB per chunk.
            Ensemble and time dimensions chunks are length 1.
        zarr_shards: Optional shard sizes for zarr dimensions. If None, defaults
            to maximum output size for a single unit of downscaling work. This
            ensures that parallel generation tasks write to separate shards.
        max_samples_per_gpu: Number of time and/or ensemble samples to include
            in a single GPU generation. Controls memory usage and time to
            generate.
        event_time: Timestamp or integer index of the event. If string, must
            match time_format. Required field.
        time_format: strptime format for parsing event_time string.
            Default: "%Y-%m-%dT%H:%M:%S" (ISO 8601)
        lat_extent: Latitude bounds in degrees limited to [-88, 88].
            Defaults to (-66, 70) which covers continental land masses aside
            from Antarctica.
        lon_extent: Longitude bounds in degrees [-180, 360]. Default: full extent
            of the underlying data.
    """

    # event_time required, but must specify as optional kwarg to allow subclassing
    event_time: str | int = ""
    time_format: str = "%Y-%m-%dT%H:%M:%S"
    lat_extent: ClosedInterval = field(
        default_factory=lambda: ClosedInterval(-66.0, 70.0)
    )
    lon_extent: ClosedInterval = field(
        default_factory=lambda: ClosedInterval(float("-inf"), float("inf"))
    )

    def __post_init__(self):
        """
        Validate that event_time was supplied and lat_extent is in bounds.

        Raises:
            ValueError: If event_time is left at its "" placeholder (or None).
        """
        # Check the placeholder explicitly rather than by truthiness: the
        # integer index 0 is falsy but is a valid event_time.
        if self.event_time is None or self.event_time == "":
            raise ValueError("event_time must be specified for EventConfig.")
        enforce_lat_bounds(self.lat_extent)

    def build(
        self,
        loader_config: DataLoaderConfig,
        requirements: DataRequirements,
        patch: PatchPredictionConfig,
        static_inputs_from_checkpoint: StaticInputs | None = None,
    ) -> DownscalingOutput:
        """
        Build the DownscalingOutput for this single-time event.

        Args:
            loader_config: Base data loader configuration to modify
            requirements: Model's data requirements (variable names, etc.)
            patch: Default patch prediction configuration
            static_inputs_from_checkpoint: Optional static inputs forwarded to
                the shared ``_build`` implementation.
        """
        # Convert single time to TimeSlice
        time: Slice | TimeSlice
        if isinstance(self.event_time, str):
            start = datetime.strptime(self.event_time, self.time_format)
            # half timestep to not include next time
            # NOTE(review): 3 hours presumably assumes 6-hourly data — confirm
            stop_time = (start + timedelta(hours=3)).strftime(self.time_format)
            time = TimeSlice(self.event_time, stop_time)
        else:
            # NOTE(review): for event_time == -1 this yields Slice(-1, 0),
            # which presumably selects nothing — confirm negative indices
            # are unsupported.
            time = Slice(self.event_time, self.event_time + 1)

        coarse = self._single_xarray_config(loader_config.coarse)
        return self._build(
            time=time,
            lat_extent=self.lat_extent,
            lon_extent=self.lon_extent,
            loader_config=loader_config,
            requirements=requirements,
            patch=patch,
            coarse=coarse,
            static_inputs_from_checkpoint=static_inputs_from_checkpoint,
        )
@dataclass
class TimeRangeConfig(DownscalingOutputConfig):
    """
    Configuration for generating a time segment over a spatial region.

    This is the most common and flexible configuration, suitable for generating
    downscaled data over regions like CONUS, continental areas, or custom domains
    over extended time periods.

    Parameters:
        name: Unique identifier for this target (used in output filename)
        n_ens: Number of ensemble members to generate when downscaling
        save_vars: List of variable names to save to zarr output. If None,
            all variables from the model output will be saved.
        zarr_chunks: Optional chunk sizes for zarr dimensions. If None,
            automatically calculated to target lat/lon shape <=10MB per chunk.
            Ensemble and time dimensions chunks are length 1.
        zarr_shards: Optional shard sizes for zarr dimensions. If None, defaults
            to maximum output size for a single unit of downscaling work. This
            ensures that parallel generation tasks write to separate shards.
        max_samples_per_gpu: Number of time and/or ensemble samples to include
            in a single GPU generation. Controls memory usage and time to
            generate.
        time_range: Time selection specification. Can be:
            - TimeSlice: Start/stop timestamps (e.g.,
              TimeSlice(start_time="2021-01-01", stop_time="2021-12-31"))
            - Slice: Integer indices (e.g., Slice(0, 365))
            - RepeatedInterval: Repeating time pattern
        lat_extent: Latitude bounds in degrees limited to [-88, 88].
            Defaults to (-66, 70) which covers continental land masses aside
            from Antarctica.
        lon_extent: Longitude bounds in degrees [-180, 360]. Default: full extent
            of the underlying data.
    """

    # time_range is required; Slice(-1, 1) is only a placeholder default
    # (rejected in __post_init__) so the field can follow the base class's
    # defaulted fields.
    time_range: TimeSlice | RepeatedInterval | Slice = field(
        default_factory=lambda: Slice(-1, 1)
    )
    lat_extent: ClosedInterval = field(
        default_factory=lambda: ClosedInterval(-66.0, 70.0)
    )
    lon_extent: ClosedInterval = field(
        default_factory=lambda: ClosedInterval(float("-inf"), float("inf"))
    )

    def __post_init__(self):
        """
        Validate that time_range was supplied and lat_extent is in bounds.

        Raises:
            ValueError: If time_range still equals the Slice(-1, 1) placeholder.
        """
        if self.time_range == Slice(-1, 1):
            raise ValueError("time_range must be specified for TimeRangeConfig.")
        enforce_lat_bounds(self.lat_extent)

    def build(
        self,
        loader_config: DataLoaderConfig,
        requirements: DataRequirements,
        patch: PatchPredictionConfig,
        static_inputs_from_checkpoint: StaticInputs | None = None,
    ) -> DownscalingOutput:
        """
        Build the DownscalingOutput for this time range.

        Args:
            loader_config: Base data loader configuration to modify
            requirements: Model's data requirements (variable names, etc.)
            patch: Default patch prediction configuration
            static_inputs_from_checkpoint: Optional static inputs forwarded to
                the shared ``_build`` implementation.
        """
        coarse = self._single_xarray_config(loader_config.coarse)
        return self._build(
            time=self.time_range,
            lat_extent=self.lat_extent,
            lon_extent=self.lon_extent,
            loader_config=loader_config,
            requirements=requirements,
            patch=patch,
            coarse=coarse,
            static_inputs_from_checkpoint=static_inputs_from_checkpoint,
        )