model_management: match AMD GPU by canonical PCI BDF, not str(pci_bus_id)

The _amd_vram_gtt_totals() device match compared str(pci_bus_id) against the
sysfs leaf BDF, but torch reports pci_bus_id as a decimal integer while amdgpu
names its nodes as a hex "domain:bus:device.function" BDF, so the comparison
never matched. A single-GPU host was rescued by the len(candidates) == 1
fallback; a hybrid / multi-GPU host has no fallback and could fall through to
shared-heavy, demoting a dedicated GPU to SHARED (reported for a GPU sitting
behind a PCIe bridge).
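
A minimal sketch of the failure, assuming a hypothetical GPU on bus 0x43
(decimal 67). old_match() is an illustrative helper that mirrors the removed
comparison; it is not a function in the tree:

```python
def old_match(pci: str, pci_bus_id: int) -> bool:
    """Mirror of the removed comparison: decimal bus number vs. hex BDF string."""
    bus_id = str(pci_bus_id).lower()
    return pci.lower().endswith(bus_id) or bus_id.endswith(pci.lower())


# Hypothetical values: torch reports bus 67 (decimal), sysfs names the node 0000:43:00.0.
sysfs_bdf = "0000:43:00.0"
torch_bus_id = 67

# The decimal string "67" is not a suffix of the hex BDF, and vice versa.
matched = old_match(sysfs_bdf, torch_bus_id)
print(matched)
```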

Build the canonical hex BDF from torch's integer pci_domain_id / pci_bus_id /
pci_device_id and compare it against the candidate's realpath leaf BDF (PCI
function stripped). realpath already collapses any bridge chain to the leaf,
so this works for directly-attached, behind-a-bridge, and multi-GPU hosts
alike. The len(candidates) == 1 fallback is kept.
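
A standalone sketch of the new match, for illustration only. canonical_bdf()
and pick_candidate() are hypothetical helper names, SimpleNamespace stands in
for the object returned by torch.cuda.get_device_properties(), and the
candidate paths are made-up placeholders:

```python
from types import SimpleNamespace


def canonical_bdf(props) -> str:
    """Build a hex "domain:bus:device" string from torch's integer PCI fields."""
    return "%04x:%02x:%02x" % (
        int(getattr(props, "pci_domain_id", 0) or 0),
        int(getattr(props, "pci_bus_id", 0) or 0),
        int(getattr(props, "pci_device_id", 0) or 0),
    )


def pick_candidate(candidates, target_bdf):
    """Return (vram_path, gtt_path) for the candidate whose leaf BDF matches."""
    for pci, vram_path, gtt_path in candidates:
        # Strip the ".function" suffix so only domain:bus:device is compared.
        if pci.lower().rsplit(".", 1)[0] == target_bdf:
            return (vram_path, gtt_path)
    return None


# Stand-in for torch device properties (assumed values: bus 67 == 0x43).
props = SimpleNamespace(pci_domain_id=0, pci_bus_id=67, pci_device_id=0)

# Hypothetical two-GPU host; paths are placeholders, not real sysfs entries.
candidates = [
    ("0000:03:00.0", "vram_igpu", "gtt_igpu"),
    ("0000:43:00.0", "vram_dgpu", "gtt_dgpu"),
]

target = canonical_bdf(props)
chosen = pick_candidate(candidates, target)
print(target, chosen)
```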

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

#14274
liminfei-amd 2026-06-22 13:41:58 +08:00
parent a2db31582f
commit 50d77af3af

@@ -473,17 +473,29 @@ def _amd_vram_gtt_totals(device=None):
     if not candidates:
         return None
     chosen = None
-    bus_id = None
+    target_bdf = None
     try:
         if device is None:
             device = get_torch_device()
-        bus_id = getattr(torch.cuda.get_device_properties(device), "pci_bus_id", None)
+        props = torch.cuda.get_device_properties(device)
+        # torch reports the PCI location as integers (pci_domain_id / pci_bus_id
+        # / pci_device_id); amdgpu names its sysfs nodes as a hex
+        # "domain:bus:device.function" BDF. Build the canonical hex BDF so the
+        # two are comparable (the old str(pci_bus_id) compared a decimal bus
+        # number against a hex BDF string and could never match). #14274
+        target_bdf = "%04x:%02x:%02x" % (
+            int(getattr(props, "pci_domain_id", 0) or 0),
+            int(getattr(props, "pci_bus_id", 0) or 0),
+            int(getattr(props, "pci_device_id", 0) or 0),
+        )
     except Exception:
-        bus_id = None
-    if bus_id:
-        bus_id = str(bus_id).lower()
+        target_bdf = None
+    if target_bdf:
         for pci, vram_path, gtt_path in candidates:
-            if pci.lower().endswith(bus_id) or bus_id.endswith(pci.lower()):
+            # candidates carry the realpath() leaf BDF (domain:bus:device.function),
+            # so matching the domain:bus:device part works whether the GPU is
+            # attached directly or sits behind a PCIe bridge (nested sysfs path). #14274
+            if pci.lower().rsplit(".", 1)[0] == target_bdf:
                 chosen = (vram_path, gtt_path)
                 break
     if chosen is None and len(candidates) == 1: