import os

import torch

from .typing_ import TensorDict, TensorMapping


def using_gpu() -> bool:
    """Return True if :func:`get_device` selects a device of type ``"cuda"``."""
    return get_device().type == "cuda"


def using_srun() -> bool:
    """Return True if the job should be treated as launched with srun.

    If using srun instead of torchrun, set ``FME_USE_SRUN=1`` in the
    environment. Any other value, or an unset variable, yields False.
    """
    return os.environ.get("FME_USE_SRUN", "0") == "1"
def get_device() -> torch.device:
    """Return the torch device that computation should run on.

    Selection order:

    1. If CUDA is available, return a CUDA device for the current CUDA
       device index (``torch.cuda.current_device()``).
    2. Otherwise, if the environment variable ``FME_USE_MPS`` is ``"1"`` and
       the MPS backend is available, return ``torch.device("mps", 0)``.
    3. Otherwise, return a CPU device.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    # MPS is opt-in: only consider it when explicitly requested. Checking the
    # environment variable first avoids probing the MPS backend needlessly.
    if os.environ.get("FME_USE_MPS", "0") == "1" and torch.backends.mps.is_available():
        return torch.device("mps", 0)
    return torch.device("cpu")