diff --git a/comfy/model_management.py b/comfy/model_management.py index b15d08ba1..39508cd43 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -432,6 +432,98 @@ def is_amd(): return True return False +def is_integrated_gpu(device=None): + # AMD APUs / integrated GPUs expose host RAM (GTT/shared) as device memory + # via mem_get_info(); torch flags these as integrated. See ComfyUI #14274. + if cpu_state != CPUState.GPU: + return False + if not (is_nvidia() or is_amd()): + return False + try: + if device is None: + device = get_torch_device() + return bool(getattr(torch.cuda.get_device_properties(device), "is_integrated", 0)) + except Exception: + return False + +def _amd_vram_gtt_totals(device=None): + # Best-effort (vram_total, gtt_total) in bytes from the amdgpu sysfs nodes + # mem_info_vram_total / mem_info_gtt_total, or None when they cannot be read + # (e.g. NVIDIA Tegra integrated parts that have no dedicated VRAM). #14274 + if not is_amd(): + return None + try: + drm_root = "/sys/class/drm" + candidates = [] + for name in os.listdir(drm_root): + if not (name.startswith("card") and name[len("card"):].isdigit()): + continue + dev_dir = os.path.join(drm_root, name, "device") + vram_path = os.path.join(dev_dir, "mem_info_vram_total") + gtt_path = os.path.join(dev_dir, "mem_info_gtt_total") + if not (os.path.exists(vram_path) and os.path.exists(gtt_path)): + continue + try: + with open(os.path.join(dev_dir, "vendor")) as vf: + if vf.read().strip().lower() != "0x1002": + continue + except OSError: + pass + candidates.append((os.path.basename(os.path.realpath(dev_dir)), vram_path, gtt_path)) + if not candidates: + return None + chosen = None + target_bdf = None + try: + if device is None: + device = get_torch_device() + props = torch.cuda.get_device_properties(device) + # torch reports the PCI location as integers (pci_domain_id / pci_bus_id + # / pci_device_id); amdgpu names its sysfs nodes as a hex + # "domain:bus:device.function" BDF. Build the canonical hex BDF so the + # two are comparable (the old str(pci_bus_id) compared a decimal bus + # number against a hex BDF string and could never match). #14274 + target_bdf = "%04x:%02x:%02x" % ( + int(getattr(props, "pci_domain_id", 0) or 0), + int(getattr(props, "pci_bus_id", 0) or 0), + int(getattr(props, "pci_device_id", 0) or 0), + ) + except Exception: + target_bdf = None + if target_bdf: + for pci, vram_path, gtt_path in candidates: + # candidates carry the realpath() leaf BDF (domain:bus:device.function), + # so matching the domain:bus:device part works whether the GPU is + # attached directly or sits behind a PCIe bridge (nested sysfs path). #14274 + if pci.lower().rsplit(".", 1)[0] == target_bdf: + chosen = (vram_path, gtt_path) + break + if chosen is None and len(candidates) == 1: + chosen = (candidates[0][1], candidates[0][2]) + if chosen is None: + return None + with open(chosen[0]) as f: + vram_total = int(f.read().strip()) + with open(chosen[1]) as f: + gtt_total = int(f.read().strip()) + return (vram_total, gtt_total) + except Exception: + return None + +def integrated_gpu_is_shared_heavy(device=None): + # For an integrated GPU, decide whether its memory is dominated by the shared + # GTT/host-RAM aperture (treat as UMA -> SHARED) or by a large dedicated VRAM + # carveout (keep NORMAL/HIGH_VRAM). Keys on the amdgpu mem_info_vram_total vs + # mem_info_gtt_total ratio (ComfyUI #14274). Defaults to True when the totals + # are unavailable (e.g. NVIDIA Tegra parts that have no dedicated VRAM). + totals = _amd_vram_gtt_totals(device) + if totals is None: + return True + vram_total, gtt_total = totals + if not vram_total or vram_total <= 0: + return True + return gtt_total >= vram_total + def amd_min_version(device=None, min_rdna_version=0): if not is_amd(): return False @@ -569,6 +661,15 @@ if cpu_state != CPUState.GPU: if cpu_state == CPUState.MPS: vram_state = VRAMState.SHARED +if vram_state == VRAMState.NORMAL_VRAM and is_integrated_gpu() and integrated_gpu_is_shared_heavy(): + # Integrated/UMA GPU whose shared GTT/host-RAM pool dominates the (small) + # dedicated VRAM carveout: treat as UMA and use SHARED so the shared pool is + # not double-counted as dedicated VRAM (#14274). Dedicated-heavy integrated + # parts (large BIOS UMA carveout, e.g. Strix Halo) keep NORMAL_VRAM where + # HIGH_VRAM is correct. + vram_state = VRAMState.SHARED + logging.info("Integrated GPU with shared-memory-dominant pool detected (UMA): using SHARED vram state to avoid double-counting GTT/shared memory as dedicated VRAM.") + logging.info(f"Set vram state to: {vram_state.name}") DISABLE_SMART_MEMORY = args.disable_smart_memory