def initialize_rng(self):
    """Lazily create the random number generator if it does not exist yet.

    Creation is deferred to runtime (rather than construction time) to help
    guarantee that the distributed seed has already been set when it is read.
    """
    if self._rng is not None:
        return
    # 684 is an offset reserved for this generator: don't use this number
    # anywhere else. The resulting seed must be the same across all processes.
    seed = Distributed.get_instance().get_seed() + 684
    self._rng = np.random.RandomState(seed)
def sample(self) -> int:
    """
    Draw a number of timesteps according to the configured probabilities.

    The RNG is initialized on first use if it has not been initialized
    externally. No other state is modified.

    Returns:
        The sampled number of timesteps, as a built-in Python ``int``.
    """
    self.initialize_rng()  # jit, if not called externally
    assert self._rng is not None  # narrows the Optional type for checkers
    # RandomState.choice returns a NumPy integer scalar; convert it so the
    # value matches the annotated ``int`` return type (e.g. for isinstance
    # checks or JSON serialization by callers).
    return int(self._rng.choice(self._n_times, p=self._probabilities))
# A time length is either a plain int or a TimeLengthProbabilities instance.
# Code consuming it (e.g. TimeLengthSchedule.max_n_forward_steps) uses an int
# directly and reads ``.max_n_forward_steps`` from the probabilities object.
# NOTE(review): keep the member order as-is; it is observable via
# ``TimeLength.__args__`` and may matter if this union is resolved in order
# during config deserialization — confirm before reordering.
TimeLength = TimeLengthProbabilities | int
@dataclasses.dataclass()
class TimeLengthMilestone:
    """
    One entry of a time length schedule, pairing an epoch with a time length.
    """

    # Epoch this milestone is attached to (how it is interpreted is decided
    # by the schedule that consumes it).
    epoch: int
    # Time length associated with ``epoch``: an int or a TimeLengthProbabilities.
    value: TimeLength
@dataclasses.dataclass
class TimeLengthSchedule:
    """
    A schedule for a time length value.

    The value for a given epoch is resolved from ``start_value`` and
    ``milestones`` by ``ValidatedMilestones``, which is built once in
    ``__post_init__``.
    """

    start_value: TimeLength
    milestones: list[TimeLengthMilestone]

    def __post_init__(self):
        # Build the milestone lookup once at construction time; later lookups
        # in get_value are delegated to it.
        self._validated_milestones = ValidatedMilestones(
            start_value=self.start_value, milestones=self.milestones
        )

    @classmethod
    def from_constant(cls, value: TimeLength) -> "TimeLengthSchedule":
        """
        Create a TimeLengthSchedule that always returns the same value.

        Parameters:
            value: The constant value.

        Returns:
            A TimeLengthSchedule instance.
        """
        return cls(start_value=value, milestones=[])

    def get_value(self, epoch: int) -> TimeLength:
        """
        Get the scheduled time length for an epoch.

        Parameters:
            epoch: The epoch to look up.

        Returns:
            The value resolved by the validated milestones for ``epoch``.
        """
        return self._validated_milestones.get_value(epoch)

    @staticmethod
    def _max_steps(value: TimeLength) -> int:
        """Return the maximum number of forward steps for one TimeLength."""
        # An int is already a step count; otherwise defer to the
        # probabilities object's own maximum.
        if isinstance(value, int):
            return value
        return value.max_n_forward_steps

    @property
    def max_n_forward_steps(self) -> IntSchedule:
        """
        Get a schedule of the maximum number of forward steps.

        The start value and each milestone value are mapped to their maximum
        number of forward steps; milestone epochs are kept unchanged.
        """
        return IntSchedule(
            start_value=self._max_steps(self.start_value),
            milestones=[
                IntMilestone(epoch=milestone.epoch, value=self._max_steps(milestone.value))
                for milestone in self.milestones
            ],
        )