@dataclasses.dataclass
class TrainAggregatorConfig:
    """
    Options selecting which metrics the train aggregator computes.

    Both metrics are enabled by default.

    Attributes:
        spherical_power_spectrum: If True, compute the spherical power
            spectrum.
        weighted_rmse: If True, compute the weighted RMSE.
    """

    spherical_power_spectrum: bool = True
    weighted_rmse: bool = True
class Aggregator(Protocol):
    """Structural interface for aggregators fed paired target/generated data."""

    def record_batch(self, target_data: TensorMapping, gen_data: TensorMapping):
        """Record one batch of target data and the matching generated data."""
        ...

    def get_logs(self, label: str) -> dict[str, torch.Tensor]:
        """Return the recorded statistics as a dictionary of logs."""
        ...


class TrainAggregator(AggregatorABC[TrainOutput]):
    """
    Aggregates statistics for the first timestep.

    Pass the output of every batch to `record_batch`; once all batches have
    been recorded, `get_logs` returns the collected statistics as a
    dictionary.
    """

    def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations):
        """
        Set up the loss accumulator and the aggregators enabled by `config`.

        Args:
            config: Selects which paired aggregators are created.
            operations: Gridded operations handed to each paired aggregator.
        """
        self._n_loss_batches = 0
        self._loss = torch.tensor(0.0, device=get_device())

        # Insertion order matters: it is the order in which the aggregators
        # record batches and contribute their logs.
        paired: dict[str, Aggregator] = {}

        if config.spherical_power_spectrum:
            try:
                nan_fill = SmoothFloodFill(num_steps=4)
                power_spectrum = PairedSphericalPowerSpectrumAggregator(
                    gridded_operations=operations,
                    report_plot=False,
                    nan_fill_fn=nan_fill,
                )
                paired["power_spectrum"] = power_spectrum
            except NotImplementedError:
                # Best effort: leave the power spectrum out instead of failing.
                logging.warning(
                    "Power spectrum aggregator not implemented "
                    "for this grid type, omitting."
                )

        if config.weighted_rmse:
            paired["mean"] = MeanAggregator(
                gridded_operations=operations,
                include_bias=False,
                include_grad_mag_percent_diff=False,
            )

        self._paired_aggregators: dict[str, Aggregator] = paired

    @torch.no_grad()
    def record_batch(self, batch: TrainOutput):
        """
        Accumulate the loss and paired statistics of a single batch.

        The generated data is passed through `fold_ensemble_dim`, and the
        target data through `fold_sized_ensemble_dim` with the ensemble size
        returned by the former, before being handed to every paired
        aggregator.

        Args:
            batch: Training output providing `metrics["loss"]`, `gen_data`
                and `target_data`.
        """
        self._loss += batch.metrics["loss"]
        self._n_loss_batches += 1

        gen, ensemble_size = fold_ensemble_dim(batch.gen_data)
        target = fold_sized_ensemble_dim(batch.target_data, ensemble_size)

        for paired_aggregator in self._paired_aggregators.values():
            paired_aggregator.record_batch(target_data=target, gen_data=gen)

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, torch.Tensor]:
        """
        Build the logs, in a form that can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.

        Returns:
            The logs of every paired aggregator with keys prefixed by
            `label`, plus `"{label}/mean/loss"` holding the accumulated loss
            divided by the number of recorded batches, passed through
            `Distributed.reduce_mean` and converted to a float. Empty when no
            batch has been recorded.
        """
        logs = {}
        # NOTE(review): the original indentation was not preserved; this
        # assumes the loss entry sits inside the batch-count guard together
        # with the paired-aggregator logs -- confirm.
        if self._n_loss_batches <= 0:
            return logs

        for name, paired_aggregator in self._paired_aggregators.items():
            for key, value in paired_aggregator.get_logs(name).items():
                logs[f"{label}/{key}"] = value

        dist = Distributed.get_instance()
        mean_loss = dist.reduce_mean(self._loss / self._n_loss_batches)
        logs[f"{label}/mean/loss"] = float(mean_loss.cpu().numpy())
        return logs

    @torch.no_grad()
    def flush_diagnostics(self, subdir: str | None) -> None:
        """Do nothing; this method is a no-op for this aggregator."""
        return None