import contextlib
import os
from collections.abc import Generator

import torch

from .typing_ import TensorDict, TensorMapping

# Module-level flag consulted by ``get_device``. It is initialized once, at
# import time, from the FME_FORCE_CPU environment variable ("1" enables it).
# Later changes to the environment variable are not seen; use ``force_cpu``
# to override the flag at runtime.
_FORCE_CPU: bool = os.environ.get("FME_FORCE_CPU", "0") == "1"


@contextlib.contextmanager
def force_cpu(force: bool = True) -> Generator[None, None, None]:
    """Force the use of CPU even if a GPU is available.

    This is useful for testing and debugging. The override is stored in a
    module-level flag, so while the context is active it affects every call
    to ``get_device`` in the process, not only those made from this code
    path. The previous value is restored on exit (also when the body
    raises), so nested uses compose correctly.

    Args:
        force: If True, force the use of CPU. If False, allow the use of GPU
            if available.
    """
    global _FORCE_CPU
    previous = _FORCE_CPU
    try:
        _FORCE_CPU = force
        yield
    finally:
        # Always restore, so an exception inside the block cannot leave the
        # process stuck in (or out of) forced-CPU mode.
        _FORCE_CPU = previous


def using_gpu() -> bool:
    """Return True if ``get_device`` currently selects a CUDA device.

    Only CUDA counts here; an MPS device selected by ``get_device`` makes
    this return False.
    """
    return get_device().type == "cuda"


def using_srun() -> bool:
    """Return True if the job is marked as launched with srun.

    If using srun instead of torchrun, set FME_USE_SRUN=1 in the
    environment. Any other value, or an unset variable, returns False. The
    variable is read on every call.
    """
    return os.environ.get("FME_USE_SRUN", "0") == "1"
def get_device() -> torch.device:
    """Return the torch device that computation should run on.

    Selection order:

    1. CPU, if CPU use is forced (FME_FORCE_CPU=1 when the module was
       imported, or while inside a ``force_cpu()`` context).
    2. The current CUDA device, if CUDA is available.
    3. MPS device 0, if MPS is available and FME_USE_MPS=1 is set in the
       environment.
    4. Otherwise, CPU.
    """
    if _FORCE_CPU:
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    mps_available = torch.backends.mps.is_available()
    # MPS is opt-in: even when available it is only used if explicitly
    # requested via FME_USE_MPS=1.
    if mps_available and os.environ.get("FME_USE_MPS", "0") == "1":
        return torch.device("mps", 0)
    return torch.device("cpu")