Source code for fme.downscaling.predictors.composite

import dataclasses

import torch

from fme.core.typing_ import TensorDict
from fme.downscaling.data import BatchData, PairedBatchData, Topography, scale_tuple
from fme.downscaling.data.patching import Patch, get_patches
from fme.downscaling.data.utils import null_generator
from fme.downscaling.models import DiffusionModel, ModelOutputs
from fme.downscaling.predictors import CascadePredictor


[docs]@dataclasses.dataclass class PatchPredictionConfig: """ Configuration to enable predictions on multiple patches for evaluation. Args: divide_generation: enables the patched prediction of the full input data extent for generation. composite_prediction: if True, recombines the smaller prediction regions into the original full region as a single sample. coarse_horizontal_overlap: number of pixels to overlap in the coarse data. """ divide_generation: bool = False composite_prediction: bool = False coarse_horizontal_overlap: int = 1 @property def needs_patch_data_generator(self): # If final predictions are not composited together, the BatchData is divided # into patches before being passed to the top level no-target generation call. # If final predictions are composited together, the BatchData is divided # into patches within the PatchPredictor's generation method. if self.divide_generation and not self.composite_prediction: return True return False @property def needs_patch_predictor(self): if self.divide_generation and self.composite_prediction: return True return False
class PatchPredictor: """ Model prediction wrapper for generating a full-extent prediction by dividing the input into a grid of patches. """ def __init__( self, model: DiffusionModel | CascadePredictor, coarse_yx_patch_extent: tuple[int, int] | None = None, coarse_horizontal_overlap: int = 1, ): """ Args: model: the model to use for generating predictions. coarse_yx_patch_extent: The shape of the coarse region passed to the downscaling model for prediction. If None, will be inferred from model.coarse_shape. coarse_horizontal_overlap: the number of pixels to overlap between patches in the coarse data. """ self.model = model self.modules = self.model.modules if coarse_yx_patch_extent is None: coarse_yx_patch_extent = self.model.coarse_shape self.coarse_yx_patch_extent = coarse_yx_patch_extent self.downscale_factor = self.model.downscale_factor self.coarse_horizontal_overlap = coarse_horizontal_overlap @property def coarse_shape(self): return self.coarse_yx_patch_extent def _get_patches( self, coarse_yx_extent, fine_yx_extent ) -> tuple[list[Patch], list[Patch]]: coarse_patches = get_patches( yx_extent=coarse_yx_extent, yx_patch_extent=self.coarse_yx_patch_extent, overlap=self.coarse_horizontal_overlap, drop_partial_patches=False, ) fine_yx_patch_extent = scale_tuple( self.coarse_yx_patch_extent, self.downscale_factor ) fine_patches = get_patches( yx_extent=fine_yx_extent, yx_patch_extent=fine_yx_patch_extent, overlap=self.coarse_horizontal_overlap * self.downscale_factor, drop_partial_patches=False, ) return coarse_patches, fine_patches @torch.no_grad() def generate_on_batch( self, batch: PairedBatchData, topography: Topography | None, n_samples: int = 1, ) -> ModelOutputs: predictions = [] loss = 0.0 coarse_patches, fine_patches = self._get_patches( coarse_yx_extent=batch.coarse.horizontal_shape, fine_yx_extent=batch.fine.horizontal_shape, ) batch_generator = batch.generate_from_patches( coarse_patches=coarse_patches, fine_patches=fine_patches ) if topography is not None: topography_generator = topography.generate_from_patches(fine_patches) else: topography_generator = null_generator(len(fine_patches)) for data_patch, topography_patch in zip(batch_generator, topography_generator): model_output = self.model.generate_on_batch( data_patch, topography_patch, n_samples ) predictions.append(model_output.prediction) loss = loss + model_output.loss prediction = composite_patch_predictions(predictions, fine_patches) # add ensemble dim to the target since not provided by the model target = {k: v.unsqueeze(1) for k, v in batch.fine.data.items()} outputs = ModelOutputs( prediction=prediction, target=target, loss=loss, ) return outputs @torch.no_grad() def generate_on_batch_no_target( self, batch: BatchData, topography: Topography | None, n_samples: int = 1, ) -> TensorDict: coarse_yx_extent = batch.horizontal_shape fine_yx_extent = scale_tuple(coarse_yx_extent, self.downscale_factor) coarse_patches, fine_patches = self._get_patches( coarse_yx_extent=coarse_yx_extent, fine_yx_extent=fine_yx_extent ) predictions = [] batch_generator = batch.generate_from_patches(coarse_patches) if topography is not None: topography_generator = topography.generate_from_patches(fine_patches) else: topography_generator = null_generator(len(fine_patches)) for data_patch, topo_patch in zip(batch_generator, topography_generator): predictions.append( self.model.generate_on_batch_no_target( batch=data_patch, topography=topo_patch, n_samples=n_samples, ) ) prediction = composite_patch_predictions(predictions, fine_patches) return prediction def _get_full_extent_from_patches(patches: list[Patch]) -> tuple[int, int]: # input patches should have int start/stop values y_max = max(patch.input_slice.y.stop for patch in patches) x_max = max(patch.input_slice.x.stop for patch in patches) return y_max, x_max def _composite_patch_tensors( predictions: list[torch.Tensor], patches: list[Patch] ) -> TensorDict: if len(predictions) != len(patches): raise ValueError("The number of predictions must match the number of patches.") y_size, x_size = _get_full_extent_from_patches(patches) # list elements have shape (n_batch, n_generated_sample, patch_lat, patch_lon) n_batch, n_gen_sample = predictions[0].shape[:2] output_sum = torch.zeros( n_batch, n_gen_sample, y_size, x_size, device=predictions[0].device ) output_count = torch.zeros( n_batch, n_gen_sample, y_size, x_size, device=predictions[0].device ) for i, pred in enumerate(predictions): in_slice = patches[i].input_slice out_slice = patches[i].output_slice # Adjust the input slice start if the output slice is trimmed adjusted_in_slice_y = slice( in_slice.y.start + (out_slice.y.start or 0), in_slice.y.stop ) adjusted_in_slice_x = slice( in_slice.x.start + (out_slice.x.start or 0), in_slice.x.stop ) trimmed_pred = pred[..., out_slice.y, out_slice.x] output_sum[..., adjusted_in_slice_y, adjusted_in_slice_x] += trimmed_pred output_count[..., adjusted_in_slice_y, adjusted_in_slice_x] += 1 return output_sum / output_count def composite_patch_predictions( predictions: list[TensorDict], patches: list[Patch] ) -> TensorDict: """ Take the predictions from patches and combine them into a single tensor with the full extent of the patches. The predictions are averaged in overlapping patch regions. """ combined_data = {} predicted_vars = list(predictions[0].keys()) for var in predicted_vars: var_predictions = [pred[var] for pred in predictions] combined_data[var] = _composite_patch_tensors(var_predictions, patches) return combined_data