# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import functools
import os
from typing import TYPE_CHECKING, List, Mapping, Tuple, Type

from ..core import logger

if TYPE_CHECKING:
    # Breaks the import cycle
    from ..core.core import Executor
    from ..core.job_environment import JobEnvironment


@functools.lru_cache()
def _get_plugins() -> Tuple[List[Type["Executor"]], List["JobEnvironment"]]:
    """Collect the built-in and plugin-provided executors and job environments.

    The built-in slurm, local and debug implementations are always listed
    first. Entry points registered under the "submitit" group are then
    appended: those named "executor" contribute an executor class, those named
    "job_environment" contribute a class that is instantiated without
    arguments. Entry points with any other name, plugins that fail to load and
    job environments that fail to instantiate are logged and skipped.

    The result is memoized with `functools.lru_cache`, so the discovery runs
    at most once per process (until the cache is cleared).

    Returns:
        A tuple `(executors, job_envs)` holding the executor classes and the
        job environment instances, in discovery order.
    """
    # pylint: disable=cyclic-import,import-outside-toplevel
    # Load dynamically to avoid import cycle
    # pkg_resources goes through all modules on import.
    import pkg_resources

    from ..local import debug, local
    from ..slurm import slurm

    # TODO: use sys.modules.keys() and importlib.resources to find the files
    # We load both kinds of entry points at the same time because we have to go through all module files anyway.
    executors: List[Type["Executor"]] = [slurm.SlurmExecutor, local.LocalExecutor, debug.DebugExecutor]
    job_envs = [slurm.SlurmJobEnvironment(), local.LocalJobEnvironment(), debug.DebugJobEnvironment()]
    for entry_point in pkg_resources.iter_entry_points("submitit"):
        if entry_point.name not in ("executor", "job_environment"):
            logger.warning(f"Found unknown entry point in package {entry_point.module_name}: {entry_point}")
            continue

        try:
            # call `load` rather than `resolve`.
            # `load` also checks the module and its dependencies are correctly installed.
            cls = entry_point.load()
        except Exception as e:
            # This may happen if the plugin hasn't been correctly installed
            logger.exception(f"Failed to load submitit plugin '{entry_point.module_name}': {e}")
            continue

        if entry_point.name == "executor":
            executors.append(cls)
            continue

        try:
            job_env = cls()
        except Exception as e:
            # `cls` is whatever object the plugin exposes, so describe it without
            # assuming it has a `name` attribute: otherwise a broken plugin could
            # make this handler raise and abort the whole discovery. Using
            # `__name__` also yields a readable name instead of a method repr.
            cls_name = getattr(cls, "__name__", type(cls).__name__)
            logger.exception(
                f"Failed to init JobEnvironment '{cls_name}' ({cls}) from submitit plugin '{entry_point.module_name}': {e}"
            )
            continue
        job_envs.append(job_env)

    return (executors, job_envs)


@functools.lru_cache()
def get_executors() -> Mapping[str, Type["Executor"]]:
    """Return the discovered executor classes keyed by their `name()`.

    The mapping is memoized. If two executors report the same name, the one
    discovered last replaces the earlier one.
    """
    # TODO: check collisions between executor names
    return {ex.name(): ex for ex in _get_plugins()[0]}


def get_job_environment() -> "JobEnvironment":
    """Return the job environment matching the current process.

    If the `_TEST_CLUSTER_` environment variable is set, the environment
    registered under that name is returned directly. Otherwise the first
    known environment whose `activated()` call returns a true value is used.

    Raises:
        AssertionError: if `_TEST_CLUSTER_` names an unknown environment.
        RuntimeError: if no known environment reports itself as activated.
    """
    # Don't cache this function. It makes testing harder.
    # The slow part is the plugin discovery anyway.
    envs = get_job_environments()
    # bypassing can be helpful for testing
    if "_TEST_CLUSTER_" in os.environ:
        c = os.environ["_TEST_CLUSTER_"]
        assert c in envs, f"Unknown $_TEST_CLUSTER_='{c}', available: {envs.keys()}."
        return envs[c]
    for env in envs.values():
        # TODO? handle the case where several envs are valid
        if env.activated():
            return env
    raise RuntimeError(
        f"Could not figure out which environment the job is runnning in. Known environments: {', '.join(envs.keys())}."
    )


@functools.lru_cache()
def get_job_environments() -> Mapping[str, "JobEnvironment"]:
    """Return the discovered job environment instances keyed by their `name()`.

    The mapping is memoized. If two environments report the same name, the one
    discovered last replaces the earlier one.
    """
    return {env.name(): env for env in _get_plugins()[1]}
Memory