# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import codecs
import contextlib
import io
import itertools
import os
import pickle
import select
import shutil
import subprocess
import sys
import tarfile
import typing as tp
from pathlib import Path

import cloudpickle


@contextlib.contextmanager
def environment_variables(**kwargs: tp.Any) -> tp.Iterator[None]:
    """Temporarily sets environment variables for the duration of the context

    Parameters
    ----------
    **kwargs: Any
        variable names mapped to their temporary values (values are converted with str)

    Note
    ----
    On exit, variables which existed beforehand get their previous value back,
    and the other ones are removed.
    """
    backup = {x: os.environ[x] for x in kwargs if x in os.environ}
    os.environ.update({x: str(y) for x, y in kwargs.items()})
    try:
        yield
    finally:
        for x in kwargs:
            # pop instead of del: the wrapped code may have unset the variable itself,
            # and a KeyError here would mask its exception and skip the restoration below
            os.environ.pop(x, None)
        os.environ.update(backup)


class UncompletedJobError(RuntimeError):
    """Job is incomplete: either unfinished or failed"""


class FailedJobError(UncompletedJobError):
    """Job failed during processing"""


class FailedSubmissionError(RuntimeError):
    """Job Submission failed"""


class JobPaths:
    """Creates paths related to the slurm job and its submission

    Parameters
    ----------
    folder: Path/str
        folder of the job files, it may contain id tags (%j, %t, %A, %a) which
        are replaced when the job id is known
    job_id: str (optional)
        id of the job, an "_" separates the array id from the array index
    task_id: int (optional)
        id of the task (0 if not provided)
    """

    def __init__(
        self, folder: tp.Union[Path, str], job_id: tp.Optional[str] = None, task_id: tp.Optional[int] = None
    ) -> None:
        self._folder = Path(folder).expanduser().absolute()
        self.job_id = job_id
        self.task_id = task_id or 0

    @property
    def folder(self) -> Path:
        """Job folder, with id tags replaced if the job id is available"""
        return self._format_id(self._folder)

    @property
    def submission_file(self) -> Path:
        """Path of the submission script"""
        if self.job_id and "_" in self.job_id:
            # We only have one submission file per job array
            return self._format_id(self.folder / "%A_submission.sh")
        return self._format_id(self.folder / "%j_submission.sh")

    @property
    def submitted_pickle(self) -> Path:
        """Path of the pickle of the submitted job"""
        return self._format_id(self.folder / "%j_submitted.pkl")

    @property
    def result_pickle(self) -> Path:
        """Path of the pickle of the result (one per job and task)"""
        return self._format_id(self.folder / "%j_%t_result.pkl")

    @property
    def stderr(self) -> Path:
        """Path of the stderr log (one per job and task)"""
        return self._format_id(self.folder / "%j_%t_log.err")

    @property
    def stdout(self) -> Path:
        """Path of the stdout log (one per job and task)"""
        return self._format_id(self.folder / "%j_%t_log.out")

    def _format_id(self, path: tp.Union[Path, str]) -> Path:
        """Replace id tag by actual id if available

        %j is replaced by the job id, %t by the task id, %A by the part of the
        job id before the first "_" and %a by the part after it.
        The path is returned unchanged if the job id is None.

        Raises
        ------
        ValueError
            if %a is in the path but the job id contains no "_"
        """
        if self.job_id is None:
            return Path(path)
        replaced_path = str(path).replace("%j", str(self.job_id)).replace("%t", str(self.task_id))
        # array_index is an empty list when the job id has no "_"
        array_id, *array_index = str(self.job_id).split("_", 1)
        if "%a" in replaced_path:
            if len(array_index) != 1:
                raise ValueError("%a is in the folder path but this is not a job array")
            replaced_path = replaced_path.replace("%a", array_index[0])
        return Path(replaced_path.replace("%A", array_id))

    def move_temporary_file(
        self, tmp_path: tp.Union[Path, str], name: str, keep_as_symlink: bool = False
    ) -> None:
        """Moves a file to the path provided by one of the attributes of this instance

        Parameters
        ----------
        tmp_path: Path/str
            current location of the file
        name: str
            name of the attribute providing the destination (eg: "submission_file")
        keep_as_symlink: bool
            if True, a symlink to the destination is created at the initial location
        """
        self.folder.mkdir(parents=True, exist_ok=True)
        Path(tmp_path).rename(getattr(self, name))
        if keep_as_symlink:
            Path(tmp_path).symlink_to(getattr(self, name))

    @staticmethod
    def get_first_id_independent_folder(folder: tp.Union[Path, str]) -> Path:
        """Returns the closest folder which is id independent"""
        parts = Path(folder).expanduser().absolute().parts
        tags = ["%j", "%t", "%A", "%a"]
        # keep the leading parts up to (excluding) the first one containing an id tag
        indep_parts = itertools.takewhile(lambda x: not any(tag in x for tag in tags), parts)
        return Path(*indep_parts)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.folder})"


class DelayedSubmission:
    """Object for specifying the function/callable call to submit and process later.
    This is only syntactic sugar to make sure everything is well formatted:
    If what you want to compute later is func(*args, **kwargs), just instantiate:
    DelayedSubmission(func, *args, **kwargs).
    It also provides convenient tools for dumping and loading.
    """

    def __init__(self, function: tp.Callable[..., tp.Any], *args: tp.Any, **kwargs: tp.Any) -> None:
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self._result: tp.Any = None
        self._done = False
        self._timeout_min: int = 0
        self._timeout_countdown: int = 0  # controlled in submission and execution

    def result(self) -> tp.Any:
        """Computes function(*args, **kwargs) on the first call, and returns the
        stored result on the following ones
        """
        if self._done:
            return self._result

        self._result = self.function(*self.args, **self.kwargs)
        self._done = True
        return self._result

    def done(self) -> bool:
        """Whether the result has already been computed"""
        return self._done

    def dump(self, filepath: tp.Union[str, Path]) -> None:
        """Dumps this instance to the provided file, using cloudpickle"""
        cloudpickle_dump(self, filepath)

    def set_timeout(self, timeout_min: int, max_num_timeout: int) -> None:
        """Records the timeout (in minutes, as per the parameter name) and the
        initial value of the timeout countdown
        """
        self._timeout_min = timeout_min
        self._timeout_countdown = max_num_timeout

    @classmethod
    def load(cls: tp.Type["DelayedSubmission"], filepath: tp.Union[str, Path]) -> "DelayedSubmission":
        """Loads an instance from a pickle file (only load files you trust, see pickle_load)"""
        obj = pickle_load(filepath)
        # following assertion is relaxed compared to isinstance, to allow flexibility
        # (Eg: copying this class in a project to be able to have checkpointable jobs without adding submitit as dependency)
        assert obj.__class__.__name__ == cls.__name__, f"Loaded object is {type(obj)} but should be {cls}."
        return obj  # type: ignore

    def _checkpoint_function(self) -> tp.Optional["DelayedSubmission"]:
        """Calls the checkpoint method of the function with the stored arguments, if any

        The method is looked up as "__submitit_checkpoint__" first, then as "checkpoint".
        Returns None if the function provides neither.
        """
        checkpoint = getattr(self.function, "__submitit_checkpoint__", None)
        if checkpoint is None:
            checkpoint = getattr(self.function, "checkpoint", None)
        if checkpoint is None:
            return None
        return checkpoint(*self.args, **self.kwargs)  # type: ignore


@contextlib.contextmanager
def temporary_save_path(filepath: tp.Union[Path, str]) -> tp.Iterator[Path]:
    """Yields a path where to save a file and moves it
    afterward to the provided location (and replaces any
    existing file)
    This is useful to avoid processes monitoring the filepath
    to break if trying to read when the file is being written.

    Raises
    ------
    FileNotFoundError
        if nothing was saved at the temporary path within the context

    Note
    ----
    The temporary path is the provided path appended with .save_tmp
    """
    filepath = Path(filepath)
    tmppath = filepath.with_suffix(filepath.suffix + ".save_tmp")
    assert not tmppath.exists(), "A temporary saved file already exists."
    yield tmppath
    if not tmppath.exists():
        raise FileNotFoundError("No file was saved at the temporary path.")
    # os.replace overwrites an existing destination in one step, so there is no
    # window during which the destination is missing (unlike remove + rename)
    os.replace(tmppath, filepath)


def archive_dev_folders(
    folders: tp.List[tp.Union[str, Path]], outfile: tp.Optional[tp.Union[str, Path]] = None
) -> Path:
    """Creates a tar.gz file with all provided folders

    Parameters
    ----------
    folders: list/tuple of str/Path
        folders to archive, each one is stored under its own name (last path component)
    outfile: str/Path (optional)
        path of the archive, which must end with .tar.gz
        (default: "_dev_folders_.tar.gz", relative to the current directory)

    Returns
    -------
    Path
        Path of the archive
    """
    assert isinstance(folders, (list, tuple)), "Only lists and tuples of folders are allowed"
    if outfile is None:
        outfile = "_dev_folders_.tar.gz"
    outfile = Path(outfile)
    assert str(outfile).endswith(".tar.gz"), "Archive file must have extension .tar.gz"
    # "w:gz" actually gzips the archive, a bare TarFile in "w" mode writes an
    # uncompressed tar which does not match the required extension
    with tarfile.open(outfile, mode="w:gz") as tf:
        for folder in folders:
            tf.add(str(folder), arcname=Path(folder).name)
    return outfile


def copy_par_file(par_file: tp.Union[str, Path], folder: tp.Union[str, Path]) -> Path:
    """Copy the par (or xar) file in the folder

    Parameters
    ----------
    par_file: str/Path
        Par file generated by buck
    folder: str/Path
        folder where the par file must be copied (created if it does not exist)

    Returns
    -------
    Path
        Path of the copied .par file
    """
    par_file = Path(par_file).expanduser().absolute()
    folder = Path(folder).expanduser().absolute()
    folder.mkdir(parents=True, exist_ok=True)
    dst_name = folder / par_file.name
    shutil.copy2(par_file, dst_name)
    return dst_name


def pickle_load(filename: tp.Union[str, Path]) -> tp.Any:
    """Loads the content of a pickle file

    NOTE(review): unpickling can execute arbitrary code, this must only be
    used on files coming from a trusted source.
    """
    # this is used by cloudpickle as well
    with open(filename, "rb") as ifile:
        return pickle.load(ifile)


def cloudpickle_dump(obj: tp.Any, filename: tp.Union[str, Path]) -> None:
    """Dumps an object to a file with cloudpickle, using the highest pickle protocol"""
    with open(filename, "wb") as ofile:
        cloudpickle.dump(obj, ofile, pickle.HIGHEST_PROTOCOL)


# pylint: disable=too-many-locals
def copy_process_streams(
    process: subprocess.Popen, stdout: io.StringIO, stderr: io.StringIO, verbose: bool = False
) -> None:
    """
    Reads the given process stdout/stderr and write them to StringIO objects.
    Make sure that there is no deadlock because of pipe congestion.
    If `verbose` the process stdout/stderr are also copying to the interpreter stdout/stderr.

    Parameters
    ----------
    process: subprocess.Popen
        process to read from, its stdout and stderr must not be None
    stdout: io.StringIO
        buffer receiving the decoded stdout of the process
    stderr: io.StringIO
        buffer receiving the decoded stderr of the process
    verbose: bool
        whether to also write to sys.stdout/sys.stderr

    Raises
    ------
    RuntimeError
        if the stdout or stderr of the process is None
    UnicodeDecodeError
        if a stream contains bytes which are not valid UTF-8
    """

    def raw(stream: tp.Optional[tp.IO[bytes]]) -> tp.IO[bytes]:
        # reads go through the unbuffered stream when the provided one is buffered
        if stream is None:
            raise RuntimeError("Stream should not be None")
        if isinstance(stream, io.BufferedIOBase):
            stream = stream.raw  # type: ignore
        return stream  # type: ignore

    p_stdout, p_stderr = raw(process.stdout), raw(process.stderr)
    stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.IO[str]]] = {
        p_stdout.fileno(): (p_stdout, stdout, sys.stdout),
        p_stderr.fileno(): (p_stderr, stderr, sys.stderr),
    }
    # One incremental decoder per stream: a multi-byte character can be split
    # across two reads, and decoding each chunk independently would then fail.
    decoders = {fd: codecs.getincrementaldecoder("utf-8")() for fd in stream_by_fd}
    fds = list(stream_by_fd.keys())
    poller = select.poll()
    for fd in stream_by_fd:
        poller.register(fd, select.POLLIN | select.POLLPRI)
    while fds:
        # `poll` syscall will wait until one of the registered file descriptors has content.
        ready = poller.poll()
        for fd, _ in ready:
            p_stream, string, std = stream_by_fd[fd]
            raw_buf = p_stream.read(2**16)
            if raw_buf:
                buf = decoders[fd].decode(raw_buf)
            else:
                # an empty read means the stream is over: stop watching it
                fds.remove(fd)
                poller.unregister(fd)
                # flush the decoder (raises if the stream ended in the middle of a character)
                buf = decoders[fd].decode(b"", final=True)
            if not buf:
                continue
            string.write(buf)
            string.flush()
            if verbose:
                std.write(buf)
                std.flush()


# used in "_core", so cannot be in "helpers"
class CommandFunction:
    """Wraps a command as a function in order to make sure it goes through the
    pipeline and notify when it is finished.
    The output is a string containing everything that has been sent to stdout.
    WARNING: use CommandFunction only if you know the output won't be too big !
    Otherwise use subprocess.run() that also streams the output to stdout/stderr.

    Parameters
    ----------
    command: list
        command to run, as a list
    verbose: bool
        prints the command and stdout at runtime
    cwd: Path/str
        path to the location where the command must run from
    env: dict (optional)
        environment passed to the subprocess

    Returns
    -------
    str
       Everything that has been sent to stdout
    """

    def __init__(
        self,
        command: tp.List[str],
        verbose: bool = True,
        cwd: tp.Optional[tp.Union[str, Path]] = None,
        env: tp.Optional[tp.Dict[str, str]] = None,
    ) -> None:
        if not isinstance(command, list):
            raise TypeError("The command must be provided as a list")
        self.command = command
        self.verbose = verbose
        self.cwd = None if cwd is None else str(cwd)
        self.env = env

    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> str:
        """Call the command line with additional arguments
        The keyword arguments will be sent as --{key}={val}
        The logs are buffered. They will be printed if the job fails, or sent as output of the function
        Errors are provided with the internal stderr.

        Returns
        -------
        str
            the stripped stdout of the command

        Raises
        ------
        FailedJobError
            if the command returns a non-zero code (the message is its stderr),
            or if reading its output fails
        """
        full_command = (
            self.command + [str(x) for x in args] + [f"--{x}={y}" for x, y in kwargs.items()]
        )  # TODO bad parsing
        if self.verbose:
            print(f"The following command is sent: \"{' '.join(full_command)}\"")
        with subprocess.Popen(
            full_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=False,
            cwd=self.cwd,
            env=self.env,
        ) as process:
            stdout_buffer = io.StringIO()
            stderr_buffer = io.StringIO()

            try:
                copy_process_streams(process, stdout_buffer, stderr_buffer, self.verbose)
            except Exception as e:
                # make sure the process does not outlive a failure to read its output
                process.kill()
                process.wait()
                raise FailedJobError("Job got killed for an unknown reason.") from e
            stdout = stdout_buffer.getvalue().strip()
            stderr = stderr_buffer.getvalue().strip()
            retcode = process.wait()
            if stderr and (retcode and not self.verbose):
                # We don't print if self.verbose, as it already happened before.
                print(stderr, file=sys.stderr)
            if retcode:
                subprocess_error = subprocess.CalledProcessError(
                    retcode, process.args, output=stdout, stderr=stderr
                )
                raise FailedJobError(stderr) from subprocess_error
        return stdout
Memory