mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-10 16:27:33 +08:00
multigpu: use unet_manual_cast for SelectModelDevice compute dtype
Replace the hardcoded `_force_fp32_cpu_compute` helper with`_force_supported_compute_dtype`, which delegates to`comfy.model_management.unet_manual_cast(weight_dtype, device)`. The interrogator already encodes per-device dtype support (CPU returns False for fp16/bf16, older GPUs may not support bf16, pre-14 MPS doesn't support bf16, etc.) and returns None when no cast is needed.For SelectModelDevice -> CPU on an fp16/bf16 model, behavior is unchanged: `unet_manual_cast` returns `torch.float32` and `set_model_compute_dtype` casts at use without touching peak memory. As a bonus the same code path now handles other `weight_dtype not supported on device` cases (e.g. bf16 weights on pre-Ampere NVIDIA, bf16 on pre-macOS-14 MPS) without growing the code surface, so the call site no longer needs the `if resolved.type == 'cpu':` gate. Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
da49b7d0b6
commit
4ca4d39076
@ -48,17 +48,25 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _force_fp32_cpu_compute(patcher: ModelPatcher):
|
||||
"""Force fp32 inference dtype for CPU.
|
||||
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
|
||||
"""Ensure the patcher's compute dtype is one the target device actually supports.
|
||||
|
||||
PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16
|
||||
and run ~500-600x slower than fp32, which makes a normal-sized workflow
|
||||
look frozen for hours. Routing through set_model_compute_dtype leaves the
|
||||
weights as-is and casts at use, so peak memory does not blow up."""
|
||||
dtype = patcher.model_dtype()
|
||||
if dtype in (torch.float16, torch.bfloat16):
|
||||
logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).")
|
||||
patcher.set_model_compute_dtype(torch.float32)
|
||||
Defers to comfy.model_management.unet_manual_cast, which already encodes
|
||||
per-device dtype support (CPU returns False for fp16/bf16, older GPUs may
|
||||
not support bf16, pre-14 MPS doesn't support bf16, etc.). It returns None
|
||||
when the weight dtype is already fine and the cast dtype otherwise.
|
||||
|
||||
Concrete motivation: PyTorch's CPU conv2d kernels emulate fp16/bf16 in
|
||||
software (~500-600x slower than fp32), so SelectModelDevice -> CPU on an
|
||||
fp16 model would otherwise look frozen for hours. Routing through
|
||||
set_model_compute_dtype leaves the weights as-is and casts at use, so peak
|
||||
memory does not blow up."""
|
||||
weight_dtype = patcher.model_dtype()
|
||||
cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
|
||||
if cast_dtype is None:
|
||||
return
|
||||
logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).")
|
||||
patcher.set_model_compute_dtype(cast_dtype)
|
||||
|
||||
|
||||
def _remember_base_devices(patcher: ModelPatcher):
|
||||
@ -229,8 +237,7 @@ class SelectModelDeviceNode(io.ComfyNode):
|
||||
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
|
||||
return io.NodeOutput(model)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
_force_fp32_cpu_compute(model)
|
||||
_force_supported_compute_dtype(model, resolved)
|
||||
_prune_multigpu_collision(model, model.load_device)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user