# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import importlib
import logging
import re
import typing as tp
from pathlib import Path

import pkg_resources
import pytest

from . import core, plugins
from .job_environment import JobEnvironment


@pytest.mark.parametrize("env", plugins.get_job_environments().values())
def test_env(env: JobEnvironment) -> None:
    """Every registered job environment is a JobEnvironment overriding `_requeue`."""
    assert isinstance(env, JobEnvironment)
    # We are not inside a submitit job
    assert not env.activated()
    assert type(env)._requeue is not JobEnvironment._requeue, "_requeue need to be overridden"


@pytest.mark.parametrize("ex", plugins.get_executors().values())
def test_executors(ex: tp.Type[core.Executor]) -> None:
    """Every registered executor is a core.Executor subclass with affinity >= -1."""
    assert isinstance(ex, type)
    assert issubclass(ex, core.Executor)
    assert ex.affinity() >= -1


def test_finds_default_environments() -> None:
    """The built-in slurm/local/debug job environments are discovered."""
    envs = plugins.get_job_environments()
    assert len(envs) >= 3
    assert "slurm" in envs
    assert "local" in envs
    assert "debug" in envs


def test_finds_default_executors() -> None:
    """The built-in slurm/local/debug executors are discovered."""
    ex = plugins.get_executors()
    assert len(ex) >= 3
    assert "slurm" in ex
    assert "local" in ex
    assert "debug" in ex


def test_job_environment_works(monkeypatch: tp.Any) -> None:
    """`_TEST_CLUSTER_` selects the environment for both the plugin lookup and JobEnvironment()."""
    monkeypatch.setenv("_TEST_CLUSTER_", "slurm")
    env = plugins.get_job_environment()
    assert env.cluster == "slurm"
    assert type(env).__name__ == "SlurmJobEnvironment"
    env2 = JobEnvironment()
    assert env2.cluster == "slurm"
    assert type(env2).__name__ == "SlurmJobEnvironment"


def test_job_environment_raises_outside_of_job() -> None:
    """Outside of any job, the lookup fails with a message listing known environments."""
    with pytest.raises(RuntimeError, match=r"which environment.*slurm.*local.*debug"):
        plugins.get_job_environment()


class PluginCreator:
    """Write throwaway plugin packages under `tmp_path` and expose them to plugin discovery.

    Used as a context manager: entering and exiting clear the plugin caches so that
    results computed by other tests do not leak in (or out).
    """

    def __init__(self, tmp_path: Path, monkeypatch: tp.Any) -> None:
        self.tmp_path = tmp_path
        self.monkeypatch = monkeypatch

    def add_plugin(self, name: str, entry_points: str, init: str) -> None:
        """Create package `name` with an `.egg-info` declaring `entry_points`.

        Args:
            name: importable package name, created as a directory under `tmp_path`.
            entry_points: content written verbatim to `<name>.egg-info/entry_points.txt`.
            init: content written verbatim to `<name>/__init__.py`.
        """
        plugin = self.tmp_path / name
        plugin.mkdir(mode=0o777)
        plugin_egg = plugin.with_suffix(".egg-info")
        plugin_egg.mkdir(mode=0o777)
        (plugin_egg / "entry_points.txt").write_text(entry_points)
        (plugin / "__init__.py").write_text(init)
        # `syspath_prepend` (in __enter__) invalidated import caches *before* these
        # files existed; a FileFinder cached for tmp_path in between could miss the
        # new package, so invalidate again now that it is on disk.
        importlib.invalidate_caches()
        # also fix pkg_resources since it already has loaded old packages in other tests.
        # The WorkingSet only scans tmp_path, so only plugins created here are found
        # through `iter_entry_points`.
        working_set = pkg_resources.WorkingSet([str(self.tmp_path)])
        self.monkeypatch.setattr(pkg_resources, "iter_entry_points", working_set.iter_entry_points)

    def __enter__(self) -> "PluginCreator":
        _clear_plugin_cache()
        # undone automatically by monkeypatch at teardown
        self.monkeypatch.syspath_prepend(self.tmp_path)
        return self

    def __exit__(self, *exception: tp.Any) -> None:
        # Drop cached results that include the temporary plugins; exceptions propagate.
        _clear_plugin_cache()


def _clear_plugin_cache() -> None:
    """Reset the memoized plugin lookups in `plugins` (both expose `cache_clear`)."""
    plugins._get_plugins.cache_clear()
    plugins.get_executors.cache_clear()


@pytest.fixture(name="plugin_creator")
def _plugin_creator(tmp_path: Path, monkeypatch: tp.Any) -> tp.Iterator[PluginCreator]:
    """Yield a PluginCreator rooted at `tmp_path`, with plugin caches cleared around the test."""
    creator = PluginCreator(tmp_path, monkeypatch)
    with creator:
        yield creator


def test_find_good_plugin(plugin_creator: PluginCreator) -> None:
    """A well-formed plugin adds its executor next to the built-in ones."""
    plugin_creator.add_plugin(
        "submitit_good",
        entry_points="""[submitit]
executor = submitit_good:GoodExecutor
job_environment = submitit_good:GoodJobEnvironment
unsupported_key = submitit_good:SomethingElse
""",
        init="""
import submitit

class GoodExecutor(submitit.Executor):
    pass

class GoodJobEnvironment:
    pass
""",
    )

    executors = plugins.get_executors().keys()
    # Besides the built-in executors, only the plugin declared with plugin_creator is
    # visible: the patched `iter_entry_points` hides other installed plugins.
    assert set(executors) == {"good", "slurm", "local", "debug"}


def test_skip_bad_plugin(caplog: tp.Any, plugin_creator: PluginCreator) -> None:
    """A broken plugin is skipped, and each failure is logged on the "submitit" logger."""
    caplog.set_level(logging.WARNING, logger="submitit")
    plugin_creator.add_plugin(
        "submitit_bad",
        entry_points="""[submitit]
executor = submitit_bad:NonExisitingExecutor
job_environment = submitit_bad:BadEnvironment
unsupported_key = submitit_bad:SomethingElse
""",
        init="""
import submitit

class BadEnvironment:
    name = "bad"

    def __init__(self):
        raise Exception("this is a bad environment")
""",
    )

    executors = plugins.get_executors().keys()
    assert {"slurm", "local", "debug"} == set(executors)
    assert "bad" not in executors
    # One log record per faulty entry point, in declaration order.
    expected = [
        (logging.ERROR, r"'submitit_bad'.*no attribute 'NonExisitingExecutor'"),
        (logging.ERROR, r"'submitit_bad'.*this is a bad environment"),
        (logging.WARNING, "unsupported_key = submitit_bad:SomethingElse"),
    ]
    assert len(caplog.records) == len(expected)
    for record, ex_record in zip(caplog.records, expected):
        assert record.name == "submitit"
        assert record.levelno == ex_record[0]
        assert re.search(ex_record[1], record.getMessage())
Memory