diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2bb363fab..c68a52cc2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -457,23 +457,38 @@ class ModelPatcher: def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None): logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.") + if self.cached_patcher_init is None: + raise RuntimeError( + f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: " + "the loader that produced this model does not support multigpu " + "(cached_patcher_init is not initialized). Use a core loader " + "(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), " + "or have the custom loader register a cached_patcher_init factory." + ) comfy.model_management.unload_model_and_clones(self) - n = self.clone() + # Produce a freshly-loaded patcher from the loader factory so the multigpu + # clone owns its own untainted model weights (rather than relying on + # copy.deepcopy of an already-patched/already-loaded module). + temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] + # Override clone()'s normal "share self.model + share backup containers" with + # the pristine model from temp_model_patcher plus empty backup containers -- + # the fresh model has no patches applied, so any deepcopy of self's stale + # backup/object_patches_backup/pinned would just propagate dead state that + # no longer corresponds to anything in n.model. + model_override = (temp_model_patcher.model, ({}, {}, {}, set())) + n = self.clone(model_override=model_override) + # clone() copies hook_backup by reference from self; reset since model is pristine. + n.hook_backup = {} # set load device, if present if new_load_device is not None: n.load_device = new_load_device - if self.cached_patcher_init is not None: - temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) - if len(self.cached_patcher_init) > 2: - temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] - n.model = temp_model_patcher.model - else: - n.model = copy.deepcopy(n.model) - # unlike for normal clone, backup dicts that shared same ref should not; - # otherwise, patchers that have deep copies of base models will erroneously influence each other. - n.backup = copy.deepcopy(n.backup) - n.object_patches_backup = copy.deepcopy(n.object_patches_backup) - n.hook_backup = copy.deepcopy(n.hook_backup) + # Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins) + # has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's + # __init__ only registered its own (default) load_device. + if hasattr(n, "register_load_device"): + n.register_load_device(n.load_device) # multigpu clone should not have multigpu additional_models entry n.remove_additional_models("multigpu") # multigpu_clone all stored additional_models; make sure circular references are properly handled diff --git a/comfy/sd.py b/comfy/sd.py index 1670a0486..084170c62 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1727,8 +1727,50 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) if out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + # Register reload factories for the CLIP and VAE produced by the same checkpoint so + # ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device, + # MultiGPU work-units, etc.) without falling back to copy.deepcopy of an + # already-loaded module. + if out[1] is not None and getattr(out[1], "patcher", None) is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) + if out[2] is not None and getattr(out[2], "patcher", None) is not None: + out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) return out + +def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init + factory for the CLIP returned by load_checkpoint_guess_config.""" + _, clip, _, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=False, + output_clip=True, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return clip.patcher + + +def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init + factory for the VAE returned by load_checkpoint_guess_config.""" + _, _, vae, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=False, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return vae.patcher + def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -1954,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model + +def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False): + """Reload a disk-backed VAE from ``vae_path`` and return its patcher. + + Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so + :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a + fresh, untainted VAE patcher (no inherited per-device load state, no + in-place quantization fallout) for multigpu work-units and the + SelectVAEDevice node. The optional ``device`` matches the source loader's + VAE initialization path; the deepclone's ``load_device`` still controls + where the cloned patcher is targeted. + """ + if metadata is None: + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + else: + sd = comfy.utils.load_torch_file(vae_path) + vae = VAE(sd=sd, metadata=metadata, device=device) + vae.throw_exception_if_invalid() + return vae.patcher + def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 0e109f426..d39cca3f8 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -49,48 +49,82 @@ class MultiGPUCFGSplitNode(io.ComfyNode): def _remember_base_devices(patcher: ModelPatcher): """Stash the original load/offload device on the underlying model. - Stored on patcher.model (which is shared across patcher clones), so - repeated selector applications can recover the loader's original - routing when the user picks "default". + Stored on patcher.model (which is shared with the input patcher), so + later "default" selections can recover the loader's original routing. + Only the first Select on a given chain writes these attrs; subsequent + deepclones inherit them onto their freshly-loaded model below. """ if not hasattr(patcher.model, "_select_base_load_device"): patcher.model._select_base_load_device = patcher.load_device patcher.model._select_base_offload_device = patcher.offload_device -def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None): - """Apply *resolved* to a freshly-cloned patcher; respect base devices on default. +def _propagate_base_devices(src_model, dst_model): + """Carry the loader-original device attrs onto the freshly-deepcloned model.""" + if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"): + dst_model._select_base_load_device = src_model._select_base_load_device + dst_model._select_base_offload_device = src_model._select_base_offload_device - Returns the (possibly newly-replaced) patcher. For CPU on a dynamic - patcher, also tries to downgrade to a plain ModelPatcher so the - dynamic-only code paths are bypassed (best-effort: silently keeps - the dynamic patcher if downgrade is not supported). + +def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device): + """Return a patcher whose actual model weights live on *target_load_device*. + + If *patcher* is already on *target_load_device* we just retarget the + (already-cloned) patcher's metadata in place. Otherwise we call + :meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from + the loader's ``cached_patcher_init`` factory -- the only safe way to + move weights that may already be partially loaded onto another device. + + NOTE: reusing the input patcher's model when the requested device + matches its current load_device is a deliberate fast path. Anything + that has already mutated the original model (e.g. a prior KSampler + invocation on the same model) will be observed here. This is by + design and documented on the SelectXDeviceNode docstrings -- placing + Select X Device after a node that consumes the same model is not + recommended. + """ + if patcher.load_device == target_load_device: + # Fast path: weights already on the desired device, just update offload. + patcher.offload_device = target_offload_device + return patcher + src_model = patcher.model + patcher = patcher.deepclone_multigpu(new_load_device=target_load_device) + patcher.offload_device = target_offload_device + _propagate_base_devices(src_model, patcher.model) + if hasattr(patcher, "register_load_device"): + patcher.register_load_device(patcher.load_device) + return patcher + + +def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None): + """Resolve the requested device and produce a patcher routed there. + + For "default" we restore the loader's original load/offload pair. + For CPU we pin both load and offload to CPU (and, on a dynamic + patcher, downgrade to a plain ModelPatcher so the dynamic-only + code paths are bypassed). + For an explicit GPU we keep the loader's original offload but + target the requested load device; if that differs from the current + load device the patcher is deepcloned onto the new device. """ _remember_base_devices(patcher) base_load = patcher.model._select_base_load_device base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device if resolved is None: - # "default" -> reset routing to whatever the loader produced - patcher.load_device = base_load - patcher.offload_device = base_offload - elif resolved.type == "cpu": + # "default" -> route back to the loader's original devices. + return _retarget_patcher(patcher, base_load, base_offload) + if resolved.type == "cpu": if patcher.is_dynamic(): - try: - patcher = patcher.clone(disable_dynamic=True) - except Exception: - # Downgrade unavailable (no cached_patcher_init); fall - # back to the existing dynamic patcher. - pass + # clone(disable_dynamic=True) requires cached_patcher_init; let the + # exception surface to the caller (Select*DeviceNode.execute), which + # will translate it into a passthrough+log so unsupported loaders + # don't hard-fail the workflow. + patcher = patcher.clone(disable_dynamic=True) patcher.load_device = resolved patcher.offload_device = resolved - else: - patcher.load_device = resolved - patcher.offload_device = base_offload - - if hasattr(patcher, "register_load_device"): - patcher.register_load_device(patcher.load_device) - return patcher + return patcher + return _retarget_patcher(patcher, resolved, base_offload) def _prune_multigpu_collision(model: ModelPatcher, primary_device): @@ -122,6 +156,12 @@ class SelectModelDeviceNode(io.ComfyNode): - "gpu:N" pins the load device to the Nth available GPU; the offload device is restored to the loader's original choice. + When the requested device differs from the device the input model is + already on, a fresh model is spawned via the loader's reload factory + (cached_patcher_init) so the new patcher owns independent weights on + the new device. Loaders that don't support multigpu (no factory) will + cause the node to pass through unchanged with a warning. + If the workflow already has MultiGPU CFG Split applied and the chosen GPU collides with one of the existing multigpu clones, that clone is dropped so two patchers don't end up bound to the same device. @@ -130,6 +170,13 @@ class SelectModelDeviceNode(io.ComfyNode): (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), the node passes the model through unchanged and logs a message instead of failing. + + NOTE: Placing Select Model Device *after* a node that has already + consumed the same model (e.g. a KSampler that ran on this model on + the original device) is not recommended -- any state the prior + consumer mutated on the original model will be observed when the + selected device matches the original (fast path). Place Select Model + Device before any consumer of the model. """ @classmethod @@ -161,7 +208,11 @@ class SelectModelDeviceNode(io.ComfyNode): if resolved is None and device not in (None, "default"): logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.") return io.NodeOutput(model) - model = _apply_patcher_device(model, resolved) + try: + model = _apply_patcher_device(model, resolved) + except RuntimeError as e: + logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") + return io.NodeOutput(model) if resolved is not None: _prune_multigpu_collision(model, model.load_device) return io.NodeOutput(model) @@ -208,7 +259,10 @@ class SelectCLIPDeviceNode(io.ComfyNode): if resolved is None and device not in (None, "default"): logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.") return io.NodeOutput(clip) - clip.patcher = _apply_patcher_device(clip.patcher, resolved) + try: + clip.patcher = _apply_patcher_device(clip.patcher, resolved) + except RuntimeError as e: + logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})") return io.NodeOutput(clip) @@ -263,13 +317,19 @@ class SelectVAEDeviceNode(io.ComfyNode): if resolved is not None and resolved.type == "cpu": logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.") return io.NodeOutput(vae) - vae.patcher = _apply_patcher_device( - vae.patcher, resolved, - base_offload_override=comfy.model_management.vae_offload_device(), - ) - # VAE caches the working device separately from its patcher. if not hasattr(vae, "_select_base_device"): vae._select_base_device = vae.device + try: + vae.patcher = _apply_patcher_device( + vae.patcher, resolved, + base_offload_override=comfy.model_management.vae_offload_device(), + ) + except RuntimeError as e: + logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})") + return io.NodeOutput(vae) + # Keep VAE wrapper in sync with whatever model the patcher now owns; + # deepclone_multigpu may have produced a fresh first_stage_model. + vae.first_stage_model = vae.patcher.model vae.device = vae._select_base_device if resolved is None else resolved return io.NodeOutput(vae) diff --git a/nodes.py b/nodes.py index d1e9a2511..fd4365c90 100644 --- a/nodes.py +++ b/nodes.py @@ -795,6 +795,7 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): metadata = None + vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -813,6 +814,14 @@ class VAELoader: metadata["tae_latent_channels"] = 128 vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() + # Register a reload factory on the patcher so multigpu deepclones + # (Select VAE Device, future MultiGPU VAE work-units) can produce + # per-device clones from the same loader context. Only set when we + # actually have a single backing file -- pixel_space and the + # image TAESDs (composed from separate encoder/decoder files via + # load_taesd) are not addressable by a single vae_path. + if vae_path is not None: + vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None)) return (vae,) class ControlNetLoader: