# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import collections
import contextlib
import datetime
import itertools
import os
import random
import shutil
import subprocess
import tempfile
import time
import typing as tp
from pathlib import Path
# pylint: disable=unused-import
# import DelayedSubmission and CommandFunction to populate helpers namespace
from .core import core
from .core.job_environment import JobEnvironment
from .core.utils import CommandFunction as CommandFunction # noqa
from .core.utils import DelayedSubmission as DelayedSubmission # noqa
from .core.utils import environment_variables as environment_variables # noqa
class Checkpointable:
"""Derived callable classes are requeued after timeout with their current
state dumped at checkpoint.
    A __call__ method must be implemented to make your class callable.
Note
----
    The default implementation of the checkpoint method below resubmits the full current
    state of the callable (self) with the initial arguments. You may want to override the
    method to curate the state (e.g. dump a neural network to a standard format and remove
    it from the state so that it is not pickled) and to change or remove the initial parameters.
"""
# pylint: disable=unused-argument
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
assert callable(
instance
), f"Class {cls.__name__} is marked as Checkpointable but doesn't have a __call__ method. Please add a __call__ method."
return instance
def checkpoint(self, *args: tp.Any, **kwargs: tp.Any) -> DelayedSubmission:
"""Resubmits the same callable with the same arguments"""
        # The only goal of the DelayedSubmission class is to register and format
        # the arguments of the call "self(*args, **kwargs)" for submission to slurm
return DelayedSubmission(self, *args, **kwargs) # type: ignore
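# Illustrative sketch (not part of the library): a minimal Checkpointable subclass.
# The class name and the toy training loop are hypothetical; the point is that `self`
# carries the progress, so the default checkpoint/resubmission mechanism is enough
# for a requeued job to resume where it stopped.
class _ExampleCheckpointableTrainer(Checkpointable):
    """Toy stateful callable that survives preemption and requeuing."""
    def __init__(self) -> None:
        self.epoch = 0
    def __call__(self, total_epochs: int = 10) -> int:
        while self.epoch < total_epochs:
            self.epoch += 1  # stand-in for one epoch of real work
        return self.epoch
    def checkpoint(self, total_epochs: int = 10) -> DelayedSubmission:
        # `self` (including `self.epoch`) is pickled, so the resubmitted call resumes from it
        return DelayedSubmission(self, total_epochs)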
class FunctionSequence(Checkpointable):
"""This is for gathering several estimations into one function, which
will return the sequence of outputs.
Also this "function" is stateful, hence it can be stopped, and recovered,
which is useful when job can be preempted.
Usage
-----
func = FunctionSequence()
func.add(my_function1, arg1, kwarg1=value_kwarg1)
func.add(my_function2, arg1, arg2)
result1, result2 = func()
Note
----
This function is checkpointable because:
- it derives from Checkpointable
    - it keeps DelayedSubmission objects as attributes, which in turn store the
      results of the computation in memory once they are computed. So at checkpoint
      time, those results will be saved, and only the remaining ones
      will be computed once the job restarts.
"""
def __init__(self, verbose: bool = False) -> None:
self.verbose = verbose
self.delayed_functions: tp.List[DelayedSubmission] = []
def add(self, func: tp.Callable[..., tp.Any], *args: tp.Any, **kwargs: tp.Any) -> None:
self.delayed_functions.append(DelayedSubmission(func, *args, **kwargs))
def __len__(self) -> int:
return len(self.delayed_functions)
def __iter__(self) -> tp.Iterator[DelayedSubmission]:
return iter(self.delayed_functions)
def __call__(self) -> tp.List[tp.Any]: # pylint: disable=arguments-differ
if self.verbose:
done = sum(f.done() for f in self) # those were computed before checkpoint
print(f"Starting from {done}/{len(self.delayed_functions)}", flush=True)
return [
f.result() for f in self.delayed_functions
        ]  # returns all results one by one (running the functions if they are not already done)
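# Illustrative sketch (not part of the library): running a FunctionSequence as a
# single checkpointable job. The log folder and timeout are hypothetical; the local
# import avoids a circular import at module load time.
def _example_function_sequence() -> None:
    import submitit
    func = FunctionSequence(verbose=True)
    for exponent in range(5):
        func.add(pow, 2, exponent)  # queue pow(2, exponent)
    executor = submitit.AutoExecutor(folder="log_example")  # hypothetical folder
    executor.update_parameters(timeout_min=60)
    job = executor.submit(func)
    print(job.result())  # [1, 2, 4, 8, 16] once the job completes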
def as_completed(
jobs: tp.Sequence[core.Job[core.R]],
timeout: tp.Optional[tp.Union[int, float]] = None,
poll_frequency: float = 10,
) -> tp.Iterator[core.Job[core.R]]:
"""
    Yields jobs as they complete (finished, failed or cancelled).
    Raises a TimeoutError if all jobs have not completed after `timeout` seconds.
    `timeout` can be an int or a float. If `timeout` is not specified or is None,
    there is no limit to the wait time.
    Parameters
    ----------
    jobs: list
        Job instances
    timeout: int/float
        Maximum time (in seconds) to wait for job completion
    poll_frequency: float
        Frequency in seconds at which job status is checked
Yields
------
Job
The next completed job
"""
start = time.time()
jobs_done: tp.Set[int] = set()
while True:
if timeout is not None and time.time() - start > timeout:
raise TimeoutError
for i, job in enumerate(jobs):
if i in jobs_done:
continue
if job.done():
jobs_done.add(i)
yield job
if len(jobs_done) == len(jobs):
break
time.sleep(poll_frequency)
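# Illustrative sketch (not part of the library): consuming results as jobs finish
# instead of waiting for the whole batch. `jobs` can come from any submitit
# executor; the timeout and poll frequency are hypothetical.
def _example_as_completed(jobs: tp.Sequence[core.Job[int]]) -> int:
    total = 0
    for job in as_completed(jobs, timeout=3600, poll_frequency=30):
        total += job.result()  # the job is already done, so result() returns immediately (or raises if it failed)
    return total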
def run_cmd(str_args, **kwargs):
return subprocess.check_output(str_args, **kwargs).decode("utf-8").strip()
class RsyncSnapshot:
"""Takes a snapshot of the git repository that the script lives in.
This ensures that remote jobs always use the code from when they are scheduled
and not the code from when they are launched / re-started.
Parameters
----------
    snapshot_dir: Path
        A path to where the snapshot should be created
    root_dir: Path (optional)
        The root of the git repository to snapshot. Defaults to the root of the
        repository containing the current working directory.
    with_submodules: bool
        Whether or not submodules should be included in the snapshot
    exclude: Sequence[str]
        An optional list of patterns to exclude from the snapshot
    include: Sequence[str]
        A list of relative file names to include in the snapshot.
        Useful for .so files or other build artifacts that are generally not tracked by git.
Note
----
- Only files that are checked in to the repository are included in the snapshot.
If you have experimental code that you would like to include in the snapshot,
      you'll need to `git add` the file first for it to be included, or use the `include` argument.
"""
def __init__(
self,
snapshot_dir: Path,
root_dir: tp.Optional[Path] = None,
with_submodules: bool = False,
exclude: tp.Sequence[str] = (),
include: tp.Sequence[str] = (),
):
self.available(throw=True)
self.snapshot_dir = Path(snapshot_dir)
self.root_dir = root_dir or run_cmd(["git", "rev-parse", "--show-toplevel"])
self.original_dir = Path.cwd()
self.with_submodules = with_submodules
self.exclude = exclude
self.include = include
@staticmethod
def available(throw: bool = False) -> bool:
if not shutil.which("rsync"):
if throw:
raise RuntimeError("RsyncSnapshot requires rsync to be installed.")
return False
return True
def __enter__(self) -> None:
self.original_dir = Path.cwd()
# Get the repository root
root_dir = str(self.root_dir)
sub = "--recurse-submodules" if self.with_submodules else "-s"
# Make a shallow git clone
if not self.snapshot_dir.exists():
self.snapshot_dir.parent.mkdir(parents=True, exist_ok=True)
subprocess.check_call(["git", "clone", "--depth=2", f"file://{root_dir}", str(self.snapshot_dir)])
# Get a list of all the checked in files that we can pass to rsync
        # Is rsync faster than a `git pull`?
with tempfile.NamedTemporaryFile() as tfile:
            # https://stackoverflow.com/a/51689219/4876946
            # List the tracked files; with `-s`, each line starts with the file mode, so
            # `grep -v ^16` drops submodule entries (mode 160000) and `cut -f2-` keeps only the path.
            run_cmd(f"git ls-files {sub} | grep -v ^16 | cut -f2- > {tfile.name}", cwd=root_dir, shell=True)
exclude = list(itertools.chain.from_iterable(("--exclude", pat) for pat in self.exclude))
with open(tfile.name, "a", encoding="utf8") as o:
for inc in self.include:
print(inc, file=o)
run_cmd(["rsync", "-a", "--files-from", tfile.name, root_dir, str(self.snapshot_dir)] + exclude)
os.chdir(self.snapshot_dir)
def __exit__(self, *args):
os.chdir(self.original_dir)
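# Illustrative sketch (not part of the library): scheduling a job from inside a
# snapshot so the remote job uses the code as it was at submission time. The
# snapshot path and executor folder are hypothetical.
def _example_rsync_snapshot() -> None:
    import submitit  # local import to avoid a circular import at module load
    snapshot_dir = Path("/tmp/snapshots") / datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        # inside the context, the current working directory is the snapshot
        executor = submitit.AutoExecutor(folder=snapshot_dir / "logs")
        job = executor.submit(run_cmd, ["git", "rev-parse", "HEAD"])
    print(job.result())  # commit of the snapshot clone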
def _default_custom_logging(monitoring_start_time: float, n_jobs: int, state_jobs: tp.Dict[str, tp.Set[int]]):
    """Prints a one-line status summary for the monitored jobs, plus the indices of failed jobs if any."""
run_time = time.time() - monitoring_start_time
date_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
failed_job_indices = sorted(state_jobs["FAILED"])
n_chars = len(str(n_jobs))
print(
f"[{date_time}] Launched {int(run_time / 60)} minutes ago,",
f"{len(state_jobs['RUNNING']):{n_chars}}/{n_jobs} jobs running,",
f"{len(failed_job_indices):{n_chars}}/{n_jobs} jobs failed,",
f"{len(state_jobs['DONE']) - len(failed_job_indices):{n_chars}}/{n_jobs} jobs done",
flush=True,
)
if len(failed_job_indices) > 0:
print(f"[{date_time}] Failed jobs, indices {failed_job_indices}", flush=True)
def monitor_jobs(
jobs: tp.Sequence[core.Job[core.R]],
poll_frequency: float = 30,
test_mode: bool = False,
custom_logging: tp.Callable = _default_custom_logging,
) -> None:
"""Continuously monitors given jobs until they are all done or failed.
Parameters
----------
    jobs: List[Job]
        A list of jobs to monitor
    poll_frequency: float
        The time (in seconds) between two refreshes of the monitoring.
        Cannot be less than 30 seconds.
    test_mode: bool
        If in test mode, the minimum poll_frequency is not enforced
"""
if not test_mode:
        assert poll_frequency >= 30, "poll_frequency must be at least 30s to avoid overloading squeue"
n_jobs = len(jobs)
if n_jobs == 0:
print("There are no jobs to monitor")
return
job_arrays = ", ".join(sorted(set(str(job.job_id).split("_", 1)[0] for job in jobs)))
print(f"Monitoring {n_jobs} jobs from job arrays {job_arrays} \n")
monitoring_start_time = time.time()
while True:
if not test_mode:
jobs[0].get_info(mode="force") # Force update once to sync the state
state_jobs = collections.defaultdict(set)
for i, job in enumerate(jobs):
state_jobs[job.state.upper()].add(i)
if job.done():
state_jobs["DONE"].add(i)
failed_job_indices = sorted(state_jobs["FAILED"])
if len(state_jobs["DONE"]) == len(jobs):
print(f"All jobs finished, jobs with indices {failed_job_indices} failed", flush=True)
break
custom_logging(monitoring_start_time, n_jobs, state_jobs)
time.sleep(poll_frequency)
print(f"Whole process is finished, took {int((time.time() - monitoring_start_time) / 60)} minutes")
@contextlib.contextmanager
def clean_env(extra_names: tp.Sequence[str] = ()) -> tp.Iterator[None]:
"""Removes slurm and submitit related environment variables so as to avoid interferences
when submiting a new job from a job.
Parameters
----------
extra_names: Sequence[str]
Additional environment variables to hide inside the context,
e.g. TRITON_CACHE_DIR and TORCHINDUCTOR_CACHE_DIR when using torch.compile.
Note
----
    A slurm job submitted from within a slurm job inherits some of its attributes, which may
    be confusing and cause weird gres errors (or break pytorch distributed setup).
    Submitting within this context should prevent this.
Usage
-----
with submitit.helpers.clean_env():
executor.submit(...)
"""
distrib_names = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
cluster_env = {
x: os.environ.pop(x)
for x in os.environ
if (
x.startswith(("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_"))
or x in distrib_names
or x in extra_names
)
}
try:
yield
finally:
os.environ.update(cluster_env)
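# Illustrative sketch (not part of the library): submitting a follow-up job from
# within a running job without leaking SLURM/submitit variables to it. The folder
# and the extra variable name are hypothetical.
def _example_clean_env() -> None:
    import submitit  # local import to avoid a circular import at module load
    executor = submitit.AutoExecutor(folder="log_followup")
    with clean_env(extra_names=("TRITON_CACHE_DIR",)):
        job = executor.submit(run_cmd, ["hostname"])
    print(job.job_id)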
class TorchDistributedEnvironment:
def __init__(self) -> None:
"""Construct a class holding the parameters required to properly setup
PyTorch distributed (with the default env:// initialization method).
Examples
--------
>>> dist_env = TorchDistributedEnvironment().export()
>>> torch.distributed.init_process_group(backend="nccl")
>>> print(f"master: {dist_env.master_addr}:{dist_env.master_port}")
"""
self._job_env = JobEnvironment()
self.master_addr = self._job_env.hostnames[0]
self.master_port = self._get_master_port()
self.rank = self._job_env.global_rank
self.world_size = self._job_env.num_tasks
self.local_rank = self._job_env.local_rank
self.local_world_size = self._job_env.num_tasks // self._job_env.num_nodes
def _get_master_port(self) -> int:
# MIN_MASTER_PORT, MAX_MASTER_PORT = (1023, 65535)
MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
master_port_str = os.environ.get("MASTER_PORT")
if master_port_str is None:
rng = random.Random(self._job_env.job_id)
return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
master_port = int(master_port_str)
        # assert MIN_MASTER_PORT <= master_port <= MAX_MASTER_PORT
return master_port
def export(
self,
set_cuda_visible_devices: bool = True,
overwrite: bool = False,
) -> "TorchDistributedEnvironment":
"""Export all the environment variables required to properly setup
PyTorch distributed (with the default env:// initialization method) i.e.
MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE (to which LOCAL_RANK and
LOCAL_WORLD_SIZE are added).
        Parameters
        ----------
        set_cuda_visible_devices: bool
if True, updates CUDA_VISIBLE_DEVICES to use only the device
matching the local rank.
overwrite: bool
if True, overwrites the environment variables if they exist;
this can be useful when launching a job from another job.
Returns
        -------
TorchDistributedEnvironment
the current instance
"""
# See the "Environment variable initialization" section from
# https://pytorch.org/docs/stable/distributed.html for the complete list of
# environment variables required for the env:// initialization method.
env_vars = {
"MASTER_ADDR": self.master_addr,
"MASTER_PORT": str(self.master_port),
"RANK": str(self.rank),
"WORLD_SIZE": str(self.world_size),
"LOCAL_RANK": str(self.local_rank), # Not required
"LOCAL_WORLD_SIZE": str(self.local_world_size), # Not required
}
if not overwrite:
for key in env_vars:
if key in os.environ:
raise RuntimeError(f"Cannot export environment variables as {key} is already set")
# Note: CUDA_VISIBLE_DEVICES may already be set with all available GPUs
if set_cuda_visible_devices:
env_vars["CUDA_VISIBLE_DEVICES"] = str(self.local_rank)
os.environ.update(env_vars)
return self
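# Illustrative sketch (not part of the library): a function to be submitted with
# several tasks; each task exports the distributed variables and then initializes
# the process group. PyTorch availability on the cluster is an assumption, hence
# the local import; the backend and the explicit cleanup are hypothetical choices.
def _example_torch_distributed() -> None:
    import torch.distributed  # assumes PyTorch is installed on the cluster
    dist_env = TorchDistributedEnvironment().export(set_cuda_visible_devices=True)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    print(f"rank {dist_env.rank}/{dist_env.world_size} via {dist_env.master_addr}:{dist_env.master_port}")
    torch.distributed.destroy_process_group()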