mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
Merge 52265120dc into 3086026401
This commit is contained in:
commit
6f212c975c
61
main.py
61
main.py
@ -5,6 +5,7 @@ import os
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import shutil
|
import shutil
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
|
import sysconfig
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
@ -29,6 +30,66 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||||
|
|
||||||
|
|
||||||
|
def is_rocm_torch_build():
|
||||||
|
try:
|
||||||
|
torch_spec = importlib.util.find_spec("torch")
|
||||||
|
if torch_spec is not None:
|
||||||
|
for folder in torch_spec.submodule_search_locations or []:
|
||||||
|
version_file = os.path.join(folder, "version.py")
|
||||||
|
if os.path.isfile(version_file):
|
||||||
|
spec = importlib.util.spec_from_file_location("torch_version_probe", version_file)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
if bool(getattr(module, "hip", None)):
|
||||||
|
return True
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch_version = importlib.metadata.version("torch").lower()
|
||||||
|
except importlib.metadata.PackageNotFoundError:
|
||||||
|
return False
|
||||||
|
return "rocm" in torch_version
|
||||||
|
|
||||||
|
|
||||||
|
def configure_windows_rocm_sdk_env():
|
||||||
|
if os.name != "nt":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not is_rocm_torch_build():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if any(os.environ.get(var) for var in ("ROCM_HOME", "ROCM_PATH", "HIP_PATH")):
|
||||||
|
return None
|
||||||
|
|
||||||
|
platlib = sysconfig.get_paths().get("platlib")
|
||||||
|
if not platlib:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rocm_sdk_root = os.path.join(platlib, "_rocm_sdk_core")
|
||||||
|
hip_header = os.path.join(rocm_sdk_root, "include", "hip", "hip_runtime.h")
|
||||||
|
if not os.path.isfile(hip_header):
|
||||||
|
return None
|
||||||
|
|
||||||
|
os.environ["ROCM_HOME"] = rocm_sdk_root
|
||||||
|
os.environ["ROCM_PATH"] = rocm_sdk_root
|
||||||
|
os.environ["HIP_PATH"] = rocm_sdk_root
|
||||||
|
|
||||||
|
rocm_bin = os.path.join(rocm_sdk_root, "bin")
|
||||||
|
if os.path.isdir(rocm_bin):
|
||||||
|
path_entries = os.environ.get("PATH", "").split(os.pathsep)
|
||||||
|
if rocm_bin not in path_entries:
|
||||||
|
os.environ["PATH"] = rocm_bin + os.pathsep + os.environ.get("PATH", "")
|
||||||
|
|
||||||
|
return rocm_sdk_root
|
||||||
|
|
||||||
|
|
||||||
|
configured_rocm_sdk_root = configure_windows_rocm_sdk_env()
|
||||||
|
if configured_rocm_sdk_root is not None:
|
||||||
|
logging.info("Configured ROCm SDK environment from Python package path: %s", configured_rocm_sdk_root)
|
||||||
|
|
||||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||||
|
|
||||||
import comfy_aimdo.control
|
import comfy_aimdo.control
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user