From dd85851efec772298772f159e2134cea45bd1b3e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:45 -0700 Subject: [PATCH] Prune inherited multigpu clones when max_gpus is lowered create_multigpu_deepclones cloned the existing 'multigpu' additional_models list verbatim and never pruned entries beyond limit_extra_devices. If a workflow was previously prepared for more GPUs, reducing max_gpus would leave stale clones attached and eligible for later scheduling. Replace the TODO block with a real prune that keeps only clones whose load_device is either the model's load_device or in limit_extra_devices, and re-match clones if anything was removed. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/multigpu.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 096270c12..eff7d0649 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -162,16 +162,16 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: gpu_options.register(model) else: logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") - # TODO: only keep model clones that don't go 'past' the intended max_gpu count - # multigpu_models = model.get_additional_models_with_key("multigpu") - # new_multigpu_models = [] - # for m in multigpu_models: - # if m.load_device in limit_extra_devices: - # new_multigpu_models.append(m) - # model.set_additional_models("multigpu", new_multigpu_models) - # persist skip_devices for use in sampling code - # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: - # model.model_options["multigpu_skip_devices"] = skip_devices + # only keep model clones that don't go 'past' the intended max_gpu count; + # this prunes any inherited multigpu clones whose load_device is no longer allowed + # when max_gpus is lowered between runs. + allowed_devices = set(limit_extra_devices) + allowed_devices.add(model.load_device) + multigpu_models = model.get_additional_models_with_key("multigpu") + new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices] + if len(new_multigpu_models) != len(multigpu_models): + model.set_additional_models("multigpu", new_multigpu_models) + model.match_multigpu_clones() return model