This commit is contained in:
liminfei-amd 2026-06-22 05:42:03 +00:00 committed by GitHub
commit e2e8dae641
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -432,6 +432,98 @@ def is_amd():
return True return True
return False return False
def is_integrated_gpu(device=None):
# AMD APUs / integrated GPUs expose host RAM (GTT/shared) as device memory
# via mem_get_info(); torch flags these as integrated. See ComfyUI #14274.
if cpu_state != CPUState.GPU:
return False
if not (is_nvidia() or is_amd()):
return False
try:
if device is None:
device = get_torch_device()
return bool(getattr(torch.cuda.get_device_properties(device), "is_integrated", 0))
except Exception:
return False
def _amd_vram_gtt_totals(device=None):
# Best-effort (vram_total, gtt_total) in bytes from the amdgpu sysfs nodes
# mem_info_vram_total / mem_info_gtt_total, or None when they cannot be read
# (e.g. NVIDIA Tegra integrated parts that have no dedicated VRAM). #14274
if not is_amd():
return None
try:
drm_root = "/sys/class/drm"
candidates = []
for name in os.listdir(drm_root):
if not (name.startswith("card") and name[len("card"):].isdigit()):
continue
dev_dir = os.path.join(drm_root, name, "device")
vram_path = os.path.join(dev_dir, "mem_info_vram_total")
gtt_path = os.path.join(dev_dir, "mem_info_gtt_total")
if not (os.path.exists(vram_path) and os.path.exists(gtt_path)):
continue
try:
with open(os.path.join(dev_dir, "vendor")) as vf:
if vf.read().strip().lower() != "0x1002":
continue
except OSError:
pass
candidates.append((os.path.basename(os.path.realpath(dev_dir)), vram_path, gtt_path))
if not candidates:
return None
chosen = None
target_bdf = None
try:
if device is None:
device = get_torch_device()
props = torch.cuda.get_device_properties(device)
# torch reports the PCI location as integers (pci_domain_id / pci_bus_id
# / pci_device_id); amdgpu names its sysfs nodes as a hex
# "domain:bus:device.function" BDF. Build the canonical hex BDF so the
# two are comparable (the old str(pci_bus_id) compared a decimal bus
# number against a hex BDF string and could never match). #14274
target_bdf = "%04x:%02x:%02x" % (
int(getattr(props, "pci_domain_id", 0) or 0),
int(getattr(props, "pci_bus_id", 0) or 0),
int(getattr(props, "pci_device_id", 0) or 0),
)
except Exception:
target_bdf = None
if target_bdf:
for pci, vram_path, gtt_path in candidates:
# candidates carry the realpath() leaf BDF (domain:bus:device.function),
# so matching the domain:bus:device part works whether the GPU is
# attached directly or sits behind a PCIe bridge (nested sysfs path). #14274
if pci.lower().rsplit(".", 1)[0] == target_bdf:
chosen = (vram_path, gtt_path)
break
if chosen is None and len(candidates) == 1:
chosen = (candidates[0][1], candidates[0][2])
if chosen is None:
return None
with open(chosen[0]) as f:
vram_total = int(f.read().strip())
with open(chosen[1]) as f:
gtt_total = int(f.read().strip())
return (vram_total, gtt_total)
except Exception:
return None
def integrated_gpu_is_shared_heavy(device=None):
# For an integrated GPU, decide whether its memory is dominated by the shared
# GTT/host-RAM aperture (treat as UMA -> SHARED) or by a large dedicated VRAM
# carveout (keep NORMAL/HIGH_VRAM). Keys on the amdgpu mem_info_vram_total vs
# mem_info_gtt_total ratio (ComfyUI #14274). Defaults to True when the totals
# are unavailable (e.g. NVIDIA Tegra parts that have no dedicated VRAM).
totals = _amd_vram_gtt_totals(device)
if totals is None:
return True
vram_total, gtt_total = totals
if not vram_total or vram_total <= 0:
return True
return gtt_total >= vram_total
def amd_min_version(device=None, min_rdna_version=0): def amd_min_version(device=None, min_rdna_version=0):
if not is_amd(): if not is_amd():
return False return False
@ -569,6 +661,15 @@ if cpu_state != CPUState.GPU:
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED vram_state = VRAMState.SHARED
if vram_state == VRAMState.NORMAL_VRAM and is_integrated_gpu() and integrated_gpu_is_shared_heavy():
# Integrated/UMA GPU whose shared GTT/host-RAM pool dominates the (small)
# dedicated VRAM carveout: treat as UMA and use SHARED so the shared pool is
# not double-counted as dedicated VRAM (#14274). Dedicated-heavy integrated
# parts (large BIOS UMA carveout, e.g. Strix Halo) keep NORMAL_VRAM where
# HIGH_VRAM is correct.
vram_state = VRAMState.SHARED
logging.info("Integrated GPU with shared-memory-dominant pool detected (UMA): using SHARED vram state to avoid double-counting GTT/shared memory as dedicated VRAM.")
logging.info(f"Set vram state to: {vram_state.name}") logging.info(f"Set vram state to: {vram_state.name}")
DISABLE_SMART_MEMORY = args.disable_smart_memory DISABLE_SMART_MEMORY = args.disable_smart_memory