1
1
import sys
2
2
import warnings
3
3
import os
4
+ import threading
4
5
import glob
5
6
from packaging .version import parse , Version
6
7
@@ -859,6 +860,44 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
859
860
)
860
861
861
862
863
+ # Patch because `setup.py bdist_wheel` and `setup.py develop` do not support the `parallel` option
864
+ parallel = None
865
+ if "--parallel" in sys .argv :
866
+ idx = sys .argv .index ("--parallel" )
867
+ parallel = int (sys .argv [idx + 1 ])
868
+ sys .argv .pop (idx + 1 )
869
+ sys .argv .pop (idx )
870
+
871
+
872
+ # Prevent file conflicts when multiple extensions are compiled simultaneously
873
+ class BuildExtensionSeparateDir (BuildExtension ):
874
+ build_extension_patch_lock = threading .Lock ()
875
+ thread_ext_name_map = {}
876
+
877
+ def finalize_options (self ):
878
+ if parallel is not None :
879
+ self .parallel = parallel
880
+ super ().finalize_options ()
881
+
882
+ def build_extension (self , ext ):
883
+ with self .build_extension_patch_lock :
884
+ if not getattr (self .compiler , "_compile_separate_output_dir" , False ):
885
+ compile_orig = self .compiler .compile
886
+
887
+ def compile_new (* args , ** kwargs ):
888
+ return compile_orig (* args , ** {
889
+ ** kwargs ,
890
+ "output_dir" : os .path .join (
891
+ kwargs ["output_dir" ],
892
+ self .thread_ext_name_map [threading .current_thread ().ident ]),
893
+ })
894
+ self .compiler .compile = compile_new
895
+ self .compiler ._compile_separate_output_dir = True
896
+ self .thread_ext_name_map [threading .current_thread ().ident ] = ext .name
897
+ objects = super ().build_extension (ext )
898
+ return objects
899
+
900
+
862
901
setup (
863
902
name = "apex" ,
864
903
version = "0.1" ,
@@ -868,6 +907,6 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
868
907
install_requires = ["packaging>20.6" ],
869
908
description = "PyTorch Extensions written by NVIDIA" ,
870
909
ext_modules = ext_modules ,
871
- cmdclass = {"build_ext" : BuildExtension } if ext_modules else {},
910
+ cmdclass = {"build_ext" : BuildExtensionSeparateDir } if ext_modules else {},
872
911
extras_require = extras ,
873
912
)
0 commit comments