From 50d1dd6273be924d5945f52e9f218ed22c4154a1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:23 -0700 Subject: [PATCH 1/9] Fix MultiGPU Options node discarding cloned GPUOptionsGroup GPUOptionsGroup.clone() returns a new instance, but the return value was discarded, causing the node to mutate the upstream caller's group in-place. When multiple MultiGPU Options nodes share an input group, each node's additions would leak into earlier siblings. Assign the clone result back to gpu_options so each node owns its own copy. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 5d24952bf..53b50029e 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -68,7 +68,8 @@ class MultiGPUOptionsNode(io.ComfyNode): def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: if not gpu_options: gpu_options = comfy.multigpu.GPUOptionsGroup() - gpu_options.clone() + else: + gpu_options = gpu_options.clone() opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) gpu_options.add(opt) From 9a681ccfc9d70f1797d1df0dd6e87eee4caf4b21 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:31 -0700 Subject: [PATCH 2/9] Guard cached_patcher_init when output_model is False load_checkpoint_guess_config_clip_only() calls load_checkpoint_guess_config() with output_model=False, leaving out[0] as None. The subsequent unconditional assignment of cached_patcher_init crashed with AttributeError, breaking CLIP-only checkpoint loading entirely. Guard the assignment with a None check. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/sd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index e7857bf0a..481c87cb1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1688,7 +1688,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + if out[0] is not None: + out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) return out def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): From ba417750a73d93c035485f71f56e5a3b146c111c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:38 -0700 Subject: [PATCH 3/9] Fix get_all_torch_devices for XPU/NPU and guard remove() torch.device(i) defaults to CUDA, so XPU/NPU branches were producing 'cuda:N' devices that don't match get_torch_device() output ('xpu:N'/'npu:N'). This caused devices.remove(get_torch_device()) to raise ValueError when exclude_current=True on non-NVIDIA hardware. Use explicit device strings, and guard the remove() with a membership check for safety. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/model_management.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2e168f363..10b982868 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -215,17 +215,19 @@ def get_all_torch_devices(exclude_current=False): if cpu_state == CPUState.GPU: if is_nvidia(): for i in range(torch.cuda.device_count()): - devices.append(torch.device(i)) + devices.append(torch.device("cuda", i)) elif is_intel_xpu(): for i in range(torch.xpu.device_count()): - devices.append(torch.device(i)) + devices.append(torch.device("xpu", i)) elif is_ascend_npu(): for i in range(torch.npu.device_count()): - devices.append(torch.device(i)) + devices.append(torch.device("npu", i)) else: devices.append(get_torch_device()) if exclude_current: - devices.remove(get_torch_device()) + current = get_torch_device() + if current in devices: + devices.remove(current) return devices def get_gpu_device_options(): From dd85851efec772298772f159e2134cea45bd1b3e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:45 -0700 Subject: [PATCH 4/9] Prune inherited multigpu clones when max_gpus is lowered create_multigpu_deepclones cloned the existing 'multigpu' additional_models list verbatim and never pruned entries beyond limit_extra_devices. If a workflow was previously prepared for more GPUs, reducing max_gpus would leave stale clones attached and eligible for later scheduling. Replace the TODO block with a real prune that keeps only clones whose load_device is either the model's load_device or in limit_extra_devices, and re-match clones if anything was removed. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/multigpu.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 096270c12..eff7d0649 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -162,16 +162,16 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: gpu_options.register(model) else: logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") - # TODO: only keep model clones that don't go 'past' the intended max_gpu count - # multigpu_models = model.get_additional_models_with_key("multigpu") - # new_multigpu_models = [] - # for m in multigpu_models: - # if m.load_device in limit_extra_devices: - # new_multigpu_models.append(m) - # model.set_additional_models("multigpu", new_multigpu_models) - # persist skip_devices for use in sampling code - # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: - # model.model_options["multigpu_skip_devices"] = skip_devices + # only keep model clones that don't go 'past' the intended max_gpu count; + # this prunes any inherited multigpu clones whose load_device is no longer allowed + # when max_gpus is lowered between runs. + allowed_devices = set(limit_extra_devices) + allowed_devices.add(model.load_device) + multigpu_models = model.get_additional_models_with_key("multigpu") + new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices] + if len(new_multigpu_models) != len(multigpu_models): + model.set_additional_models("multigpu", new_multigpu_models) + model.match_multigpu_clones() return model From ac0a90c323735333346397ff2d9b7bf493b531d3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 19:52:03 -0700 Subject: [PATCH 5/9] Use cond_shapes in multigpu memory-fit check (parity with single-GPU path) The multigpu cond-batching loop called model.memory_required(input_shape) without conditioning shapes, while the single-GPU path at line 279 passes cond_shapes. Large conditioning tensors (e.g. video prompts, control inputs) were therefore under-counted, risking OOM at runtime when the chosen batch size was too large. Match the single-GPU pattern by building cond_shapes from each batched cond's conditioning dict and passing it to memory_required. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/samplers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f0d67cb7e..a99af5217 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -433,7 +433,11 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) * 1.5 < free_memory: + cond_shapes = collections.defaultdict(list) + for tt in batch_amount: + for k, v in to_run[tt][0].conditioning.items(): + cond_shapes[k].append(v.size()) + if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: to_batch = batch_amount break From 4d9106dcedbecb3df8c98a9cd05cfa8fdb3fd862 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 20:48:59 -0700 Subject: [PATCH 6/9] Document --cuda-device comma format and MultiGPU Options relative_speed gap Two doc-only changes addressing minor CodeRabbit findings on PR #7063: * cli_args.py: clarify --cuda-device help text to document the required comma-separated format ('0' or '0,1'), matching how the value is consumed by CUDA_VISIBLE_DEVICES in main.py. * nodes_multigpu.py: add a docstring NOTE on the (currently unregistered) MultiGPUOptionsNode explaining that its relative_speed input is plumbed through to model_options['multigpu_options'] but is not yet consulted by the cond scheduler, which still uses uniform round-robin via next_available_device(). Wire relative_speed into the scheduler before re-enabling the node. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/cli_args.py | 2 +- comfy_extras/nodes_multigpu.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index df3841871..3a14a470d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") -parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.") +parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.") parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 53b50029e..fedafef71 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -45,6 +45,16 @@ class MultiGPUCFGSplitNode(io.ComfyNode): class MultiGPUOptionsNode(io.ComfyNode): """ Select the relative speed of GPUs in the special case they have significantly different performance from one another. + + NOTE (not registered yet, see MultiGPUExtension.get_node_list below): + The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on + model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond + scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult + relative_speed when distributing conds across devices; it uses a uniform conds_per_device + round-robin via next_available_device(). Before re-enabling this node, wire its + relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(), + which already implements the proportional split) so the input actually affects work + distribution. """ @classmethod From adde1239b1037f7bf1b2dfce9052e6fd1fde4edf Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:35:39 -0700 Subject: [PATCH 7/9] Restore prepare_state backward-compatible signature Drop the new ignore_multigpu positional argument from prepare_state and from the ON_PREPARE_STATE callbacks; pass the flag via model_options instead. This restores the original 3-arg callback signature so existing custom-node ON_PREPARE_STATE handlers keep working unchanged, while still letting prepare_state's recursive call into multigpu_clones short-circuit. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/model_patcher.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 00d60ff72..b680de058 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1361,13 +1361,18 @@ class ModelPatcher: for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep, model_options, ignore_multigpu=False): + def prepare_state(self, timestep, model_options): + ignore_multigpu = model_options.get("ignore_multigpu", False) for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep, model_options, ignore_multigpu) + callback(self, timestep, model_options) if not ignore_multigpu and "multigpu_clones" in model_options: - for p in model_options["multigpu_clones"].values(): - p: ModelPatcher - p.prepare_state(timestep, model_options, ignore_multigpu=True) + model_options["ignore_multigpu"] = True + try: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options) + finally: + model_options.pop("ignore_multigpu", None) def restore_hook_patches(self): if self.hook_patches_backup is not None: From 963621603ce2b43a567ec7cf88709555dfa9d6b5 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:35:54 -0700 Subject: [PATCH 8/9] Free QwenFunControlNet base_model reference in cleanup QwenFunControlNet.pre_run stashes the model's diffusion_model into self.extra_args['base_model'], but ControlBase.cleanup never clears extra_args. The diffusion_model reference therefore lingered between sampling runs, blocking ComfyUI's model offload/eviction logic from freeing the UNet and -- for multigpu -- holding one such reference per per-device control clone (defeating the max_gpus pruning added in this PR). Override cleanup to drop the entry; super().cleanup() already recurses into multigpu_clones so each per-device clone pops its own. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/controlnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 837aa907a..6dbbaa959 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -357,6 +357,10 @@ class QwenFunControlNet(ControlNet): super().pre_run(model, percent_to_timestep_function) self.set_extra_arg("base_model", model.diffusion_model) + def cleanup(self): + self.extra_args.pop("base_model", None) + super().cleanup() + def copy(self): c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c.control_model = self.control_model From a18dd219d57079c8d20ee1506feaf17a0e995ffb Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:40:49 -0700 Subject: [PATCH 9/9] Pass per-device model to multigpu control clones in pre_run_control QwenFunControlNet.pre_run stashes model.diffusion_model into extra_args, which the control_model then uses for forward passes (img_in, txt_in, pe_embedder, time_text_embed). With multigpu, every per-device control clone was being pre_run with the base model on GPU0, so secondary devices would invoke those modules with parameters on GPU0 and inputs on their own device, raising 'Expected all tensors to be on the same device'. Build a device -> per-device BaseModel lookup from the patcher's additional multigpu models and pass each clone the model on its own device. Falls back to the base model when no per-device match is found (single-GPU path and the case where cnet.multigpu_clones lags the patcher's clone set). Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/samplers.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index a99af5217..8bfc42bdb 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -870,14 +870,21 @@ def calculate_start_end_timesteps(model, conds): def pre_run_control(model, conds): s = model.model_sampling + # Per-device model lookup so multigpu control clones get the matching + # diffusion_model (e.g. QwenFunControlNet stashes it into extra_args). + device_models: dict = {} + patcher = getattr(model, "current_patcher", None) + if patcher is not None: + for p in patcher.get_additional_models_with_key("multigpu"): + device_models[p.load_device] = p.model for t in range(len(conds)): x = conds[t] percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: x['control'].pre_run(model, percent_to_timestep_function) - for device_cnet in x['control'].multigpu_clones.values(): - device_cnet.pre_run(model, percent_to_timestep_function) + for device, device_cnet in x['control'].multigpu_clones.items(): + device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = []