diff --git a/comfy/multigpu.py b/comfy/multigpu.py index eff7d0649..e7f5b3d6f 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -131,7 +131,11 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: skip_devices.add(mm.load_device) skip_devices = list(skip_devices) - full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + # Exclude the primary model's actual device, not the global current device: + # after SelectModelDevice(gpu:N) the primary may not live on the process's + # current CUDA device, and excluding the wrong device picks bad extras. + all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False) + full_extra_devices = [d for d in all_devices if d != model.load_device] limit_extra_devices = full_extra_devices[:max_gpus-1] extra_devices = limit_extra_devices.copy() # exclude skipped devices @@ -143,16 +147,30 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: for device in extra_devices: device_patcher = None if reuse_loaded: - # check if there are any ModelPatchers currently loaded that could be referenced here after a clone + # Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice + # patcher on the same device shares clone_base_uuid but has + # is_multigpu_base_clone=False, which would later be filtered out by + # prepare_model_patcher_multigpu_clones() and silently shrink the + # work split back to one GPU. loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models() for lm in loaded_models: - if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device: - device_patcher = lm.clone() - logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}") - break + if lm.model is None: + continue + if lm.load_device != device: + continue + if lm.clone_base_uuid != model.clone_base_uuid: + continue + if not getattr(lm, "is_multigpu_base_clone", False): + continue + device_patcher = lm.clone() + logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}") + break if device_patcher is None: device_patcher = model.deepclone_multigpu(new_load_device=device) - device_patcher.is_multigpu_base_clone = True + # Always flag the clone; whether reused or freshly deepcloned, it must + # advertise itself as a MultiGPU base clone so the cond scheduler picks + # it up in prepare_model_patcher_multigpu_clones(). + device_patcher.is_multigpu_base_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) model.set_additional_models("multigpu", multigpu_models) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index d39cca3f8..2bd752b7d 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -11,6 +11,8 @@ from comfy_api.latest import ComfyExtension, io if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.sd import CLIP, VAE +import torch + import comfy.model_management import comfy.multigpu @@ -46,6 +48,19 @@ class MultiGPUCFGSplitNode(io.ComfyNode): return io.NodeOutput(model) +def _force_fp32_cpu_compute(patcher: ModelPatcher): + """Force fp32 inference dtype for CPU. + + PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16 + and run ~500-600x slower than fp32, which makes a normal-sized workflow + look frozen for hours. Routing through set_model_compute_dtype leaves the + weights as-is and casts at use, so peak memory does not blow up.""" + dtype = patcher.model_dtype() + if dtype in (torch.float16, torch.bfloat16): + logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).") + patcher.set_model_compute_dtype(torch.float32) + + def _remember_base_devices(patcher: ModelPatcher): """Stash the original load/offload device on the underlying model. @@ -214,6 +229,8 @@ class SelectModelDeviceNode(io.ComfyNode): logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") return io.NodeOutput(model) if resolved is not None: + if resolved.type == "cpu": + _force_fp32_cpu_compute(model) _prune_multigpu_collision(model, model.load_device) return io.NodeOutput(model)