import contextlib
import dataclasses
import logging
import os
from typing import Any, Dict, Mapping, Optional, Union
from fme.core.distributed import Distributed
from fme.core.wandb import WandB
# Names of the environment variables that `retrieve_env_vars` looks up by
# default (it uses this tuple as the default value of its `names` argument).
ENV_VAR_NAMES = (
"BEAKER_EXPERIMENT_ID",
"SLURM_JOB_ID",
"SLURM_JOB_USER",
"FME_TRAIN_DIR",
"FME_VALID_DIR",
"FME_STATS_DIR",
"FME_CHECKPOINT_DIR",
"FME_OUTPUT_DIR",
"FME_IMAGE",
)
@dataclasses.dataclass
class LoggingConfig:
    """
    Configuration for logging.

    Attributes:
        project: name of the project in Weights & Biases
        entity: name of the entity in Weights & Biases
        log_to_screen: whether to log to the screen
        log_to_file: whether to log to a file
        log_to_wandb: whether to log to Weights & Biases
        log_format: format of the log messages
        level: log level, given either as a level name or as a numeric
            level; defaults to ``logging.INFO``
    """

    project: str = "ace"
    entity: str = "ai2cm"
    log_to_screen: bool = True
    log_to_file: bool = True
    log_to_wandb: bool = True
    log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    level: Union[str, int] = logging.INFO

    def __post_init__(self):
        # Keep a handle on the object returned by Distributed.get_instance().
        self._dist = Distributed.get_instance()

    def configure_wandb(
        self,
        config: Mapping[str, Any],
        env_vars: Optional[Mapping[str, Any]] = None,
        **kwargs,
    ):
        """
        Configure the WandB wrapper and initialize a run.

        Args:
            config: run configuration passed to ``wandb.init``. It must
                contain an ``"experiment_dir"`` key, whose value is passed
                as the ``dir`` argument. The mapping itself is not mutated;
                a shallow copy is what gets passed on.
            env_vars: optional environment variables to record under the
                ``"environment"`` key of the copied config. Ignored if
                ``config`` already has an ``"environment"`` key.
            **kwargs: forwarded unchanged to ``wandb.init``.

        Raises:
            KeyError: if ``config`` has no ``"experiment_dir"`` key.
        """
        # Work on a shallow copy so the caller's mapping is left untouched.
        config_copy = {**config}
        if "environment" in config_copy:
            # An existing "environment" entry takes precedence over env_vars.
            logging.warning(
                "Not recording environmental variables since 'environment' key is "
                "already present in config."
            )
        elif env_vars is not None:
            config_copy["environment"] = env_vars
        # must ensure wandb.configure is called before wandb.init
        wandb = WandB.get_instance()
        wandb.configure(log_to_wandb=self.log_to_wandb)
        wandb.init(
            config=config_copy,
            project=self.project,
            entity=self.entity,
            dir=config["experiment_dir"],
            **kwargs,
        )

    def clean_wandb(self, experiment_dir: str):
        """
        Delegate to ``clean_wandb_dir`` on the WandB wrapper.

        Args:
            experiment_dir: passed through as the ``experiment_dir`` argument.
        """
        wandb = WandB.get_instance()
        wandb.clean_wandb_dir(experiment_dir=experiment_dir)
def log_versions():
    """Log the installed Torch version, framed by banner lines, at INFO level."""
    # Imported lazily so that torch is only loaded when this is called.
    import torch

    banner_lines = (
        "--------------- Versions ---------------",
        "Torch: " + str(torch.__version__),
        "----------------------------------------",
    )
    for line in banner_lines:
        logging.info(line)
def retrieve_env_vars(names=ENV_VAR_NAMES) -> Dict[str, str]:
    """Look up the given environment variables and return those that are set.

    Each variable that is set is logged at INFO level together with its
    value; each one that is missing is logged as a warning and left out of
    the result.

    Args:
        names: names of the environment variables to look up.

    Returns:
        A dictionary mapping each variable that was found to its value.
    """
    found = {}
    for name in names:
        value = os.environ.get(name)
        if value is None:
            logging.warning(f"Environmental variable {name} not found.")
            continue
        found[name] = value
        logging.info(f"Environmental variable {name}={value}.")
    return found
def log_beaker_url(beaker_id=None):
    """Log the Beaker ID and URL of an experiment.

    Args:
        beaker_id: The Beaker ID of the experiment. If None, the value of
            the `BEAKER_EXPERIMENT_ID` environment variable is used.

    Returns:
        The Beaker URL, or None when no ID was given and the environment
        variable is not set (a warning is logged in that case).
    """
    if beaker_id is None:
        # Fall back to the environment when the caller gave no ID.
        beaker_id = os.environ.get("BEAKER_EXPERIMENT_ID")
        if beaker_id is None:
            logging.warning("Beaker Experiment ID not found.")
            return None

    beaker_url = f"https://beaker.org/ex/{beaker_id}"
    logging.info(f"Beaker ID: {beaker_id}")
    logging.info(f"Beaker URL: {beaker_url}")
    return beaker_url
@contextlib.contextmanager
def log_level(level):
    """Context manager that runs its body with the root logger set to `level`.

    The level the root logger had on entry is restored on exit, including
    when the body raises.
    """
    # presently, data loading uses the root logger
    root_logger = logging.getLogger()
    previous_level = root_logger.getEffectiveLevel()
    try:
        root_logger.setLevel(level)
        yield
    finally:
        # Always put the original level back, even on error.
        root_logger.setLevel(previous_level)