From 4272b38867aefe162708fc9bada14868b5141b19 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Tue, 25 Feb 2025 10:54:42 +0800 Subject: [PATCH] Support specifying a torch range --- setup.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 7eb94528af..001877e23c 100644 --- a/setup.py +++ b/setup.py @@ -39,9 +39,12 @@ def _make_version_file(version, sha): def _get_pytorch_version(): - if "PYTORCH_VERSION" in os.environ: - return f"torch=={os.environ['PYTORCH_VERSION']}" - return "torch" + pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch") + if version_pin := os.getenv("PYTORCH_VERSION"): + pytorch_dep += "==" + version_pin + elif (version_pin_ge := os.getenv("PYTORCH_VERSION_GE")) and (version_pin_lt := os.getenv("PYTORCH_VERSION_LT")): + pytorch_dep += f">={version_pin_ge},<{version_pin_lt}" + return pytorch_dep class clean(distutils.command.clean.clean):