# Source code for fme.core.dataset.xarray

import dataclasses
import datetime
import functools
import json
import logging
import multiprocessing
import os
import re
import warnings
from collections import namedtuple
from collections.abc import Mapping, Sequence
from functools import lru_cache
from typing import Literal
from urllib.parse import urlparse

import fsspec
import numpy as np
import torch
import xarray as xr
from xarray.coding.times import CFDatetimeCoder

from fme.core.coordinates import (
    DepthCoordinate,
    HorizontalCoordinates,
    HybridSigmaPressureCoordinate,
    NullVerticalCoordinate,
    VerticalCoordinate,
)
from fme.core.dataset.config import DatasetConfigABC
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.schedule import IntSchedule
from fme.core.dataset.time import RepeatedInterval, TimeSlice
from fme.core.dataset.utils import FillNaNsConfig
from fme.core.mask_provider import MaskProvider
from fme.core.stacker import Stacker
from fme.core.typing_ import Slice, TensorDict

from .data_typing import VariableMetadata
from .dataset import DatasetABC, DatasetItem
from .utils import (
    as_broadcasted_tensor,
    get_horizontal_coordinates,
    get_nonspacetime_dimensions,
    load_series_data,
    load_series_data_zarr_async,
)

# "Select everything" slice; used as the selection for loaded non-time
# dimensions that have no entry in the isel configuration.
SLICE_NONE = slice(None)
# _get_raw_times only reads file times with a multiprocessing pool when there
# are more than this many files; below it, the pool start-up cost dominates.
GET_RAW_TIMES_NUM_FILES_PARALLELIZATION_THRESHOLD = 12
logger = logging.getLogger(__name__)

# Requested variable names grouped by how they are loaded: variables with a
# "time" dimension, variables without one, and variables derived from the
# horizontal coordinates (see StaticDerivedData).
VariableNames = namedtuple(
    "VariableNames",
    (
        "time_dependent_names",
        "time_invariant_names",
        "static_derived_names",
    ),
)


def _get_sorted_indexed_tensors(ds: xr.Dataset, prefix: str) -> list[torch.Tensor]:
    """Return the variables named ``{prefix}N`` as tensors, ordered by integer N.

    Both data variables and coordinates of ``ds`` are considered. The original
    dtype is kept; callers cast as needed.
    """
    by_index = {
        int(name[len(prefix) :]): torch.as_tensor(ds[name].values)
        for name in ds.variables
        if name.startswith(prefix)
    }
    return [by_index[k] for k in sorted(by_index)]


def _get_vertical_coordinate(
    ds: xr.Dataset, dtype: torch.dtype | None
) -> VerticalCoordinate:
    """
    Get vertical coordinate from a dataset.

    If the dataset contains variables named `ak_N` and `bk_N` where
    `N` is the level number, then a hybrid sigma-pressure coordinate
    will be returned. If the dataset contains variables named
    `idepth_N` then a depth coordinate will be returned. If neither is
    present, a NullVerticalCoordinate is returned.

    For a depth coordinate, the data variables matching `mask_N` are stacked
    into the coordinate's mask when `mask_0` is present; otherwise an all-ones
    mask of length ``len(idepth) - 1`` is used. A time-independent `deptho`
    data variable, if present, is passed to the coordinate as well.

    Args:
        ds: Dataset to get vertical coordinates from.
        dtype: Data type of the returned tensors. If None, the dtype is not
            changed from the original in ds.

    Returns:
        A DepthCoordinate, HybridSigmaPressureCoordinate or
        NullVerticalCoordinate.

    Raises:
        ValueError: If `ak_N`, `bk_N` and `idepth_N` variables are all present,
            or if an ocean mask layer or `deptho` has a "time" dimension.
    """
    ak_list = _get_sorted_indexed_tensors(ds, "ak_")
    bk_list = _get_sorted_indexed_tensors(ds, "bk_")
    idepth_list = _get_sorted_indexed_tensors(ds, "idepth_")

    if len(ak_list) > 0 and len(bk_list) > 0 and len(idepth_list) > 0:
        raise ValueError(
            "Dataset contains both hybrid sigma-pressure and depth coordinates. "
            "Can only provide one, or else the vertical coordinate is ambiguous."
        )

    coordinate: VerticalCoordinate
    deptho = None
    if len(idepth_list) > 0:
        if "mask_0" in ds.data_vars:
            # Only exact "mask_<int>" names are layer masks; other "mask_*"
            # variables are ignored here.
            mask_layers = {
                name: torch.as_tensor(ds[name].values, dtype=dtype)
                for name in ds.data_vars
                if re.match(r"mask_(\d+)$", name)
            }
            for name in mask_layers:
                if "time" in ds[name].dims:
                    raise ValueError("The ocean mask must be time-independent.")
            stacker = Stacker({"mask": ["mask_"]})
            mask = stacker("mask", mask_layers)
        else:
            logger.warning(
                "Dataset does not contain a mask. Providing a DepthCoordinate with "
                "mask set to 1 at all layers."
            )
            # One mask value per layer; N interface depths bound N - 1 layers.
            mask = torch.ones(len(idepth_list) - 1, dtype=dtype)
        if "deptho" in ds.data_vars:
            if "time" in ds["deptho"].dims:
                raise ValueError("'deptho' must be time-independent.")
            deptho = torch.as_tensor(ds["deptho"].values, dtype=dtype)
        else:
            logger.warning(
                "Dataset does not have a variable named 'deptho' (sea floor depth). "
                "The ocean depth integral will not account for partial bottom cells."
            )
        coordinate = DepthCoordinate(
            torch.as_tensor(idepth_list, dtype=dtype), mask, deptho
        )
    elif len(ak_list) > 0 and len(bk_list) > 0:
        coordinate = HybridSigmaPressureCoordinate(
            ak=torch.as_tensor(ak_list, dtype=dtype),
            bk=torch.as_tensor(bk_list, dtype=dtype),
        )
    else:
        logger.warning("Dataset does not contain a vertical coordinate.")
        coordinate = NullVerticalCoordinate()

    return coordinate


def _get_raw_times_single_file(path: str, engine: str | None = None) -> np.ndarray:
    """Return the values of the ``time`` coordinate stored at ``path``.

    The file is opened with ``_open_xr_dataset`` and closed again (via the
    context manager) before returning.

    Args:
        path: Local path or URL of the file/store to read.
        engine: xarray backend engine; None lets xarray choose.
    """
    with _open_xr_dataset(path, engine=engine) as ds:
        return ds.time.values


def _get_raw_times(paths: list[str], engine: str) -> list[np.ndarray]:
    """Read the time coordinate of every file in ``paths``.

    Args:
        paths: Files to read, in order.
        engine: xarray backend engine used to open each file.

    Returns:
        One time array per path, in the same order as ``paths``.
    """
    read_times = functools.partial(_get_raw_times_single_file, engine=engine)
    n_files = len(paths)

    # A process pool only pays off for many files; small inputs (such as the
    # data loading tests) are read serially to avoid the pool start-up cost.
    if n_files <= GET_RAW_TIMES_NUM_FILES_PARALLELIZATION_THRESHOLD:
        return [read_times(file_path) for file_path in paths]

    n_workers = min(multiprocessing.cpu_count(), n_files)
    with multiprocessing.Pool(n_workers) as pool:
        return pool.map(read_times, paths)


def _repeat_and_increment_time(
    raw_times: list[np.ndarray], n_repeats: int, timestep: datetime.timedelta
) -> list[np.ndarray]:
    """Repeat a collection of evenly spaced time arrays ``n_repeats`` times.

    Each repetition is shifted forward by the span of the whole collection
    (``timestep`` times the total number of times), so the result continues
    evenly in time.

    Returns:
        A flat list of ``n_repeats * len(raw_times)`` arrays: all of
        ``raw_times`` for repetition 0, then for repetition 1, and so on.
    """
    total_steps = sum(map(len, raw_times))
    span = timestep * total_steps
    return [
        chunk + repeat_index * span
        for repeat_index in range(n_repeats)
        for chunk in raw_times
    ]


def _get_cumulative_timesteps(time: list[np.ndarray]) -> np.ndarray:
    """Return the running total of timesteps across the arrays in ``time``.

    The result has ``len(time) + 1`` entries and starts at 0, so entry ``i`` is
    the global index of the first time in ``time[i]`` and the last entry is the
    total number of times.
    """
    lengths = [0] + [len(chunk) for chunk in time]
    return np.cumsum(lengths)


def _get_file_local_index(index: int, start_indices: np.ndarray) -> tuple[int, int]:
    """
    Return a tuple of the index of the file containing the time point at `index`
    and the index of the time point within that file.

    Args:
        index: Global (dataset-wide) time index; presumably non-negative and
            less than the total number of timesteps (not checked here).
        start_indices: Sorted global index of the first time point of each
            file, starting at 0.

    Returns:
        ``(file_index, time_index)`` as plain Python ints.
    """
    # side="right" so that an index equal to a file's start index maps to that
    # file rather than to the previous one.
    file_index = int(np.searchsorted(start_indices, index, side="right")) - 1
    # Cast so callers get a Python int rather than a NumPy integer scalar.
    time_index = int(index - start_indices[file_index])
    return file_index, time_index


class StaticDerivedData:
    """Euclidean x/y/z fields derived from a set of horizontal coordinates.

    The tensors are computed from ``coordinates.xyz`` on first access and
    cached on the instance afterwards.
    """

    names = ("x", "y", "z")
    metadata = {
        "x": VariableMetadata(units="", long_name="Euclidean x-coordinate"),
        "y": VariableMetadata(units="", long_name="Euclidean y-coordinate"),
        "z": VariableMetadata(units="", long_name="Euclidean z-coordinate"),
    }

    def __init__(self, coordinates: HorizontalCoordinates):
        self._coords = coordinates
        self._x: torch.Tensor | None = None
        self._y: torch.Tensor | None = None
        self._z: torch.Tensor | None = None

    def _get_xyz(self) -> TensorDict:
        """Return the x/y/z tensors, computing them first if any is missing."""
        cached = (self._x, self._y, self._z)
        if any(component is None for component in cached):
            x, y, z = self._coords.xyz
            self._x, self._y, self._z = map(torch.as_tensor, (x, y, z))
        return dict(x=self._x, y=self._y, z=self._z)

    def __getitem__(self, name: str) -> torch.Tensor:
        """Return the derived tensor called ``name`` ("x", "y" or "z")."""
        return self._get_xyz()[name]


def _get_protocol(path):
    """Return the URL scheme of ``path`` (e.g. ``"gs"``), or ``""`` if it has none."""
    parsed = urlparse(str(path))
    return parsed.scheme


def _get_fs(path):
    """Build the fsspec filesystem for ``path``.

    Paths without a URL scheme use the local ``"file"`` filesystem. Any
    protocol-specific options (e.g. credentials) come from
    ``_get_fs_protocol_kwargs``.
    """
    protocol = _get_protocol(path) or "file"
    filesystem_options = _get_fs_protocol_kwargs(path)
    return fsspec.filesystem(protocol, **filesystem_options)


def _preserve_protocol(original_path, glob_paths):
    """Re-attach the URL scheme of ``original_path`` to each of ``glob_paths``.

    If ``original_path`` has no scheme, ``glob_paths`` is returned unchanged
    (the same object). Otherwise a new list of ``"<scheme>://<path>"`` strings
    is returned.
    """
    scheme = _get_protocol(str(original_path))
    if not scheme:
        return glob_paths
    return [f"{scheme}://{glob_path}" for glob_path in glob_paths]


def _get_fs_protocol_kwargs(path):
    """Return extra fsspec keyword arguments required to access ``path``.

    - ``gs://``: a ``token`` taken from FSSPEC_GS_KEY_JSON (parsed JSON), else
      GOOGLE_APPLICATION_CREDENTIALS (a key file path), else
      ``"google_default"`` (with a logged warning).
    - ``s3://``: no kwargs are added, but a warning is issued for each missing
      FSSPEC_S3_* environment variable.
    - Anything else: an empty dict.
    """
    options = {}
    protocol = _get_protocol(path)

    if protocol == "gs":
        # https://gcsfs.readthedocs.io/en/latest/api.html#gcsfs.core.GCSFileSystem
        key_json = os.environ.get("FSSPEC_GS_KEY_JSON", None)
        key_file = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", None)
        if key_json is not None:
            options["token"] = json.loads(key_json)
        elif key_file is not None:
            options["token"] = key_file
        else:
            logger.warning(
                "GCS currently expects user credentials authenticated using"
                " `gcloud auth application-default login`. This is not recommended for "
                "production use."
            )
            options["token"] = "google_default"
    elif protocol == "s3":
        # https://s3fs.readthedocs.io/en/latest/#s3-compatible-storage
        expected_env = (
            "FSSPEC_S3_KEY",
            "FSSPEC_S3_SECRET",
            "FSSPEC_S3_ENDPOINT_URL",
        )
        for env_name in expected_env:
            if env_name in os.environ:
                continue
            warnings.warn(
                f"An S3 path was specified but environment variable {env_name} "
                "was not found. This may cause authentication issues if not "
                "set and no other defaults are present. See "
                "https://s3fs.readthedocs.io/en/latest/#s3-compatible-storage"
                " for details."
            )

    return options


def _open_xr_dataset(path: str, *args, **kwargs):
    """Open ``path`` with ``xr.open_dataset`` using this module's settings.

    Times are decoded with ``CFDatetimeCoder(use_cftime=True)``, timedeltas
    are not decoded, ``mask_and_scale`` is off, ``cache=False`` and
    ``chunks=None``. When ``_get_fs_protocol_kwargs`` returns options for the
    path's protocol, they are passed as ``storage_options``, replacing any
    ``storage_options`` given by the caller.
    """
    # The backend needs protocol-specific arguments (e.g. GCS credentials).
    storage_options = _get_fs_protocol_kwargs(path)
    if storage_options:
        kwargs["storage_options"] = storage_options

    time_coder = CFDatetimeCoder(use_cftime=True)
    return xr.open_dataset(
        path,
        *args,
        decode_times=time_coder,
        decode_timedelta=False,
        mask_and_scale=False,
        cache=False,
        chunks=None,
        **kwargs,
    )


# Memoized variant of _open_xr_dataset. lru_cache() with no arguments keeps up
# to 128 entries keyed on the call arguments (which must therefore be
# hashable) and returns the same Dataset object for repeated calls with equal
# arguments. _open_file_fh_cached uses it for paths that have a URL protocol.
_open_xr_dataset_lru = lru_cache()(_open_xr_dataset)


def _open_file_fh_cached(path, **kwargs):
    """Open ``path`` as a dataset, reusing a cached handle where possible.

    Paths with a URL protocol (remote stores) go through the module-level LRU
    cache ``_open_xr_dataset_lru``, so ``kwargs`` must be hashable for them.
    Local paths are opened with ``_open_xr_dataset`` directly.
    """
    if not _get_protocol(path):
        # netcdf4 and h5engine have a filehandle LRU cache in xarray
        # https://github.com/pydata/xarray/blob/cd3ab8d5580eeb3639d38e1e884d2d9838ef6aa1/xarray/backends/file_manager.py#L54 # noqa: E501
        return _open_xr_dataset(path, **kwargs)
    # add an LRU cache for remote zarrs
    return _open_xr_dataset_lru(path, **kwargs)


def get_raw_paths(path, file_pattern):
    """Return the sorted paths under ``path`` that match ``file_pattern``.

    The glob runs on the fsspec filesystem for ``path``, and the URL scheme of
    ``path`` (if any) is re-attached to each match.
    """
    filesystem = _get_fs(path)
    pattern = os.path.join(path, file_pattern)
    matches = sorted(filesystem.glob(pattern))
    return _preserve_protocol(path, matches)


def _get_mask_provider(ds: xr.Dataset, dtype: torch.dtype | None) -> MaskProvider:
    """Get mask provider from a dataset.

    Every data variable whose name starts with "mask_" is converted to a
    tensor and passed to a MaskProvider. If there are no such variables, the
    MaskProvider is built from an empty dict.

    Args:
        ds: Dataset to get masks from.
        dtype: Data type of the returned tensors. If None, the dtype is not
            changed from the original in ds.

    Returns:
        The MaskProvider built from the mask variables.

    Raises:
        ValueError: If any "mask_" variable has a "time" dimension.
    """
    mask_names = [name for name in ds.data_vars if name.startswith("mask_")]
    # Validate dimensions before loading any data, so a time-dependent mask
    # fails fast instead of being read into memory first.
    for name in mask_names:
        if "time" in ds[name].dims:
            raise ValueError("Masks must be time-independent.")
    masks: dict[str, torch.Tensor] = {
        name: torch.as_tensor(ds[name].values, dtype=dtype) for name in mask_names
    }
    mask_provider = MaskProvider(masks)
    logger.info(f"Initialized {mask_provider}.")
    return mask_provider


@dataclasses.dataclass
class OverwriteConfig:
    """Configuration to overwrite field values in XarrayDataset.

    A variable may appear in ``constant`` or in ``multiply_scalar``, but not in
    both (checked at construction).

    Parameters:
        constant: Fill field with constant value.
        multiply_scalar: Multiply field by scalar value.
    """

    constant: Mapping[str, float] = dataclasses.field(default_factory=dict)
    multiply_scalar: Mapping[str, float] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        key_overlap = set(self.constant.keys()) & set(self.multiply_scalar.keys())
        if not key_overlap:
            return
        raise ValueError(
            "OverwriteConfig cannot have the same variable in both constant "
            f"and multiply_scalar: {key_overlap}"
        )

    def apply(self, tensors: TensorDict) -> TensorDict:
        """Apply the overwrites to ``tensors`` in place and return the same dict.

        Every configured variable must be present in ``tensors`` (a missing
        one raises KeyError). Scalars are converted to each tensor's dtype and
        device before use.
        """
        for name, value in self.constant.items():
            original = tensors[name]
            fill = torch.tensor(value, dtype=original.dtype, device=original.device)
            tensors[name] = torch.ones_like(original) * fill
        for name, factor in self.multiply_scalar.items():
            original = tensors[name]
            scale = torch.tensor(factor, dtype=original.dtype, device=original.device)
            tensors[name] = original * scale
        return tensors

    @property
    def variables(self):
        """Names of all variables this config overwrites."""
        constant_names = set(self.constant.keys())
        scaled_names = set(self.multiply_scalar.keys())
        return constant_names | scaled_names
@dataclasses.dataclass
class XarrayDataConfig(DatasetConfigABC):
    """
    Parameters:
        data_path: Path to the data.
        file_pattern: Glob pattern to match files in the data_path.
        n_repeats: Number of times to repeat the dataset (in time). It is up
            to the user to ensure that the input dataset to repeat results in
            data that is reasonably continuous across repetitions.
        engine: Backend used in xarray.open_dataset call.
        spatial_dimensions: Specifies the spatial dimensions for the grid, default
            is lat/lon. If 'latlon', it is assumed that the last two dimensions are
            latitude and longitude, respectively. If 'healpix', it is assumed that
            the last three dimensions are face, height, and width, respectively.
        subset: Subset of the XarrayDataset samples to load. This can be a
            `Slice` of integer indices, a `TimeSlice` of timestamps, or a
            `RepeatedInterval`, which selects samples with a boolean mask
            computed from the number of samples and the dataset timestep.
            This feature is applied directly to the dataset samples. For example,
            if the file(s) have the time coordinate (t0, t1, t2, t3) and
            requirements.n_timesteps=2, then subset=Slice(stop=2) will provide
            two samples: (t0, t1), (t1, t2).
        infer_timestep: Whether to infer the timestep from the provided data.
            This should be set to True (the default) for ACE training. It may
            be useful to toggle this to False for applications like downscaling,
            which do not depend on the timestep of the data and therefore lack
            the additional requirement that the data be ordered and evenly
            spaced in time. It must be set to True if n_repeats > 1 in order
            to be able to infer the full time coordinate.
        dtype: Data type to cast the data to. If None, no casting is done. It is
            required that 'torch.{dtype}' is a valid dtype.
        overwrite: Optional OverwriteConfig to overwrite loaded field values.
        fill_nans: Optional FillNaNsConfig to fill NaNs with a constant value.
        isel: Optional xarray isel arguments to be passed to the dataset. Will
            raise ValueError if time is included here, since the subset argument
            is used specifically for selecting times. Horizontal dimensions are
            also not currently supported.
        labels: Optional list of labels to be returned with the data.

    Examples:
        If data is stored in a directory with multiple netCDF files which can be
        concatenated along the time dimension, use:

        >>> fme.ace.XarrayDataConfig(data_path="/some/directory", file_pattern="*.nc") # doctest: +IGNORE_OUTPUT

        If data is stored in a single zarr store at ``/some/directory/dataset.zarr``,
        use:

        >>> fme.ace.XarrayDataConfig(
        ...     data_path="/some/directory",
        ...     file_pattern="dataset.zarr",
        ...     engine="zarr"
        ... ) # doctest: +IGNORE_OUTPUT
    """  # noqa: E501

    data_path: str
    file_pattern: str = "*.nc"
    n_repeats: int = 1
    engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"
    spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
    subset: Slice | TimeSlice | RepeatedInterval = dataclasses.field(
        default_factory=Slice
    )
    infer_timestep: bool = True
    dtype: str | None = "float32"
    overwrite: OverwriteConfig = dataclasses.field(default_factory=OverwriteConfig)
    fill_nans: FillNaNsConfig | None = None
    isel: Mapping[str, Slice | int] = dataclasses.field(default_factory=dict)
    labels: list[str] | None = None

    def _default_file_pattern_check(self):
        """Reject the NetCDF default file pattern when the zarr engine is used."""
        if self.engine == "zarr" and self.file_pattern == "*.nc":
            raise ValueError(
                "The file pattern is set to the default NetCDF file pattern *.nc "
                "but the engine is specified as 'zarr'. Please set "
                "`XarrayDataConfig.file_pattern` to match the zarr filename."
            )

    @property
    def available_labels(self) -> set[str] | None:
        """
        Return the labels that are available in the dataset.
        """
        if self.labels is None:
            return None
        return set(self.labels)

    @property
    def torch_dtype(self) -> torch.dtype | None:
        """The torch dtype named by ``dtype``, or None if ``dtype`` is None.

        Raises:
            ValueError: If ``torch.<dtype>`` does not exist or is not a dtype.
        """
        if self.dtype is None:
            return None
        else:
            try:
                torch_dtype = getattr(torch, self.dtype)
            except AttributeError:
                # The AttributeError adds no information beyond this message.
                raise ValueError(f"Invalid dtype '{self.dtype}'") from None
            if not isinstance(torch_dtype, torch.dtype):
                raise ValueError(f"Invalid dtype '{self.dtype}'")
            return torch_dtype

    def __post_init__(self):
        if self.n_repeats > 1 and not self.infer_timestep:
            raise ValueError(
                "infer_timestep must be True if n_repeats is greater than 1"
            )
        if self.spatial_dimensions not in ["latlon", "healpix"]:
            raise ValueError(
                f"unexpected spatial_dimensions {self.spatial_dimensions},"
                " should be one of 'latlon' or 'healpix'"
            )
        self.torch_dtype  # check it can be retrieved
        self._default_file_pattern_check()

    @property
    def zarr_engine_used(self) -> bool:
        return self.engine == "zarr"

    def update_subset(self, subset: Slice | TimeSlice | RepeatedInterval):
        self.subset = subset

    def build(
        self,
        names: Sequence[str],
        n_timesteps: IntSchedule,
    ) -> tuple["XarraySubset", DatasetProperties]:
        return get_xarray_dataset(
            self,
            list(names),
            n_timesteps,
        )
class XarrayDataset(DatasetABC):
    """Load data from a directory of files matching a pattern using xarray.

    The number of contiguous timesteps to load for each sample is specified by
    the n_timesteps argument.

    For example, if the file(s) have the time coordinate
    (t0, t1, t2, t3, t4) and n_timesteps=3, then this dataset will
    provide three samples: (t0, t1, t2), (t1, t2, t3), and (t2, t3, t4).
    """

    def __init__(
        self, config: XarrayDataConfig, names: Sequence[str], n_timesteps: IntSchedule
    ):
        self._horizontal_coordinates: HorizontalCoordinates
        self._names = names
        self.path = config.data_path
        self.file_pattern = config.file_pattern
        self.engine = config.engine
        self.dtype = config.torch_dtype
        self.spatial_dimensions = config.spatial_dimensions
        self.fill_nans = config.fill_nans
        self.subset_config = config.subset
        self._raw_paths = get_raw_paths(self.path, self.file_pattern)
        if len(self._raw_paths) == 0:
            raise ValueError(
                f"No files found matching '{self.path}/{self.file_pattern}'."
            )
        self.full_paths = self._raw_paths * config.n_repeats
        self._n_timesteps_schedule = n_timesteps
        # The number of initial conditions is fixed from the largest sample
        # length the schedule can produce, so every start index stays valid
        # whatever the current epoch's sample length is.
        self._get_files_stats(
            config.n_repeats,
            config.infer_timestep,
            max_sample_n_times=n_timesteps.max_value,
        )
        # Static metadata is read from the first file, which is closed once
        # everything needed from it has been extracted.
        with xr.open_dataset(
            self.full_paths[0],
            decode_times=False,
            decode_timedelta=False,
            engine=self.engine,
            chunks=None,
        ) as first_dataset:
            self._mask_provider = _get_mask_provider(first_dataset, self.dtype)
            (
                self._horizontal_coordinates,
                self._static_derived_data,
                _loaded_horizontal_dims,
            ) = self.configure_horizontal_coordinates(first_dataset)
            (
                self._time_dependent_names,
                self._time_invariant_names,
                self._static_derived_names,
            ) = self._group_variable_names_by_time_type()
            self._vertical_coordinate = _get_vertical_coordinate(
                first_dataset, self.dtype
            )
            self.overwrite = config.overwrite
            self._nonspacetime_dims = get_nonspacetime_dimensions(
                first_dataset, _loaded_horizontal_dims
            )
            self._shape_excluding_time = [
                first_dataset.sizes[dim]
                for dim in (self._nonspacetime_dims + _loaded_horizontal_dims)
            ]
            self._loaded_dims = (
                ["time"] + self._nonspacetime_dims + _loaded_horizontal_dims
            )
            self.isel = {
                dim: v if isinstance(v, int) else v.slice
                for dim, v in config.isel.items()
            }
            # One selection per loaded non-time dimension, in loaded order.
            self._isel_tuple = tuple(
                [self.isel.get(dim, SLICE_NONE) for dim in self._loaded_dims[1:]]
            )
            self._check_isel_dimensions(first_dataset.sizes)
        self._apply_sample_n_times(self._n_timesteps_schedule.get_value(0))
        self._labels = set(config.labels) if config.labels is not None else None
        self._infer_timestep = config.infer_timestep
        self._local_epoch: int = -1
        self._global_epoch = torch.tensor(
            -1
        ).share_memory_()  # required for multi-worker parallelism

    def _ensure_epoch_synchronized(self):
        """Ensure that the local epoch is synchronized with the global epoch.

        This is required for multi-worker data loading, where each worker
        process has its own copy of the dataset object.
        """
        if self._local_epoch != self._global_epoch.item():
            self._local_epoch = self._global_epoch.item()
            sample_n_times = self._n_timesteps_schedule.get_value(self._local_epoch)
            self._apply_sample_n_times(sample_n_times)

    @property
    def _epoch(self) -> int | None:
        """Current epoch, or None if set_epoch has not been called."""
        self._ensure_epoch_synchronized()
        if self._local_epoch == -1:
            return None
        return self._local_epoch

    def _apply_sample_n_times(self, sample_n_times: int):
        """Set the per-sample length; the number of samples is unchanged."""
        self._sample_n_times = sample_n_times
        logger.info(
            f"Dataset now has {self._n_initial_conditions} samples of "
            f"length {sample_n_times}."
        )

    def _check_isel_dimensions(self, data_dim_sizes):
        # Horizontal dimensions are not currently supported, as the current isel code
        # does not adjust HorizonalCoordinates to match selection.
        if "time" in self.isel:
            raise ValueError("isel cannot be used to select time. Use subset instead.")

        for dim, selection in self.isel.items():
            if dim not in self._nonspacetime_dims:
                raise ValueError(
                    f"isel dimension {dim} must be a non-spacetime dimension "
                    f"of the dataset ({self._nonspacetime_dims})."
                )
            # For slices only the start is checked; Python slicing clips stops.
            max_isel_index = (
                (selection.start or 0) if isinstance(selection, slice) else selection
            )
            if max_isel_index >= data_dim_sizes[dim]:
                raise ValueError(
                    f"isel index {max_isel_index} is out of bounds for dimension "
                    f"{dim} with size {data_dim_sizes[dim]}."
                )

    @property
    def _shape_excluding_time_after_selection(self):
        final_shape = []
        for orig_size, sel in zip(self._shape_excluding_time, self._isel_tuple):
            # if selecting a single index, dimension is squeezed
            # so it is not included in the final shape
            if isinstance(sel, slice):
                if sel.start is None and sel.stop is None and sel.step is None:
                    final_shape.append(orig_size)
                else:
                    final_shape.append(len(range(*sel.indices(orig_size))))
        return final_shape

    @property
    def dims(self) -> list[str]:
        # Final dimensions of returned data after dims that are selected
        # with a single index are dropped
        final_dims = ["time"]
        for dim, sel in zip(self._loaded_dims[1:], self._isel_tuple):
            if isinstance(sel, slice):
                final_dims.append(dim)
        return final_dims

    @property
    def properties(self) -> DatasetProperties:
        return DatasetProperties(
            self._variable_metadata,
            self._vertical_coordinate,
            self._horizontal_coordinates,
            self._mask_provider,
            self.timestep,
            self._is_remote,
            self._labels,
        )

    @property
    def _is_remote(self) -> bool:
        protocol = _get_protocol(str(self.path))
        if not protocol or protocol == "file":
            return False
        return True

    def _get_variable_metadata(self, ds):
        """Collect units/long_name metadata for the requested variables."""
        result = {}
        for name in self._names:
            if name in StaticDerivedData.names:
                result[name] = StaticDerivedData.metadata[name]
                continue
            try:
                da = ds[name]
            except KeyError:
                # Skip here so that _group_variable_names_by_time_type (called
                # later in __init__) reports the missing variable clearly.
                continue
            if hasattr(da, "units") and hasattr(da, "long_name"):
                result[name] = VariableMetadata(
                    units=da.units,
                    long_name=da.long_name,
                )
        self._variable_metadata = result

    def _get_files_stats(
        self, n_repeats: int, infer_timestep: bool, max_sample_n_times: int
    ):
        """Compute time-index bookkeeping and variable metadata for the files."""
        logger.info(f"Opening data at {os.path.join(self.path, self.file_pattern)}")
        raw_times = _get_raw_times(self._raw_paths, engine=self.engine)

        self._timestep: datetime.timedelta | None
        if infer_timestep:
            inferred_timestep = _get_timestep(np.concatenate(raw_times))
            time_coord = _repeat_and_increment_time(
                raw_times, n_repeats, inferred_timestep
            )
            self._timestep = inferred_timestep
        else:
            self._timestep = None
            time_coord = raw_times

        cum_num_timesteps = _get_cumulative_timesteps(time_coord)
        self.start_indices = cum_num_timesteps[:-1]
        self._total_timesteps = cum_num_timesteps[-1]
        # Clamp at zero: if the data is shorter than the longest sample, there
        # are no valid samples (a negative value would slice from the end).
        self._n_initial_conditions = max(
            self._total_timesteps - max_sample_n_times + 1, 0
        )
        all_times = np.concatenate(time_coord)
        self._sample_start_times = xr.CFTimeIndex(
            all_times[: self._n_initial_conditions]
        )
        self._all_times = xr.CFTimeIndex(all_times)

        # NOTE(review): this dataset is not closed here; remote paths come from
        # a shared LRU cache, so closing it would affect other users.
        ds = self._open_file(0)
        self._get_variable_metadata(ds)

    def _group_variable_names_by_time_type(self) -> VariableNames:
        """Returns lists of time-dependent variable names, time-independent
        variable names, and variables which are only present as an initial
        condition."""
        (
            time_dependent_names,
            time_invariant_names,
            static_derived_names,
        ) = ([], [], [])
        # Don't use open_mfdataset here, because it will give time-invariant
        # fields a time dimension. We assume that all fields are present in the
        # netcdf file corresponding to the first chunk of time.
        with _open_xr_dataset(self.full_paths[0], engine=self.engine) as ds:
            for name in self._names:
                if name in StaticDerivedData.names:
                    static_derived_names.append(name)
                else:
                    try:
                        da = ds[name]
                    except KeyError:
                        raise ValueError(
                            f"Required variable not found in dataset: {name}."
                        ) from None
                    else:
                        dims = da.dims
                        if "time" in dims:
                            time_dependent_names.append(name)
                        else:
                            time_invariant_names.append(name)
            logger.info(
                f"The required variables have been found in the dataset: {self._names}."
            )

        return VariableNames(
            time_dependent_names,
            time_invariant_names,
            static_derived_names,
        )

    def configure_horizontal_coordinates(
        self, first_dataset
    ) -> tuple[HorizontalCoordinates, StaticDerivedData, list[str]]:
        horizontal_coordinates: HorizontalCoordinates
        static_derived_data: StaticDerivedData

        horizontal_coordinates, dim_names = get_horizontal_coordinates(
            first_dataset, self.spatial_dimensions, self.dtype
        )
        static_derived_data = StaticDerivedData(horizontal_coordinates)

        coords_sizes = {
            coord_name: len(coord)
            for coord_name, coord in horizontal_coordinates.coords.items()
        }
        logger.info(f"Horizontal coordinate sizes are {coords_sizes}.")
        return horizontal_coordinates, static_derived_data, dim_names

    @property
    def timestep(self) -> datetime.timedelta | None:
        if self._timestep is None:
            if self._infer_timestep is False:
                warnings.warn(
                    "XarrayDataConfig.infer_timestep set to False. "
                    "Timestep was not inferred in the data loader."
                )
                return self._timestep
            else:
                raise ValueError(
                    "Timestep was not inferred in the data loader. Note "
                    "XarrayDataConfig.infer_timestep must be set to True for this "
                    "to occur."
                )
        else:
            return self._timestep

    def _open_file(self, idx):
        logger.debug(f"Opening file {self.full_paths[idx]}")
        return _open_file_fh_cached(self.full_paths[idx], engine=self.engine)

    @property
    def sample_start_times(self) -> xr.CFTimeIndex:
        """Return cftime index corresponding to start time of each sample."""
        self._ensure_epoch_synchronized()
        return self._sample_start_times

    @property
    def all_times(self) -> xr.CFTimeIndex:
        """
        Like sample_start_times, but includes all times in the dataset, including
        final times which are not valid as a start index.

        This is relevant for inference, where we may use get_sample_by_time_slice to
        retrieve time windows directly.
        """
        return self._all_times

    @property
    def sample_n_times(self) -> int:
        """Number of timesteps in each sample."""
        self._ensure_epoch_synchronized()
        return self._sample_n_times

    def __getitem__(self, idx: int) -> DatasetItem:
        """Return a sample of data spanning the timesteps
        [idx, idx + self.sample_n_times).

        Args:
            idx: Index of the sample to retrieve.

        Returns:
            Tuple of a sample's data (i.e. a mapping from names to torch.Tensors)
            and its corresponding time coordinate.
        """
        self._ensure_epoch_synchronized()
        time_slice = slice(idx, idx + self.sample_n_times)
        return self.get_sample_by_time_slice(time_slice)

    def validate_inference_length(self, max_start_index: int, max_window_len: int):
        self._ensure_epoch_synchronized()
        if max_window_len + max_start_index > self._total_timesteps:
            raise ValueError(
                f"The maximum start index {max_start_index} plus window length "
                f"{max_window_len} must be less than or "
                f"equal to the number of steps in the dataset {self._total_timesteps}."
            )

    def get_sample_by_time_slice(self, time_slice: slice) -> DatasetItem:
        self._ensure_epoch_synchronized()
        input_file_idx, input_local_idx = _get_file_local_index(
            time_slice.start, self.start_indices
        )
        output_file_idx, output_local_idx = _get_file_local_index(
            time_slice.stop - 1, self.start_indices
        )

        # get the sequence of observations
        arrays: dict[str, list[torch.Tensor]] = {}
        idxs = range(input_file_idx, output_file_idx + 1)
        total_steps = 0
        for i, file_idx in enumerate(idxs):
            start = input_local_idx if i == 0 else 0
            if i == len(idxs) - 1:
                stop = output_local_idx
            else:
                # Last local index of this file (inclusive).
                stop = (
                    self.start_indices[file_idx + 1]
                    - self.start_indices[file_idx]
                    - 1
                )

            n_steps = stop - start + 1
            shape = [n_steps] + self._shape_excluding_time_after_selection
            total_steps += n_steps
            if self.engine == "zarr":
                tensor_dict = load_series_data_zarr_async(
                    idx=start,
                    n_steps=n_steps,
                    path=self.full_paths[file_idx],
                    names=self._time_dependent_names,
                    final_dims=self.dims,
                    final_shape=shape,
                    fill_nans=self.fill_nans,
                    nontime_selection=self._isel_tuple,
                )
            else:
                ds = self._open_file(file_idx)
                ds = ds.isel(**self.isel)
                tensor_dict = load_series_data(
                    idx=start,
                    n_steps=n_steps,
                    ds=ds,
                    names=self._time_dependent_names,
                    final_dims=self.dims,
                    final_shape=shape,
                    fill_nans=self.fill_nans,
                )
                ds.close()
                del ds
            for n in self._time_dependent_names:
                arrays.setdefault(n, []).append(tensor_dict[n])

        tensors: TensorDict = {}
        for n, tensor_list in arrays.items():
            tensors[n] = torch.cat(tensor_list)
        del arrays

        # load time-invariant variables from first dataset
        if len(self._time_invariant_names) > 0:
            ds = self._open_file(idxs[0])
            ds = ds.isel(**self.isel)
            shape = [total_steps] + self._shape_excluding_time_after_selection
            for name in self._time_invariant_names:
                variable = ds[name].variable
                if self.fill_nans is not None:
                    variable = variable.fillna(self.fill_nans.value)
                tensors[name] = as_broadcasted_tensor(variable, self.dims, shape)
            ds.close()
            del ds

        # load static derived variables
        for name in self._static_derived_names:
            tensor = self._static_derived_data[name]
            horizontal_dims = [1] * tensor.ndim
            tensors[name] = tensor.repeat((total_steps, *horizontal_dims))

        # cast to desired dtype
        tensors = {k: v.to(dtype=self.dtype) for k, v in tensors.items()}

        # Apply field overwrites
        tensors = self.overwrite.apply(tensors)

        # Create a DataArray of times to return corresponding to the slice that
        # is valid even when n_repeats > 1.
        time = xr.DataArray(self.all_times[time_slice].values, dims=["time"])

        return tensors, time, self._labels, self._epoch

    def set_epoch(self, epoch: int):
        """
        Set the epoch for the dataset.

        The value is written to shared memory and applied lazily (so that
        DataLoader worker copies see it too): the next epoch-dependent access
        updates sample_n_times from the n_timesteps schedule. The number of
        initial conditions and the sample start times do not change; they are
        fixed at construction from the schedule's maximum value.
        """
        self._global_epoch.fill_(epoch)  # values get set lazily based on this


def _get_timestep(time: np.ndarray) -> datetime.timedelta:
    """Compute the timestep of a 1D time coordinate array.

    Raises:
        ValueError: If ``time`` is not 1D, has one or fewer times, or is not
            separated by a positive constant interval.
    """
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if time.ndim != 1:
        raise ValueError("times must be a 1D array")
    if len(time) > 1:
        timesteps = np.diff(time)
        timestep = timesteps[0]

        if not (timestep > datetime.timedelta(days=0)):
            raise ValueError("Timestep of data must be greater than zero.")

        if not np.all(timesteps == timestep):
            raise ValueError("Time coordinate does not have a uniform timestep.")

        return timestep
    else:
        raise ValueError(
            "Time coordinate does not have enough times to infer a timestep."
        )


def _as_index_selection(
    subset: Slice | TimeSlice | RepeatedInterval, dataset: XarrayDataset
) -> slice | np.ndarray:
    """Convert a Slice, TimeSlice or RepeatedInterval into a sample selection.

    Slice and TimeSlice give an index ``slice`` (TimeSlice is resolved against
    the dataset's sample start times); RepeatedInterval gives a boolean mask
    over the dataset's samples.
    """
    if isinstance(subset, Slice):
        index_selection = subset.slice
    elif isinstance(subset, TimeSlice):
        index_selection = subset.slice(dataset.sample_start_times)
    elif isinstance(subset, RepeatedInterval):
        try:
            index_selection = subset.get_boolean_mask(len(dataset), dataset.timestep)
        except ValueError as e:
            raise ValueError(
                f"Error when applying RepeatedInterval to dataset: {e}"
            ) from e
    else:
        raise TypeError(
            "subset must be Slice, TimeSlice or RepeatedInterval, "
            f"got {type(subset)}"
        )
    return index_selection


class XarraySubset(DatasetABC):
    def __init__(self, dataset: XarrayDataset, subset: slice | np.ndarray):
        indices = np.arange(len(dataset))[subset]
        logger.info(f"Subsetting dataset samples according to {subset}.")
        self._wrapped_dataset = dataset
        self._dataset = torch.utils.data.Subset(dataset, indices)
        self._sample_start_times = dataset.sample_start_times[indices]
        self._sample_n_times = dataset.sample_n_times
        self._max_timestep_index: int | None = None
        # Inference needs time-ordered samples; only then is the last index
        # the largest one.
        if len(indices) > 0 and np.all(indices[:-1] <= indices[1:]):
            self._max_timestep_index = indices[-1] + dataset.sample_n_times - 1
        self.dims = dataset.dims

    def __getitem__(self, idx: int) -> DatasetItem:
        return self._dataset[idx]

    @property
    def sample_start_times(self):
        return self._sample_start_times

    @property
    def all_times(self) -> xr.CFTimeIndex:
        """
        Like sample_start_times, but includes all times in the dataset, including
        final times which are not valid as a start index.

        This is relevant for inference, where we may use get_sample_by_time_slice to
        retrieve time windows directly.

        If this dataset does not support inference, this will raise a
        NotImplementedError.
        """
        raise NotImplementedError("XarraySubset does not support inference.")

    @property
    def sample_n_times(self) -> int:
        """The length of the time dimension of each sample."""
        # Delegate so the value follows the wrapped dataset's epoch schedule
        # (set_epoch is forwarded to it) instead of the construction-time value.
        return self._wrapped_dataset.sample_n_times

    def get_sample_by_time_slice(self, time_slice: slice) -> DatasetItem:
        raise NotImplementedError(
            "XarraySubset does not support getting samples by time slice, "
            "is this a bug?."
        )

    def validate_inference_length(self, max_start_index: int, max_window_len: int):
        if self._max_timestep_index is None:
            raise ValueError(
                "XarraySubset that does not preserve time ordering of the data "
                "cannot be used for inference."
            )
        if max_start_index + max_window_len - 1 > self._max_timestep_index:
            raise ValueError(
                f"The maximum start index {max_start_index} plus forward steps "
                f"{max_window_len - 1} must be less than or equal to the "
                f"max timestep index in the dataset {self._max_timestep_index}."
            )

    @property
    def properties(self) -> DatasetProperties:
        return self._wrapped_dataset.properties

    def set_epoch(self, epoch: int):
        self._wrapped_dataset.set_epoch(epoch)


def get_xarray_dataset(
    config: XarrayDataConfig, names: Sequence[str], n_timesteps: IntSchedule
) -> tuple["XarraySubset", DatasetProperties]:
    dataset = XarrayDataset(config, names, n_timesteps)
    properties = dataset.properties
    index_slice = _as_index_selection(config.subset, dataset)
    return XarraySubset(dataset, index_slice), properties


def get_xarray_datasets(
    dataset_configs: Sequence[XarrayDataConfig],
    names: Sequence[str],
    n_timesteps: IntSchedule,
    strict: bool = True,
) -> tuple[list[XarraySubset], DatasetProperties]:
    datasets = []
    properties: DatasetProperties | None = None
    for config in dataset_configs:
        dataset, new_properties = get_xarray_dataset(config, names, n_timesteps)
        datasets.append(dataset)
        if properties is None:
            properties = new_properties
        else:
            properties.update(new_properties, strict=strict)
    if properties is None:
        raise ValueError("At least one dataset must be provided.")
    return datasets, properties