Skip to content
Snippets Groups Projects
Commit be6c0b9f authored by Alain Riou's avatar Alain Riou
Browse files

initial commit

parents
No related branches found
No related tags found
No related merge requests found
import logging
from pathlib import Path
from typing import Sequence
import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt
log = logging.getLogger(__name__)
@rank_zero_only
def print_config_tree(
cfg: DictConfig,
print_order: Sequence[str] = (
"data",
"model",
"callbacks",
"logger",
"trainer",
"paths",
"extras",
),
resolve: bool = False,
save_to_file: bool = False,
) -> None:
"""Prints the contents of a DictConfig as a tree structure using the Rich library.
:param cfg: A DictConfig composed by Hydra.
:param print_order: Determines in what order config components are printed. Default is ``("data", "model",
"callbacks", "logger", "trainer", "paths", "extras")``.
:param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
:param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
"""
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
queue = []
# add fields from `print_order` to queue
for field in print_order:
queue.append(field) if field in cfg else log.warning(
f"Field '{field}' not found in config. Skipping '{field}' config printing..."
)
# add all the other fields to queue (not specified in `print_order`)
for field in cfg:
if field not in queue:
queue.append(field)
# generate config tree from queue
for field in queue:
branch = tree.add(field, style=style, guide_style=style)
config_group = cfg[field]
if isinstance(config_group, DictConfig):
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
else:
branch_content = str(config_group)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
# print config tree
rich.print(tree)
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""Prompts user to input tags from command line if no tags are provided in config.
:param cfg: A DictConfig composed by Hydra.
:param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
"""
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
log.warning("No tags provided in config. Prompting user to input tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
tags = [t.strip() for t in tags.split(",") if t != ""]
with open_dict(cfg):
cfg.tags = tags
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
rich.print(cfg.tags, file=file)
import logging
import warnings
from importlib.util import find_spec
from typing import Any, Callable, Dict, Optional, Tuple
from omegaconf import DictConfig, OmegaConf
from src.utils import rich_utils
log = logging.getLogger(__name__)
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
:param cfg: A DictConfig object containing the config tree.
"""
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found! <cfg.extras=null>")
return
# disable python warnings
if cfg.extras.get("ignore_warnings"):
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# prompt user to input tags from command line if none are provided in the config
if cfg.extras.get("enforce_tags"):
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
rich_utils.enforce_tags(cfg, save_to_file=True)
# pretty print config tree using Rich library
if cfg.extras.get("print_config"):
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
if cfg.extras.get("save_config"):
log.info(f"Saving config in {cfg.paths.output_dir}/config.yaml! <cfg.extras.save_config=True>")
OmegaConf.save(cfg, f"{cfg.paths.output_dir}/config.yaml", resolve=True)
def task_wrapper(task_func: Callable) -> Callable:
"""Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return metric_dict, object_dict
```
:param task_func: The task function to be wrapped.
:return: The wrapped task function.
"""
def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# execute the task
try:
metric_dict, object_dict = task_func(cfg=cfg)
# things to do if exception occurs
except Exception as ex:
# save exception to `.log` file
log.exception("")
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise ex
# things to always do after either success or exception
finally:
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.output_dir}")
# always close wandb run (even if exception occurs so multirun won't fail)
if find_spec("wandb"): # check if wandb is installed
import wandb
if wandb.run:
log.info("Closing wandb!")
wandb.finish()
return metric_dict, object_dict
return wrap
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
"""Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: If provided, the name of the metric to retrieve.
:return: If a metric name was provided, the value of the metric.
"""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise Exception(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
def remove_omegaconf_dependencies(obj: Any) -> Any:
if OmegaConf.is_config(obj):
return OmegaConf.to_container(obj, resolve=True)
if isinstance(obj, dict):
return {k: remove_omegaconf_dependencies(v) for k, v in obj.items()}
if isinstance(obj, list):
return [remove_omegaconf_dependencies(e) for e in obj]
return obj
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment