mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
multigpu: use unet_manual_cast for SelectModelDevice compute dtype (#14108)
This commit is contained in:
parent
da49b7d0b6
commit
88956e77af
@ -48,17 +48,14 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
|
|||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
def _force_fp32_cpu_compute(patcher: ModelPatcher):
|
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
|
||||||
"""Force fp32 inference dtype for CPU.
|
"""Cast compute dtype to one the device supports; no-op if already supported."""
|
||||||
|
weight_dtype = patcher.model_dtype()
|
||||||
PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16
|
cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
|
||||||
and run ~500-600x slower than fp32, which makes a normal-sized workflow
|
if cast_dtype is None:
|
||||||
look frozen for hours. Routing through set_model_compute_dtype leaves the
|
return
|
||||||
weights as-is and casts at use, so peak memory does not blow up."""
|
logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).")
|
||||||
dtype = patcher.model_dtype()
|
patcher.set_model_compute_dtype(cast_dtype)
|
||||||
if dtype in (torch.float16, torch.bfloat16):
|
|
||||||
logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).")
|
|
||||||
patcher.set_model_compute_dtype(torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def _remember_base_devices(patcher: ModelPatcher):
|
def _remember_base_devices(patcher: ModelPatcher):
|
||||||
@ -229,8 +226,7 @@ class SelectModelDeviceNode(io.ComfyNode):
|
|||||||
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
|
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
if resolved is not None:
|
if resolved is not None:
|
||||||
if resolved.type == "cpu":
|
_force_supported_compute_dtype(model, resolved)
|
||||||
_force_fp32_cpu_compute(model)
|
|
||||||
_prune_multigpu_collision(model, model.load_device)
|
_prune_multigpu_collision(model, model.load_device)
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user