diff --git a/requirements.txt b/requirements.txt index 9cc2ec676..f9a728992 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,4 +27,4 @@ scipy tqdm diffusers>=0.16.1 protobuf==3.20.3 -rembg[gpu] +rembg diff --git a/setup.py b/setup.py index 1c597b390..9c832a36c 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,11 @@ cpu_torch_index_nightlies = "https://download.pytorch.org/whl/nightly/cpu" # xformers not required for new torch +""" +Packages that should have a specific option set when a GPU accelerator is present +""" +gpu_accelerated_packages = {"rembg": "rembg[gpu]"} + def _is_nvidia() -> bool: system = platform.system().lower() @@ -92,12 +97,15 @@ def dependencies() -> [str]: _alternative_indices = [amd_torch_index, nvidia_torch_index, cpu_torch_index_nightlies] session = PipSession() + gpu_accelerated = False index_urls = ['https://pypi.org/simple'] # prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device if _is_nvidia(): index_urls += [nvidia_torch_index] + gpu_accelerated = True elif _is_amd(): index_urls += [amd_torch_index] + gpu_accelerated = True else: index_urls += [cpu_torch_index_nightlies] @@ -121,8 +129,11 @@ def dependencies() -> [str]: for i, package in enumerate(_dependencies[:]): requirement = InstallRequirement(Requirement(package), comes_from=f"{package_name}=={version}") candidate = finder.find_best_candidate(requirement.name, requirement.specifier) - if any([url in candidate.best_candidate.link.url for url in _alternative_indices]): - _dependencies[i] = f"{requirement.name} @ {candidate.best_candidate.link.url}" + if candidate.best_candidate is not None: + if gpu_accelerated and requirement.name in gpu_accelerated_packages: + _dependencies[i] = gpu_accelerated_packages[requirement.name] + if any([url in candidate.best_candidate.link.url for url in _alternative_indices]): + _dependencies[i] = f"{requirement.name} @ {candidate.best_candidate.link.url}" return _dependencies @@ -137,7 +148,7 @@ setup( # todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins packages=find_packages(where=".", include=['comfy', 'comfy_extras']), install_requires=dependencies(), - setup_requires=["wheel"], + setup_requires=["pip", "wheel"], entry_points={ 'console_scripts': [ # todo: eventually migrate main here