import dataclasses
import logging
from dataclasses import dataclass, field
import dacite
import numpy as np
import torch
import yaml
from fme.core.cli import prepare_directory
from fme.core.generics.trainer import count_parameters
from fme.core.logging_utils import LoggingConfig
from ..data import DataLoaderConfig
from ..models import CheckpointModelConfig, DiffusionModel
from ..predictors import PatchPredictionConfig, PatchPredictor
from .output import DownscalingOutput, EventConfig, TimeRangeConfig
from .work_items import LoadedSliceWorkItem
class Downscaler:
    """
    Orchestrates downscaling generation across multiple outputs.

    Each output can have different spatial extents, time ranges, and ensemble
    sizes. Generation is performed sequentially across outputs.

    Parameters:
        model: Model used for generation; wrapped in a ``PatchPredictor`` per
            output when the input domain is larger than the model patch.
        outputs: Outputs to generate, processed in list order.
        output_dir: Directory handed to each output's writer.
    """

    def __init__(
        self,
        model: DiffusionModel,
        outputs: list[DownscalingOutput],
        output_dir: str = ".",
    ):
        self.model = model
        self.outputs = outputs
        self.output_dir = output_dir

    def run_all(self):
        """Run generation for all outputs."""
        logging.info(f"Starting generation for {len(self.outputs)} output(s)")
        for output in self.outputs:
            # Clear GPU cache before each output
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.run_output_generation(output=output)
        logging.info("All outputs completed successfully")

    def _get_generation_model(
        self,
        input_shape: tuple[int, int],
        output: DownscalingOutput,
    ) -> DiffusionModel | PatchPredictor:
        """
        Set up the model, wrapping with PatchPredictor if needed.

        While models are probably capable of generating any domain size, we
        haven't tested for domains smaller than the model patch size, so we
        raise an error in that case, and prompt the user to use patching for
        larger domains because that provides better generations.

        Args:
            input_shape: Horizontal shape of the coarse input batch.
            output: Output being generated; supplies the name used in
                messages and the patch configuration.

        Returns:
            The bare model when ``input_shape`` equals the model's coarse
            shape, otherwise a ``PatchPredictor`` wrapping the model.

        Raises:
            ValueError: If any input dimension is smaller than the model's
                coarse shape, or if the shapes differ and patch prediction is
                not configured for this output.
        """
        model_patch_shape = self.model.coarse_shape
        if model_patch_shape == input_shape:
            # short circuit, no patching necessary
            return self.model
        elif any(
            expected > actual
            for expected, actual in zip(model_patch_shape, input_shape)
        ):
            # we don't support generating regions smaller than the model patch size
            raise ValueError(
                f"Model coarse shape {model_patch_shape} is larger than "
                f"actual input shape {input_shape} for output {output.name}. "
                "We do not support generating outputs with a smaller spatial extent"
                " than the model's trained patch size. Please adjust the spatial extent"
                " to be at least as large as the model's input patch size."
            )
        elif output.patch.needs_patch_predictor:
            # Use a patch predictor
            logging.info(f"Using PatchPredictor for output: {output.name}")
            return PatchPredictor(
                model=self.model,
                coarse_horizontal_overlap=output.patch.coarse_horizontal_overlap,
            )
        else:
            # User should enable patching
            raise ValueError(
                f"Model coarse shape {model_patch_shape} does not match "
                f"actual input shape {input_shape} for output {output.name}, "
                "and patch prediction is not configured. Generation for larger domains "
                "requires patch prediction."
            )

    def _on_device_generator(self, loader):
        """Yield each item of ``loader`` after calling its ``to_device()``."""
        for loaded_item in loader:
            yield loaded_item.to_device()

    def run_output_generation(self, output: DownscalingOutput):
        """
        Execute the generation loop for this output.

        The generation model and the writer are created lazily from the first
        work item, because the fine coordinates needed by the writer are
        derived from a batch.

        Args:
            output: The output whose work items are generated and written.
        """
        logging.info(f"Generating downscaled outputs for output: {output.name}")

        # initialize writer and model in loop for coord info
        model = None
        writer = None
        total_batches = len(output.data.loader)

        loaded_item: LoadedSliceWorkItem
        for i, loaded_item in enumerate(output.data.get_generator()):
            input_shape = loaded_item.batch.horizontal_shape
            if model is None:
                model = self._get_generation_model(
                    input_shape=input_shape, output=output
                )

            if writer is None:
                fine_latlon_coords = model.get_fine_coords_for_batch(loaded_item.batch)
                writer = output.get_writer(
                    latlon_coords=fine_latlon_coords,
                    output_dir=self.output_dir,
                )
                writer.initialize_store(np.float32)

            logging.info(
                f"[{output.name}] Batch {i+1}/{total_batches}, "
                f"generating work slice {loaded_item.dim_insert_slices} "
            )

            # Generation runs for every work item, including padding items;
            # only the write below is skipped for padding.
            # NOTE(review): presumably this keeps distributed ranks in step -
            # confirm against the data loader.
            output_data = model.generate_on_batch_no_target(
                loaded_item.batch,
                n_samples=loaded_item.n_ens,
            )
            output_np = {key: value.cpu().numpy() for key, value in output_data.items()}
            insert_slices = loaded_item.dim_insert_slices

            if not loaded_item.is_padding:
                writer.record_batch(output_np, position_slices=insert_slices)
            else:
                logging.info("Skipping padding work item. No data will be written.")

        logging.info(f"Completed generation for output: {output.name}")
@dataclass
class InferenceConfig:
    """
    Top-level configuration for downscaling generation entrypoint.

    Defines the model, base data source, and one or more outputs to generate.
    Fine-resolution outputs are generated from coarse-resolution inputs without
    requiring fine-resolution target data (unlike training/evaluation).

    Each output can specify different spatial regions, time ranges, ensemble
    sizes, and output variables. Outputs are processed sequentially, with
    generation parallelized across GPUs using distributed data loading.

    Parameters:
        model: Model specification to load for generation.
        data: Base data loader configuration that is shared to each output
            generation task. Specifics for each output like the time(range),
            spatial extent, saved variables, and max_samples_per_gpu
            (effective batch size) are specified in each output.
        experiment_dir: Directory for saving generated zarr files and logs.
        outputs: List of output specifications. Each output generates a
            separate zarr file.
        logging: Logging configuration.
        patch: Default patch prediction configuration.

    Exclude following from autoclass documentation:

    Example YAML configuration::

        experiment_dir: /results
        model:
            checkpoint_path: /checkpoints/best_histogram_tail.ckpt
        data:
            topography: /climate-default/X-SHiELD-AMIP-downscaling/3km.zarr
            coarse:
              - data_path: /climate-default/X-SHiELD-AMIP-downscaling
                engine: zarr
                file_pattern: 100km.zarr
            batch_size: 4  # Value is overridden by each output
            num_data_workers: 0
            strict_ensemble: False
        patch:
            divide_generation: true
            composite_prediction: true
            coarse_horizontal_overlap: 0
        outputs:
          - name: "WA_AR_20230206"
            save_vars: ["PRATEsfc"]
            n_ens: 128
            max_samples_per_gpu: 8
            event_time: "2023-02-06T06:00:00"
            lat_extent:
                start: 36.0
                stop: 52.0
            lon_extent:
                start: 228.0
                stop: 244.0
          - name: "CONUS_2023"
            save_vars: ["PRATEsfc"]
            n_ens: 8
            max_samples_per_gpu: 8
            time_range:
                start_time: "2023-01-01T00:00:00"
                end_time: "2023-12-31T18:00:00"
            lat_extent:
                start: 22.0
                stop: 50.0
            lon_extent:
                start: 230.0
                stop: 295.0
        logging:
            log_to_screen: true
            log_to_wandb: false
            log_to_file: true
            project: downscaling
            entity: my_organization
    """

    model: CheckpointModelConfig
    data: DataLoaderConfig
    experiment_dir: str
    outputs: list[EventConfig | TimeRangeConfig]
    logging: LoggingConfig
    patch: PatchPredictionConfig = field(default_factory=PatchPredictionConfig)

    def configure_logging(self, log_filename: str):
        """
        Configure logging through the ``logging`` sub-config.

        The whole configuration is converted to a plain dict and passed along
        with the experiment directory; logging is set up as resumable.

        Args:
            log_filename: Name of the log file passed to the logging config.
        """
        config = dataclasses.asdict(self)
        self.logging.configure_logging(
            self.experiment_dir, log_filename, config=config, resumable=True
        )

    def build(self) -> Downscaler:
        """
        Build the model and every configured output.

        Each output is built with the shared data loader config, the model
        config's data requirements, the default patch configuration, and the
        built model's fine shape.

        Returns:
            A ``Downscaler`` that writes into ``experiment_dir``.
        """
        model = self.model.build()
        outputs = [
            output_cfg.build(
                loader_config=self.data,
                requirements=self.model.data_requirements,
                patch=self.patch,
                fine_shape=model.fine_shape,
            )
            for output_cfg in self.outputs
        ]
        return Downscaler(model=model, outputs=outputs, output_dir=self.experiment_dir)
def main(config_path: str):
    """
    Entry point: run downscaling generation from a YAML configuration file.

    The file is parsed with ``yaml.safe_load`` and converted into an
    ``InferenceConfig`` with strict dacite parsing. The experiment directory
    is then prepared, logging is configured to write ``out.log``, the
    downscaler is built, its parameter count is logged, and all outputs are
    generated.

    Args:
        config_path: Path to the YAML configuration file.
    """
    with open(config_path) as config_file:
        raw_config = yaml.safe_load(config_file)

    # strict=True is passed to dacite so parsing is strict about the input.
    strict_parsing = dacite.Config(strict=True)
    inference_config: InferenceConfig = dacite.from_dict(
        data_class=InferenceConfig,
        data=raw_config,
        config=strict_parsing,
    )

    # The raw (unparsed) mapping is handed to prepare_directory.
    prepare_directory(inference_config.experiment_dir, raw_config)
    inference_config.configure_logging(log_filename="out.log")

    logging.info("Starting downscaling generation...")
    downscaler = inference_config.build()
    logging.info(f"Number of parameters: {count_parameters(downscaler.model.modules)}")
    downscaler.run_all()