# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, List, Optional, Type

from pytest import mark, param, raises, skip

from hydra.core.default_element import InputDefault
from hydra.core.object_type import ObjectType
from hydra.plugins.config_source import ConfigLoadError, ConfigSource


class ConfigSourceTestSuite:
    """Reusable pytest checks for a ``ConfigSource`` implementation.

    Every test takes a ``type_`` (the ``ConfigSource`` subclass under test)
    and a ``path`` argument that this class does not supply itself; a
    subclass must provide them (presumably via ``mark.parametrize`` or
    fixtures — confirm in the concrete plugin test module). Each test builds
    the source with ``provider="foo"`` and checks lookups against the
    expected layout of the shared test config tree.
    """

    def skip_overlap_config_path_name(self) -> bool:
        """
        Some config source plugins do not support config name and path overlap.
        For example the following may not be allowed:
        (dataset exists both as a config object and a config group)
        /dataset.yaml
        /dataset/cifar.yaml

        Overriding and returning True here will disable testing of this scenario
        by assuming the dataset config (dataset.yaml) is not present.
        """
        return False

    def test_not_available(self, type_: Type[ConfigSource], path: str) -> None:
        """A source pointed at a nonexistent location must report unavailable."""
        # A throwaway instance is built only to learn the URL scheme.
        scheme = type_(provider="foo", path=path).scheme()
        # Test is meaningless for StructuredConfigSource; report it as skipped
        # rather than silently passing.
        if scheme == "structured":
            skip(
                f"ConfigSourcePlugin {type_.__name__} has no notion of a missing "
                f"location; test_not_available does not apply."
            )
        src = type_(provider="foo", path=f"{scheme}://___NOT_FOUND___")
        assert not src.available()

    @mark.parametrize(
        "config_path, expected",
        [
            param("", True, id="empty"),
            param("dataset", True, id="dataset"),
            param("optimizer", True, id="optimizer"),
            param(
                "configs_with_defaults_list",
                True,
                id="configs_with_defaults_list",
            ),
            param("dataset/imagenet", False, id="dataset/imagenet"),
            param("level1", True, id="level1"),
            param("level1/level2", True, id="level1/level2"),
            param("level1/level2/nested1", False, id="level1/level2/nested1"),
            param("not_found", False, id="not_found"),
        ],
    )
    def test_is_group(
        self, type_: Type[ConfigSource], path: str, config_path: str, expected: bool
    ) -> None:
        """``is_group`` is True for directories/groups and False for configs or missing paths."""
        src = type_(provider="foo", path=path)
        ret = src.is_group(config_path=config_path)
        assert ret == expected

    @mark.parametrize(
        "config_path, expected",
        [
            ("", False),
            ("optimizer", False),
            ("dataset/imagenet", True),
            ("dataset/imagenet.yaml", True),
            ("dataset/imagenet.foobar", False),
            ("configs_with_defaults_list/global_package", True),
            ("configs_with_defaults_list/group_package", True),
            ("level1", False),
            ("level1/level2", False),
            ("level1/level2/nested1", True),
            ("not_found", False),
        ],
    )
    def test_is_config(
        self, type_: Type[ConfigSource], path: str, config_path: str, expected: bool
    ) -> None:
        """``is_config`` accepts names with or without ``.yaml`` but rejects other suffixes."""
        src = type_(provider="foo", path=path)
        ret = src.is_config(config_path=config_path)
        assert ret == expected

    @mark.parametrize(
        "config_path, expected",
        [
            ("dataset", True),
        ],
    )
    def test_is_config_with_overlap_name(
        self, type_: Type[ConfigSource], path: str, config_path: str, expected: bool
    ) -> None:
        """A name that is both a group and a config is still reported as a config."""
        if self.skip_overlap_config_path_name():
            skip(
                f"ConfigSourcePlugin {type_.__name__} does not support config objects and config groups "
                f"with overlapping names."
            )
        src = type_(provider="foo", path=path)
        ret = src.is_config(config_path=config_path)
        assert ret == expected

    @mark.parametrize(
        "config_path,results_filter,expected",
        [
            # groups
            ("", ObjectType.GROUP, ["dataset", "level1", "optimizer"]),
            ("dataset", ObjectType.GROUP, []),
            ("optimizer", ObjectType.GROUP, []),
            ("level1", ObjectType.GROUP, ["level2"]),
            ("level1/level2", ObjectType.GROUP, []),
            # Configs
            ("", ObjectType.CONFIG, ["config_without_group"]),
            ("dataset", ObjectType.CONFIG, ["cifar10", "imagenet"]),
            ("optimizer", ObjectType.CONFIG, ["adam", "nesterov"]),
            ("level1", ObjectType.CONFIG, []),
            ("level1/level2", ObjectType.CONFIG, ["nested1", "nested2"]),
            # both
            ("", None, ["config_without_group", "dataset", "level1", "optimizer"]),
            ("dataset", None, ["cifar10", "imagenet"]),
            ("optimizer", None, ["adam", "nesterov"]),
            ("level1", None, ["level2"]),
            ("level1/level2", None, ["nested1", "nested2"]),
        ],
    )
    def test_list(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        """``list`` includes every expected name and returns results sorted.

        Only membership is checked, not equality, so a source may expose
        additional entries beyond ``expected``.
        """
        src = type_(provider="foo", path=path)
        ret = src.list(config_path=config_path, results_filter=results_filter)
        for x in expected:
            assert x in ret
        assert ret == sorted(ret)

    @mark.parametrize(
        "config_path,results_filter,expected",
        [
            # Configs
            ("", ObjectType.CONFIG, ["dataset"]),
        ],
    )
    def test_list_with_overlap_name(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        """A name that is both a group and a config is listed among configs."""
        if self.skip_overlap_config_path_name():
            skip(
                f"ConfigSourcePlugin {type_.__name__} does not support config objects and config groups "
                f"with overlapping names."
            )
        src = type_(provider="foo", path=path)
        ret = src.list(config_path=config_path, results_filter=results_filter)
        for x in expected:
            assert x in ret
        assert ret == sorted(ret)

    @mark.parametrize(
        "config_path,expected_config,expected_defaults_list,expected_package",
        [
            param(
                "config_without_group",
                {"group": False},
                None,
                None,
                id="config_without_group",
            ),
            param(
                "config_with_unicode",
                {"group": "数据库"},
                None,
                None,
                id="config_with_unicode",
            ),
            param(
                "dataset/imagenet",
                {"name": "imagenet", "path": "/datasets/imagenet"},
                None,
                None,
                id="dataset/imagenet",
            ),
            param(
                "dataset/cifar10",
                {"name": "cifar10", "path": "/datasets/cifar10"},
                None,
                None,
                id="dataset/cifar10",
            ),
            param(
                "dataset/not_found",
                raises(ConfigLoadError),
                None,
                None,
                id="dataset/not_found",
            ),
            param(
                "level1/level2/nested1",
                {"l1_l2_n1": True},
                None,
                None,
                id="level1/level2/nested1",
            ),
            param(
                "level1/level2/nested2",
                {"l1_l2_n2": True},
                None,
                None,
                id="level1/level2/nested2",
            ),
            param(
                "config_with_defaults_list",
                {
                    "defaults": [{"dataset": "imagenet"}],
                    "key": "value",
                },
                None,
                None,
                id="config_with_defaults_list",
            ),
            param(
                "configs_with_defaults_list/global_package",
                {
                    "defaults": [{"foo": "bar"}],
                    "x": 10,
                },
                None,
                "_global_",
                id="configs_with_defaults_list/global_package",
            ),
            param(
                "configs_with_defaults_list/group_package",
                {
                    "defaults": [{"foo": "bar"}],
                    "x": 10,
                },
                None,
                "_group_",
                id="configs_with_defaults_list/group_package",
            ),
        ],
    )
    def test_source_load_config(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_defaults_list: Optional[List[InputDefault]],
        expected_package: Any,
        expected_config: Any,
        recwarn: Any,  # pytest fixture; records warnings emitted during the load
    ) -> None:
        """``load_config`` returns the expected content, package header and defaults list.

        ``expected_config`` is either the expected config as a dict, or a
        context manager (``raises(ConfigLoadError)``) that the load must
        satisfy.
        """
        assert issubclass(type_, ConfigSource)
        src = type_(provider="foo", path=path)
        if isinstance(expected_config, dict):
            ret = src.load_config(config_path=config_path)
            assert ret.config == expected_config
            assert ret.header["package"] == expected_package
            assert ret.defaults_list == expected_defaults_list
        else:
            with expected_config:
                src.load_config(config_path=config_path)

    @mark.parametrize(
        "config_path, expected_result, expected_package",
        [
            param("package_test/none", {"foo": "bar"}, None, id="none"),
            param("package_test/explicit", {"foo": "bar"}, "a.b", id="explicit"),
            param("package_test/global", {"foo": "bar"}, "_global_", id="global"),
            param("package_test/group", {"foo": "bar"}, "_group_", id="group"),
            param(
                "package_test/group_name",
                {"foo": "bar"},
                "foo._group_._name_",
                id="group_name",
            ),
            param("package_test/name", {"foo": "bar"}, "_name_", id="name"),
        ],
    )
    def test_package_behavior(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_result: Any,
        expected_package: Optional[str],
    ) -> None:
        """The package header is read verbatim; ``None`` when the config declares none."""
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path=config_path)
        assert cfg.header["package"] == expected_package
        assert cfg.config == expected_result

    def test_default_package_for_primary_config(
        self, type_: Type[ConfigSource], path: str
    ) -> None:
        """A primary config with no package declaration has a ``None`` package header."""
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path="primary_config")
        assert cfg.header["package"] is None

    def test_primary_config_with_non_global_package(
        self, type_: Type[ConfigSource], path: str
    ) -> None:
        """A primary config's explicit package declaration is preserved in the header."""
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path="primary_config_with_non_global_package")
        assert cfg.header["package"] == "foo"
Memory