# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import contextlib
import io
import itertools
import os
import pickle
import select
import shutil
import subprocess
import sys
import tarfile
import typing as tp
from pathlib import Path
import cloudpickle
@contextlib.contextmanager
def environment_variables(**kwargs: tp.Any) -> tp.Iterator[None]:
backup = {x: os.environ[x] for x in kwargs if x in os.environ}
os.environ.update({x: str(y) for x, y in kwargs.items()})
try:
yield
finally:
        for x in kwargs:
            # pop() tolerates variables already removed by the wrapped code itself
            os.environ.pop(x, None)
        os.environ.update(backup)
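
# A minimal usage sketch for environment_variables (hedged: "MY_VAR" is an
# illustrative name, not something this module defines):
#
#     with environment_variables(MY_VAR="1"):
#         assert os.environ["MY_VAR"] == "1"
#     # on exit, MY_VAR is removed again (or restored to its previous value)
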
class UncompletedJobError(RuntimeError):
"""Job is uncomplete: either unfinished or failed"""
class FailedJobError(UncompletedJobError):
"""Job failed during processing"""
class FailedSubmissionError(RuntimeError):
"""Job Submission failed"""
class JobPaths:
"""Creates paths related to the slurm job and its submission"""
def __init__(
self, folder: tp.Union[Path, str], job_id: tp.Optional[str] = None, task_id: tp.Optional[int] = None
) -> None:
self._folder = Path(folder).expanduser().absolute()
self.job_id = job_id
self.task_id = task_id or 0
@property
def folder(self) -> Path:
return self._format_id(self._folder)
@property
def submission_file(self) -> Path:
if self.job_id and "_" in self.job_id:
# We only have one submission file per job array
return self._format_id(self.folder / "%A_submission.sh")
return self._format_id(self.folder / "%j_submission.sh")
@property
def submitted_pickle(self) -> Path:
return self._format_id(self.folder / "%j_submitted.pkl")
@property
def result_pickle(self) -> Path:
return self._format_id(self.folder / "%j_%t_result.pkl")
@property
def stderr(self) -> Path:
return self._format_id(self.folder / "%j_%t_log.err")
@property
def stdout(self) -> Path:
return self._format_id(self.folder / "%j_%t_log.out")
def _format_id(self, path: tp.Union[Path, str]) -> Path:
"""Replace id tag by actual id if available"""
if self.job_id is None:
return Path(path)
replaced_path = str(path).replace("%j", str(self.job_id)).replace("%t", str(self.task_id))
array_id, *array_index = str(self.job_id).split("_", 1)
if "%a" in replaced_path:
if len(array_index) != 1:
raise ValueError("%a is in the folder path but this is not a job array")
replaced_path = replaced_path.replace("%a", array_index[0])
return Path(replaced_path.replace("%A", array_id))
def move_temporary_file(
self, tmp_path: tp.Union[Path, str], name: str, keep_as_symlink: bool = False
) -> None:
self.folder.mkdir(parents=True, exist_ok=True)
Path(tmp_path).rename(getattr(self, name))
if keep_as_symlink:
Path(tmp_path).symlink_to(getattr(self, name))
@staticmethod
def get_first_id_independent_folder(folder: tp.Union[Path, str]) -> Path:
"""Returns the closest folder which is id independent"""
parts = Path(folder).expanduser().absolute().parts
tags = ["%j", "%t", "%A", "%a"]
indep_parts = itertools.takewhile(lambda x: not any(tag in x for tag in tags), parts)
return Path(*indep_parts)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.folder})"
class DelayedSubmission:
"""Object for specifying the function/callable call to submit and process later.
This is only syntactic sugar to make sure everything is well formatted:
    If what you want to compute later is func(*args, **kwargs), just instantiate:
DelayedSubmission(func, *args, **kwargs).
It also provides convenient tools for dumping and loading.
"""
def __init__(self, function: tp.Callable[..., tp.Any], *args: tp.Any, **kwargs: tp.Any) -> None:
self.function = function
self.args = args
self.kwargs = kwargs
self._result: tp.Any = None
self._done = False
self._timeout_min: int = 0
self._timeout_countdown: int = 0 # controlled in submission and execution
def result(self) -> tp.Any:
if self._done:
return self._result
self._result = self.function(*self.args, **self.kwargs)
self._done = True
return self._result
def done(self) -> bool:
return self._done
def dump(self, filepath: tp.Union[str, Path]) -> None:
cloudpickle_dump(self, filepath)
def set_timeout(self, timeout_min: int, max_num_timeout: int) -> None:
self._timeout_min = timeout_min
self._timeout_countdown = max_num_timeout
@classmethod
def load(cls: tp.Type["DelayedSubmission"], filepath: tp.Union[str, Path]) -> "DelayedSubmission":
obj = pickle_load(filepath)
        # the following assertion is relaxed compared to isinstance, to allow flexibility
        # (e.g. copying this class into a project to get checkpointable jobs without adding submitit as a dependency)
assert obj.__class__.__name__ == cls.__name__, f"Loaded object is {type(obj)} but should be {cls}."
return obj # type: ignore
def _checkpoint_function(self) -> tp.Optional["DelayedSubmission"]:
checkpoint = getattr(self.function, "__submitit_checkpoint__", None)
if checkpoint is None:
checkpoint = getattr(self.function, "checkpoint", None)
if checkpoint is None:
return None
return checkpoint(*self.args, **self.kwargs) # type: ignore
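
# A minimal usage sketch for DelayedSubmission (hedged: "job.pkl" is an
# illustrative path):
#
#     delayed = DelayedSubmission(sum, [1, 2, 3])
#     delayed.dump("job.pkl")  # serialized with cloudpickle
#     restored = DelayedSubmission.load("job.pkl")
#     assert restored.result() == 6  # runs sum([1, 2, 3]) and caches the result
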
@contextlib.contextmanager
def temporary_save_path(filepath: tp.Union[Path, str]) -> tp.Iterator[Path]:
"""Yields a path where to save a file and moves it
afterward to the provided location (and replaces any
existing file)
This is useful to avoid processes monitoring the filepath
to break if trying to read when the file is being written.
Note
----
The temporary path is the provided path appended with .save_tmp
"""
filepath = Path(filepath)
tmppath = filepath.with_suffix(filepath.suffix + ".save_tmp")
assert not tmppath.exists(), "A temporary saved file already exists."
yield tmppath
if not tmppath.exists():
raise FileNotFoundError("No file was saved at the temporary path.")
    # os.replace is atomic on both POSIX and Windows and overwrites any existing
    # file, so monitoring processes never observe a missing or partial file
    os.replace(tmppath, filepath)
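
# A minimal usage sketch for temporary_save_path (hedged: "data.pkl" is an
# illustrative path):
#
#     with temporary_save_path("data.pkl") as tmp:
#         cloudpickle_dump({"step": 12}, tmp)  # writes data.pkl.save_tmp
#     # on exit, data.pkl.save_tmp atomically replaces data.pkl
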
def archive_dev_folders(
folders: tp.List[tp.Union[str, Path]], outfile: tp.Optional[tp.Union[str, Path]] = None
) -> Path:
"""Creates a tar.gz file with all provided folders"""
assert isinstance(folders, (list, tuple)), "Only lists and tuples of folders are allowed"
if outfile is None:
outfile = "_dev_folders_.tar.gz"
outfile = Path(outfile)
assert str(outfile).endswith(".tar.gz"), "Archive file must have extension .tar.gz"
    # "w:gz" gzip-compresses the archive, matching the asserted .tar.gz extension
    with tarfile.open(outfile, mode="w:gz") as tf:
for folder in folders:
tf.add(str(folder), arcname=Path(folder).name)
return outfile
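
# A minimal usage sketch for archive_dev_folders (hedged: the folder names are
# illustrative):
#
#     archive = archive_dev_folders(["my_package", "configs"])
#     print(archive)  # _dev_folders_.tar.gz, with each folder at the archive root
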
def copy_par_file(par_file: tp.Union[str, Path], folder: tp.Union[str, Path]) -> Path:
"""Copy the par (or xar) file in the folder
    Parameters
    ----------
par_file: str/Path
Par file generated by buck
folder: str/Path
folder where the par file must be copied
Returns
-------
Path
Path of the copied .par file
"""
par_file = Path(par_file).expanduser().absolute()
folder = Path(folder).expanduser().absolute()
folder.mkdir(parents=True, exist_ok=True)
dst_name = folder / par_file.name
shutil.copy2(par_file, dst_name)
return dst_name
def pickle_load(filename: tp.Union[str, Path]) -> tp.Any:
# this is used by cloudpickle as well
with open(filename, "rb") as ifile:
return pickle.load(ifile)
def cloudpickle_dump(obj: tp.Any, filename: tp.Union[str, Path]) -> None:
with open(filename, "wb") as ofile:
cloudpickle.dump(obj, ofile, pickle.HIGHEST_PROTOCOL)
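
# A minimal round-trip sketch (hedged: "fn.pkl" is an illustrative path).
# cloudpickle can serialize objects the standard pickle module cannot (e.g.
# lambdas and interactively defined functions), while plain pickle.load is
# enough to read them back:
#
#     cloudpickle_dump(lambda x: x + 1, "fn.pkl")
#     fn = pickle_load("fn.pkl")
#     assert fn(1) == 2
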
# pylint: disable=too-many-locals
def copy_process_streams(
process: subprocess.Popen, stdout: io.StringIO, stderr: io.StringIO, verbose: bool = False
):
"""
Reads the given process stdout/stderr and write them to StringIO objects.
Make sure that there is no deadlock because of pipe congestion.
If `verbose` the process stdout/stderr are also copying to the interpreter stdout/stderr.
"""
def raw(stream: tp.Optional[tp.IO[bytes]]) -> tp.IO[bytes]:
if stream is None:
raise RuntimeError("Stream should not be None")
if isinstance(stream, io.BufferedIOBase):
stream = stream.raw # type: ignore
return stream # type: ignore
p_stdout, p_stderr = raw(process.stdout), raw(process.stderr)
stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.IO[str]]] = {
p_stdout.fileno(): (p_stdout, stdout, sys.stdout),
p_stderr.fileno(): (p_stderr, stderr, sys.stderr),
}
fds = list(stream_by_fd.keys())
poller = select.poll()
for fd in stream_by_fd:
poller.register(fd, select.POLLIN | select.POLLPRI)
while fds:
# `poll` syscall will wait until one of the registered file descriptors has content.
ready = poller.poll()
for fd, _ in ready:
p_stream, string, std = stream_by_fd[fd]
raw_buf = p_stream.read(2**16)
if not raw_buf:
fds.remove(fd)
poller.unregister(fd)
continue
buf = raw_buf.decode()
string.write(buf)
string.flush()
if verbose:
std.write(buf)
std.flush()
# used in "_core", so cannot be in "helpers"
class CommandFunction:
"""Wraps a command as a function in order to make sure it goes through the
pipeline and notify when it is finished.
The output is a string containing everything that has been sent to stdout.
WARNING: use CommandFunction only if you know the output won't be too big !
Otherwise use subprocess.run() that also streams the outputto stdout/stderr.
Parameters
----------
command: list
command to run, as a list
verbose: bool
prints the command and stdout at runtime
cwd: Path/str
path to the location where the command must run from
Returns
-------
str
Everything that has been sent to stdout
"""
def __init__(
self,
command: tp.List[str],
verbose: bool = True,
cwd: tp.Optional[tp.Union[str, Path]] = None,
env: tp.Optional[tp.Dict[str, str]] = None,
) -> None:
if not isinstance(command, list):
raise TypeError("The command must be provided as a list")
self.command = command
self.verbose = verbose
self.cwd = None if cwd is None else str(cwd)
self.env = env
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> str:
"""Call the cammand line with addidional arguments
The keyword arguments will be sent as --{key}={val}
The logs bufferized. They will be printed if the job fails, or sent as output of the function
Errors are provided with the internal stderr.
"""
full_command = (
self.command + [str(x) for x in args] + [f"--{x}={y}" for x, y in kwargs.items()]
) # TODO bad parsing
if self.verbose:
print(f"The following command is sent: \"{' '.join(full_command)}\"")
with subprocess.Popen(
full_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=False,
cwd=self.cwd,
env=self.env,
) as process:
stdout_buffer = io.StringIO()
stderr_buffer = io.StringIO()
try:
copy_process_streams(process, stdout_buffer, stderr_buffer, self.verbose)
except Exception as e:
process.kill()
process.wait()
raise FailedJobError("Job got killed for an unknown reason.") from e
stdout = stdout_buffer.getvalue().strip()
stderr = stderr_buffer.getvalue().strip()
retcode = process.wait()
            if stderr and retcode and not self.verbose:
                # no need to print stderr if self.verbose, since it was already streamed above
print(stderr, file=sys.stderr)
if retcode:
subprocess_error = subprocess.CalledProcessError(
retcode, process.args, output=stdout, stderr=stderr
)
raise FailedJobError(stderr) from subprocess_error
return stdout
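
# A minimal usage sketch for CommandFunction (hedged: the echo command is
# illustrative):
#
#     fn = CommandFunction(["echo"], verbose=False)
#     out = fn("hello", num=7)  # runs: echo hello --num=7
#     assert out == "hello --num=7"
#
# A non-zero return code raises FailedJobError carrying the captured stderr.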