Downscaling Inference¶
Overview¶
The downscaling inference entrypoint generates high-resolution downscaled outputs from trained diffusion models using coarse-resolution input data. Unlike training or evaluation, this entrypoint does not require fine-resolution target data, making it suitable for generating downscaled predictions for any region and time period where coarse data is available. Multiple outputs can be specified in a single configuration file, each with different spatial regions, time ranges, ensemble sizes, and output variables. Outputs are processed sequentially, with generation parallelized across GPUs using distributed data loading.
A separate zarr file is generated for each output at {experiment_dir}/{output_name}.zarr with dimensions (time, ensemble, latitude, longitude).
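As a minimal sketch of the path template above, the expected output locations can be derived from experiment_dir and the output names. The helper name expected_output_paths is hypothetical and not part of fme; it only illustrates the {experiment_dir}/{output_name}.zarr layout.

```python
from pathlib import PurePosixPath


def expected_output_paths(experiment_dir: str, output_names: list[str]) -> dict[str, str]:
    """Map each output name to {experiment_dir}/{output_name}.zarr.

    Hypothetical helper for illustration only; not part of fme.
    """
    root = PurePosixPath(experiment_dir)
    return {name: str(root / f"{name}.zarr") for name in output_names}


# Output names taken from the example configuration in this page.
paths = expected_output_paths("/output_directory", ["PNW_ensemble0000", "WA_AR_20230206"])
for name, path in paths.items():
    print(name, "->", path)
```

Each store could then be opened with a zarr-aware reader such as xarray.open_zarr(path), assuming xarray and zarr are installed.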
Launching Inference¶
To run inference on GPUs:
torchrun --nproc_per_node=<num_gpus> -m fme.downscaling.inference config.yaml
Replace <num_gpus> with the number of GPUs you want to use.
Example YAML Configuration¶
The following example shows a configuration which generates two outputs: one using EventConfig for a single time snapshot, and one using TimeRangeConfig for a time range.
experiment_dir: /output_directory
patch:
  divide_generation: true
  composite_prediction: true
  coarse_horizontal_overlap: 1
model:
  checkpoint_path: /HiRO.ckpt
  rename:
    eastward_wind_at_ten_meters: UGRD10m
    northward_wind_at_ten_meters: VGRD10m
data:
  coarse:
    - data_path: /output_directory
      engine: zarr
      file_pattern: output_6hourly_predictions_ic0000.zarr
  batch_size: 4
  num_data_workers: 2
  strict_ensemble: false
outputs:
  - name: "PNW_ensemble0000"
    save_vars: ["PRATEsfc"]
    n_ens: 1
    max_samples_per_gpu: 32
    time_range:
      start_time: "2014-01-01T00:00:00"
      stop_time: "2023-12-31T18:00:00"
    lat_extent:
      start: 32.9
      stop: 50.0
    lon_extent:
      start: 233.0
      stop: 250.0
  - name: "WA_AR_20230206"  # Example of a single event downscaling (may not exist in stochastically generated ACE outputs)
    save_vars: ["PRATEsfc"]
    n_ens: 16
    max_samples_per_gpu: 8
    event_time: "2023-02-06T06:00:00"
    lat_extent:
      start: 36.0
      stop: 52.0
    lon_extent:
      start: 228.0
      stop: 244.0
logging:
  log_to_screen: true
  log_to_wandb: false
  log_to_file: true
Configuration Structure¶
We use the Builder pattern to load this configuration into a multi-level dataclass structure. The top-level configuration is the fme.downscaling.inference.inference.InferenceConfig class.
- class fme.downscaling.inference.inference.InferenceConfig(model, data, experiment_dir, outputs, logging, patch=<factory>)
Bases: object
Top-level configuration for downscaling generation entrypoint.
Defines the model, base data source, and one or more outputs to generate. Fine-resolution outputs are generated from coarse-resolution inputs without requiring fine-resolution target data (unlike training/evaluation).
Each output can specify different spatial regions, time ranges, ensemble sizes, and output variables. Outputs are processed sequentially, with generation parallelized across GPUs using distributed data loading.
- Parameters:
model (CheckpointModelConfig) – Model specification to load for generation.
data (DataLoaderConfig) – Base data loader configuration shared by every output generation task. Per-output specifics such as the time or time range, spatial extent, saved variables, and max_samples_per_gpu (effective batch size) are specified in each output.
experiment_dir (str) – Directory for saving generated zarr files and logs.
outputs (list[EventConfig | TimeRangeConfig]) – List of output specifications. Each output generates a separate zarr file.
logging (LoggingConfig) – Logging configuration.
patch (PatchPredictionConfig, default: <factory>) – Default patch prediction configuration.
Output Configuration Types¶
The outputs list can contain two types of configurations: EventConfig for single time snapshots and TimeRangeConfig for time ranges.
EventConfig¶
fme.downscaling.inference.output.EventConfig is used for generating a single time snapshot over a spatial region. This is useful for capturing specific events like hurricane landfall, extreme weather events, or any single-timestep high-resolution snapshot of a region.
- class fme.downscaling.inference.output.EventConfig(name, n_ens, save_vars=None, zarr_chunks=None, zarr_shards=None, max_samples_per_gpu=4, event_time='', time_format='%Y-%m-%dT%H:%M:%S', lat_extent=<factory>, lon_extent=<factory>)
Bases: DownscalingOutputConfig
Configuration for generating a single time snapshot over a spatial region.
Useful for capturing specific events like hurricane landfall, extreme weather events, or any single-timestep high-resolution snapshot of a region.
If n_ens > max_samples_per_gpu, this event can be run in a distributed manner where each GPU generates a subset of the ensemble members for the event.
- Parameters:
name (str) – Unique identifier for this target (used in output filename).
n_ens (int) – Number of ensemble members to generate when downscaling.
save_vars (Optional[list[str]], default: None) – List of variable names to save to zarr output. If None, all variables from the model output will be saved.
zarr_chunks (Optional[dict[str, int]], default: None) – Optional chunk sizes for zarr dimensions. If None, automatically calculated to target a lat/lon shape of <=10MB per chunk. Chunks along the ensemble and time dimensions have length 1.
zarr_shards (Optional[dict[str, int]], default: None) – Optional shard sizes for zarr dimensions. If None, defaults to the maximum output size for a single unit of downscaling work. This ensures that parallel generation tasks write to separate shards.
max_samples_per_gpu (int, default: 4) – Number of time and/or ensemble samples to include in a single GPU generation. Controls memory usage and time to generate.
event_time (str | int, default: '') – Timestamp or integer index of the event. If string, must match time_format. Required field.
time_format (str, default: '%Y-%m-%dT%H:%M:%S') – strptime format for parsing the event_time string. Default: "%Y-%m-%dT%H:%M:%S" (ISO 8601).
lat_extent (ClosedInterval, default: <factory>) – Latitude bounds in degrees limited to [-88, 88]. Defaults to (-66, 70) which covers continental land masses aside from Antarctica.
lon_extent (ClosedInterval, default: <factory>) – Longitude bounds in degrees [-180, 360]. Default: full extent of the underlying data.
Example EventConfig:
name: "hurricane_landfall_2023"
save_vars: ["PRATEsfc"]
n_ens: 64
max_samples_per_gpu: 8
event_time: "2023-09-15T12:00:00"
time_format: "%Y-%m-%dT%H:%M:%S"
lat_extent:
  start: 25.0
  stop: 35.0
lon_extent:
  start: 260.0
  stop: 275.0
You can also use integer indices for event_time:
name: "event_at_index_100"
save_vars: ["PRATEsfc"]
n_ens: 32
max_samples_per_gpu: 4
event_time: 100
lat_extent:
  start: 30.0
  stop: 40.0
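The two accepted forms of event_time can be sketched as follows. This is an illustration of the documented rule (strings must match time_format, integers are used as indices), not fme's actual parsing code; the function name parse_event_time is hypothetical.

```python
from datetime import datetime


def parse_event_time(event_time, time_format="%Y-%m-%dT%H:%M:%S"):
    """Return an integer index unchanged, or parse a string with strptime.

    Hypothetical helper mirroring the documented behaviour; a string that
    does not match time_format raises ValueError.
    """
    if isinstance(event_time, int):
        return event_time
    return datetime.strptime(event_time, time_format)


print(parse_event_time("2023-09-15T12:00:00"))  # timestamp form
print(parse_event_time(100))                    # integer index form
```

A custom time_format only needs to be set when the timestamp string is written in a different layout, for example "%Y%m%d %H" for "20230915 12".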
TimeRangeConfig¶
fme.downscaling.inference.output.TimeRangeConfig is used for generating a time segment over a spatial region. This is the most common and flexible configuration, suitable for generating downscaled data over regions like CONUS, continental areas, or custom domains over extended time periods.
- class fme.downscaling.inference.output.TimeRangeConfig(name, n_ens, save_vars=None, zarr_chunks=None, zarr_shards=None, max_samples_per_gpu=4, time_range=<factory>, lat_extent=<factory>, lon_extent=<factory>)
Bases: DownscalingOutputConfig
Configuration for generating a time segment over a spatial region.
This is the most common and flexible configuration, suitable for generating downscaled data over regions like CONUS, continental areas, or custom domains over extended time periods.
- Parameters:
name (str) – Unique identifier for this target (used in output filename).
n_ens (int) – Number of ensemble members to generate when downscaling.
save_vars (Optional[list[str]], default: None) – List of variable names to save to zarr output. If None, all variables from the model output will be saved.
zarr_chunks (Optional[dict[str, int]], default: None) – Optional chunk sizes for zarr dimensions. If None, automatically calculated to target a lat/lon shape of <=10MB per chunk. Chunks along the ensemble and time dimensions have length 1.
zarr_shards (Optional[dict[str, int]], default: None) – Optional shard sizes for zarr dimensions. If None, defaults to the maximum output size for a single unit of downscaling work. This ensures that parallel generation tasks write to separate shards.
max_samples_per_gpu (int, default: 4) – Number of time and/or ensemble samples to include in a single GPU generation. Controls memory usage and time to generate.
time_range (TimeSlice | RepeatedInterval | Slice, default: <factory>) – Time selection specification. Can be:
  - TimeSlice: Start/stop timestamps (e.g., TimeSlice(start_time="2021-01-01", stop_time="2021-12-31"))
  - Slice: Integer indices (e.g., Slice(0, 365))
  - RepeatedInterval: Repeating time pattern
lat_extent (ClosedInterval, default: <factory>) – Latitude bounds in degrees limited to [-88, 88]. Defaults to (-66, 70) which covers continental land masses aside from Antarctica.
lon_extent (ClosedInterval, default: <factory>) – Longitude bounds in degrees [-180, 360]. Default: full extent of the underlying data.
Example TimeRangeConfig with TimeSlice:
name: "CONUS_full_year"
n_ens: 4
max_samples_per_gpu: 4
time_range:
  start_time: "2023-01-01T00:00:00"
  stop_time: "2023-12-31T18:00:00"
Example TimeRangeConfig with Slice:
name: "first_year_indices"
n_ens: 4
max_samples_per_gpu: 4
time_range:
  start: 0
  stop: 36
Example TimeRangeConfig with RepeatedInterval:
name: "weekly_snapshots"
n_ens: 4
max_samples_per_gpu: 4
time_range:
  interval_length: "1d"
  block_length: "7d"
  start: "0d"
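To make the repeating pattern concrete, the sketch below marks which steps would be selected, assuming the semantics are "within every block of block_length, keep interval_length steps beginning at offset start". That reading, the function name repeated_interval_mask, and the use of one step per day are all assumptions for illustration, not fme's implementation.

```python
def repeated_interval_mask(
    n_steps: int, interval_length: int, block_length: int, start: int = 0
) -> list[bool]:
    """Mark the steps kept by a repeating interval (hypothetical sketch).

    All lengths are expressed in steps; here one step stands for one day.
    """
    if start + interval_length > block_length:
        raise ValueError("interval must fit inside the block")
    return [
        start <= (step % block_length) < start + interval_length
        for step in range(n_steps)
    ]


# interval_length "1d", block_length "7d", start "0d" over two weeks of daily steps
mask = repeated_interval_mask(n_steps=14, interval_length=1, block_length=7, start=0)
selected = [step for step, keep in enumerate(mask) if keep]
print(selected)  # [0, 7]
```

With 6-hourly input data, one day corresponds to four time steps, so the same pattern would keep four consecutive steps out of every 28 under these assumptions.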
Common Configuration Patterns¶
Renaming model variables¶
You can rename input/output variables for the model loaded from the checkpoint. This is useful if the model input variables’ names are not the same as the variable names in the coarse input dataset, or if the model output variables are not the same as the variable names you want to save.
For example, ACE outputs coarse-grid 10m winds as UGRD10m and VGRD10m, while the downscaling checkpoint was created using data with the variable names eastward_wind_at_ten_meters and northward_wind_at_ten_meters. The model configuration in the example therefore requires the following rename fields:
model:
  checkpoint_path: /HiRO.ckpt
  rename:
    eastward_wind_at_ten_meters: UGRD10m
    northward_wind_at_ten_meters: VGRD10m
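The effect of the mapping can be sketched in a few lines: keys are the names stored in the checkpoint and values are the names used in the coarse dataset. The helper apply_rename and the list of model inputs below are illustrative assumptions, not fme code or the checkpoint's actual variable list.

```python
def apply_rename(names: list[str], rename: dict[str, str]) -> list[str]:
    """Replace each checkpoint name with its dataset name, if one is given."""
    return [rename.get(name, name) for name in names]


rename = {
    "eastward_wind_at_ten_meters": "UGRD10m",
    "northward_wind_at_ten_meters": "VGRD10m",
}
# Illustrative input list; names without a rename entry pass through unchanged.
model_inputs = ["eastward_wind_at_ten_meters", "northward_wind_at_ten_meters", "PRATEsfc"]
print(apply_rename(model_inputs, rename))  # ['UGRD10m', 'VGRD10m', 'PRATEsfc']
```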
Multiple Outputs¶
You can mix EventConfig and TimeRangeConfig outputs in a single configuration file. Outputs are processed sequentially:
outputs:
  - name: "event_1"
    event_time: "2023-06-15T00:00:00"
    save_vars: ["PRATEsfc"]
    n_ens: 128
    max_samples_per_gpu: 8
  - name: "time_range_1"
    time_range:
      start_time: "2023-01-01T00:00:00"
      stop_time: "2023-03-31T18:00:00"
    n_ens: 8
    max_samples_per_gpu: 8
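Because each output name becomes a zarr filename, a quick pre-flight check on a mixed outputs list can catch duplicate names before a long run. The sketch below works on plain dictionaries as they would appear after loading the YAML; summarize_outputs is a hypothetical helper, and inferring the type from the presence of event_time is an assumption about how the two configurations differ.

```python
from collections import Counter


def summarize_outputs(outputs: list[dict]) -> list[tuple[str, str]]:
    """Return (name, kind) pairs and reject duplicate output names."""
    counts = Counter(output["name"] for output in outputs)
    duplicates = sorted(name for name, count in counts.items() if count > 1)
    if duplicates:
        raise ValueError(f"duplicate output names: {duplicates}")
    return [
        (output["name"], "EventConfig" if "event_time" in output else "TimeRangeConfig")
        for output in outputs
    ]


outputs = [
    {"name": "event_1", "event_time": "2023-06-15T00:00:00", "n_ens": 128},
    {
        "name": "time_range_1",
        "time_range": {"start_time": "2023-01-01T00:00:00", "stop_time": "2023-03-31T18:00:00"},
        "n_ens": 8,
    },
]
print(summarize_outputs(outputs))
```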
Spatial Extent Configuration¶
Both EventConfig and TimeRangeConfig support spatial extent configuration via lat_extent and lon_extent. These define the latitude and longitude bounds for the output region:
lat_extent:
  start: 25.0
  stop: 50.0
lon_extent:
  start: 230.0
  stop: 295.0
Latitude bounds must lie within [-88, 88] degrees, and longitude bounds within [-180, 360] degrees. If lat_extent is not specified, the output defaults to the latitude range used in training, (-66, 70) degrees; if lon_extent is not specified, it defaults to the full longitude extent of the underlying data. Note: leaving both unset will generate a very large output dataset!
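A small check along these lines can catch out-of-range bounds before launching a run. It is a sketch only: validate_extent is a hypothetical helper, the limits are treated as inclusive (as the ClosedInterval type name suggests), and requiring start to be less than stop is an added assumption.

```python
def validate_extent(lat_start: float, lat_stop: float, lon_start: float, lon_stop: float) -> None:
    """Raise ValueError if the requested region falls outside the documented limits."""
    if not (-88.0 <= lat_start < lat_stop <= 88.0):
        raise ValueError(f"latitude extent ({lat_start}, {lat_stop}) must lie within [-88, 88]")
    if not (-180.0 <= lon_start < lon_stop <= 360.0):
        raise ValueError(f"longitude extent ({lon_start}, {lon_stop}) must lie within [-180, 360]")


# The region from the example above passes silently.
validate_extent(lat_start=25.0, lat_stop=50.0, lon_start=230.0, lon_stop=295.0)

try:
    validate_extent(lat_start=-90.0, lat_stop=50.0, lon_start=230.0, lon_stop=295.0)
except ValueError as err:
    print("rejected:", err)
```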
Ensemble Size Configuration¶
The n_ens field specifies the total number of ensemble members to generate. The max_samples_per_gpu field controls how many time and/or ensemble samples are included in a single GPU batch, which affects memory usage and generation time. If not provided, the default value for max_samples_per_gpu is 4.
n_ens: 128
max_samples_per_gpu: 4
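For planning purposes, the number of GPU batches implied by these two fields can be estimated roughly. The sketch assumes samples are counted as time steps multiplied by ensemble members and that every batch is filled to max_samples_per_gpu; how fme actually divides the work may differ, and estimate_gpu_batches is a hypothetical name.

```python
import math


def estimate_gpu_batches(n_times: int, n_ens: int, max_samples_per_gpu: int = 4) -> int:
    """Rough count of GPU batches, assuming n_times * n_ens total samples."""
    total_samples = n_times * n_ens
    return math.ceil(total_samples / max_samples_per_gpu)


# A single-time event with the settings shown above.
print(estimate_gpu_batches(n_times=1, n_ens=128, max_samples_per_gpu=4))  # 32
```

Under the same assumptions, spreading those 32 batches over 8 GPUs would leave about 4 batches per GPU.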
Variable Selection¶
Use the save_vars field to specify which variables to save to the output zarr file. If save_vars is null or not specified, all variables from the model output will be saved:
save_vars: ["PRATEsfc"]
Patch Prediction for Large Domains¶
For domains larger than the model’s patch size, subdivision of the full domain into patches for prediction must be configured in the top-level patch section.
Generation for region sizes smaller than the size the model was trained on is not supported.
patch:
  divide_generation: true
  composite_prediction: true
  coarse_horizontal_overlap: 0
- class fme.downscaling.predictors.composite.PatchPredictionConfig(divide_generation=False, composite_prediction=False, coarse_horizontal_overlap=1)
Bases: object
Configuration to enable predictions on multiple patches for evaluation.
- Parameters:
divide_generation (bool, default: False) – Enables the patched prediction of the full input data extent for generation.
composite_prediction (bool, default: False) – If True, recombines the smaller prediction regions into the original full region as a single sample.
coarse_horizontal_overlap (int, default: 1) – Number of pixels to overlap in the coarse data.
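The idea of dividing a domain into overlapping patches can be sketched in one dimension. This is a generic tiling scheme written for illustration, not the algorithm fme uses: the function patch_starts, the choice to shift the final patch back so it stays inside the domain, and the example sizes are all assumptions.

```python
def patch_starts(domain_size: int, patch_size: int, overlap: int) -> list[int]:
    """Start indices of patches covering one axis of the coarse grid.

    Neighbouring patches share `overlap` pixels; the last patch is shifted
    back if needed so that it ends exactly at the domain edge.
    """
    if patch_size > domain_size:
        raise ValueError("domains smaller than the patch size are not supported")
    if not 0 <= overlap < patch_size:
        raise ValueError("overlap must be non-negative and smaller than the patch size")
    stride = patch_size - overlap
    starts = list(range(0, domain_size - patch_size + 1, stride))
    if starts[-1] + patch_size < domain_size:
        starts.append(domain_size - patch_size)
    return starts


# A 40-pixel axis, 16-pixel patches, coarse_horizontal_overlap of 1
print(patch_starts(domain_size=40, patch_size=16, overlap=1))  # [0, 15, 24]
```

Applying the same function to the latitude and longitude axes and taking all combinations of start indices gives a 2-D grid of patches under these assumptions.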