diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 2bd752b7d..0bd5f2995 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -48,17 +48,25 @@ class MultiGPUCFGSplitNode(io.ComfyNode): return io.NodeOutput(model) -def _force_fp32_cpu_compute(patcher: ModelPatcher): - """Force fp32 inference dtype for CPU. +def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device): + """Ensure the patcher's compute dtype is one the target device actually supports. - PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16 - and run ~500-600x slower than fp32, which makes a normal-sized workflow - look frozen for hours. Routing through set_model_compute_dtype leaves the - weights as-is and casts at use, so peak memory does not blow up.""" - dtype = patcher.model_dtype() - if dtype in (torch.float16, torch.bfloat16): - logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).") - patcher.set_model_compute_dtype(torch.float32) + Defers to comfy.model_management.unet_manual_cast, which already encodes + per-device dtype support (CPU returns False for fp16/bf16, older GPUs may + not support bf16, pre-14 MPS doesn't support bf16, etc.). It returns None + when the weight dtype is already fine and the cast dtype otherwise. + + Concrete motivation: PyTorch's CPU conv2d kernels emulate fp16/bf16 in + software (~500-600x slower than fp32), so SelectModelDevice -> CPU on an + fp16 model would otherwise look frozen for hours. Routing through + set_model_compute_dtype leaves the weights as-is and casts at use, so peak + memory does not blow up.""" + weight_dtype = patcher.model_dtype() + cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device) + if cast_dtype is None: + return + logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).") + patcher.set_model_compute_dtype(cast_dtype) def _remember_base_devices(patcher: ModelPatcher): @@ -229,8 +237,7 @@ class SelectModelDeviceNode(io.ComfyNode): logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") return io.NodeOutput(model) if resolved is not None: - if resolved.type == "cpu": - _force_fp32_cpu_compute(model) + _force_supported_compute_dtype(model, resolved) _prune_multigpu_collision(model, model.load_device) return io.NodeOutput(model)