# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Utilities used by tests """ import copy import logging import os import re import shutil import string import subprocess import sys import tempfile from contextlib import contextmanager from difflib import unified_diff from pathlib import Path from subprocess import PIPE, Popen from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union from omegaconf import Container, DictConfig, OmegaConf from hydra._internal.hydra import Hydra from hydra._internal.utils import detect_task_name from hydra.core.global_hydra import GlobalHydra from hydra.core.utils import JobReturn, validate_config_path from hydra.types import TaskFunction if sys.version_info >= (3, 8, 0): from typing import Protocol else: from typing_extensions import Protocol # type: ignore @contextmanager def does_not_raise(enter_result: Any = None) -> Iterator[Any]: yield enter_result class TaskTestFunction: """ Context function """ def __init__(self) -> None: self.temp_dir: Optional[str] = None self.overrides: Optional[List[str]] = None self.calling_file: Optional[str] = None self.calling_module: Optional[str] = None self.config_path: Optional[str] = None self.config_name: Optional[str] = None self.hydra: Optional[Hydra] = None self.job_ret: Optional[JobReturn] = None self.configure_logging: bool = False def __call__(self, cfg: DictConfig) -> Any: """ Actual function being executed by Hydra """ return 100 def __enter__(self) -> "TaskTestFunction": try: validate_config_path(self.config_path) job_name = detect_task_name(self.calling_file, self.calling_module) self.hydra = Hydra.create_main_hydra_file_or_module( calling_file=self.calling_file, calling_module=self.calling_module, config_path=self.config_path, job_name=job_name, ) self.temp_dir = tempfile.mkdtemp() overrides = copy.deepcopy(self.overrides) assert overrides is not None overrides.append(f"hydra.run.dir={self.temp_dir}") self.job_ret = self.hydra.run( config_name=self.config_name, task_function=self, overrides=overrides, with_log_configuration=self.configure_logging, ) return self finally: GlobalHydra().clear() def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # release log file handles. if self.configure_logging: logging.shutdown() assert self.temp_dir is not None shutil.rmtree(self.temp_dir, ignore_errors=True) class TTaskRunner(Protocol): def __call__( self, calling_file: Optional[str], calling_module: Optional[str], config_path: Optional[str], config_name: Optional[str], overrides: Optional[List[str]] = None, configure_logging: bool = False, ) -> TaskTestFunction: ... class SweepTaskFunction: """ Context function """ def __init__(self) -> None: """ if sweep_dir is None, we use a temp dir, else we will create dir with the path from sweep_dir. """ self.temp_dir: Optional[str] = None self.overrides: Optional[List[str]] = None self.calling_file: Optional[str] = None self.calling_module: Optional[str] = None self.task_function: Optional[TaskFunction] = None self.config_path: Optional[str] = None self.config_name: Optional[str] = None self.sweeps = None self.returns = None self.configure_logging: bool = False def __call__(self, cfg: DictConfig) -> Any: """ Actual function being executed by Hydra """ if self.task_function is not None: return self.task_function(cfg) return 100 def __enter__(self) -> "SweepTaskFunction": overrides = copy.deepcopy(self.overrides) assert overrides is not None if self.temp_dir: Path(self.temp_dir).mkdir(parents=True, exist_ok=True) else: self.temp_dir = tempfile.mkdtemp() overrides.append(f"hydra.sweep.dir={self.temp_dir}") try: validate_config_path(self.config_path) job_name = detect_task_name(self.calling_file, self.calling_module) hydra_ = Hydra.create_main_hydra_file_or_module( calling_file=self.calling_file, calling_module=self.calling_module, config_path=self.config_path, job_name=job_name, ) self.returns = hydra_.multirun( config_name=self.config_name, task_function=self, overrides=overrides, with_log_configuration=self.configure_logging, ) finally: GlobalHydra().clear() return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self.configure_logging: logging.shutdown() assert self.temp_dir is not None shutil.rmtree(self.temp_dir, ignore_errors=True) class TSweepRunner(Protocol): returns: List[List[JobReturn]] def __call__( self, calling_file: Optional[str], calling_module: Optional[str], task_function: Optional[TaskFunction], config_path: Optional[str], config_name: Optional[str], overrides: Optional[List[str]], temp_dir: Optional[Path] = None, ) -> SweepTaskFunction: ... def chdir_hydra_root(subdir: Optional[str] = None) -> None: """ Change the cwd to the root of the hydra project. used from unit tests to make them runnable from anywhere in the tree. """ _chdir_to_dir_containing(target="ATTRIBUTION") if subdir is not None: os.chdir(subdir) def chdir_plugin_root() -> None: """ Change the cwd to the root of the plugin (location of setup.py) """ _chdir_to_dir_containing(target="setup.py") def _chdir_to_dir_containing( target: str, max_up: int = 6, initial_dir: str = os.getcwd() ) -> None: cur = find_parent_dir_containing(target, max_up, initial_dir) os.chdir(cur) def find_parent_dir_containing( target: str, max_up: int = 6, initial_dir: str = os.getcwd() ) -> str: cur = initial_dir while not os.path.exists(os.path.join(cur, target)) and max_up > 0: cur = os.path.relpath(os.path.join(cur, "..")) max_up = max_up - 1 if max_up == 0: raise IOError(f"Could not find {target} in parents of {os.getcwd()}") return cur def verify_dir_outputs( job_return: JobReturn, overrides: Optional[List[str]] = None ) -> None: """ Verify that directory output makes sense """ assert isinstance(job_return, JobReturn) assert job_return.working_dir is not None assert job_return.task_name is not None assert job_return.hydra_cfg is not None assert os.path.exists( os.path.join(job_return.working_dir, job_return.task_name + ".log") ) hydra_dir = os.path.join( job_return.working_dir, job_return.hydra_cfg.hydra.output_subdir ) assert os.path.exists(os.path.join(hydra_dir, "config.yaml")) assert os.path.exists(os.path.join(hydra_dir, "overrides.yaml")) assert OmegaConf.load( os.path.join(hydra_dir, "overrides.yaml") ) == OmegaConf.create(overrides or []) def _get_statements(indent: str, statements: Union[None, str, List[str]]) -> str: if isinstance(statements, str): statements = [statements] code = "" if statements is None or len(statements) == 0: code = "pass" else: for p in statements: code += f"{indent}{p}\n" return code def integration_test( tmpdir: Path, task_config: Any, overrides: List[str], prints: Union[str, List[str]], expected_outputs: Union[str, List[str]], prolog: Union[None, str, List[str]] = None, filename: str = "task.py", env_override: Optional[Dict[str, str]] = None, clean_environment: bool = False, generate_custom_cmd: Callable[..., List[str]] = lambda cmd, *args, **kwargs: cmd, ) -> str: Path(tmpdir).mkdir(parents=True, exist_ok=True) if isinstance(expected_outputs, str): expected_outputs = [expected_outputs] if not isinstance(task_config, Container): task_config = OmegaConf.create(task_config) if isinstance(prints, str): prints = [prints] prints = [f'f.write({p} + "\\n")' for p in prints] s = string.Template( """import hydra import os from hydra.core.hydra_config import HydraConfig $PROLOG @hydra.main(version_base=None, config_path='.', config_name='config') def experiment(cfg): with open("$OUTPUT_FILE", "w") as f: $PRINTS if __name__ == "__main__": experiment() """ ) print_code = _get_statements(indent=" ", statements=prints) prolog_code = _get_statements(indent="", statements=prolog) if task_config is not None: cfg_file = tmpdir / "config.yaml" with open(str(cfg_file), "w") as f: f.write("# @package _global_\n") OmegaConf.save(task_config, f) output_file = str(tmpdir / "output.txt") # replace Windows path separator \ with an escaped version \\ output_file = output_file.replace("\\", "\\\\") code = s.substitute( PRINTS=print_code, OUTPUT_FILE=output_file, PROLOG=prolog_code, ) task_file = tmpdir / filename task_file.write_text(str(code), encoding="utf-8") cmd = [sys.executable, str(task_file)] orig_dir = os.getcwd() try: os.chdir(str(tmpdir)) cmd = generate_custom_cmd(cmd, filename) cmd.extend(overrides) if clean_environment: modified_env = {} else: modified_env = os.environ.copy() if env_override is not None: modified_env.update(env_override) subprocess.check_call(cmd, env=modified_env) with open(output_file, "r") as f: file_str = f.read() output = str.splitlines(file_str) if expected_outputs is not None: assert len(output) == len( expected_outputs ), f"Unexpected number of output lines from {task_file}, output lines:\n\n{file_str}" for idx in range(len(output)): assert_regex_match(expected_outputs[idx], output[idx]) # some tests are parsing the file output for more specialized testing. return file_str finally: os.chdir(orig_dir) def run_with_error(cmd: Any, env: Any = None) -> str: cmd = [sys.executable, "-Werror"] + cmd with Popen(cmd, stdout=PIPE, stderr=PIPE, env=env) as p: _stdout, stderr = p.communicate() err = stderr.decode("utf-8").rstrip().replace("\r\n", "\n") assert p.returncode != 0 return err def run_python_script( cmd: Any, env: Any = None, allow_warnings: bool = False, print_error: bool = True, raise_exception: bool = True, ) -> Tuple[str, str]: if allow_warnings: cmd = [sys.executable] + cmd else: cmd = [sys.executable, "-Werror"] + cmd return run_process(cmd, env, print_error, raise_exception) def run_process( cmd: Any, env: Any = None, print_error: bool = True, raise_exception: bool = True, timeout: Optional[float] = None, ) -> Tuple[str, str]: try: process = subprocess.Popen( args=cmd, shell=False, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) bstdout, bstderr = process.communicate(timeout=timeout) stdout = normalize_newlines(bstdout.decode().rstrip()) stderr = normalize_newlines(bstderr.decode().rstrip()) if process.returncode != 0: if print_error: sys.stderr.write(f"Subprocess error:\n{stderr}\n") if raise_exception: raise subprocess.CalledProcessError( returncode=process.returncode, cmd=cmd ) return stdout, stderr except Exception as e: if print_error: cmd = " ".join(cmd) sys.stderr.write(f"=== Error executing:\n{cmd}\n===================") raise e def normalize_newlines(s: str) -> str: """ Normalizes new lines such they are comparable across different operating systems :param s: :return: """ return s.replace("\r\n", "\n").replace("\r", "\n") def assert_text_same( from_line: str, to_line: str, from_name: str = "Expected", to_name: str = "Actual" ) -> None: from_line = normalize_newlines(from_line) to_line = normalize_newlines(to_line) lines = [ line for line in unified_diff( a=from_line.splitlines(), b=to_line.splitlines(), fromfile=from_name, tofile=to_name, ) ] diff = "\n".join(lines) if len(diff) > 0: print("\n------------ DIFF -------------") print(diff) print("-------------------------------") assert False, "Mismatch between expected and actual text" def assert_regex_match( from_line: str, to_line: str, from_name: str = "Expected", to_name: str = "Actual" ) -> None: """Check that the lines of `from_line` (which can be a regex expression) matches the corresponding lines of `to_line` string. In case the regex match fails, we display the diff as if `from_line` was a regular string. """ normalized_from_line = [x for x in normalize_newlines(from_line).split("\n") if x] normalized_to_line = [x for x in normalize_newlines(to_line).split("\n") if x] if len(normalized_from_line) != len(normalized_to_line): assert_text_same( from_line=from_line, to_line=to_line, from_name=from_name, to_name=to_name, ) for line1, line2 in zip(normalized_from_line, normalized_to_line): if line1 != line2 and re.match(line1, line2) is None: assert_text_same( from_line=from_line, to_line=to_line, from_name=from_name, to_name=to_name, ) def assert_multiline_regex_search( pattern: str, string: str, from_name: str = "Expected", to_name: str = "Actual" ) -> None: """Check that `pattern` (which can be a regex expression) matches the corresponding lines of `string` string. In case the regex match fails, we display the diff as if `pattern` was a regular string. """ pattern = normalize_newlines(pattern) string = normalize_newlines(string) if re.search(pattern, string, flags=re.MULTILINE) is None: print("\n-------- PATTERN: -----------") print(pattern) print("---------- STRING: ------------") print(string) print("-------------------------------") assert False, "Regex pattern did not match"
Memory