# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Start multiple process locally for DDP.
"""
from functools import partial
import os
import subprocess as sp
import sys
import typing as tp
from .log import simple_log, fatal
log = partial(simple_log, "Executor:")
class ChildrenManager:
def __init__(self):
self.children = []
self.failed = False
def add(self, child):
child.rank = len(self.children)
self.children.append(child)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_value is not None:
log(f"An exception happened while starting workers {exc_value}")
self.failed = True
try:
while self.children and not self.failed:
for child in list(self.children):
try:
exitcode = child.wait(0.05)
except sp.TimeoutExpired:
continue
else:
self.children.remove(child)
if exitcode:
log(f"Worker {child.rank} died, killing all workers")
self.failed = True
except KeyboardInterrupt:
log("Received keyboard interrupt, trying to kill all workers.")
self.failed = True
for child in self.children:
child.terminate()
if not self.failed:
log("All workers completed successfully")
def start_ddp_workers(main, argv, num_workers: tp.Optional[int] = None):
import torch as th
world_size = num_workers or th.cuda.device_count()
if not world_size:
fatal(
"DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
sys.exit(1)
xp = main.get_xp(argv)
xp.folder.mkdir(exist_ok=True, parents=True)
if xp.rendezvous_file.exists():
xp.rendezvous_file.unlink()
log(f"Starting {world_size} worker processes for DDP.")
with ChildrenManager() as manager:
for rank in range(world_size):
kwargs: tp.Dict[str, tp.Any] = {}
env = dict(os.environ)
env['RANK'] = str(rank)
env['WORLD_SIZE'] = str(world_size)
env['MASTER_ADDR'] = '127.0.0.1'
args = ["-m", "dora", "-P", main.package, "--main_module", main.main_module,
"run", "--"]
args += argv
if rank > 0:
kwargs['stdin'] = sp.DEVNULL
kwargs['stdout'] = open(xp.folder / f'worker_{rank}.log', 'w')
kwargs['stderr'] = sp.STDOUT
manager.add(
sp.Popen([sys.executable] + args, env=env, **kwargs))
sys.exit(int(manager.failed))