# Source code for fme.core.dataset.xarray

import dataclasses
import datetime
import functools
import json
import logging
import multiprocessing
import os
import re
import warnings
from collections import namedtuple
from collections.abc import Mapping, Sequence
from functools import lru_cache
from typing import Literal
from urllib.parse import urlparse

import fsspec
import numpy as np
import torch
import xarray as xr
from xarray.coding.times import CFDatetimeCoder

from fme.core.coordinates import (
    DepthCoordinate,
    HorizontalCoordinates,
    HybridSigmaPressureCoordinate,
    NullVerticalCoordinate,
    VerticalCoordinate,
)
from fme.core.dataset.config import DatasetConfigABC
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.time import RepeatedInterval, TimeSlice
from fme.core.dataset.utils import FillNaNsConfig
from fme.core.mask_provider import MaskProvider
from fme.core.stacker import Stacker
from fme.core.typing_ import Slice, TensorDict

from .data_typing import VariableMetadata
from .utils import (
    as_broadcasted_tensor,
    get_horizontal_coordinates,
    get_nonspacetime_dimensions,
    load_series_data,
    load_series_data_zarr_async,
)

# Selection that keeps a whole dimension; used as the fallback entry when the
# per-dimension isel tuple is assembled.
SLICE_NONE = slice(None)
# Raw time coordinates are read through a multiprocessing pool only when the
# number of files exceeds this threshold.
GET_RAW_TIMES_NUM_FILES_PARALLELIZATION_THRESHOLD = 12
logger = logging.getLogger(__name__)

# Requested variable names split into three groups: names whose data has a
# "time" dimension, names whose data does not, and names listed in
# StaticDerivedData.names.
VariableNames = namedtuple(
    "VariableNames",
    (
        "time_dependent_names",
        "time_invariant_names",
        "static_derived_names",
    ),
)


def _collect_numbered_levels(ds: xr.Dataset, prefix: str) -> list[torch.Tensor]:
    """Return the variables named ``<prefix>N`` as tensors, ordered by integer N."""
    levels = {
        int(name[len(prefix) :]): torch.as_tensor(ds[name].values)
        for name in ds.variables
        if name.startswith(prefix)
    }
    return [levels[level] for level in sorted(levels)]


def _get_vertical_coordinate(
    ds: xr.Dataset, dtype: torch.dtype | None
) -> VerticalCoordinate:
    """
    Get vertical coordinate from a dataset.

    If the dataset contains variables named `idepth_N`, where `N` is the
    level number, a depth coordinate is returned. Otherwise, if it contains
    variables named `ak_N` and `bk_N`, a hybrid sigma-pressure coordinate is
    returned. If neither is true, a warning is logged and a
    NullVerticalCoordinate is returned.

    For a depth coordinate the mask is stacked from the data variables
    matching `mask_N` when `mask_0` is present (together with the optional
    `surface_mask` variable); otherwise a mask of ones is used and a warning
    is logged.

    Args:
        ds: Dataset to get vertical coordinates from.
        dtype: Data type of the returned tensors. If None, the dtype is not
            changed from the original in ds.

    Returns:
        The vertical coordinate described by the dataset.

    Raises:
        ValueError: If the dataset holds `ak_N`, `bk_N` and `idepth_N`
            variables at the same time, or if a mask variable has a "time"
            dimension.
    """
    ak_list = _collect_numbered_levels(ds, "ak_")
    bk_list = _collect_numbered_levels(ds, "bk_")
    idepth_list = _collect_numbered_levels(ds, "idepth_")

    has_hybrid = len(ak_list) > 0 and len(bk_list) > 0
    has_depth = len(idepth_list) > 0

    if has_hybrid and has_depth:
        raise ValueError(
            "Dataset contains both hybrid sigma-pressure and depth coordinates. "
            "Can only provide one, or else the vertical coordinate is ambiguous."
        )

    coordinate: VerticalCoordinate
    if has_depth:
        surface_mask = None
        if "mask_0" in ds.data_vars:
            mask_names = [
                name for name in ds.data_vars if re.match(r"mask_(\d+)$", name)
            ]
            # Reject time-dependent masks before reading any of their values.
            for name in mask_names:
                if "time" in ds[name].dims:
                    raise ValueError("The ocean mask must by time-independent.")
            mask_layers = {
                name: torch.as_tensor(ds[name].values, dtype=dtype)
                for name in mask_names
            }
            stacker = Stacker({"mask": ["mask_"]})
            mask = stacker("mask", mask_layers)
            # The surface mask is only consulted when layer masks are present.
            if "surface_mask" in ds.data_vars:
                if "time" in ds["surface_mask"].dims:
                    raise ValueError("The surface mask must be time-independent.")
                surface_mask = torch.as_tensor(ds["surface_mask"].values, dtype=dtype)
        else:
            logger.warning(
                "Dataset does not contain a mask. Providing a DepthCoordinate with "
                "mask set to 1 at all layers."
            )
            # One mask entry per layer, i.e. one fewer than the idepth values.
            mask = torch.ones(len(idepth_list) - 1, dtype=dtype)
        coordinate = DepthCoordinate(
            torch.as_tensor(idepth_list, dtype=dtype), mask, surface_mask
        )
    elif has_hybrid:
        coordinate = HybridSigmaPressureCoordinate(
            ak=torch.as_tensor(ak_list, dtype=dtype),
            bk=torch.as_tensor(bk_list, dtype=dtype),
        )
    else:
        logger.warning("Dataset does not contain a vertical coordinate.")
        coordinate = NullVerticalCoordinate()

    return coordinate


def _get_raw_times_single_file(path: str, engine: str | None = None) -> np.ndarray:
    """Return the values of the ``time`` coordinate of the dataset at ``path``.

    Args:
        path: Location of the dataset to open.
        engine: Backend forwarded to the dataset opener.

    Returns:
        The ``time`` values as read while the dataset is open; the dataset is
        closed again before returning.
    """
    with _open_xr_dataset(path, engine=engine) as ds:
        return ds.time.values


def _get_raw_times(paths: list[str], engine: str) -> list[np.ndarray]:
    """Read the time coordinate of every file in ``paths``, preserving order.

    Args:
        paths: Files to read.
        engine: Backend forwarded to the single-file reader.

    Returns:
        One array of time values per path.
    """
    read_times = functools.partial(_get_raw_times_single_file, engine=engine)

    # Setting up a worker pool has a cost, so small collections of files
    # (typical for data loading tests) are read serially.
    if len(paths) <= GET_RAW_TIMES_NUM_FILES_PARALLELIZATION_THRESHOLD:
        return [read_times(path) for path in paths]

    n_workers = min(multiprocessing.cpu_count(), len(paths))
    with multiprocessing.Pool(n_workers) as pool:
        return pool.map(read_times, paths)


def _repeat_and_increment_time(
    raw_times: list[np.ndarray], n_repeats: int, timestep: datetime.timedelta
) -> list[np.ndarray]:
    """Repeats and increments a collection of arrays of evenly spaced times."""
    n_timesteps = sum(len(times) for times in raw_times)
    timespan = timestep * n_timesteps

    repeated_and_incremented_time = []
    for repeats in range(n_repeats):
        increment = repeats * timespan
        for time in raw_times:
            incremented_time = time + increment
            repeated_and_incremented_time.append(incremented_time)
    return repeated_and_incremented_time


def _get_cumulative_timesteps(time: list[np.ndarray]) -> np.ndarray:
    """Returns a list of cumulative timesteps for each item in a time coordinate."""
    num_timesteps_per_file = [0]
    for time_coord in time:
        num_timesteps_per_file.append(len(time_coord))
    return np.array(num_timesteps_per_file).cumsum()


def _get_file_local_index(index: int, start_indices: np.ndarray) -> tuple[int, int]:
    """
    Return a tuple of the index of the file containing the time point at `index`
    and the index of the time point within that file.
    """
    file_index = np.searchsorted(start_indices, index, side="right") - 1
    time_index = index - start_indices[file_index]
    return int(file_index), time_index


class StaticDerivedData:
    """Lazily built "x", "y" and "z" tensors derived from horizontal coordinates.

    The tensors come from the ``xyz`` attribute of the coordinates, are
    converted on first access and are then reused for later lookups.
    """

    names = ("x", "y", "z")
    metadata = {
        axis: VariableMetadata(units="", long_name=f"Euclidean {axis}-coordinate")
        for axis in ("x", "y", "z")
    }

    def __init__(self, coordinates: HorizontalCoordinates):
        self._coords = coordinates
        # Filled in by the first call to _get_xyz.
        self._x: torch.Tensor | None = None
        self._y: torch.Tensor | None = None
        self._z: torch.Tensor | None = None

    def _get_xyz(self) -> TensorDict:
        """Return all three fields keyed by name, computing them if needed."""
        already_built = not (self._x is None or self._y is None or self._z is None)
        if not already_built:
            x_raw, y_raw, z_raw = self._coords.xyz
            self._x = torch.as_tensor(x_raw)
            self._y = torch.as_tensor(y_raw)
            self._z = torch.as_tensor(z_raw)
        return dict(x=self._x, y=self._y, z=self._z)

    def __getitem__(self, name: str) -> torch.Tensor:
        """Look up one derived field; unknown names raise KeyError."""
        fields = self._get_xyz()
        return fields[name]


def _get_protocol(path):
    return urlparse(str(path)).scheme


def _get_fs(path):
    """Return the fsspec filesystem for ``path``.

    Paths without a URL scheme use the "file" protocol. Protocol-specific
    keyword arguments are passed on to ``fsspec.filesystem``.
    """
    scheme = _get_protocol(path) or "file"
    options = _get_fs_protocol_kwargs(path)
    return fsspec.filesystem(scheme, **options)


def _preserve_protocol(original_path, glob_paths):
    """Prefix ``glob_paths`` with the URL scheme of ``original_path``, if any.

    When ``original_path`` has no scheme, ``glob_paths`` is returned unchanged.
    """
    scheme = _get_protocol(str(original_path))
    if not scheme:
        return glob_paths
    return [f"{scheme}://{entry}" for entry in glob_paths]


def _get_fs_protocol_kwargs(path):
    """Build the protocol-specific filesystem keyword arguments for ``path``.

    For "gs" paths a ``token`` entry is returned, taken from the
    ``FSSPEC_GS_KEY_JSON`` environment variable (parsed as JSON), else from
    ``GOOGLE_APPLICATION_CREDENTIALS``, else the string "google_default"
    (with a logged warning). For "s3" paths nothing is added, but a warning
    is issued for each expected environment variable that is unset. All other
    paths give an empty dict.
    """
    options = {}
    scheme = _get_protocol(path)

    if scheme == "s3":
        # https://s3fs.readthedocs.io/en/latest/#s3-compatible-storage
        expected = (
            "FSSPEC_S3_KEY",
            "FSSPEC_S3_SECRET",
            "FSSPEC_S3_ENDPOINT_URL",
        )
        for name in expected:
            if name in os.environ:
                continue
            warnings.warn(
                f"An S3 path was specified but environment variable {name} "
                "was not found. This may cause authentication issues if not "
                "set and no other defaults are present. See "
                "https://s3fs.readthedocs.io/en/latest/#s3-compatible-storage"
                " for details."
            )
    elif scheme == "gs":
        # https://gcsfs.readthedocs.io/en/latest/api.html#gcsfs.core.GCSFileSystem
        inline_key = os.environ.get("FSSPEC_GS_KEY_JSON", None)
        key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", None)

        if inline_key is not None:
            options["token"] = json.loads(inline_key)
        elif key_path is not None:
            options["token"] = key_path
        else:
            logger.warning(
                "GCS currently expects user credentials authenticated using"
                " `gcloud auth application-default login`. This is not recommended for "
                "production use."
            )
            options["token"] = "google_default"

    return options


def _open_xr_dataset(path: str, *args, **kwargs):
    """Open ``path`` with ``xr.open_dataset`` using this module's fixed options.

    Times are decoded with cftime, timedeltas are not decoded, masking and
    scaling are off, caching is off and no chunking is requested. When the
    path's protocol needs filesystem arguments they are supplied as
    ``storage_options``, replacing any value given by the caller.
    """
    # The path determines which protocol-specific backend arguments are needed.
    storage_options = _get_fs_protocol_kwargs(path)
    if storage_options:
        kwargs["storage_options"] = storage_options

    fixed_options = dict(
        decode_times=CFDatetimeCoder(use_cftime=True),
        decode_timedelta=False,
        mask_and_scale=False,
        cache=False,
        chunks=None,
    )
    return xr.open_dataset(path, *args, **fixed_options, **kwargs)


# Memoized variant of _open_xr_dataset, wrapped with functools.lru_cache using
# its default settings; every argument passed to it must be hashable.
_open_xr_dataset_lru = lru_cache()(_open_xr_dataset)


def _open_file_fh_cached(path, **kwargs):
    """Open a dataset, going through the LRU cache when ``path`` has a URL scheme.

    Paths without a scheme are opened directly.
    """
    # Remote zarrs get an LRU cache here. For local files, netcdf4 and
    # h5engine already have a filehandle LRU cache in xarray:
    # https://github.com/pydata/xarray/blob/cd3ab8d5580eeb3639d38e1e884d2d9838ef6aa1/xarray/backends/file_manager.py#L54 # noqa: E501
    opener = _open_xr_dataset_lru if _get_protocol(path) else _open_xr_dataset
    return opener(path, **kwargs)


def get_raw_paths(path, file_pattern):
    """Return the sorted paths under ``path`` that match ``file_pattern``.

    The glob runs on the filesystem selected for ``path``; if ``path`` has a
    URL scheme, it is put back on every match.
    """
    filesystem = _get_fs(path)
    pattern = os.path.join(path, file_pattern)
    matches = sorted(filesystem.glob(pattern))
    return _preserve_protocol(path, matches)


def _get_mask_provider(ds: xr.Dataset, dtype: torch.dtype | None) -> MaskProvider:
    """
    Get mask provider from a dataset.

    Every data variable whose name contains the string "mask_" is used to
    instantiate a MaskProvider object. If there is no such variable, the
    MaskProvider is built from an empty mapping.

    Args:
        ds: Dataset to get masks from.
        dtype: Data type of the returned tensors. If None, the dtype is not
            changed from the original in ds.

    Returns:
        The MaskProvider built from the selected variables.

    Raises:
        ValueError: If any selected mask variable has a "time" dimension.
    """
    # NOTE(review): selection is by substring, so a variable named
    # "surface_mask" (which does not contain "mask_") is not picked up here
    # -- confirm this is intended.
    mask_names = [name for name in ds.data_vars if "mask_" in name]
    # Reject time-dependent masks before reading any of their values.
    for name in mask_names:
        if "time" in ds[name].dims:
            raise ValueError("Masks must be time-independent.")
    masks: dict[str, torch.Tensor] = {
        name: torch.as_tensor(ds[name].values, dtype=dtype) for name in mask_names
    }
    mask_provider = MaskProvider(masks)
    logger.info(f"Initialized {mask_provider}.")
    return mask_provider


@dataclasses.dataclass
class OverwriteConfig:
    """Configuration to overwrite field values in XarrayDataset.

    Parameters:
        constant: Fill field with constant value.
        multiply_scalar: Multiply field by scalar value.

    Raises:
        ValueError: On construction, if a variable appears in both
            ``constant`` and ``multiply_scalar``.
    """

    constant: Mapping[str, float] = dataclasses.field(default_factory=dict)
    multiply_scalar: Mapping[str, float] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        key_overlap = set(self.constant.keys()) & set(self.multiply_scalar.keys())
        if key_overlap:
            raise ValueError(
                "OverwriteConfig cannot have the same variable in both constant "
                f"and multiply_scalar: {key_overlap}"
            )

    def apply(self, tensors: TensorDict) -> TensorDict:
        """Apply the configured overwrites to ``tensors``.

        The mapping is updated in place and also returned. Replacement
        tensors keep the shape, dtype and device of the tensors they replace.

        Args:
            tensors: Mapping from variable name to tensor; it must contain
                every configured variable.

        Returns:
            The same mapping, with the configured entries replaced.

        Raises:
            KeyError: If a configured variable is missing from ``tensors``.
        """
        for var, fill_value in self.constant.items():
            tensors[var] = torch.full_like(tensors[var], fill_value)
        for var, multiplier in self.multiply_scalar.items():
            data = tensors[var]
            # Build the scalar in the data's dtype so the product keeps it.
            tensors[var] = data * torch.tensor(
                multiplier, dtype=data.dtype, device=data.device
            )
        return tensors

    @property
    def variables(self):
        """Names of all variables touched by this configuration."""
        return set(self.constant.keys()) | set(self.multiply_scalar.keys())
@dataclasses.dataclass
class XarrayDataConfig(DatasetConfigABC):
    """
    Parameters:
        data_path: Path to the data.
        file_pattern: Glob pattern to match files in the data_path.
        n_repeats: Number of times to repeat the dataset (in time). It is up
            to the user to ensure that the input dataset to repeat results in
            data that is reasonably continuous across repetitions.
        engine: Backend used in xarray.open_dataset call.
        spatial_dimensions: Specifies the spatial dimensions for the grid, default
            is lat/lon. If 'latlon', it is assumed that the last two dimensions are
            latitude and longitude, respectively. If 'healpix', it is assumed that
            the last three dimensions are face, height, and width, respectively.
        subset: Slice defining a subset of the XarrayDataset to load. This can
            either be a `Slice` of integer indices or a `TimeSlice` of timestamps.
            This feature is applied directly to the dataset samples. For example,
            if the file(s) have the time coordinate (t0, t1, t2, t3) and
            requirements.n_timesteps=2, then subset=Slice(stop=2) will
            provide two samples: (t0, t1), (t1, t2).
        infer_timestep: Whether to infer the timestep from the provided data.
            This should be set to True (the default) for ACE training. It may
            be useful to toggle this to False for applications like downscaling,
            which do not depend on the timestep of the data and therefore lack
            the additional requirement that the data be ordered and evenly
            spaced in time. It must be set to True if n_repeats > 1 in order
            to be able to infer the full time coordinate.
        dtype: Data type to cast the data to. If None, no casting is done. It is
            required that 'torch.{dtype}' is a valid dtype.
        overwrite: Optional OverwriteConfig to overwrite loaded field values.
        fill_nans: Optional FillNaNsConfig to fill NaNs with a constant value.
        isel: Optional xarray isel arguments to be passed to the dataset. Will
            raise ValueError if time is included here, since the subset argument
            is used specifically for selecting times. Horizontal dimensions are
            also not currently supported.
        labels: Optional list of labels to be returned with the data.

    Examples:
        If data is stored in a directory with multiple netCDF files which can be
        concatenated along the time dimension, use:

        >>> fme.ace.XarrayDataConfig(data_path="/some/directory", file_pattern="*.nc") # doctest: +IGNORE_OUTPUT

        If data is stored in a single zarr store at
        ``/some/directory/dataset.zarr``, use:

        >>> fme.ace.XarrayDataConfig(
        ...     data_path="/some/directory",
        ...     file_pattern="dataset.zarr",
        ...     engine="zarr"
        ... ) # doctest: +IGNORE_OUTPUT
    """  # noqa: E501

    data_path: str
    file_pattern: str = "*.nc"
    n_repeats: int = 1
    engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"
    spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
    subset: Slice | TimeSlice | RepeatedInterval = dataclasses.field(
        default_factory=Slice
    )
    infer_timestep: bool = True
    dtype: str | None = "float32"
    overwrite: OverwriteConfig = dataclasses.field(default_factory=OverwriteConfig)
    fill_nans: FillNaNsConfig | None = None
    isel: Mapping[str, Slice | int] = dataclasses.field(default_factory=dict)
    labels: list[str] = dataclasses.field(default_factory=list)

    def _default_file_pattern_check(self):
        """Reject the NetCDF default file pattern when the zarr engine is used."""
        if self.engine == "zarr" and self.file_pattern == "*.nc":
            raise ValueError(
                "The file pattern is set to the default NetCDF file pattern *.nc "
                "but the engine is specified as 'zarr'. Please set "
                "`XarrayDataConfig.file_pattern` to match the zarr filename."
            )

    @property
    def torch_dtype(self) -> torch.dtype | None:
        """The torch dtype named by ``dtype``, or None when ``dtype`` is None.

        Raises:
            ValueError: If ``torch.<dtype>`` does not exist or is not a
                ``torch.dtype``.
        """
        if self.dtype is None:
            return None
        # A missing attribute and a non-dtype attribute are both invalid, so
        # they are handled by the same check.
        torch_dtype = getattr(torch, self.dtype, None)
        if not isinstance(torch_dtype, torch.dtype):
            raise ValueError(f"Invalid dtype '{self.dtype}'")
        return torch_dtype

    def __post_init__(self):
        # A dataset repeated fewer than one time has no files to read, so
        # fail here with a clear message instead of later during loading.
        if self.n_repeats < 1:
            raise ValueError(f"n_repeats must be at least 1, got {self.n_repeats}")
        if self.n_repeats > 1 and not self.infer_timestep:
            raise ValueError(
                "infer_timestep must be True if n_repeats is greater than 1"
            )
        if self.spatial_dimensions not in ["latlon", "healpix"]:
            raise ValueError(
                f"unexpected spatial_dimensions {self.spatial_dimensions},"
                " should be one of 'latlon' or 'healpix'"
            )
        self.torch_dtype  # check it can be retrieved
        self._default_file_pattern_check()
        self.zarr_engine_used = self.engine == "zarr"

    def build(
        self,
        names: Sequence[str],
        n_timesteps: int,
    ) -> tuple[torch.utils.data.Dataset, DatasetProperties]:
        """Build the dataset described by this config via ``get_xarray_dataset``.

        Args:
            names: Names of the variables to load.
            n_timesteps: Number of timesteps in each sample.

        Returns:
            The dataset and its properties.
        """
        return get_xarray_dataset(
            self,
            list(names),
            n_timesteps,
        )
class XarrayDataset(torch.utils.data.Dataset):
    """Load data from a directory of files matching a pattern using xarray. The
    number of contiguous timesteps to load for each sample is specified by the
    n_timesteps argument.

    For example, if the file(s) have the time coordinate
    (t0, t1, t2, t3, t4) and n_timesteps=3, then this dataset will
    provide three samples: (t0, t1, t2), (t1, t2, t3), and (t2, t3, t4).
    """

    def __init__(
        self, config: XarrayDataConfig, names: Sequence[str], n_timesteps: int
    ):
        self._horizontal_coordinates: HorizontalCoordinates
        self._names = names
        self.path = config.data_path
        self.file_pattern = config.file_pattern
        self.engine = config.engine
        self.dtype = config.torch_dtype
        self.spatial_dimensions = config.spatial_dimensions
        self.fill_nans = config.fill_nans
        self.subset_config = config.subset
        self._raw_paths = get_raw_paths(self.path, self.file_pattern)
        if len(self._raw_paths) == 0:
            raise ValueError(
                f"No files found matching '{self.path}/{self.file_pattern}'."
            )
        self.full_paths = self._raw_paths * config.n_repeats
        self.sample_n_times = n_timesteps
        self._get_files_stats(config.n_repeats, config.infer_timestep)
        # NOTE(review): this dataset is opened without the protocol-specific
        # storage options used elsewhere in this module and is never closed
        # explicitly -- confirm both are intended.
        first_dataset = xr.open_dataset(
            self.full_paths[0],
            decode_times=False,
            decode_timedelta=False,
            engine=self.engine,
            chunks=None,
        )
        self._mask_provider = _get_mask_provider(first_dataset, self.dtype)
        (
            self._horizontal_coordinates,
            self._static_derived_data,
            _loaded_horizontal_dims,
        ) = self.configure_horizontal_coordinates(first_dataset)
        (
            self._time_dependent_names,
            self._time_invariant_names,
            self._static_derived_names,
        ) = self._group_variable_names_by_time_type()
        self._vertical_coordinate = _get_vertical_coordinate(first_dataset, self.dtype)
        self.overwrite = config.overwrite
        self._nonspacetime_dims = get_nonspacetime_dimensions(
            first_dataset, _loaded_horizontal_dims
        )
        self._shape_excluding_time = [
            first_dataset.sizes[dim]
            for dim in (self._nonspacetime_dims + _loaded_horizontal_dims)
        ]
        self._loaded_dims = ["time"] + self._nonspacetime_dims + _loaded_horizontal_dims
        # Integer selections are kept as-is; Slice configs are converted to
        # builtin slices.
        self.isel = {
            dim: v if isinstance(v, int) else v.slice for dim, v in config.isel.items()
        }
        # One selection per non-time loaded dimension, in loading order.
        self._isel_tuple = tuple(
            [self.isel.get(dim, SLICE_NONE) for dim in self._loaded_dims[1:]]
        )
        self._check_isel_dimensions(first_dataset.sizes)
        self._labels = set(config.labels)
        self._infer_timestep = config.infer_timestep

    def _check_isel_dimensions(self, data_dim_sizes):
        """Validate ``self.isel`` against the dimension sizes of the data.

        Args:
            data_dim_sizes: Mapping from dimension name to its size.

        Raises:
            ValueError: If "time" is selected, if a selected dimension is not
                a non-spacetime dimension, or if a selection is out of bounds.
        """
        # Horizontal dimensions are not currently supported, as the current isel code
        # does not adjust HorizonalCoordinates to match selection.
        if "time" in self.isel:
            raise ValueError("isel cannot be used to select time. Use subset instead.")
        for dim, selection in self.isel.items():
            if dim not in self._nonspacetime_dims:
                raise ValueError(
                    f"isel dimension {dim} must be a non-spacetime dimension "
                    f"of the dataset ({self._nonspacetime_dims})."
                )
            size = data_dim_sizes[dim]
            if isinstance(selection, slice):
                # Only the start of a slice is checked.
                isel_index = selection.start or 0
                out_of_bounds = isel_index >= size
            else:
                # Negative integers count from the end, so the valid range
                # is [-size, size).
                isel_index = selection
                out_of_bounds = not -size <= isel_index < size
            if out_of_bounds:
                raise ValueError(
                    f"isel index {isel_index} is out of bounds for dimension "
                    f"{dim} with size {size}."
                )

    @property
    def _shape_excluding_time_after_selection(self):
        """Sizes of the non-time dimensions once ``isel`` has been applied."""
        final_shape = []
        for orig_size, sel in zip(self._shape_excluding_time, self._isel_tuple):
            # if selecting a single index, dimension is squeezed
            # so it is not included in the final shape
            if isinstance(sel, slice):
                if sel.start is None and sel.stop is None and sel.step is None:
                    final_shape.append(orig_size)
                else:
                    final_shape.append(len(range(*sel.indices(orig_size))))
        return final_shape

    @property
    def dims(self) -> list[str]:
        """Dimension names of returned data, starting with "time"."""
        # Final dimensions of returned data after dims that are selected
        # with a single index are dropped
        final_dims = ["time"]
        for dim, sel in zip(self._loaded_dims[1:], self._isel_tuple):
            if isinstance(sel, slice):
                final_dims.append(dim)
        return final_dims

    @property
    def properties(self) -> DatasetProperties:
        return DatasetProperties(
            self._variable_metadata,
            self._vertical_coordinate,
            self._horizontal_coordinates,
            self._mask_provider,
            self.timestep,
            self._is_remote,
            self._labels,
        )

    @property
    def _is_remote(self) -> bool:
        """Whether the data path has a URL scheme other than "file"."""
        protocol = _get_protocol(str(self.path))
        if not protocol or protocol == "file":
            return False
        return True

    @property
    def all_times(self) -> xr.CFTimeIndex:
        """Time index of all available times in the data."""
        return self._all_times

    def _get_variable_metadata(self, ds):
        """Store units and long_name metadata for the requested variables.

        Static derived variables use the metadata of StaticDerivedData; other
        variables are included only if they have both a ``units`` and a
        ``long_name`` attribute in ``ds``.
        """
        result = {}
        for name in self._names:
            if name in StaticDerivedData.names:
                result[name] = StaticDerivedData.metadata[name]
            elif hasattr(ds[name], "units") and hasattr(ds[name], "long_name"):
                result[name] = VariableMetadata(
                    units=ds[name].units,
                    long_name=ds[name].long_name,
                )
        self._variable_metadata = result

    def _get_files_stats(self, n_repeats: int, infer_timestep: bool):
        """Read the time coordinates of all files and derive indexing state.

        Sets the timestep, the per-file start indices, the total number of
        timesteps, the number of samples, the sample start times, the full
        time index and the variable metadata.
        """
        logger.info(f"Opening data at {os.path.join(self.path, self.file_pattern)}")
        raw_times = _get_raw_times(self._raw_paths, engine=self.engine)

        self._timestep: datetime.timedelta | None
        if infer_timestep:
            inferred_timestep = _get_timestep(np.concatenate(raw_times))
            time_coord = _repeat_and_increment_time(
                raw_times, n_repeats, inferred_timestep
            )
            self._timestep = inferred_timestep
        else:
            self._timestep = None
            time_coord = raw_times

        cum_num_timesteps = _get_cumulative_timesteps(time_coord)
        self.start_indices = cum_num_timesteps[:-1]
        self.total_timesteps = cum_num_timesteps[-1]
        # Every start position that leaves room for a full sample.
        self._n_initial_conditions = self.total_timesteps - self.sample_n_times + 1
        self._sample_start_times = xr.CFTimeIndex(
            np.concatenate(time_coord)[: self._n_initial_conditions]
        )
        self._all_times = xr.CFTimeIndex(np.concatenate(time_coord))

        del cum_num_timesteps, time_coord

        ds = self._open_file(0)
        self._get_variable_metadata(ds)

        logger.info(f"Found {self._n_initial_conditions} samples.")

    def _group_variable_names_by_time_type(self) -> VariableNames:
        """Returns lists of time-dependent variable names, time-independent
        variable names, and static derived variable names (those listed in
        StaticDerivedData.names).

        Raises:
            ValueError: If a requested variable is not in the first file.
        """
        (
            time_dependent_names,
            time_invariant_names,
            static_derived_names,
        ) = ([], [], [])
        # Don't use open_mfdataset here, because it will give time-invariant
        # fields a time dimension. We assume that all fields are present in the
        # netcdf file corresponding to the first chunk of time.
        with _open_xr_dataset(self.full_paths[0], engine=self.engine) as ds:
            for name in self._names:
                if name in StaticDerivedData.names:
                    static_derived_names.append(name)
                    continue
                try:
                    da = ds[name]
                except KeyError as err:
                    raise ValueError(
                        f"Required variable not found in dataset: {name}."
                    ) from err
                if "time" in da.dims:
                    time_dependent_names.append(name)
                else:
                    time_invariant_names.append(name)
            logger.info(
                f"The required variables have been found in the dataset: {self._names}."
            )

        return VariableNames(
            time_dependent_names,
            time_invariant_names,
            static_derived_names,
        )

    def configure_horizontal_coordinates(
        self, first_dataset
    ) -> tuple[HorizontalCoordinates, StaticDerivedData, list[str]]:
        """Build the horizontal coordinates and static derived data.

        Returns:
            The horizontal coordinates, the StaticDerivedData built from them,
            and the horizontal dimension names.
        """
        horizontal_coordinates: HorizontalCoordinates
        static_derived_data: StaticDerivedData

        horizontal_coordinates, dim_names = get_horizontal_coordinates(
            first_dataset, self.spatial_dimensions, self.dtype
        )
        static_derived_data = StaticDerivedData(horizontal_coordinates)

        coords_sizes = {
            coord_name: len(coord)
            for coord_name, coord in horizontal_coordinates.coords.items()
        }
        logger.info(f"Horizontal coordinate sizes are {coords_sizes}.")
        return horizontal_coordinates, static_derived_data, dim_names

    @property
    def timestep(self) -> datetime.timedelta | None:
        """The inferred timestep, or None (with a warning) if inference is off.

        Raises:
            ValueError: If no timestep is stored although inference was
                requested.
        """
        if self._timestep is None:
            if self._infer_timestep is False:
                warnings.warn(
                    "XarrayDataConfig.infer_timestep set to False. "
                    "Timestep was not inferred in the data loader."
                )
                return self._timestep
            else:
                raise ValueError(
                    "Timestep was not inferred in the data loader. Note "
                    "XarrayDataConfig.infer_timestep must be set to True for this "
                    "to occur."
                )
        else:
            return self._timestep

    def __len__(self):
        return self._n_initial_conditions

    def _open_file(self, idx):
        logger.debug(f"Opening file {self.full_paths[idx]}")
        return _open_file_fh_cached(self.full_paths[idx], engine=self.engine)

    @property
    def sample_start_times(self) -> xr.CFTimeIndex:
        """Return cftime index corresponding to start time of each sample."""
        return self._sample_start_times

    def __getitem__(self, idx: int) -> tuple[TensorDict, xr.DataArray, set[str]]:
        """Return a sample of data spanning the timesteps
        [idx, idx + self.sample_n_times).

        Args:
            idx: Index of the sample to retrieve.

        Returns:
            Tuple of a sample's data (i.e. a mapping from names to torch.Tensors),
            its corresponding time coordinate, and the dataset's labels.
        """
        time_slice = slice(idx, idx + self.sample_n_times)
        return self.get_sample_by_time_slice(time_slice)

    def get_sample_by_time_slice(
        self, time_slice: slice
    ) -> tuple[TensorDict, xr.DataArray, set[str]]:
        """Load all requested variables for the global time range ``time_slice``.

        Args:
            time_slice: Slice with explicit start and stop, indexing the
                concatenated time axis of all files.

        Returns:
            Tuple of the data (a mapping from names to torch.Tensors), the
            corresponding time coordinate, and the dataset's labels.
        """
        input_file_idx, input_local_idx = _get_file_local_index(
            time_slice.start, self.start_indices
        )
        output_file_idx, output_local_idx = _get_file_local_index(
            time_slice.stop - 1, self.start_indices
        )

        # get the sequence of observations
        arrays: dict[str, list[torch.Tensor]] = {}
        idxs = range(input_file_idx, output_file_idx + 1)
        total_steps = 0
        for i, file_idx in enumerate(idxs):
            start = input_local_idx if i == 0 else 0
            if i == len(idxs) - 1:
                stop = output_local_idx
            else:
                # Not the last file of the sample: read through its final step.
                stop = (
                    self.start_indices[file_idx + 1] - self.start_indices[file_idx] - 1
                )

            n_steps = stop - start + 1
            shape = [n_steps] + self._shape_excluding_time_after_selection
            total_steps += n_steps
            if self.engine == "zarr":
                tensor_dict = load_series_data_zarr_async(
                    idx=start,
                    n_steps=n_steps,
                    path=self.full_paths[file_idx],
                    names=self._time_dependent_names,
                    final_dims=self.dims,
                    final_shape=shape,
                    fill_nans=self.fill_nans,
                    nontime_selection=self._isel_tuple,
                )
            else:
                ds = self._open_file(file_idx)
                ds = ds.isel(**self.isel)
                tensor_dict = load_series_data(
                    idx=start,
                    n_steps=n_steps,
                    ds=ds,
                    names=self._time_dependent_names,
                    final_dims=self.dims,
                    final_shape=shape,
                    fill_nans=self.fill_nans,
                )
                ds.close()
                del ds
            for n in self._time_dependent_names:
                arrays.setdefault(n, []).append(tensor_dict[n])

        # Join the per-file pieces of each variable along the leading axis.
        tensors: TensorDict = {}
        for n, tensor_list in arrays.items():
            tensors[n] = torch.cat(tensor_list)
        del arrays

        # load time-invariant variables from first dataset
        if len(self._time_invariant_names) > 0:
            ds = self._open_file(idxs[0])
            ds = ds.isel(**self.isel)
            shape = [total_steps] + self._shape_excluding_time_after_selection
            for name in self._time_invariant_names:
                variable = ds[name].variable
                if self.fill_nans is not None:
                    variable = variable.fillna(self.fill_nans.value)
                tensors[name] = as_broadcasted_tensor(variable, self.dims, shape)
            ds.close()
            del ds

        # load static derived variables
        for name in self._static_derived_names:
            tensor = self._static_derived_data[name]
            horizontal_dims = [1] * tensor.ndim
            tensors[name] = tensor.repeat((total_steps, *horizontal_dims))

        # cast to desired dtype
        tensors = {k: v.to(dtype=self.dtype) for k, v in tensors.items()}

        # Apply field overwrites
        tensors = self.overwrite.apply(tensors)

        # Create a DataArray of times to return corresponding to the slice that
        # is valid even when n_repeats > 1.
        time = xr.DataArray(self.all_times[time_slice].values, dims=["time"])

        return tensors, time, self._labels


def _get_timestep(time: np.ndarray) -> datetime.timedelta:
    """Computes the timestep of an array of a time coordinate array.

    Raises an error if the times are not separated by a positive constant
    interval, or if the array has one or fewer times.
    """
    assert len(time.shape) == 1, "times must be a 1D array"

    if len(time) > 1:
        timesteps = np.diff(time)
        timestep = timesteps[0]

        if not (timestep > datetime.timedelta(days=0)):
            raise ValueError("Timestep of data must be greater than zero.")

        if not np.all(timesteps == timestep):
            raise ValueError("Time coordinate does not have a uniform timestep.")

        return timestep
    else:
        raise ValueError(
            "Time coordinate does not have enough times to infer a timestep."
        )


def _as_index_selection(
    subset: Slice | TimeSlice | RepeatedInterval, dataset: XarrayDataset
) -> slice | np.ndarray:
    """Converts a subset defined as a Slice, TimeSlice or RepeatedInterval
    into an index selection for the samples of the provided dataset.

    A Slice is used directly, a TimeSlice is resolved against the dataset's
    sample start times, and a RepeatedInterval is turned into a boolean mask
    using the dataset's length and timestep.

    Raises:
        ValueError: If the RepeatedInterval cannot be applied to the dataset.
        TypeError: If ``subset`` is of none of the supported types.
    """
    if isinstance(subset, Slice):
        index_selection = subset.slice
    elif isinstance(subset, TimeSlice):
        index_selection = subset.slice(dataset.sample_start_times)
    elif isinstance(subset, RepeatedInterval):
        try:
            index_selection = subset.get_boolean_mask(len(dataset), dataset.timestep)
        except ValueError as e:
            raise ValueError(
                f"Error when applying RepeatedInterval to dataset: {e}"
            ) from e
    else:
        raise TypeError(f"subset must be Slice or TimeSlice, got {type(subset)}")
    return index_selection


class XarraySubset(torch.utils.data.Dataset):
    """A selection of the samples of an XarrayDataset.

    Exposes the sample start times of the selected samples along with the
    ``sample_n_times`` and ``dims`` of the wrapped dataset.
    """

    def __init__(self, dataset: XarrayDataset, subset: slice | np.ndarray):
        indices = np.arange(len(dataset))[subset]
        logger.info(f"Subsetting dataset samples according to {subset}.")
        self._dataset = torch.utils.data.Subset(dataset, indices)
        self._sample_start_times = dataset.sample_start_times[indices]
        self._sample_n_times = dataset.sample_n_times
        self.dims = dataset.dims

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx: int) -> tuple[TensorDict, xr.DataArray, set[str]]:
        return self._dataset[idx]

    @property
    def sample_start_times(self):
        return self._sample_start_times

    @property
    def sample_n_times(self) -> int:
        """The length of the time dimension of each sample."""
        return self._sample_n_times


def get_xarray_dataset(
    config: XarrayDataConfig, names: Sequence[str], n_timesteps: int
) -> tuple["XarraySubset", DatasetProperties]:
    """Build the dataset for ``config`` and apply its ``subset`` selection.

    Returns:
        The subsetted dataset and the properties of the full dataset.
    """
    dataset = XarrayDataset(config, names, n_timesteps)
    properties = dataset.properties
    index_slice = _as_index_selection(config.subset, dataset)
    dataset = XarraySubset(dataset, index_slice)
    return dataset, properties


def get_xarray_datasets(
    dataset_configs: Sequence[XarrayDataConfig],
    names: Sequence[str],
    n_timesteps: int,
    strict: bool = True,
) -> tuple[list[XarraySubset], DatasetProperties]:
    """Build one dataset per config and merge their properties.

    Args:
        dataset_configs: Configurations of the datasets to build.
        names: Names of the variables to load.
        n_timesteps: Number of timesteps in each sample.
        strict: If True, a ValueError raised while merging properties
            propagates; if False it is reported as a warning instead.

    Returns:
        The datasets, in the order of ``dataset_configs``, and the merged
        properties.

    Raises:
        ValueError: If ``dataset_configs`` is empty.
    """
    datasets = []
    properties: DatasetProperties | None = None
    for config in dataset_configs:
        dataset, new_properties = get_xarray_dataset(config, names, n_timesteps)
        datasets.append(dataset)
        if properties is None:
            properties = new_properties
        elif not strict:
            try:
                properties.update(new_properties)
            except ValueError as e:
                warnings.warn(
                    f"Metadata for each ensemble member are not the same: {e}"
                )
        else:
            properties.update(new_properties)
    if properties is None:
        raise ValueError("At least one dataset must be provided.")

    return datasets, properties