Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pesto-full
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Paul Best
pesto-full
Commits
be6c0b9f
Commit
be6c0b9f
authored
1 year ago
by
Alain Riou
Browse files
Options
Downloads
Patches
Plain Diff
initial commit
parents
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
62
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/utils/rich_utils.py
+98
-0
98 additions, 0 deletions
src/utils/rich_utils.py
src/utils/utils.py
+135
-0
135 additions, 0 deletions
src/utils/utils.py
with
233 additions
and
0 deletions
src/utils/rich_utils.py
0 → 100644
+
98
−
0
View file @
be6c0b9f
import
logging
from
pathlib
import
Path
from
typing
import
Sequence
import
rich
import
rich.syntax
import
rich.tree
from
hydra.core.hydra_config
import
HydraConfig
from
lightning_utilities.core.rank_zero
import
rank_zero_only
from
omegaconf
import
DictConfig
,
OmegaConf
,
open_dict
from
rich.prompt
import
Prompt
log
=
logging
.
getLogger
(
__name__
)
@rank_zero_only
def
print_config_tree
(
cfg
:
DictConfig
,
print_order
:
Sequence
[
str
]
=
(
"
data
"
,
"
model
"
,
"
callbacks
"
,
"
logger
"
,
"
trainer
"
,
"
paths
"
,
"
extras
"
,
),
resolve
:
bool
=
False
,
save_to_file
:
bool
=
False
,
)
->
None
:
"""
Prints the contents of a DictConfig as a tree structure using the Rich library.
:param cfg: A DictConfig composed by Hydra.
:param print_order: Determines in what order config components are printed. Default is ``(
"
data
"
,
"
model
"
,
"
callbacks
"
,
"
logger
"
,
"
trainer
"
,
"
paths
"
,
"
extras
"
)``.
:param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
:param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
"""
style
=
"
dim
"
tree
=
rich
.
tree
.
Tree
(
"
CONFIG
"
,
style
=
style
,
guide_style
=
style
)
queue
=
[]
# add fields from `print_order` to queue
for
field
in
print_order
:
queue
.
append
(
field
)
if
field
in
cfg
else
log
.
warning
(
f
"
Field
'
{
field
}
'
not found in config. Skipping
'
{
field
}
'
config printing...
"
)
# add all the other fields to queue (not specified in `print_order`)
for
field
in
cfg
:
if
field
not
in
queue
:
queue
.
append
(
field
)
# generate config tree from queue
for
field
in
queue
:
branch
=
tree
.
add
(
field
,
style
=
style
,
guide_style
=
style
)
config_group
=
cfg
[
field
]
if
isinstance
(
config_group
,
DictConfig
):
branch_content
=
OmegaConf
.
to_yaml
(
config_group
,
resolve
=
resolve
)
else
:
branch_content
=
str
(
config_group
)
branch
.
add
(
rich
.
syntax
.
Syntax
(
branch_content
,
"
yaml
"
))
# print config tree
rich
.
print
(
tree
)
# save config tree to file
if
save_to_file
:
with
open
(
Path
(
cfg
.
paths
.
output_dir
,
"
config_tree.log
"
),
"
w
"
)
as
file
:
rich
.
print
(
tree
,
file
=
file
)
@rank_zero_only
def
enforce_tags
(
cfg
:
DictConfig
,
save_to_file
:
bool
=
False
)
->
None
:
"""
Prompts user to input tags from command line if no tags are provided in config.
:param cfg: A DictConfig composed by Hydra.
:param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
"""
if
not
cfg
.
get
(
"
tags
"
):
if
"
id
"
in
HydraConfig
().
cfg
.
hydra
.
job
:
raise
ValueError
(
"
Specify tags before launching a multirun!
"
)
log
.
warning
(
"
No tags provided in config. Prompting user to input tags...
"
)
tags
=
Prompt
.
ask
(
"
Enter a list of comma separated tags
"
,
default
=
"
dev
"
)
tags
=
[
t
.
strip
()
for
t
in
tags
.
split
(
"
,
"
)
if
t
!=
""
]
with
open_dict
(
cfg
):
cfg
.
tags
=
tags
log
.
info
(
f
"
Tags:
{
cfg
.
tags
}
"
)
if
save_to_file
:
with
open
(
Path
(
cfg
.
paths
.
output_dir
,
"
tags.log
"
),
"
w
"
)
as
file
:
rich
.
print
(
cfg
.
tags
,
file
=
file
)
This diff is collapsed.
Click to expand it.
src/utils/utils.py
0 → 100644
+
135
−
0
View file @
be6c0b9f
import
logging
import
warnings
from
importlib.util
import
find_spec
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
from
omegaconf
import
DictConfig
,
OmegaConf
from
src.utils
import
rich_utils
log
=
logging
.
getLogger
(
__name__
)
def
extras
(
cfg
:
DictConfig
)
->
None
:
"""
Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
:param cfg: A DictConfig object containing the config tree.
"""
# return if no `extras` config
if
not
cfg
.
get
(
"
extras
"
):
log
.
warning
(
"
Extras config not found! <cfg.extras=null>
"
)
return
# disable python warnings
if
cfg
.
extras
.
get
(
"
ignore_warnings
"
):
log
.
info
(
"
Disabling python warnings! <cfg.extras.ignore_warnings=True>
"
)
warnings
.
filterwarnings
(
"
ignore
"
)
# prompt user to input tags from command line if none are provided in the config
if
cfg
.
extras
.
get
(
"
enforce_tags
"
):
log
.
info
(
"
Enforcing tags! <cfg.extras.enforce_tags=True>
"
)
rich_utils
.
enforce_tags
(
cfg
,
save_to_file
=
True
)
# pretty print config tree using Rich library
if
cfg
.
extras
.
get
(
"
print_config
"
):
log
.
info
(
"
Printing config tree with Rich! <cfg.extras.print_config=True>
"
)
rich_utils
.
print_config_tree
(
cfg
,
resolve
=
True
,
save_to_file
=
True
)
if
cfg
.
extras
.
get
(
"
save_config
"
):
log
.
info
(
f
"
Saving config in
{
cfg
.
paths
.
output_dir
}
/config.yaml! <cfg.extras.save_config=True>
"
)
OmegaConf
.
save
(
cfg
,
f
"
{
cfg
.
paths
.
output_dir
}
/config.yaml
"
,
resolve
=
True
)
def
task_wrapper
(
task_func
:
Callable
)
->
Callable
:
"""
Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return
metric_dict
,
object_dict
``
`
:param task_func: The task function to be wrapped.
:return: The wrapped task function.
"""
def
wrap
(
cfg
:
DictConfig
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
# execute the task
try
:
metric_dict
,
object_dict
=
task_func
(
cfg
=
cfg
)
# things to do if exception occurs
except
Exception
as
ex
:
# save exception to `.log` file
log
.
exception
(
""
)
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise
ex
# things to always do after either success or exception
finally
:
# display output dir path in terminal
log
.
info
(
f
"
Output dir:
{
cfg
.
paths
.
output_dir
}
"
)
# always close wandb run (even if exception occurs so multirun won't fail)
if
find_spec
(
"
wandb
"
):
# check if wandb is installed
import
wandb
if
wandb
.
run
:
log
.
info
(
"
Closing wandb!
"
)
wandb
.
finish
()
return
metric_dict
,
object_dict
return
wrap
def
get_metric_value
(
metric_dict
:
Dict
[
str
,
Any
],
metric_name
:
Optional
[
str
])
->
Optional
[
float
]:
"""
Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: If provided, the name of the metric to retrieve.
:return: If a metric name was provided, the value of the metric.
"""
if
not
metric_name
:
log
.
info
(
"
Metric name is None! Skipping metric value retrieval...
"
)
return
None
if
metric_name
not
in
metric_dict
:
raise
Exception
(
f
"
Metric value not found! <metric_name=
{
metric_name
}
>
\n
"
"
Make sure metric name logged in LightningModule is correct!
\n
"
"
Make sure `optimized_metric` name in `hparams_search` config is correct!
"
)
metric_value
=
metric_dict
[
metric_name
].
item
()
log
.
info
(
f
"
Retrieved metric value! <
{
metric_name
}
=
{
metric_value
}
>
"
)
return
metric_value
def
remove_omegaconf_dependencies
(
obj
:
Any
)
->
Any
:
if
OmegaConf
.
is_config
(
obj
):
return
OmegaConf
.
to_container
(
obj
,
resolve
=
True
)
if
isinstance
(
obj
,
dict
):
return
{
k
:
remove_omegaconf_dependencies
(
v
)
for
k
,
v
in
obj
.
items
()}
if
isinstance
(
obj
,
list
):
return
[
remove_omegaconf_dependencies
(
e
)
for
e
in
obj
]
return
obj
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
Next
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment