# Source code for fme.core.scheduler

import dataclasses
from typing import Any, Mapping, Optional

import torch.optim.lr_scheduler


@dataclasses.dataclass
class SchedulerConfig:
    """
    Configuration for a scheduler to use during training.

    Attributes:
        type: Name of scheduler class from torch.optim.lr_scheduler,
            no scheduler is used by default.
        kwargs: Keyword arguments to pass to the scheduler constructor.
    """

    type: Optional[str] = None
    kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)

    def build(
        self, optimizer, max_epochs
    ) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
        """
        Build the scheduler.

        Args:
            optimizer: Forwarded to the scheduler constructor as the
                ``optimizer`` keyword argument.
            max_epochs: Used as ``T_max`` when ``type`` is
                "CosineAnnealingLR" and ``kwargs`` has no "T_max" entry;
                not used otherwise.

        Returns:
            None when ``type`` is None, otherwise a new instance of the
            class named by ``type``, constructed with ``optimizer`` and
            the configured keyword arguments.

        Raises:
            AttributeError: If torch.optim.lr_scheduler has no attribute
                named ``type``.
        """
        if self.type is None:
            return None
        # Shallow copy so the injected default below never leaks back
        # into self.kwargs.
        build_kwargs = {**self.kwargs}
        # work-around so we don't need to specify T_max
        # in the yaml file for this scheduler
        if self.type == "CosineAnnealingLR" and "T_max" not in self.kwargs:
            build_kwargs["T_max"] = max_epochs
        scheduler_class = getattr(torch.optim.lr_scheduler, self.type)
        return scheduler_class(optimizer=optimizer, **build_kwargs)