mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
multigpu: shorten _force_supported_compute_dtype docstring
Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
4ca4d39076
commit
8969bbbf02
@ -49,18 +49,11 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
|
|||||||
|
|
||||||
|
|
||||||
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
|
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
|
||||||
"""Ensure the patcher's compute dtype is one the target device actually supports.
|
"""Cast compute dtype to one the device supports; no-op if already supported.
|
||||||
|
|
||||||
Defers to comfy.model_management.unet_manual_cast, which already encodes
|
Uses unet_manual_cast which encodes per-device dtype support (e.g. CPU
|
||||||
per-device dtype support (CPU returns False for fp16/bf16, older GPUs may
|
rejects fp16/bf16, falling back to fp32 to avoid PyTorch's ~500-600x
|
||||||
not support bf16, pre-14 MPS doesn't support bf16, etc.). It returns None
|
slower software emulation)."""
|
||||||
when the weight dtype is already fine and the cast dtype otherwise.
|
|
||||||
|
|
||||||
Concrete motivation: PyTorch's CPU conv2d kernels emulate fp16/bf16 in
|
|
||||||
software (~500-600x slower than fp32), so SelectModelDevice -> CPU on an
|
|
||||||
fp16 model would otherwise look frozen for hours. Routing through
|
|
||||||
set_model_compute_dtype leaves the weights as-is and casts at use, so peak
|
|
||||||
memory does not blow up."""
|
|
||||||
weight_dtype = patcher.model_dtype()
|
weight_dtype = patcher.model_dtype()
|
||||||
cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
|
cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
|
||||||
if cast_dtype is None:
|
if cast_dtype is None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user