Skip to content

Commit c9e6f05

Browse files
authored
Add option to build extensions in parallel (#1882)
* build extensions in parallel * fix: setup.py develop supports --parallel * update README.md * Update README.md: Remove an `export` command
1 parent b216eee commit c9e6f05

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation -
135135
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
136136
```
137137

138+
To reduce the build time of APEX, parallel building can be enhanced via
139+
```bash
140+
NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" ./
141+
```
142+
When CPU cores or memory are limited, the `--parallel` option is generally preferred over `--threads`. See [pull#1882](https://github.com/NVIDIA/apex/pull/1882) for more details.
143+
138144
APEX also supports a Python-only build via
139145
```bash
140146
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./

setup.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import warnings
33
import os
4+
import threading
45
import glob
56
from packaging.version import parse, Version
67

@@ -859,6 +860,44 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
859860
)
860861

861862

863+
# Patch because `setup.py bdist_wheel` and `setup.py develop` do not support the `parallel` option
864+
parallel = None
865+
if "--parallel" in sys.argv:
866+
idx = sys.argv.index("--parallel")
867+
parallel = int(sys.argv[idx + 1])
868+
sys.argv.pop(idx + 1)
869+
sys.argv.pop(idx)
870+
871+
872+
# Prevent file conflicts when multiple extensions are compiled simultaneously
873+
class BuildExtensionSeparateDir(BuildExtension):
874+
build_extension_patch_lock = threading.Lock()
875+
thread_ext_name_map = {}
876+
877+
def finalize_options(self):
878+
if parallel is not None:
879+
self.parallel = parallel
880+
super().finalize_options()
881+
882+
def build_extension(self, ext):
883+
with self.build_extension_patch_lock:
884+
if not getattr(self.compiler, "_compile_separate_output_dir", False):
885+
compile_orig = self.compiler.compile
886+
887+
def compile_new(*args, **kwargs):
888+
return compile_orig(*args, **{
889+
**kwargs,
890+
"output_dir": os.path.join(
891+
kwargs["output_dir"],
892+
self.thread_ext_name_map[threading.current_thread().ident]),
893+
})
894+
self.compiler.compile = compile_new
895+
self.compiler._compile_separate_output_dir = True
896+
self.thread_ext_name_map[threading.current_thread().ident] = ext.name
897+
objects = super().build_extension(ext)
898+
return objects
899+
900+
862901
setup(
863902
name="apex",
864903
version="0.1",
@@ -868,6 +907,6 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
868907
install_requires=["packaging>20.6"],
869908
description="PyTorch Extensions written by NVIDIA",
870909
ext_modules=ext_modules,
871-
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
910+
cmdclass={"build_ext": BuildExtensionSeparateDir} if ext_modules else {},
872911
extras_require=extras,
873912
)

0 commit comments

Comments
 (0)