# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # import os import signal import socket import sys import time import types import typing as tp from pathlib import Path from . import logger, utils from .utils import DelayedSubmission, JobPaths _PREEMPT_SIG_ENV = "SUBMITIT_PREEMPT_SIGNAL" class JobEnvironment: """Describe the environment inside which the job is running. This includes job id, number of GPUs available, ... This class can only be instantiated from a running submitit job. @plugin-dev: default implementation look for information into environment variables. Override _env to map environment variable to each property. """ # preemption signal uses USR2 as default, but this behavior # can be overiden (eg: export SUBMITIT_PREEMPT_SIGNAL=USR2) # CAUTION: NCCL may catch USR1 so it should be avoided USR_SIG = os.environ.get(_PREEMPT_SIG_ENV, "USR2") _env: tp.ClassVar[tp.Dict[str, str]] = {} def __new__(cls, *args: tp.Any) -> "JobEnvironment": if cls is not JobEnvironment: return super().__new__(cls, *args) # type: ignore from . import plugins # pylint: disable=cyclic-import,import-outside-toplevel return plugins.get_job_environment() def __init__(self) -> None: self.cluster = self.name() @classmethod def name(cls) -> str: n = cls.__name__ if n.endswith("JobEnvironment"): n = n[: -len("JobEnvironment")] return n.lower() @property def paths(self) -> JobPaths: """Provides the paths used by submitit, including stdout, stderr, submitted_pickle and folder. """ folder = os.environ["SUBMITIT_FOLDER"] return JobPaths(folder, job_id=self.job_id, task_id=self.global_rank) def activated(self) -> bool: """Tests if we are running inside this environment. @plugin-dev: assumes that the SUBMITIT_EXECUTOR variable has been set to the executor name """ return os.environ.get("SUBMITIT_EXECUTOR", "") == self.name() @property def hostname(self) -> str: return socket.gethostname() @property def hostnames(self) -> tp.Sequence[str]: return [self.hostname] @property def job_id(self) -> str: if self.array_job_id: return f"{self.array_job_id}_{self.array_task_id}" else: return self.raw_job_id @property def raw_job_id(self) -> str: return os.environ[self._env["job_id"]] @property def array_job_id(self) -> tp.Optional[str]: n = "array_job_id" return None if n not in self._env else os.environ.get(self._env[n], None) @property def array_task_id(self) -> tp.Optional[str]: n = "array_task_id" return None if n not in self._env else os.environ.get(self._env[n], None) @property def num_tasks(self) -> int: """Total number of tasks for the job""" return int(os.environ.get(self._env["num_tasks"], 1)) @property def num_nodes(self) -> int: """Total number of nodes for the job""" return int(os.environ.get(self._env["num_nodes"], 1)) @property def node(self) -> int: """Id of the current node""" return int(os.environ.get(self._env["node"], 0)) @property def global_rank(self) -> int: """Global rank of the task""" return int(os.environ.get(self._env["global_rank"], 0)) @property def local_rank(self) -> int: """Local rank of the task, ie on the current node.""" return int(os.environ.get(self._env["local_rank"], 0)) def __repr__(self) -> str: # should look like this: # JobEnvironment(job_id=17015819, hostname=learnfair0218, local_rank=2(3), node=1(2), global_rank=5(6)) info = [f"{n}={getattr(self, n)}" for n in ("job_id", "hostname")] names = ("local_rank", "node", "global_rank") totals = [self.num_tasks // self.num_nodes, self.num_nodes, self.num_tasks] info += [f"{n}={getattr(self, n)}({t})" for n, t in zip(names, totals)] info_str = ", ".join(info) return f"JobEnvironment({info_str})" @classmethod def _usr_sig(cls) -> tp.Any: name = "SIG" + cls.USR_SIG out = getattr(signal, name, None) if out is None: raise RuntimeError( f"Unknown signal {name}, you may need to unset or update env var {_PREEMPT_SIG_ENV} (Eg: USR2)" ) return out def _handle_signals(self, paths: JobPaths, submission: DelayedSubmission) -> None: """Set up signals handler for the current executable. The default implementation checkpoint the given submission and requeues it. @plugin-dev: Should be adapted to the signals used in this cluster. """ handler = SignalHandler(self, paths, submission) signal.signal(self._usr_sig(), handler.checkpoint_and_try_requeue) # A priori we don't need other signals anymore, # but still log them to make it easier to debug. signal.signal(signal.SIGTERM, handler.bypass) signal.signal(signal.SIGCONT, handler.bypass) # pylint: disable=unused-argument def _requeue(self, countdown: int) -> None: """Requeue the current job. @plugin-dev:Must be overridden by JobEnvironment implementations. Use self.job_id to find what need to be requeued. """ class SignalHandler: def __init__(self, env: JobEnvironment, job_paths: JobPaths, delayed: DelayedSubmission) -> None: self.env = env self._job_paths = job_paths self._delayed = delayed self._logger = logger.get_logger() self._start_time = time.time() def has_timed_out(self) -> bool: # SignalHandler is created by submitit as soon as the process start, # so _start_time is an accurate measure of the global runtime of the job. walltime = time.time() - self._start_time max_walltime = self._delayed._timeout_min * 60 guaranteed_walltime = min(max_walltime * 0.8, max_walltime - 10 * 60) timed_out = walltime >= guaranteed_walltime if timed_out: self._logger.info( f"Job has timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes." ) else: self._logger.info( f"Job has not timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes." ) return timed_out # pylint:disable=unused-argument def bypass(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None: self._logger.warning(f"Bypassing signal {signal.Signals(signum).name}") # pylint:disable=unused-argument def checkpoint_and_try_requeue(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None: timed_out = self.has_timed_out() case = "timed-out" if timed_out else "preempted" self._logger.warning( f"Caught signal {signal.Signals(signum).name} on {socket.gethostname()}: this job is {case}." ) procid = self.env.global_rank if procid != 0: self._logger.info(f"Not checkpointing nor requeuing since I am a slave (procid={procid}).") # do not sys.exit, because it might kill the master task return delayed = self._delayed countdown = delayed._timeout_countdown - timed_out no_requeue_reason = "" if hasattr(delayed.function, "checkpoint"): no_requeue_reason = _checkpoint(delayed, self._job_paths.submitted_pickle, countdown) elif timed_out: no_requeue_reason = "timed-out and not checkpointable" if countdown < 0: # this is the end no_requeue_reason = "timed-out too many times" if no_requeue_reason: # raise an error so as to create "result_pickle" file which notifies the job is over # this is caught by the try/except in "process_job" message = f"Job not requeued because: {no_requeue_reason}." self._logger.info(message) raise utils.UncompletedJobError(message) # if everything went well, requeue! self.env._requeue(countdown) self._exit() # pylint:disable=unused-argument def checkpoint_and_exit(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None: # Note: no signal is actually bound to `checkpoint_and_exit` but this is used by plugins. self._logger.info(f"Caught signal {signal.Signals(signum).name} on {socket.gethostname()}") procid = self.env.global_rank if procid: self._logger.info(f"Not checkpointing since I am a slave (procid={procid}).") # do not sys.exit, because it might kill the master task return delayed = self._delayed if hasattr(delayed.function, "checkpoint"): _checkpoint(self._delayed, self._job_paths.submitted_pickle, self._delayed._timeout_countdown) self._exit() def _exit(self) -> None: # extracted for mocking self._logger.info("Exiting gracefully after preemption/timeout.") sys.exit(-1) def _checkpoint(delayed: DelayedSubmission, filepath: Path, countdown: int) -> str: """Call the checkpoint method and dump the updated delayed. Returns: -------- no_requeue_reason: str a string explaining while there was no requeuing, else empty string if requeuing works """ logger.get_logger().info("Calling checkpoint method.") ckpt_delayed = delayed._checkpoint_function() if ckpt_delayed is None: return "checkpoint function returned None" ckpt_delayed.set_timeout(delayed._timeout_min, countdown) with utils.temporary_save_path(filepath) as tmp: ckpt_delayed.dump(tmp) return "" # requeues
Memory