# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import signal
import socket
import sys
import time
import types
import typing as tp
from pathlib import Path
from . import logger, utils
from .utils import DelayedSubmission, JobPaths
# Name of the environment variable read by JobEnvironment.USR_SIG to select
# the preemption signal (without the "SIG" prefix, eg: "USR2").
_PREEMPT_SIG_ENV = "SUBMITIT_PREEMPT_SIGNAL"
class JobEnvironment:
    """Describe the environment inside which the job is running.

    This includes job id, number of GPUs available, ...
    This class can only be instantiated from a running submitit job.

    Instantiating JobEnvironment itself delegates to
    plugins.get_job_environment(); instantiating a subclass creates
    an instance of that subclass.

    @plugin-dev: default implementation looks for information into environment variables.
    Override _env to map environment variable to each property.
    """

    # preemption signal uses USR2 as default, but this behavior
    # can be overridden (eg: export SUBMITIT_PREEMPT_SIGNAL=USR2)
    # CAUTION: NCCL may catch USR1 so it should be avoided
    USR_SIG = os.environ.get(_PREEMPT_SIG_ENV, "USR2")
    # maps a property name (eg: "job_id", "num_tasks") to the name of the
    # environment variable which holds its value
    _env: tp.ClassVar[tp.Dict[str, str]] = {}

    def __new__(cls, *args: tp.Any, **kwargs: tp.Any) -> "JobEnvironment":
        if cls is not JobEnvironment:
            # object.__new__ raises TypeError on extra arguments when __new__
            # is overridden, so constructor arguments are left to __init__.
            return super().__new__(cls)

        from . import plugins  # pylint: disable=cyclic-import,import-outside-toplevel

        return plugins.get_job_environment()

    def __init__(self) -> None:
        self.cluster = self.name()

    @classmethod
    def name(cls) -> str:
        """Lower-cased class name, without the "JobEnvironment" suffix if present."""
        n = cls.__name__
        if n.endswith("JobEnvironment"):
            n = n[: -len("JobEnvironment")]
        return n.lower()

    @property
    def paths(self) -> JobPaths:
        """Provides the paths used by submitit, including
        stdout, stderr, submitted_pickle and folder.

        The folder is read from the SUBMITIT_FOLDER environment variable
        (KeyError if it is not set).
        """
        folder = os.environ["SUBMITIT_FOLDER"]
        return JobPaths(folder, job_id=self.job_id, task_id=self.global_rank)

    def activated(self) -> bool:
        """Tests if we are running inside this environment.

        @plugin-dev: assumes that the SUBMITIT_EXECUTOR variable has been
        set to the executor name
        """
        return os.environ.get("SUBMITIT_EXECUTOR", "") == self.name()

    @property
    def hostname(self) -> str:
        """Name of the current host, as given by socket.gethostname()"""
        return socket.gethostname()

    @property
    def hostnames(self) -> tp.Sequence[str]:
        """Hosts of the job; the default implementation only lists the current host."""
        return [self.hostname]

    @property
    def job_id(self) -> str:
        """Full job id: "<array_job_id>_<array_task_id>" for array jobs, else the raw job id."""
        if self.array_job_id:
            return f"{self.array_job_id}_{self.array_task_id}"
        else:
            return self.raw_job_id

    @property
    def raw_job_id(self) -> str:
        """Job id read from the environment variable mapped to "job_id" in _env."""
        return os.environ[self._env["job_id"]]

    @property
    def array_job_id(self) -> tp.Optional[str]:
        """Id of the job array, or None if not mapped in _env or not set in the environment."""
        n = "array_job_id"
        return None if n not in self._env else os.environ.get(self._env[n], None)

    @property
    def array_task_id(self) -> tp.Optional[str]:
        """Id of the task within the job array, or None if not mapped in _env or not set."""
        n = "array_task_id"
        return None if n not in self._env else os.environ.get(self._env[n], None)

    @property
    def num_tasks(self) -> int:
        """Total number of tasks for the job"""
        return int(os.environ.get(self._env["num_tasks"], 1))

    @property
    def num_nodes(self) -> int:
        """Total number of nodes for the job"""
        return int(os.environ.get(self._env["num_nodes"], 1))

    @property
    def node(self) -> int:
        """Id of the current node"""
        return int(os.environ.get(self._env["node"], 0))

    @property
    def global_rank(self) -> int:
        """Global rank of the task"""
        return int(os.environ.get(self._env["global_rank"], 0))

    @property
    def local_rank(self) -> int:
        """Local rank of the task, ie on the current node."""
        return int(os.environ.get(self._env["local_rank"], 0))

    def __repr__(self) -> str:
        # should look like this:
        # JobEnvironment(job_id=17015819, hostname=learnfair0218, local_rank=2(3), node=1(2), global_rank=5(6))
        info = [f"{n}={getattr(self, n)}" for n in ("job_id", "hostname")]
        names = ("local_rank", "node", "global_rank")
        # totals are displayed in parentheses next to each index
        totals = [self.num_tasks // self.num_nodes, self.num_nodes, self.num_tasks]
        info += [f"{n}={getattr(self, n)}({t})" for n, t in zip(names, totals)]
        info_str = ", ".join(info)
        return f"JobEnvironment({info_str})"

    @classmethod
    def _usr_sig(cls) -> tp.Any:
        """Return the signal.Signals member designated by USR_SIG.

        The value is matched case-insensitively and may be given with or
        without the "SIG" prefix (eg: "USR2", "usr2" or "SIGUSR2").

        Raises
        ------
        RuntimeError
            if the value does not name a signal of the signal module.
        """
        short = cls.USR_SIG.strip().upper()
        if short.startswith("SIG"):
            short = short[len("SIG") :]
        name = "SIG" + short
        out = getattr(signal, name, None)
        # the signal module also exposes non-signal "SIG*" constants
        # (eg: SIG_DFL, SIG_IGN), which must not be used as a signal number
        if not isinstance(out, signal.Signals):
            raise RuntimeError(
                f"Unknown signal {name}, you may need to unset or update env var {_PREEMPT_SIG_ENV} (Eg: USR2)"
            )
        return out

    def _handle_signals(self, paths: JobPaths, submission: DelayedSubmission) -> None:
        """Set up signal handlers for the current executable.

        The default implementation binds the preemption signal to
        SignalHandler.checkpoint_and_try_requeue.

        @plugin-dev: Should be adapted to the signals used in this cluster.
        """
        handler = SignalHandler(self, paths, submission)
        signal.signal(self._usr_sig(), handler.checkpoint_and_try_requeue)
        # A priori we don't need other signals anymore,
        # but still log them to make it easier to debug.
        signal.signal(signal.SIGTERM, handler.bypass)
        signal.signal(signal.SIGCONT, handler.bypass)

    # pylint: disable=unused-argument
    def _requeue(self, countdown: int) -> None:
        """Requeue the current job.

        The default implementation does nothing.

        @plugin-dev:Must be overridden by JobEnvironment implementations.
        Use self.job_id to find what need to be requeued.
        """
class SignalHandler:
    """Signal callbacks used to checkpoint, requeue or stop the running job.

    Each callback follows the signature expected by signal.signal:
    (signum, frame).
    """

    def __init__(self, env: JobEnvironment, job_paths: JobPaths, delayed: DelayedSubmission) -> None:
        self.env = env
        self._job_paths = job_paths
        self._delayed = delayed
        self._logger = logger.get_logger()
        # reference point for has_timed_out
        self._start_time = time.time()

    def has_timed_out(self) -> bool:
        """Tell whether the job ran for at least its guaranteed walltime.

        The guaranteed walltime is the smaller of 80% of the requested
        timeout and the requested timeout minus 10 minutes.
        The outcome is logged in both cases.
        """
        # SignalHandler is created by submitit as soon as the process start,
        # so _start_time is an accurate measure of the global runtime of the job.
        walltime = time.time() - self._start_time
        max_walltime = self._delayed._timeout_min * 60
        threshold = min(max_walltime * 0.8, max_walltime - 10 * 60)
        timed_out = walltime >= threshold
        if not timed_out:
            self._logger.info(
                f"Job has not timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes."
            )
            return timed_out
        self._logger.info(
            f"Job has timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes."
        )
        return timed_out

    # pylint:disable=unused-argument
    def bypass(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None:
        """Only log the received signal."""
        signame = signal.Signals(signum).name
        self._logger.warning(f"Bypassing signal {signame}")

    # pylint:disable=unused-argument
    def checkpoint_and_try_requeue(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None:
        """Checkpoint the submission if possible, then requeue the job and exit.

        Only the task with global rank 0 checkpoints and requeues, other
        tasks return immediately.

        Raises
        ------
        utils.UncompletedJobError
            if the job must not be requeued (checkpointing gave a reason not
            to, timed-out without a checkpoint method, or timed-out too many
            times).
        """
        timed_out = self.has_timed_out()
        if timed_out:
            case = "timed-out"
        else:
            case = "preempted"
        self._logger.warning(
            f"Caught signal {signal.Signals(signum).name} on {socket.gethostname()}: this job is {case}."
        )
        procid = self.env.global_rank
        if procid != 0:
            self._logger.info(f"Not checkpointing nor requeuing since I am a slave (procid={procid}).")
            # do not sys.exit, because it might kill the master task
            return
        submission = self._delayed
        # a time-out consumes one unit of the countdown (timed_out is a bool)
        countdown = submission._timeout_countdown - timed_out
        reason = ""
        if hasattr(submission.function, "checkpoint"):
            reason = _checkpoint(submission, self._job_paths.submitted_pickle, countdown)
        elif timed_out:
            reason = "timed-out and not checkpointable"
        if countdown < 0:  # this is the end
            reason = "timed-out too many times"
        if not reason:
            # if everything went well, requeue!
            self.env._requeue(countdown)
            self._exit()
            return
        # raise an error so as to create "result_pickle" file which notifies the job is over
        # this is caught by the try/except in "process_job"
        message = f"Job not requeued because: {reason}."
        self._logger.info(message)
        raise utils.UncompletedJobError(message)

    # pylint:disable=unused-argument
    def checkpoint_and_exit(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None:
        """Checkpoint the submission if possible, then exit without requeuing.

        Only the task with global rank 0 checkpoints and exits, other tasks
        return immediately.
        """
        # Note: no signal is actually bound to `checkpoint_and_exit` but this is used by plugins.
        signame = signal.Signals(signum).name
        self._logger.info(f"Caught signal {signame} on {socket.gethostname()}")
        procid = self.env.global_rank
        if procid:
            self._logger.info(f"Not checkpointing since I am a slave (procid={procid}).")
            # do not sys.exit, because it might kill the master task
            return
        submission = self._delayed
        if hasattr(submission.function, "checkpoint"):
            _checkpoint(submission, self._job_paths.submitted_pickle, submission._timeout_countdown)
        self._exit()

    def _exit(self) -> None:
        """Log and terminate the process through sys.exit(-1)."""
        # extracted for mocking
        self._logger.info("Exiting gracefully after preemption/timeout.")
        sys.exit(-1)
def _checkpoint(delayed: DelayedSubmission, filepath: Path, countdown: int) -> str:
    """Run the checkpoint function and dump the submission it returns.

    The returned submission gets the timeout of `delayed` with the given
    countdown, and is dumped to `filepath` through a temporary save path.

    Returns:
    --------
    no_requeue_reason: str
        a string explaining why there was no requeuing, else empty string if requeuing works
    """
    log = logger.get_logger()
    log.info("Calling checkpoint method.")
    resubmission = delayed._checkpoint_function()
    if resubmission is not None:
        resubmission.set_timeout(delayed._timeout_min, countdown)
        with utils.temporary_save_path(filepath) as tmp_path:
            resubmission.dump(tmp_path)
        return ""  # requeues
    return "checkpoint function returned None"