From 8969bbbf0252080b3e5186e2150d626dd98cd5b4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 25 May 2026 19:56:07 -0700 Subject: [PATCH] multigpu: shorten _force_supported_compute_dtype docstring Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792 Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 0bd5f2995..878d85baf 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -49,18 +49,11 @@ class MultiGPUCFGSplitNode(io.ComfyNode): def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device): - """Ensure the patcher's compute dtype is one the target device actually supports. + """Cast compute dtype to one the device supports; no-op if already supported. - Defers to comfy.model_management.unet_manual_cast, which already encodes - per-device dtype support (CPU returns False for fp16/bf16, older GPUs may - not support bf16, pre-14 MPS doesn't support bf16, etc.). It returns None - when the weight dtype is already fine and the cast dtype otherwise. - - Concrete motivation: PyTorch's CPU conv2d kernels emulate fp16/bf16 in - software (~500-600x slower than fp32), so SelectModelDevice -> CPU on an - fp16 model would otherwise look frozen for hours. Routing through - set_model_compute_dtype leaves the weights as-is and casts at use, so peak - memory does not blow up.""" + Uses unet_manual_cast which encodes per-device dtype support (e.g. CPU + rejects fp16/bf16, falling back to fp32 to avoid PyTorch's ~500-600x + slower software emulation).""" weight_dtype = patcher.model_dtype() cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device) if cast_dtype is None: