diff --git a/server.py b/server.py index 44470b904..268441bd1 100644 --- a/server.py +++ b/server.py @@ -646,18 +646,37 @@ class PromptServer(): @routes.get("/system_stats") async def system_stats(request): - device = comfy.model_management.get_torch_device() - device_name = comfy.model_management.get_torch_device_name(device) + primary_device = comfy.model_management.get_torch_device() cpu_device = comfy.model_management.torch.device("cpu") ram_total = comfy.model_management.get_total_memory(cpu_device) ram_free = comfy.model_management.get_free_memory(cpu_device) - vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) - vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) required_frontend_version = FrontendManager.get_required_frontend_version() installed_templates_version = FrontendManager.get_installed_templates_version() required_templates_version = FrontendManager.get_required_templates_version() comfy_package_versions = FrontendManager.get_comfy_package_versions() + # Report every torch device visible to multigpu, with the primary + # device first so existing clients that read devices[0] keep working. + torch_devices = comfy.model_management.get_all_torch_devices() + if primary_device in torch_devices: + torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device] + else: + torch_devices = [primary_device] + list(torch_devices) + + device_entries = [] + for d in torch_devices: + vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True) + device_entries.append({ + "name": comfy.model_management.get_torch_device_name(d), + "type": d.type, + "index": d.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + }) + system_stats = { "system": { "os": sys.platform, @@ -673,17 +692,7 @@ class PromptServer(): "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", "argv": sys.argv }, - "devices": [ - { - "name": device_name, - "type": device.type, - "index": device.index, - "vram_total": vram_total, - "vram_free": vram_free, - "torch_vram_total": torch_vram_total, - "torch_vram_free": torch_vram_free, - } - ] + "devices": device_entries } return web.json_response(system_stats)