@dataclasses.dataclass
class LRTuningConfig:
    """
    Configuration for periodic learning rate tuning trials.

    At the start of epochs contained in ``epochs``, the trainer forks the
    current model into a baseline and a candidate copy. Both are trained on
    the first ``num_batches`` batches of the epoch; the candidate uses a
    learning rate of ``current_lr * lr_factor``. Both are then validated. If
    the candidate's validation loss is less than the baseline's by at least
    ``improvement_threshold`` times the baseline's validation loss, the
    trainer adopts the candidate's learning rate.

    Parameters:
        epochs: A Slice selecting which epochs to run trials on. For example,
            ``Slice(start=1, step=2)`` runs at epochs 1, 3, 5, … (skipping
            epoch 0).
        lr_factor: Multiply the current LR by this to get the candidate LR.
        num_batches: Number of training batches for each fork in the trial.
        improvement_threshold: The candidate must beat the baseline's
            validation loss by at least this fraction of the baseline's
            validation loss (e.g. 0.01 means the candidate must be lower by
            at least 1% of the baseline loss).
    """

    lr_factor: float = 0.5
    num_batches: int = 200
    epochs: Slice = dataclasses.field(default_factory=Slice)
    improvement_threshold: float = 0.001
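The acceptance rule described in the ``LRTuningConfig`` docstring can be sketched in isolation as follows. This is a minimal illustration, not part of the module: the helper name ``candidate_wins`` is hypothetical, and it only mirrors the threshold arithmetic used by the trial (a strict ``<`` comparison against ``baseline - fraction * baseline``).

```python
def candidate_wins(
    baseline_loss: float,
    candidate_loss: float,
    improvement_threshold: float,
) -> bool:
    """Hypothetical helper: True if the candidate clears the relative margin."""
    # The candidate must fall strictly below the baseline loss reduced by
    # the requested fraction of the baseline loss.
    threshold = baseline_loss - improvement_threshold * baseline_loss
    return candidate_loss < threshold


# With a 1% margin and a baseline loss of 2.0, the bar is 1.98.
print(candidate_wins(2.0, 1.90, 0.01))   # well below the bar
print(candidate_wins(2.0, 1.995, 0.01))  # better than baseline, but not by 1%
```

Under this rule, a candidate that is only marginally better than the baseline does not trigger a learning rate change, which guards against adopting a new LR on noise.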

def run_lr_tuning_trial(
    train_data: GriddedDataABC,
    valid_data: GriddedDataABC,
    optimization: OptimizationABC,
    copy_stepper: Callable[[], TrainStepperABC],
    build_optimization: Callable[[torch.nn.ModuleList], OptimizationABC],
    copy_ema: Callable[[torch.nn.ModuleList], EMATracker],
    config: LRTuningConfig,
    current_lr: float,
    get_validation_aggregator: Callable[[], AggregatorABC],
    validate_using_ema: bool,
) -> float | None:
    """
    Run an isolated LR tuning trial comparing the current LR against a candidate.

    Creates two stepper forks, trains both, validates both, and compares
    validation loss improvements. Does not mutate the original stepper or
    optimization. Does not log to wandb.

    Args:
        train_data: Training data; ``subset_loader`` is used for the first
            N batches. The caller must have already called ``set_epoch``.
        valid_data: Validation data.
        optimization: The current optimization (used to copy momentum state
            into the forks).
        copy_stepper: Factory that returns a new stepper initialized from the
            current stepper's state. Called twice (baseline and candidate).
            The caller is responsible for ensuring proper deep copy semantics
            (e.g. using get_state/load_state rather than copy.deepcopy).
        build_optimization: Factory to build a fresh optimization for a given
            ModuleList.
        copy_ema: Factory that returns a new EMA tracker initialized from the
            current EMA state but tracking the given modules. Called twice.
        config: The LR tuning configuration.
        current_lr: The current learning rate.
        get_validation_aggregator: Factory for validation aggregators.
        validate_using_ema: Whether to use EMA parameters during validation.

    Returns:
        The candidate learning rate if the candidate wins, otherwise None.
    """
    candidate_lr = current_lr * config.lr_factor
    optimization_state = copy.deepcopy(optimization.get_state())

    baseline_stepper = copy_stepper()
    candidate_stepper = copy_stepper()

    baseline_opt = build_optimization(baseline_stepper.modules)
    baseline_opt.load_state(copy.deepcopy(optimization_state))
    baseline_opt.set_learning_rate(current_lr)

    candidate_opt = build_optimization(candidate_stepper.modules)
    candidate_opt.load_state(copy.deepcopy(optimization_state))
    candidate_opt.set_learning_rate(candidate_lr)

    baseline_ema = copy_ema(baseline_stepper.modules)
    candidate_ema = copy_ema(candidate_stepper.modules)

    # Train both forks
    baseline_stepper.set_train()
    candidate_stepper.set_train()
    for batch in train_data.subset_loader(stop_batch=config.num_batches):
        baseline_stepper.train_on_batch(batch, baseline_opt)
        baseline_ema(baseline_stepper.modules)
        candidate_stepper.train_on_batch(batch, candidate_opt)
        candidate_ema(candidate_stepper.modules)

    # Validate both forks
    baseline_agg = get_validation_aggregator()
    run_validation_loop(
        stepper=baseline_stepper,
        valid_data=valid_data,
        aggregator=baseline_agg,
        ema=baseline_ema,
        validate_using_ema=validate_using_ema,
    )
    baseline_val_logs = baseline_agg.get_logs(label="val")

    candidate_agg = get_validation_aggregator()
    run_validation_loop(
        stepper=candidate_stepper,
        valid_data=valid_data,
        aggregator=candidate_agg,
        ema=candidate_ema,
        validate_using_ema=validate_using_ema,
    )
    candidate_val_logs = candidate_agg.get_logs(label="val")

    baseline_val_loss = baseline_val_logs["val/mean/loss"]
    candidate_val_loss = candidate_val_logs["val/mean/loss"]
    threshold = baseline_val_loss - config.improvement_threshold * baseline_val_loss

    logging.info(
        f"LR tuning trial: baseline LR={current_lr}, candidate LR={candidate_lr}, "
        f"baseline val loss={baseline_val_loss:.6f}, "
        f"candidate val loss={candidate_val_loss:.6f}, "
        f"threshold={threshold:.6f}"
    )

    if candidate_val_loss < threshold:
        logging.info(
            f"LR tuning trial: candidate wins "
            f"(candidate loss {candidate_val_loss:.6f} < "
            f"threshold {threshold:.6f})"
        )
        return candidate_lr

    logging.info("LR tuning trial: baseline wins, keeping current LR")
    return None