# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
This module provides support for Hydra, in particular the `main` wrapper between
the end user `main` function and Hydra.
"""
import copy
from collections import namedtuple, OrderedDict
from importlib.util import find_spec
import json
import logging
from pathlib import Path
import sys
import typing as tp
from unittest import mock
import hydra
from hydra.core.global_hydra import GlobalHydra
try:
from hydra import compose, initialize_config_dir # type: ignore
except ImportError:
from hydra.experimental import compose, initialize_config_dir # type: ignore
old_hydra = True
else:
old_hydra = False
from omegaconf.dictconfig import DictConfig
from .conf import DoraConfig, SlurmConfig, update_from_hydra
from .main import DecoratedMain, MainFun
from .xp import XP, get_xp, is_xp
logger = logging.getLogger(__name__)
def _no_copy(self: tp.Any, memo: tp.Any):
# Dirty trick to speed up Hydra, will remove when Hydra 1.1
# is released, which solves the issues.
return self
_Difference = namedtuple("_Difference", "path key ref other ref_value other_value")
class _NotThere:
pass
NotThere = _NotThere()
def _compare_config(ref, other, path=[]):
"""
Given two configs, gives an iterator over all the differences. For each difference,
this will give a _Difference namedtuple.
"""
keys = sorted(ref.keys())
remaining = sorted(set(other.keys()) - set(ref.keys()))
delta = []
path.append(None)
for key in keys:
path[-1] = key
ref_value = ref[key]
assert key in other, f"XP config shouldn't be missing any key. Missing key {key}"
other_value = other[key]
if isinstance(ref_value, DictConfig):
assert isinstance(other_value, DictConfig), \
"Structure of config should be identical between XPs. "\
f"Wrong type for {key}, expected DictConfig, got {type(other_value)}."
yield from _compare_config(ref_value, other_value, path)
elif other_value != ref_value:
yield _Difference(list(path), key, ref, other, ref_value, other_value)
for key in remaining:
path[-1] = key
other_value = other[key]
yield _Difference(list(path), key, ref, other, NotThere, other_value)
path.pop(-1)
return delta
def _simplify_argv(argv: tp.Sequence[str]) -> tp.List[str]:
simplified = []
seen = set()
for arg in list(argv)[::-1]:
assert '=' in arg, f'Argument {arg} does not contain ='
key, value = arg.split('=', 1)
key = key.strip()
if key in seen:
continue
else:
seen.add(key)
simplified.append(arg)
return simplified[::-1]
def _dump_key(key):
if key is None:
return "null"
elif isinstance(key, (bool, int, float)):
return str(key)
elif isinstance(key, str):
assert ":" not in key
return key
else:
raise TypeError(f"Unsupported dict key type {type(key)} for key {key}")
def _hydra_value_as_override(value):
# hydra doesn't support parsing dict with the json format, so for now
# we have to use a custom function to dump a value.
if value is None:
return "null"
elif isinstance(value, (bool, int, float, str)):
return json.dumps(value)
elif isinstance(value, dict):
return "{" + ", ".join(
f"{_dump_key(key)}: {_hydra_value_as_override(val)}"
for key, val in value.items()
) + "}"
elif isinstance(value, (list, tuple)):
return "[" + ", ".join(_hydra_value_as_override(val) for val in value) + "]"
else:
raise TypeError(f"Unsupported value type {type(value)} for value {value}")
class HydraMain(DecoratedMain):
    """`DecoratedMain` flavor where the configuration is composed by Hydra.

    The config directory is resolved relative to the file defining the
    wrapped `main` function. Configs are obtained through the Hydra compose
    API, and the actual run is delegated to `hydra.main`.
    """
    _slow = True

    def __init__(self, main: MainFun, config_name: str, config_path: str, **kwargs):
        """
        Args:
            main: end user main function.
            config_name: name of the Hydra config to compose.
            config_path: config folder, relative to the file defining `main`.
                When None, the folder containing that file is used.
            **kwargs: extra arguments forwarded to `hydra.main` and
                `initialize_config_dir`.
        """
        self.config_name = config_name
        self.config_path = config_path
        self.hydra_kwargs = kwargs

        module = main.__module__
        if module == "__main__":
            spec = sys.modules[module].__spec__
            if spec is None:
                # Started as a plain script: the job is named after the script
                # file, without its extension. "app" is only a fallback for
                # when `sys.argv[0]` carries no usable file name.
                module_path = sys.argv[0]
                self._job_name = Path(module_path).stem or "app"
            else:
                assert spec.origin is not None
                module_path = spec.origin
                module = spec.name
                # `[-1]` rather than `[1]`: a top-level module has no dot.
                self._job_name = module.rsplit(".", 1)[-1]
        else:
            spec = find_spec(module)
            assert spec is not None and spec.origin is not None
            module_path = spec.origin
            self._job_name = module.rsplit(".", 1)[-1]

        self.full_config_path = Path(module_path).parent.resolve()
        if config_path is not None:
            self.full_config_path = self.full_config_path / config_path

        self._initialized = False
        self._base_cfg = self._get_config()
        self._config_groups = self._get_config_groups()
        dora = self._get_dora()
        super().__init__(main, dora)
        # this is a really dirty hack to make Hydra believe that this is
        # coming from the __main__ module, as it would usually be.
        # This allows using relative paths for config_path.
        main.__module__ = "__main__"

    def _get_dora(self) -> DoraConfig:
        """Build the Dora config, updated from the `dora` entry of the base config if any.

        `dora.*` and `slurm.*` are always added to the exclude patterns.
        """
        dora = DoraConfig()
        if hasattr(self._base_cfg, "dora"):
            update_from_hydra(dora, self._base_cfg.dora)
        dora.exclude += ["dora.*", "slurm.*"]
        dora.dir = Path(dora.dir)
        return dora

    def get_slurm_config(self) -> SlurmConfig:
        """Return default Slurm config for the launch and grid actions.
        """
        slurm = SlurmConfig()
        if hasattr(self._base_cfg, "slurm"):
            update_from_hydra(slurm, self._base_cfg.slurm)
        return slurm

    def get_xp(self, argv: tp.Sequence[str]):
        """Return the XP matching the given overrides.

        The XP delta is made of the config group choices found in `argv`,
        followed by every entry of the full config that differs from the
        config obtained with those group choices only.

        Args:
            argv: `key=value` overrides. Shadowed overrides are dropped first.
        """
        argv = _simplify_argv(argv)
        cfg = self._get_config(argv)
        base, delta = self._get_base_config(argv)
        delta += self._get_delta(base, cfg)
        xp = XP(dora=self.dora, cfg=cfg, argv=argv, delta=delta)
        return xp

    def value_to_argv(self, arg: tp.Any) -> tp.List[str]:
        """Convert a value passed to the grid launcher into a list of overrides.

        Args:
            arg: either a str (in which case it is a raw override),
                or a dict, in which case each entry is an override,
                or a list / tuple of those, processed recursively.

        Raises:
            ValueError: if `arg` (or a nested part) is of another type.
        """
        argv: tp.List[str] = []
        if isinstance(arg, str):
            argv.append(arg)
        elif isinstance(arg, dict):
            for key, value in arg.items():
                if key not in self._config_groups:
                    # We need to convert the value using a custom function
                    # to respect how Hydra parses overrides. Config group
                    # choices are passed through as is.
                    value = _hydra_value_as_override(value)
                argv.append(f"{key}={value}")
        elif isinstance(arg, (list, tuple)):
            for part in arg:
                argv += self.value_to_argv(part)
        else:
            raise ValueError(f"Can only process dict, tuple, lists and str, but got {arg}")
        return argv

    def get_name_parts(self, xp: XP) -> OrderedDict:
        """Return the entries of the XP delta as an ordered `name -> value` mapping."""
        parts = OrderedDict()
        assert xp.delta is not None
        for name, value in xp.delta:
            parts[name] = value
        return parts

    def _main(self):
        """Run the user main function through `hydra.main`.

        When called within an XP, `hydra.run.dir` is pointed to the XP folder
        through a temporary extra entry in `sys.argv`.
        """
        run_dir: tp.Optional[str] = None
        if is_xp():
            run_dir = f"hydra.run.dir={get_xp().folder}"
            sys.argv.append(run_dir)
        try:
            return hydra.main(
                config_name=self.config_name,
                config_path=self.config_path,
                **self.hydra_kwargs)(self.main)()
        finally:
            # Only undo what we did, and never let the clean up raise:
            # it would hide the exception coming from the main function.
            if run_dir is not None and run_dir in sys.argv:
                sys.argv.remove(run_dir)

    def _get_config_groups(self) -> tp.List[str]:
        """Return all the config groups reported by Hydra for the config folder."""
        with initialize_config_dir(str(self.full_config_path), job_name=self._job_name,
                                   **self.hydra_kwargs):
            gh = GlobalHydra.instance().hydra
            assert gh is not None
            return list(gh.list_all_config_groups())

    def _is_active(self, argv: tp.List[str]) -> bool:
        """Return False when a Hydra multirun flag (`-m`, `--multirun`) is in `argv`."""
        return '-m' not in argv and '--multirun' not in argv

    def _get_base_config(
        self, overrides: tp.Optional[tp.List[str]] = None
    ) -> tp.Tuple[DictConfig, tp.List[tp.Tuple[str, str]]]:
        """
        Return base config based on composition, along with delta for the
        composition overrides.

        Only the overrides of the form `<group>=<value>`, with `<group>` one of
        the known config groups, are used. If a group is set multiple times,
        only its last value is kept in the delta. With no such override,
        the default config and an empty delta are returned.
        """
        with initialize_config_dir(str(self.full_config_path), job_name=self._job_name,
                                   **self.hydra_kwargs):
            gh = GlobalHydra.instance().hydra
            assert gh is not None

            to_keep = []
            delta: tp.List[tp.Tuple[str, str]] = []
            for arg in overrides or []:
                for group in self._config_groups:
                    if arg.startswith(f'{group}='):
                        to_keep.append(arg)
                        _, value = arg.split('=', 1)
                        delta = [(g, v) for g, v in delta if g != group]
                        delta.append((group, value))
            if not to_keep:
                return self._base_cfg, []
            cfg = self._get_config_noinit(to_keep)
            return cfg, delta

    def _get_config(self,
                    overrides: tp.Optional[tp.List[str]] = None) -> DictConfig:
        """
        Internal method, returns the config for the given override,
        but without the dora.sig field filled.
        """
        with initialize_config_dir(str(self.full_config_path), job_name=self._job_name,
                                   **self.hydra_kwargs):
            return self._get_config_noinit(overrides)

    def _get_config_noinit(self, overrides: tp.Optional[tp.List[str]] = None) -> DictConfig:
        """Compose the config for the given overrides.

        Must be called with Hydra already initialized, e.g. from within
        `initialize_config_dir`.
        """
        overrides = list(overrides) if overrides else []
        if old_hydra:
            # Deep copies are disabled while composing (see `_no_copy`).
            with mock.patch.object(DictConfig, "__deepcopy__", _no_copy):
                cfg = compose(self.config_name, overrides)  # type: ignore
            # Copy once the patch is lifted, so that this is an actual copy.
            cfg = copy.deepcopy(cfg)
        else:
            cfg = compose(self.config_name, overrides)  # type: ignore
        return cfg

    def _get_delta(self, init: DictConfig, other: DictConfig):
        """
        Returns the list of all the differences between the init and other config,
        as `(dotted.path, value in other)` tuples.
        """
        delta = []
        for diff in _compare_config(init, other):
            # Keys are not guaranteed to be strings, convert before joining.
            name = ".".join(str(part) for part in diff.path)
            delta.append((name, diff.other_value))
        return delta
def hydra_main(config_name: str, config_path: str, **kwargs):
    """Wrap your main function with this.
    You can pass extra kwargs, e.g. `version_base` introduced in 1.2.

    Args:
        config_name: name of the Hydra config, passed on to `HydraMain`.
        config_path: path to the config folder, passed on to `HydraMain`.
        **kwargs: extra keyword arguments, passed on to `HydraMain`.

    Returns:
        A decorator turning the main function into a `HydraMain` instance.
    """
    def _decorator(main: MainFun):
        decorated = HydraMain(
            main,
            config_name=config_name,
            config_path=config_path,
            **kwargs,
        )
        return decorated

    return _decorator