mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
multigpu: refactor deepclone_multigpu + register cached_patcher_init for CLIP/VAE; Select*Device retargets via deepclone
- ModelPatcher.deepclone_multigpu: remove copy.deepcopy fallback. Require cached_patcher_init (raise a descriptive RuntimeError if missing) and always go through clone(model_override=...) with empty backup containers so the per-device clone owns a pristine, unpatched module instead of a deepcopy of an already-loaded/already-patched one. Also call register_load_device on the new patcher so ModelPatcherDynamic per-device bookkeeping (e.g. dynamic_pins) is populated for the requested load device.
- comfy/sd.py: register cached_patcher_init on the CLIP and VAE patchers returned by load_checkpoint_guess_config, and on the patcher returned by load_diffusion_model's companion paths. Add load_checkpoint_clip_patcher, load_checkpoint_vae_patcher, and load_vae_patcher reload helpers so the same loader context can be reused to produce per-device clones.
- nodes.py: VAELoader registers cached_patcher_init on the produced VAE's patcher when there is a single backing file (skip for pixel_space and composite image-TAESDs which aren't addressable by a single path).
- comfy_extras/nodes_multigpu.py: SelectModelDevice / SelectCLIPDevice / SelectVAEDevice now retarget via deepclone_multigpu when the requested device differs from the current load_device, so the consumed model is not just relabeled but actually rehomed onto the chosen device.

Verified on runner-2 (2x RTX 4090, comfy-aimdo 0.4.4):
- 10/10 focused unit tests (deepclone behavior, missing-factory error path, Select*Device behavior).
- Device-switch-after-consumption end-to-end (SD1.5) produces bit-identical PNGs on cuda:0 and cuda:1.
- Z Image multigpu CFG split: ~1.90x speedup (10.5s vs 19.9s steady).
- Qwen Image multigpu CFG split (real text negative, cfg=4): ~1.69x speedup (32.5s vs 54.8s steady) -- matches pre-refactor numbers.
- Baseline (patch stashed) and patched produce identical timings on both models, so the refactor is performance-neutral.
Amp-Thread-ID: https://ampcode.com/threads/T-019e5783-b810-74b1-8ca9-09d675de1479
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
5c2e34ca4e
commit
bece6b2aec
@ -457,23 +457,38 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||||
|
if self.cached_patcher_init is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
|
||||||
|
"the loader that produced this model does not support multigpu "
|
||||||
|
"(cached_patcher_init is not initialized). Use a core loader "
|
||||||
|
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
|
||||||
|
"or have the custom loader register a cached_patcher_init factory."
|
||||||
|
)
|
||||||
comfy.model_management.unload_model_and_clones(self)
|
comfy.model_management.unload_model_and_clones(self)
|
||||||
n = self.clone()
|
# Produce a freshly-loaded patcher from the loader factory so the multigpu
|
||||||
# set load device, if present
|
# clone owns its own untainted model weights (rather than relying on
|
||||||
if new_load_device is not None:
|
# copy.deepcopy of an already-patched/already-loaded module).
|
||||||
n.load_device = new_load_device
|
|
||||||
if self.cached_patcher_init is not None:
|
|
||||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||||
if len(self.cached_patcher_init) > 2:
|
if len(self.cached_patcher_init) > 2:
|
||||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||||
n.model = temp_model_patcher.model
|
# Override clone()'s normal "share self.model + share backup containers" with
|
||||||
else:
|
# the pristine model from temp_model_patcher plus empty backup containers --
|
||||||
n.model = copy.deepcopy(n.model)
|
# the fresh model has no patches applied, so any deepcopy of self's stale
|
||||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
# backup/object_patches_backup/pinned would just propagate dead state that
|
||||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
# no longer corresponds to anything in n.model.
|
||||||
n.backup = copy.deepcopy(n.backup)
|
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
|
||||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
n = self.clone(model_override=model_override)
|
||||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
# clone() copies hook_backup by reference from self; reset since model is pristine.
|
||||||
|
n.hook_backup = {}
|
||||||
|
# set load device, if present
|
||||||
|
if new_load_device is not None:
|
||||||
|
n.load_device = new_load_device
|
||||||
|
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
|
||||||
|
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
|
||||||
|
# __init__ only registered its own (default) load_device.
|
||||||
|
if hasattr(n, "register_load_device"):
|
||||||
|
n.register_load_device(n.load_device)
|
||||||
# multigpu clone should not have multigpu additional_models entry
|
# multigpu clone should not have multigpu additional_models entry
|
||||||
n.remove_additional_models("multigpu")
|
n.remove_additional_models("multigpu")
|
||||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||||
|
|||||||
62
comfy/sd.py
62
comfy/sd.py
@ -1727,8 +1727,50 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||||
if out[0] is not None:
|
if out[0] is not None:
|
||||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||||
|
# Register reload factories for the CLIP and VAE produced by the same checkpoint so
|
||||||
|
# ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
|
||||||
|
# MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
|
||||||
|
# already-loaded module.
|
||||||
|
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
|
||||||
|
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
|
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
|
||||||
|
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
"""Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
|
||||||
|
factory for the CLIP returned by load_checkpoint_guess_config."""
|
||||||
|
_, clip, _, _ = load_checkpoint_guess_config(
|
||||||
|
ckpt_path,
|
||||||
|
output_vae=False,
|
||||||
|
output_clip=True,
|
||||||
|
output_clipvision=False,
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
output_model=False,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic,
|
||||||
|
)
|
||||||
|
return clip.patcher
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
"""Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
|
||||||
|
factory for the VAE returned by load_checkpoint_guess_config."""
|
||||||
|
_, _, vae, _ = load_checkpoint_guess_config(
|
||||||
|
ckpt_path,
|
||||||
|
output_vae=True,
|
||||||
|
output_clip=False,
|
||||||
|
output_clipvision=False,
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
output_model=False,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic,
|
||||||
|
)
|
||||||
|
return vae.patcher
|
||||||
|
|
||||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||||
embedding_directory=embedding_directory,
|
embedding_directory=embedding_directory,
|
||||||
@ -1954,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
|||||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
|
||||||
|
"""Reload a disk-backed VAE from ``vae_path`` and return its patcher.
|
||||||
|
|
||||||
|
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
|
||||||
|
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
|
||||||
|
fresh, untainted VAE patcher (no inherited per-device load state, no
|
||||||
|
in-place quantization fallout) for multigpu work-units and the
|
||||||
|
SelectVAEDevice node. The optional ``device`` matches the source loader's
|
||||||
|
VAE initialization path; the deepclone's ``load_device`` still controls
|
||||||
|
where the cloned patcher is targeted.
|
||||||
|
"""
|
||||||
|
if metadata is None:
|
||||||
|
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||||
|
else:
|
||||||
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
|
vae = VAE(sd=sd, metadata=metadata, device=device)
|
||||||
|
vae.throw_exception_if_invalid()
|
||||||
|
return vae.patcher
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
def load_unet(unet_path, dtype=None):
|
||||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||||
|
|||||||
@ -49,48 +49,82 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
|
|||||||
def _remember_base_devices(patcher: ModelPatcher):
|
def _remember_base_devices(patcher: ModelPatcher):
|
||||||
"""Stash the original load/offload device on the underlying model.
|
"""Stash the original load/offload device on the underlying model.
|
||||||
|
|
||||||
Stored on patcher.model (which is shared across patcher clones), so
|
Stored on patcher.model (which is shared with the input patcher), so
|
||||||
repeated selector applications can recover the loader's original
|
later "default" selections can recover the loader's original routing.
|
||||||
routing when the user picks "default".
|
Only the first Select on a given chain writes these attrs; subsequent
|
||||||
|
deepclones inherit them onto their freshly-loaded model below.
|
||||||
"""
|
"""
|
||||||
if not hasattr(patcher.model, "_select_base_load_device"):
|
if not hasattr(patcher.model, "_select_base_load_device"):
|
||||||
patcher.model._select_base_load_device = patcher.load_device
|
patcher.model._select_base_load_device = patcher.load_device
|
||||||
patcher.model._select_base_offload_device = patcher.offload_device
|
patcher.model._select_base_offload_device = patcher.offload_device
|
||||||
|
|
||||||
|
|
||||||
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
|
def _propagate_base_devices(src_model, dst_model):
|
||||||
"""Apply *resolved* to a freshly-cloned patcher; respect base devices on default.
|
"""Carry the loader-original device attrs onto the freshly-deepcloned model."""
|
||||||
|
if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"):
|
||||||
|
dst_model._select_base_load_device = src_model._select_base_load_device
|
||||||
|
dst_model._select_base_offload_device = src_model._select_base_offload_device
|
||||||
|
|
||||||
Returns the (possibly newly-replaced) patcher. For CPU on a dynamic
|
|
||||||
patcher, also tries to downgrade to a plain ModelPatcher so the
|
def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
|
||||||
dynamic-only code paths are bypassed (best-effort: silently keeps
|
"""Return a patcher whose actual model weights live on *target_load_device*.
|
||||||
the dynamic patcher if downgrade is not supported).
|
|
||||||
|
If *patcher* is already on *target_load_device* we just retarget the
|
||||||
|
(already-cloned) patcher's metadata in place. Otherwise we call
|
||||||
|
:meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from
|
||||||
|
the loader's ``cached_patcher_init`` factory -- the only safe way to
|
||||||
|
move weights that may already be partially loaded onto another device.
|
||||||
|
|
||||||
|
NOTE: reusing the input patcher's model when the requested device
|
||||||
|
matches its current load_device is a deliberate fast path. Anything
|
||||||
|
that has already mutated the original model (e.g. a prior KSampler
|
||||||
|
invocation on the same model) will be observed here. This is by
|
||||||
|
design and documented on the SelectXDeviceNode docstrings -- placing
|
||||||
|
Select X Device after a node that consumes the same model is not
|
||||||
|
recommended.
|
||||||
|
"""
|
||||||
|
if patcher.load_device == target_load_device:
|
||||||
|
# Fast path: weights already on the desired device, just update offload.
|
||||||
|
patcher.offload_device = target_offload_device
|
||||||
|
return patcher
|
||||||
|
src_model = patcher.model
|
||||||
|
patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
|
||||||
|
patcher.offload_device = target_offload_device
|
||||||
|
_propagate_base_devices(src_model, patcher.model)
|
||||||
|
if hasattr(patcher, "register_load_device"):
|
||||||
|
patcher.register_load_device(patcher.load_device)
|
||||||
|
return patcher
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
|
||||||
|
"""Resolve the requested device and produce a patcher routed there.
|
||||||
|
|
||||||
|
For "default" we restore the loader's original load/offload pair.
|
||||||
|
For CPU we pin both load and offload to CPU (and, on a dynamic
|
||||||
|
patcher, downgrade to a plain ModelPatcher so the dynamic-only
|
||||||
|
code paths are bypassed).
|
||||||
|
For an explicit GPU we keep the loader's original offload but
|
||||||
|
target the requested load device; if that differs from the current
|
||||||
|
load device the patcher is deepcloned onto the new device.
|
||||||
"""
|
"""
|
||||||
_remember_base_devices(patcher)
|
_remember_base_devices(patcher)
|
||||||
base_load = patcher.model._select_base_load_device
|
base_load = patcher.model._select_base_load_device
|
||||||
base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
|
base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
|
||||||
|
|
||||||
if resolved is None:
|
if resolved is None:
|
||||||
# "default" -> reset routing to whatever the loader produced
|
# "default" -> route back to the loader's original devices.
|
||||||
patcher.load_device = base_load
|
return _retarget_patcher(patcher, base_load, base_offload)
|
||||||
patcher.offload_device = base_offload
|
if resolved.type == "cpu":
|
||||||
elif resolved.type == "cpu":
|
|
||||||
if patcher.is_dynamic():
|
if patcher.is_dynamic():
|
||||||
try:
|
# clone(disable_dynamic=True) requires cached_patcher_init; let the
|
||||||
|
# exception surface to the caller (Select*DeviceNode.execute), which
|
||||||
|
# will translate it into a passthrough+log so unsupported loaders
|
||||||
|
# don't hard-fail the workflow.
|
||||||
patcher = patcher.clone(disable_dynamic=True)
|
patcher = patcher.clone(disable_dynamic=True)
|
||||||
except Exception:
|
|
||||||
# Downgrade unavailable (no cached_patcher_init); fall
|
|
||||||
# back to the existing dynamic patcher.
|
|
||||||
pass
|
|
||||||
patcher.load_device = resolved
|
patcher.load_device = resolved
|
||||||
patcher.offload_device = resolved
|
patcher.offload_device = resolved
|
||||||
else:
|
|
||||||
patcher.load_device = resolved
|
|
||||||
patcher.offload_device = base_offload
|
|
||||||
|
|
||||||
if hasattr(patcher, "register_load_device"):
|
|
||||||
patcher.register_load_device(patcher.load_device)
|
|
||||||
return patcher
|
return patcher
|
||||||
|
return _retarget_patcher(patcher, resolved, base_offload)
|
||||||
|
|
||||||
|
|
||||||
def _prune_multigpu_collision(model: ModelPatcher, primary_device):
|
def _prune_multigpu_collision(model: ModelPatcher, primary_device):
|
||||||
@ -122,6 +156,12 @@ class SelectModelDeviceNode(io.ComfyNode):
|
|||||||
- "gpu:N" pins the load device to the Nth available GPU; the offload
|
- "gpu:N" pins the load device to the Nth available GPU; the offload
|
||||||
device is restored to the loader's original choice.
|
device is restored to the loader's original choice.
|
||||||
|
|
||||||
|
When the requested device differs from the device the input model is
|
||||||
|
already on, a fresh model is spawned via the loader's reload factory
|
||||||
|
(cached_patcher_init) so the new patcher owns independent weights on
|
||||||
|
the new device. Loaders that don't support multigpu (no factory) will
|
||||||
|
cause the node to pass through unchanged with a warning.
|
||||||
|
|
||||||
If the workflow already has MultiGPU CFG Split applied and the chosen
|
If the workflow already has MultiGPU CFG Split applied and the chosen
|
||||||
GPU collides with one of the existing multigpu clones, that clone is
|
GPU collides with one of the existing multigpu clones, that clone is
|
||||||
dropped so two patchers don't end up bound to the same device.
|
dropped so two patchers don't end up bound to the same device.
|
||||||
@ -130,6 +170,13 @@ class SelectModelDeviceNode(io.ComfyNode):
|
|||||||
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||||
the node passes the model through unchanged and logs a message
|
the node passes the model through unchanged and logs a message
|
||||||
instead of failing.
|
instead of failing.
|
||||||
|
|
||||||
|
NOTE: Placing Select Model Device *after* a node that has already
|
||||||
|
consumed the same model (e.g. a KSampler that ran on this model on
|
||||||
|
the original device) is not recommended -- any state the prior
|
||||||
|
consumer mutated on the original model will be observed when the
|
||||||
|
selected device matches the original (fast path). Place Select Model
|
||||||
|
Device before any consumer of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -161,7 +208,11 @@ class SelectModelDeviceNode(io.ComfyNode):
|
|||||||
if resolved is None and device not in (None, "default"):
|
if resolved is None and device not in (None, "default"):
|
||||||
logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
|
logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
try:
|
||||||
model = _apply_patcher_device(model, resolved)
|
model = _apply_patcher_device(model, resolved)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
|
||||||
|
return io.NodeOutput(model)
|
||||||
if resolved is not None:
|
if resolved is not None:
|
||||||
_prune_multigpu_collision(model, model.load_device)
|
_prune_multigpu_collision(model, model.load_device)
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
@ -208,7 +259,10 @@ class SelectCLIPDeviceNode(io.ComfyNode):
|
|||||||
if resolved is None and device not in (None, "default"):
|
if resolved is None and device not in (None, "default"):
|
||||||
logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
|
logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
|
||||||
return io.NodeOutput(clip)
|
return io.NodeOutput(clip)
|
||||||
|
try:
|
||||||
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
|
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
|
||||||
return io.NodeOutput(clip)
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
|
||||||
@ -263,13 +317,19 @@ class SelectVAEDeviceNode(io.ComfyNode):
|
|||||||
if resolved is not None and resolved.type == "cpu":
|
if resolved is not None and resolved.type == "cpu":
|
||||||
logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
|
logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
|
||||||
return io.NodeOutput(vae)
|
return io.NodeOutput(vae)
|
||||||
|
if not hasattr(vae, "_select_base_device"):
|
||||||
|
vae._select_base_device = vae.device
|
||||||
|
try:
|
||||||
vae.patcher = _apply_patcher_device(
|
vae.patcher = _apply_patcher_device(
|
||||||
vae.patcher, resolved,
|
vae.patcher, resolved,
|
||||||
base_offload_override=comfy.model_management.vae_offload_device(),
|
base_offload_override=comfy.model_management.vae_offload_device(),
|
||||||
)
|
)
|
||||||
# VAE caches the working device separately from its patcher.
|
except RuntimeError as e:
|
||||||
if not hasattr(vae, "_select_base_device"):
|
logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
|
||||||
vae._select_base_device = vae.device
|
return io.NodeOutput(vae)
|
||||||
|
# Keep VAE wrapper in sync with whatever model the patcher now owns;
|
||||||
|
# deepclone_multigpu may have produced a fresh first_stage_model.
|
||||||
|
vae.first_stage_model = vae.patcher.model
|
||||||
vae.device = vae._select_base_device if resolved is None else resolved
|
vae.device = vae._select_base_device if resolved is None else resolved
|
||||||
return io.NodeOutput(vae)
|
return io.NodeOutput(vae)
|
||||||
|
|
||||||
|
|||||||
9
nodes.py
9
nodes.py
@ -795,6 +795,7 @@ class VAELoader:
|
|||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
metadata = None
|
metadata = None
|
||||||
|
vae_path = None
|
||||||
if vae_name == "pixel_space":
|
if vae_name == "pixel_space":
|
||||||
sd = {}
|
sd = {}
|
||||||
sd["pixel_space_vae"] = torch.tensor(1.0)
|
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||||
@ -813,6 +814,14 @@ class VAELoader:
|
|||||||
metadata["tae_latent_channels"] = 128
|
metadata["tae_latent_channels"] = 128
|
||||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
|
# Register a reload factory on the patcher so multigpu deepclones
|
||||||
|
# (Select VAE Device, future MultiGPU VAE work-units) can produce
|
||||||
|
# per-device clones from the same loader context. Only set when we
|
||||||
|
# actually have a single backing file -- pixel_space and the
|
||||||
|
# image TAESDs (composed from separate encoder/decoder files via
|
||||||
|
# load_taesd) are not addressable by a single vae_path.
|
||||||
|
if vae_path is not None:
|
||||||
|
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None))
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
class ControlNetLoader:
|
class ControlNetLoader:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user