[docs]@dataclasses.dataclassclassSchedulerConfig:""" Configuration for a scheduler to use during training. Parameters: type: Name of scheduler class from torch.optim.lr_scheduler, no scheduler is used by default. kwargs: Keyword arguments to pass to the scheduler constructor. step_each_iteration: If true, step after each batch. Otherwise, just step at the end of each epoch. Schedulers that step with every iteration won't be passed the validation loss. """type:str|None=Nonekwargs:Mapping[str,Any]=dataclasses.field(default_factory=dict)step_each_iteration:bool=False
[docs]defbuild(self,optimizer,max_epochs)->"LRScheduler":""" Build the scheduler. """ifself.typeisNone:returnLRScheduler()build_kwargs={**self.kwargs}# work-around so we don't need to specify T_max# in the yaml file for this schedulerifself.type=="CosineAnnealingLR"and"T_max"notinself.kwargs:build_kwargs["T_max"]=max_epochsscheduler_class=getattr(torch.optim.lr_scheduler,self.type)returnLRScheduler(scheduler_obj=scheduler_class(optimizer=optimizer,**build_kwargs),step_each_iteration=self.step_each_iteration,)
@dataclasses.dataclass
class SequentialSchedulerConfig:
    """
    Configuration for using torch.optim.lr_scheduler.SequentialLR to build a
    sequence of LR schedulers that run one after the other.

    Parameters:
        schedulers: Ordered, non-empty sequence of SchedulerConfigs to define
            the schedulers for the SequentialLR. Note that all schedulers in
            the sequence must have the same value for step_each_iteration.
        milestones: Sequence of integers that reflects milestone points,
            where milestones[i] corresponds to the last epoch or iteration
            where schedulers[i] is active before switching to
            schedulers[i+1]. For example, with two schedulers and
            milestones=[10] the first 10 epochs will use the first scheduler
            and then switch to the second scheduler for epoch 11.
        last_epoch: The index of last epoch. Default: -1.

    Raises:
        ValueError: If ``schedulers`` is empty, or if its entries do not all
            share the same ``step_each_iteration`` value.
    """

    schedulers: Sequence[SchedulerConfig]
    milestones: Sequence[int]
    last_epoch: int = -1

    def __post_init__(self):
        # Fail fast with a clear message: an empty sequence would otherwise
        # only surface later as an IndexError from `step_each_iteration`.
        if not self.schedulers:
            raise ValueError(
                "SequentialSchedulerConfig requires at least one SchedulerConfig "
                "in schedulers."
            )
        expected = self.schedulers[0].step_each_iteration
        if any(x.step_each_iteration != expected for x in self.schedulers):
            raise ValueError(
                "All SchedulerConfigs in the SequentialSchedulerConfig must have "
                "identical values for step_each_iteration."
            )

    @property
    def step_each_iteration(self) -> bool:
        """Stepping mode shared by every scheduler in the sequence."""
        # __post_init__ guarantees all entries agree, so the first one
        # is representative.
        return self.schedulers[0].step_each_iteration

    def build(self, optimizer, max_epochs) -> "LRScheduler":
        """
        Build the SequentialLR scheduler.

        Parameters:
            optimizer: Optimizer passed to every sub-scheduler and to the
                SequentialLR itself.
            max_epochs: Forwarded to each SchedulerConfig.build call.

        Returns:
            An LRScheduler wrapping the SequentialLR, using the shared
            ``step_each_iteration`` setting.
        """
        # SequentialLR needs the raw scheduler objects, so unwrap each
        # LRScheduler produced by the sub-configs.
        schedulers = [
            x.build(optimizer, max_epochs).scheduler_obj for x in self.schedulers
        ]
        return LRScheduler(
            scheduler_obj=SequentialLR(
                optimizer=optimizer,
                schedulers=schedulers,
                milestones=self.milestones,
                last_epoch=self.last_epoch,
            ),
            step_each_iteration=self.step_each_iteration,
        )
classLRScheduler:"""Thin wrapper around torch.optim.lr_scheduler._LRScheduler."""def__init__(self,scheduler_obj:torch.optim.lr_scheduler._LRScheduler|None=None,step_each_iteration:bool=False,):self._scheduler_obj=scheduler_objself._step_each_iteration=step_each_iteration@propertydefscheduler_obj(self)->torch.optim.lr_scheduler._LRScheduler|None:returnself._scheduler_objdefshould_step(self,is_iteration:bool)->bool:"""Determine whether the scheduler should be stepped based on configuration and context. """ifself._scheduler_objisNone:returnFalsedo_iter_step=self._step_each_iterationandis_iterationdo_epoch_step=notself._step_each_iterationandnotis_iterationreturndo_iter_stepordo_epoch_stepdefstep(self,*args,**kwargs):ifself._scheduler_objisnotNone:self._scheduler_obj.step(*args,**kwargs)defstate_dict(self):ifself._scheduler_objisNone:returnNonereturnself._scheduler_obj.state_dict()defload_state_dict(self,state):ifself._scheduler_objisnotNoneandstateisnotNone:self._scheduler_obj.load_state_dict(state)