# Source code for fme.downscaling.models

import dataclasses
from collections.abc import Mapping
from typing import Any

import dacite
import torch

from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.loss import LossConfig
from fme.core.normalizer import NormalizationConfig, StandardNormalizer
from fme.core.optimization import NullOptimization, Optimization
from fme.core.packer import Packer
from fme.core.rand import randn, randn_like
from fme.core.typing_ import TensorDict, TensorMapping
from fme.downscaling.data import (
    BatchData,
    PairedBatchData,
    StaticInputs,
    Topography,
    get_normalized_topography,
)
from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector
from fme.downscaling.requirements import DataRequirements
from fme.downscaling.samplers import stochastic_sampler as edm_sampler
from fme.downscaling.typing_ import FineResCoarseResPair


@dataclasses.dataclass
class ModelOutputs:
    """
    Container for the results of a training or generation step.
    """

    # Denormalized model output, keyed by variable name.
    prediction: TensorDict
    # Fine-resolution target data restricted to the model's output variables.
    target: TensorDict
    # Loss computed in normalized space for this step.
    loss: torch.Tensor
    # Intermediate latents returned by the sampler during generation; left
    # empty by the training step.
    latent_steps: list[torch.Tensor] = dataclasses.field(default_factory=list)


def _rename_normalizer(
    normalizer: StandardNormalizer, rename: dict[str, str] | None
) -> StandardNormalizer:
    if not rename:
        return normalizer
    new_means = {
        rename.get(name, name): value for name, value in normalizer.means.items()
    }
    new_stds = {
        rename.get(name, name): value for name, value in normalizer.stds.items()
    }
    return StandardNormalizer(means=new_means, stds=new_stds)


@dataclasses.dataclass
class PairedNormalizationConfig:
    """
    Normalization configuration for paired fine and coarse resolution data.

    Parameters:
        fine: The normalization configuration for fine-resolution data.
        coarse: The normalization configuration for coarse-resolution data.
    """

    fine: NormalizationConfig
    coarse: NormalizationConfig

    def build(
        self,
        in_names: list[str],
        out_names: list[str],
        rename: dict[str, str] | None = None,
    ) -> FineResCoarseResPair[StandardNormalizer]:
        """
        Build the fine and coarse normalizers.

        Args:
            in_names: The input variable names.
            out_names: The output variable names.
            rename: Optional mapping of {old: new} names applied to both
                normalizers after they are built.

        Returns:
            The pair of (optionally renamed) normalizers. The coarse normalizer
            covers the union of input and output names, the fine normalizer
            covers the output names only.
        """
        pair_type = FineResCoarseResPair[StandardNormalizer]
        coarse_names = set(in_names).union(out_names)
        coarse_normalizer = self.coarse.build(list(coarse_names))
        fine_normalizer = self.fine.build(out_names)

        renamed_coarse = _rename_normalizer(coarse_normalizer, rename)
        renamed_fine = _rename_normalizer(fine_normalizer, rename)
        return pair_type(coarse=renamed_coarse, fine=renamed_fine)

    def load(self):
        """
        Load the normalization configuration from the netCDF files.

        Updates the configuration so it no longer requires external files.
        """
        for resolution_config in (self.fine, self.coarse):
            resolution_config.load()


@dataclasses.dataclass
class DiffusionModelConfig:
    """
    Configuration for a diffusion downscaling model that implements or wraps
    the algorithms described in `EDM`_.

    .. _EDM: https://arxiv.org/abs/2206.00364

    Parameters:
        module: The module registry selector for the diffusion model.
        loss: The loss configuration for the diffusion model.
        in_names: The input variable names for the diffusion model.
        out_names: The output variable names for the diffusion model.
        normalization: The normalization configurations for the diffusion model.
        p_mean: The mean of noise distribution used during training.
        p_std: The std of the noise distribution used during training.
        sigma_min: Min noise level for generation.
        sigma_max: Max noise level for generation.
        churn: The amount of stochasticity during generation.
        num_diffusion_generation_steps: Number of diffusion generation steps.
        predict_residual: Whether the model predicts the residual relative to
            the interpolated, normalized coarse fields of the output variables
            rather than the full normalized fine fields.
        use_fine_topography: Whether to use fine topography in the model.
    """

    module: DiffusionModuleRegistrySelector
    loss: LossConfig
    in_names: list[str]
    out_names: list[str]
    normalization: PairedNormalizationConfig
    p_mean: float
    p_std: float
    sigma_min: float
    sigma_max: float
    churn: float
    num_diffusion_generation_steps: int
    predict_residual: bool
    use_fine_topography: bool = False

    def __post_init__(self):
        # Cache whether the selected module consumes interpolated coarse input.
        self._interpolate_input = self.module.expects_interpolated_input
        if not self.use_fine_topography:
            return
        if self._interpolate_input:
            return
        # Fine topography was requested but the module takes raw coarse input.
        raise ValueError(
            "Fine topography can only be used when predicting on interpolated"
            " coarse input"
        )

    def build(
        self,
        coarse_shape: tuple[int, int],
        downscale_factor: int,
        rename: dict[str, str] | None = None,
        static_inputs: StaticInputs | None = None,
    ) -> "DiffusionModel":
        """
        Build the diffusion model described by this configuration.

        Args:
            coarse_shape: The (height, width) of the coarse-resolution input.
            downscale_factor: The factor between coarse and fine resolution.
            rename: Optional mapping of {old: new} variable names. The
                configured names are mapped back through it to look up
                normalization statistics, which are then renamed.
            static_inputs: Optional static inputs to attach to the model.

        Returns:
            The constructed DiffusionModel.
        """
        # Map the (possibly renamed) configured names back to original names.
        new_to_old: dict[str, str] = {}
        if rename:
            for old_name, new_name in rename.items():
                new_to_old[new_name] = old_name
        original_inputs = [new_to_old.get(name, name) for name in self.in_names]
        original_outputs = [new_to_old.get(name, name) for name in self.out_names]

        normalizer = self.normalization.build(
            original_inputs, original_outputs, rename
        )
        loss = self.loss.build(reduction="none", gridded_operations=None)

        # We always use standard score normalization, so sigma_data is
        # always 1.0. See below for standard score normalization:
        # https://en.wikipedia.org/wiki/Standard_score
        sigma_data = 1.0

        # fine topography is already normalized and at fine scale, so it is
        # handled separately for now and contributes one extra input channel
        extra_channels = 1 if self.use_fine_topography else 0
        n_in_channels = len(self.in_names) + extra_channels

        module = self.module.build(
            n_in_channels=n_in_channels,
            n_out_channels=len(self.out_names),
            coarse_shape=coarse_shape,
            downscale_factor=downscale_factor,
            sigma_data=sigma_data,
        )

        return DiffusionModel(
            config=self,
            module=module,
            normalizer=normalizer,
            loss=loss,
            coarse_shape=coarse_shape,
            downscale_factor=downscale_factor,
            sigma_data=sigma_data,
            static_inputs=static_inputs,
        )

    def get_state(self) -> Mapping[str, Any]:
        """
        Return this configuration as a plain dictionary.

        The normalization configuration is loaded first so that the returned
        state no longer requires external files.
        """
        self.normalization.load()
        state = dataclasses.asdict(self)
        return state

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "DiffusionModelConfig":
        """Rebuild a configuration from a dictionary as made by get_state."""
        config = dacite.from_dict(data_class=cls, data=state)
        return config

    @property
    def data_requirements(self) -> DataRequirements:
        """The data needed to train or evaluate a model with this config."""
        # Requires output names in coarse for aggregators checking relative measures
        coarse_names = list(set(self.in_names).union(self.out_names))
        return DataRequirements(
            fine_names=self.out_names,
            coarse_names=coarse_names,
            n_timesteps=1,
            use_fine_topography=self.use_fine_topography,
        )


def _repeat_batch_by_samples(tensor: torch.Tensor, n_samples: int) -> torch.Tensor:
    """
    Repeat the batch dimension of a tensor n_samples times.  Used to parallelize
    sample generation in the diffusion model, but done so such that the tensor still
    only leads with a single batch dimension.  Added as a separate function for
    round-trip testing.
    """
    return tensor.repeat_interleave(dim=0, repeats=n_samples)


def _separate_interleaved_samples(tensor: torch.Tensor, n_samples: int) -> torch.Tensor:
    """
    Reshape an interleaved tensor to have a leading [batch, sample, ...] dimension.
    """
    if tensor.shape[0] % n_samples != 0:
        raise ValueError(
            "The interleaved batch+sample dimension of the tensor must be divisible "
            "by n_samples."
        )
    n_batch = tensor.shape[0] // n_samples
    return tensor.reshape(n_batch, n_samples, *tensor.shape[1:])


@dataclasses.dataclass
class ConditionedTarget:
    """
    A class to hold the conditioned targets and the loss weighting.

    Attributes:
        latents: The normalized targets with noise added.
        sigma: The noise level used to scale the added noise.
        weight: The loss weighting associated with each noise level.
    """

    # Noisy version of the normalized targets that is fed to the denoiser.
    latents: torch.Tensor
    # Noise level; condition_with_noise_for_training draws one per batch member.
    sigma: torch.Tensor
    # Multiplies the element-wise loss before it is averaged.
    weight: torch.Tensor


def condition_with_noise_for_training(
    targets_norm: torch.Tensor,
    p_std: float,
    p_mean: float,
    sigma_data: float,
) -> ConditionedTarget:
    """
    Add randomly scaled noise to the normalized targets for a training step.

    Args:
        targets_norm: The normalized targets.
        p_std: The standard deviation of the noise distribution used during training.
        p_mean: The mean of the noise distribution used during training.
        sigma_data: The standard deviation of the data,
            used to determine loss weighting.

    Returns:
        The noisy targets together with the per-batch-member noise level and
        the corresponding loss weighting.
    """
    n_batch = targets_norm.shape[0]
    # One draw per batch member, shaped [n_batch, 1, 1, 1] so that it
    # broadcasts over the remaining dimensions of the targets.
    log_sigma = randn([n_batch, 1, 1, 1], device=targets_norm.device) * p_std + p_mean
    # The noise level and loss weighting follow EDM's original EDMLoss:
    # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L72-L80  # noqa: E501
    sigma = log_sigma.exp()
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    noisy_targets = targets_norm + randn_like(targets_norm) * sigma
    return ConditionedTarget(latents=noisy_targets, sigma=sigma, weight=weight)


class DiffusionModel:
    """
    Diffusion downscaling model: a denoising module together with the
    normalizers, packers, loss and sampling configuration needed to train it
    and to generate fine-resolution samples from coarse-resolution inputs.
    """

    def __init__(
        self,
        config: DiffusionModelConfig,
        module: torch.nn.Module,
        normalizer: FineResCoarseResPair[StandardNormalizer],
        loss: torch.nn.Module,
        coarse_shape: tuple[int, int],
        downscale_factor: int,
        sigma_data: float,
        static_inputs: StaticInputs | None = None,
    ) -> None:
        """
        Args:
            config: The configuration object for the diffusion model.
            module: The neural network module used for downscaling. Note: this
                should *not* be DistributedDataParallel since it is wrapped by
                it in this init method.
            normalizer: The normalizer object used for data normalization.
            loss: The loss function used for training the model.
            coarse_shape: The height (lat) and width (lon) of the
                coarse-resolution input data.
            downscale_factor: The factor by which the data is downscaled from
                coarse to fine.
            sigma_data: The standard deviation of the data, used for diffusion
                model preconditioning.
            static_inputs: Optional static inputs to the model that may be loaded
                from saved checkpoint. If required by the model but not passed at
                init, they are expected to be provided in the loaded dataset.
        """
        self.coarse_shape = coarse_shape
        self.downscale_factor = downscale_factor
        self.sigma_data = sigma_data
        dist = Distributed.get_instance()
        # Move to the target device before applying the distributed wrapper.
        self.module = dist.wrap_module(module.to(get_device()))
        self.normalizer = normalizer
        self.loss = loss
        self.in_packer = Packer(config.in_names)
        self.out_packer = Packer(config.out_names)
        self.config = config
        # Variables are stacked on the third-from-last axis, i.e. [..., C, H, W].
        self._channel_axis = -3
        self.static_inputs = static_inputs

    @property
    def modules(self) -> torch.nn.ModuleList:
        """The wrapped module as a single-element ModuleList."""
        return torch.nn.ModuleList([self.module])

    @property
    def fine_shape(self) -> tuple[int, int]:
        """The fine-resolution shape corresponding to ``coarse_shape``."""
        return self._get_fine_shape(self.coarse_shape)

    def _get_fine_shape(self, coarse_shape: tuple[int, int]) -> tuple[int, int]:
        """
        Calculate the fine shape based on the coarse shape and downscale factor.
        """
        return (
            coarse_shape[0] * self.downscale_factor,
            coarse_shape[1] * self.downscale_factor,
        )

    def _get_input_from_coarse(
        self, coarse: TensorMapping, topography: Topography | None
    ) -> torch.Tensor:
        """
        Build the conditioning tensor for the module from coarse data.

        Args:
            coarse: Coarse-resolution data, which must contain the input names.
            topography: Fine topography, required when the configuration
                enables ``use_fine_topography``.

        Returns:
            The normalized, channel-packed inputs. When the module expects
            interpolated input these are interpolated by the downscale factor
            and, if enabled, the topography is appended as an extra channel.

        Raises:
            ValueError: If fine topography is enabled but none was provided.
        """
        inputs = filter_tensor_mapping(coarse, self.in_packer.names)
        normalized = self.in_packer.pack(
            self.normalizer.coarse.normalize(inputs), axis=self._channel_axis
        )
        interpolated = interpolate(normalized, self.downscale_factor)

        if self.config.use_fine_topography:
            if topography is None:
                raise ValueError(
                    "Topography must be provided for each batch when use of fine "
                    "topography is enabled."
                )
            else:
                n_batches = normalized.shape[0]
                # Join the normalized topography to the input (see dataset for details)
                # NOTE(review): the repeat pattern assumes topography.data is
                # a 2D (height, width) field shared by the batch -- confirm.
                topo = topography.data.unsqueeze(0).repeat(n_batches, 1, 1)
                topo = topo.unsqueeze(self._channel_axis)
                interpolated = torch.concat(
                    [interpolated, topo], axis=self._channel_axis
                )

        if self.config._interpolate_input:
            return interpolated
        return normalized

    def _get_base_prediction(self, coarse: TensorMapping) -> torch.Tensor:
        """
        Return the baseline that residual predictions are defined against:
        the coarse fields of the output variables, normalized with the coarse
        normalizer, channel-packed and interpolated by the downscale factor.
        """
        packed = self.out_packer.pack(
            self.normalizer.coarse.normalize(
                {k: coarse[k] for k in self.out_packer.names}
            ),
            axis=self._channel_axis,
        )
        return interpolate(packed, self.downscale_factor)

    def train_on_batch(
        self,
        batch: PairedBatchData,
        topography: Topography | None,
        optimizer: Optimization | NullOptimization,
    ) -> ModelOutputs:
        """
        Performs a denoising training step on a batch of data.

        Args:
            batch: Paired coarse and fine data for the step.
            topography: Fine topography, required if the model uses it.
            optimizer: Receives the loss and steps the weights; pass a
                NullOptimization to evaluate without updating.

        Returns:
            The denormalized denoised prediction, the fine targets restricted
            to the output names, and the weighted mean loss.
        """
        coarse, fine = batch.coarse.data, batch.fine.data
        inputs_norm = self._get_input_from_coarse(coarse, topography)
        targets_norm = self.out_packer.pack(
            self.normalizer.fine.normalize(dict(fine)), axis=self._channel_axis
        )

        if self.config.predict_residual:
            # Train on the difference from the interpolated coarse fields.
            base_prediction = self._get_base_prediction(coarse)
            targets_norm = targets_norm - base_prediction

        conditioned_target = condition_with_noise_for_training(
            targets_norm, self.config.p_std, self.config.p_mean, self.sigma_data
        )

        denoised_norm = self.module(
            conditioned_target.latents, inputs_norm, conditioned_target.sigma
        )
        loss = torch.mean(
            conditioned_target.weight * self.loss(denoised_norm, targets_norm)
        )
        optimizer.accumulate_loss(loss)
        optimizer.step_weights()

        if self.config.predict_residual:
            # Add the baseline back so the reported prediction is a full field.
            denoised_norm = denoised_norm + base_prediction

        target = filter_tensor_mapping(batch.fine.data, set(self.out_packer.names))
        denoised = self.normalizer.fine.denormalize(
            self.out_packer.unpack(denoised_norm, axis=self._channel_axis)
        )
        return ModelOutputs(
            prediction=denoised, target=target, loss=loss, latent_steps=[]
        )

    @torch.no_grad()
    def generate(
        self,
        coarse_data: TensorMapping,
        topography: Topography | None,
        n_samples: int = 1,
    ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]:
        """
        Generate fine-resolution samples conditioned on coarse data.

        Args:
            coarse_data: Coarse-resolution data containing the input names
                (and the output names when predicting residuals).
            topography: Fine topography, required if the model uses it.
            n_samples: Number of samples to generate per batch member.

        Returns:
            A tuple of the denormalized samples with leading
            [batch, sample, ...] dimensions, the normalized samples with the
            batch and sample dimensions folded into one leading dimension,
            and the latent steps returned by the sampler.
        """
        inputs_ = self._get_input_from_coarse(coarse_data, topography)
        # expand samples and fold to
        # [batch * n_samples, output_channels, height, width]
        inputs_ = _repeat_batch_by_samples(inputs_, n_samples)
        # Use the shape of the data actually passed in rather than
        # self.coarse_shape to size the generated output.
        coarse_input_shape = next(iter(coarse_data.values())).shape[-2:]

        outputs_shape = (
            inputs_.shape[0],
            len(self.out_packer.names),
            *self._get_fine_shape(coarse_input_shape),
        )
        latents = torch.randn(outputs_shape).to(device=get_device())

        generated_norm, latent_steps = edm_sampler(
            self.module,
            latents,
            inputs_,
            S_churn=self.config.churn,
            sigma_min=self.config.sigma_min,
            sigma_max=self.config.sigma_max,
            num_steps=self.config.num_diffusion_generation_steps,
        )

        if self.config.predict_residual:
            base_prediction = self._get_base_prediction(coarse_data)
            generated_norm = generated_norm + _repeat_batch_by_samples(
                base_prediction, n_samples
            )

        generated_norm_reshaped = _separate_interleaved_samples(
            generated_norm, n_samples
        )
        generated = self.normalizer.fine.denormalize(
            self.out_packer.unpack(generated_norm_reshaped, axis=self._channel_axis)
        )
        return generated, generated_norm, latent_steps

    @torch.no_grad()
    def generate_on_batch_no_target(
        self,
        batch: BatchData,
        topography: Topography | None,
        n_samples: int = 1,
    ) -> TensorDict:
        """
        Generate samples for a coarse-only batch and return just the
        denormalized samples.
        """
        generated, _, _ = self.generate(batch.data, topography, n_samples)
        return generated

    @torch.no_grad()
    def generate_on_batch(
        self,
        batch: PairedBatchData,
        topography: Topography | None,
        n_samples: int = 1,
    ) -> ModelOutputs:
        """
        Generate samples for a paired batch and compare them to the targets.

        Args:
            batch: Paired coarse and fine data.
            topography: Fine topography, required if the model uses it.
            n_samples: Number of samples to generate per batch member.

        Returns:
            The generated samples, the fine targets with a singleton sample
            dimension inserted at position 1, the loss between the normalized
            samples and the repeated normalized targets, and the latent steps
            from the sampler.
        """
        coarse, fine = batch.coarse.data, batch.fine.data
        generated, generated_norm, latent_steps = self.generate(
            coarse, topography, n_samples
        )

        targets_norm = self.out_packer.pack(
            self.normalizer.fine.normalize(dict(fine)), axis=self._channel_axis
        )
        # Match the folded batch * sample leading dimension of generated_norm.
        targets_norm = _repeat_batch_by_samples(targets_norm, n_samples)

        targets = filter_tensor_mapping(batch.fine.data, set(self.out_packer.names))
        targets = {k: v.unsqueeze(1) for k, v in targets.items()}

        loss = self.loss(generated_norm, targets_norm)
        return ModelOutputs(
            prediction=generated, target=targets, loss=loss, latent_steps=latent_steps
        )

    def get_state(self) -> Mapping[str, Any]:
        """
        Return the state needed to rebuild this model with ``from_state``:
        the config state, module weights, coarse shape, downscale factor and
        (if present) the static inputs state.
        """
        if self.static_inputs is not None:
            static_inputs_state = self.static_inputs.to_state()
        else:
            static_inputs_state = None

        return {
            "config": self.config.get_state(),
            "module": self.module.state_dict(),
            "coarse_shape": self.coarse_shape,
            "downscale_factor": self.downscale_factor,
            "static_inputs": static_inputs_state,
        }

    @classmethod
    def from_state(
        cls,
        state: Mapping[str, Any],
    ) -> "DiffusionModel":
        """
        Rebuild a model from a state mapping as produced by ``get_state``.

        Args:
            state: Mapping with "config", "module", "coarse_shape" and
                "downscale_factor" entries, and optionally "static_inputs".

        Returns:
            The model with its module weights loaded (strictly) and any
            serialized static inputs restored.
        """
        config = DiffusionModelConfig.from_state(state["config"])
        # Restore static inputs written by get_state. ``.get`` keeps states
        # saved before static inputs were serialized loadable.
        static_inputs_state = state.get("static_inputs")
        if static_inputs_state is not None:
            static_inputs = StaticInputs.from_state(static_inputs_state)
        else:
            static_inputs = None
        model = config.build(
            state["coarse_shape"],
            state["downscale_factor"],
            static_inputs=static_inputs,
        )
        model.module.load_state_dict(state["module"], strict=True)
        return model


@dataclasses.dataclass
class _CheckpointModelConfigSelector:
    """
    Thin wrapper used to parse a checkpointed model config with strict
    dacite validation.
    """

    wrapper: DiffusionModelConfig

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> DiffusionModelConfig:
        """
        Parse ``state`` into a DiffusionModelConfig.

        The state is nested under the ``wrapper`` field and parsed with
        ``strict=True`` before the inner config is returned.
        """
        strict_config = dacite.Config(strict=True)
        selector = dacite.from_dict(
            data={"wrapper": state}, data_class=cls, config=strict_config
        )
        return selector.wrapper


@dataclasses.dataclass
class CheckpointModelConfig:
    """
    This class specifies a diffusion model loaded from a checkpoint file.

    Parameters:
        checkpoint_path: The path to the checkpoint file.
        rename: Optional mapping of {old: new} model input/output names to rename.
        fine_topography_path: Optional path to the fine topography file, if needed.
            This is useful when no fine res data is used during evaluation but the
            model still needs fine res static input data.
        model_updates: Optional mapping of {key: new_value} model config updates to
            apply when loading the model. This is useful for running evaluation with
            different parameters than at training time. Use with caution; not all
            parameters can or should be updated at evaluation time.
    """

    checkpoint_path: str
    rename: dict[str, str] | None = None
    fine_topography_path: str | None = None
    model_updates: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        # For config validation testing, we don't want to load immediately
        # so we defer until build or properties are accessed.
        self._checkpoint_is_loaded = False
        self._rename = self.rename or {}
        if "module" in (self.model_updates or {}):
            raise ValueError("'module' cannot be updated in model_updates.")

    @property
    def _checkpoint(self) -> Mapping[str, Any]:
        """
        The checkpoint contents, loaded from ``checkpoint_path`` on first
        access and cached afterwards.

        On load, the configured input/output names are renamed, a missing
        "static_inputs" entry is defaulted to None, and ``model_updates`` are
        written into the model config.
        """
        if not self._checkpoint_is_loaded:
            # NOTE: weights_only=False unpickles arbitrary objects, so only
            # checkpoints from trusted sources should be loaded here.
            checkpoint_data = torch.load(self.checkpoint_path, weights_only=False)
            model_config = checkpoint_data["model"]["config"]
            model_config["in_names"] = [
                self._rename.get(name, name) for name in model_config["in_names"]
            ]
            model_config["out_names"] = [
                self._rename.get(name, name) for name in model_config["out_names"]
            ]
            # backwards compatibility for models before static inputs serialization
            checkpoint_data["model"].setdefault("static_inputs", None)
            # Apply the overrides once, while the freshly loaded data is in
            # scope; later accesses return the cached, already-updated data.
            if self.model_updates is not None:
                for k, v in self.model_updates.items():
                    model_config[k] = v
            self._checkpoint_data = checkpoint_data
            self._checkpoint_is_loaded = True
        return self._checkpoint_data

    def build(
        self,
    ) -> DiffusionModel:
        """
        Build the model described by the checkpoint and load its weights.

        Returns:
            The DiffusionModel built from the (renamed and updated) checkpoint
            config, with static inputs restored when the checkpoint has them.
        """
        if self._checkpoint["model"]["static_inputs"] is not None:
            static_inputs = StaticInputs.from_state(
                self._checkpoint["model"]["static_inputs"]
            )
        else:
            static_inputs = None
        model = _CheckpointModelConfigSelector.from_state(
            self._checkpoint["model"]["config"]
        ).build(
            coarse_shape=self._checkpoint["model"]["coarse_shape"],
            downscale_factor=self._checkpoint["model"]["downscale_factor"],
            rename=self._rename,
            static_inputs=static_inputs,
        )
        model.module.load_state_dict(self._checkpoint["model"]["module"])
        return model

    @property
    def data_requirements(self) -> DataRequirements:
        """The data needed to run the checkpointed model."""
        in_names = self.in_names
        out_names = self.out_names
        return DataRequirements(
            fine_names=out_names,
            coarse_names=list(set(in_names).union(out_names)),
            n_timesteps=1,
            use_fine_topography=self._checkpoint["model"]["config"][
                "use_fine_topography"
            ],
        )

    @property
    def in_names(self):
        """The (renamed) input variable names stored in the checkpoint config."""
        return self._checkpoint["model"]["config"]["in_names"]

    @property
    def out_names(self):
        """The (renamed) output variable names stored in the checkpoint config."""
        return self._checkpoint["model"]["config"]["out_names"]

    def get_topography(self) -> Topography | None:
        """
        Load the normalized fine topography if the model requires it.

        Returns:
            The topography moved to the device, or None when the model does
            not use fine topography.

        Raises:
            ValueError: If the model uses fine topography but no
                ``fine_topography_path`` was configured.
        """
        if self.data_requirements.use_fine_topography:
            if self.fine_topography_path is None:
                raise ValueError(
                    "Topography path must be provided for model configured "
                    "to use fine topography."
                )
            return get_normalized_topography(self.fine_topography_path).to_device()
        else:
            return None