mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 18f4fe8567 (parent 2a420dc4db)

    add bypass fwd into nodes list/trainer

comfy/sd.py | 100
@@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import comfy.weight_adapter
import yaml
import math
import os
@@ -100,6 +101,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    return (new_modelpatcher, new_clip)


def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
    """
    Load LoRA in bypass mode without modifying base model weights.

    Instead of patching weights, this injects the LoRA computation into the
    forward pass: output = base_forward(x) + lora_path(x)

    Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.

    This is useful for training and when model weights are offloaded.
    """
    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")

    lora = comfy.lora_convert.convert_lora(lora)
    loaded = comfy.lora.load_lora(lora, key_map)

    logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")

    # Separate adapters (for bypass) from other patches (for regular patching)
    bypass_patches = {}   # WeightAdapterBase instances -> bypass mode
    regular_patches = {}  # diff, set, bias patches -> regular weight patching

    for key, patch_data in loaded.items():
        if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
            bypass_patches[key] = patch_data
        else:
            regular_patches[key] = patch_data

    logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")

    k = set()
    k1 = set()

    if model is not None:
        new_modelpatcher = model.clone()

        # Apply regular patches (bias diff, weight diff, etc.) via normal patching
        if regular_patches:
            patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
            k.update(patched_keys)

        # Apply adapter patches via bypass injection
        manager = comfy.weight_adapter.BypassInjectionManager()
        model_sd_keys = set(new_modelpatcher.model.state_dict().keys())

        for key, adapter in bypass_patches.items():
            if key in model_sd_keys:
                manager.add_adapter(key, adapter, strength=strength_model)
                k.add(key)
            else:
                logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")

        injections = manager.create_injections(new_modelpatcher.model)

        if manager.get_hook_count() > 0:
            new_modelpatcher.set_injections("bypass_lora", injections)
    else:
        new_modelpatcher = None

    if clip is not None:
        new_clip = clip.clone()

        # Apply regular patches to clip
        if regular_patches:
            patched_keys = new_clip.add_patches(regular_patches, strength_clip)
            k1.update(patched_keys)

        # Apply adapter patches via bypass injection
        clip_manager = comfy.weight_adapter.BypassInjectionManager()
        clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())

        for key, adapter in bypass_patches.items():
            if key in clip_sd_keys:
                clip_manager.add_adapter(key, adapter, strength=strength_clip)
                k1.add(key)

        clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
        if clip_manager.get_hook_count() > 0:
            new_clip.patcher.set_injections("bypass_lora", clip_injections)
    else:
        new_clip = None

    for x in loaded:
        if (x not in k) and (x not in k1):
            patch_data = loaded[x]
            patch_type = type(patch_data).__name__
            if isinstance(patch_data, tuple):
                patch_type = f"tuple({patch_data[0]})"
            logging.warning(f"NOT LOADED: {x} (type={patch_type})")

    return (new_modelpatcher, new_clip)


class CLIP:
    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
        if no_init:
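The docstring's formula, output = base_forward(x) + lora_path(x), is the whole idea behind bypass mode. A minimal, self-contained sketch of that computation using a plain PyTorch forward hook (the low-rank pair and strength handling here are illustrative assumptions, not ComfyUI's BypassInjectionManager):

```python
import torch
import torch.nn as nn

base = nn.Linear(64, 64)                    # stand-in for a frozen/quantized base layer
lora_down = nn.Linear(64, 8, bias=False)    # rank-8 down projection
lora_up = nn.Linear(8, 64, bias=False)      # rank-8 up projection
nn.init.zeros_(lora_up.weight)              # start as a no-op, as LoRA usually does
strength = 1.0

def bypass_hook(module, inputs, output):
    # inputs is the tuple of positional args; the first entry is the layer input
    x = inputs[0]
    return output + strength * lora_up(lora_down(x))

handle = base.register_forward_hook(bypass_hook)
y = base(torch.randn(2, 64))                # base weights untouched; LoRA added on the output
handle.remove()                             # removing the hook restores the original behavior
```

Because the base weights are never rewritten, the same mechanism works whether they live on GPU, are offloaded, or are stored in a quantized format.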
@@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy.weight_adapter.bypass import BypassInjectionManager
from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar

@@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler):
                self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)

            if (i + 1) % self.grad_acc == 0:
                for param_groups in self.optimizer.param_groups:
                    for param in param_groups["params"]:
                        if param.grad is None:
                            continue
                        param.grad.data = param.grad.data.to(param.data.dtype)
                self.optimizer.step()
                self.optimizer.zero_grad()
            ui_pbar.update(1)
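The added cast above makes the optimizer step on gradients that match the parameter dtype, which matters when grads are produced in a different precision than the trained LoRA weights. A compact, self-contained sketch of that accumulate-cast-step pattern (illustrative only, not TrainSampler itself):

```python
import torch

def accumulation_step(optimizer, i, grad_acc):
    """Step only every grad_acc micro-batches, casting grads to each param's dtype first."""
    if (i + 1) % grad_acc != 0:
        return
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.grad.data = p.grad.data.to(p.data.dtype)  # align grad dtype with the parameter
    optimizer.step()
    optimizer.zero_grad()

# usage sketch
w = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.SGD([w], lr=0.1)
w.sum().backward()
accumulation_step(opt, i=0, grad_acc=1)
```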
@@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
        num_images = sum(t.shape[0] for t in latents)
        multi_res = False  # Not using multi_res path in bucket mode

-       logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
+       logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
        for i, lat in enumerate(latents):
-           logging.info(f"  Bucket {i}: shape {lat.shape}")
+           logging.debug(f"  Bucket {i}: shape {lat.shape}")
        return latents, num_images, multi_res

    # Non-bucket mode
@@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
    latents = [t.to(dtype) for t in latents]
    for latent in latents:
        all_shapes.add(latent.shape)
-   logging.info(f"Latent shapes: {all_shapes}")
+   logging.debug(f"Latent shapes: {all_shapes}")
    if len(all_shapes) > 1:
        multi_res = True
    else:
@@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
    if bucket_mode:
        return positive  # Skip validation in bucket mode

-   logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
+   logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}")
    if len(positive) == 1 and num_images > 1:
        return positive * num_images
    elif len(positive) != num_images:
@@ -596,6 +602,8 @@ def _create_weight_adapter(
    shape = module.weight.shape
    lora_params = {}

    logging.debug(f"Creating weight adapter for {key} with shape {shape}")

    if len(shape) >= 2:
        alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
        dora_scale = existing_weights.get(f"{key}.dora_scale", None)
@@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
    return lora_sd, all_weight_adapters


def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank):
    """Setup LoRA adapters in bypass mode.

    In bypass mode:
    - Weight adapters (lora/lokr/oft) use bypass injection (forward hook)
    - Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification)

    This is useful when the base model weights are quantized and cannot be
    directly modified.

    Args:
        mp: Model patcher
        existing_weights: Dict of existing LoRA weights
        algorithm: Algorithm name for new adapters
        lora_dtype: dtype for LoRA weights
        rank: Rank for new LoRA adapters

    Returns:
        tuple: (lora_sd dict, all_weight_adapters list, bypass_manager)
    """
    lora_sd = {}
    all_weight_adapters = []
    bypass_manager = BypassInjectionManager()

    for n, m in mp.model.named_modules():
        if hasattr(m, "weight_function"):
            if m.weight is not None:
                adapter, params = _create_weight_adapter(
                    m, n, existing_weights, algorithm, lora_dtype, rank
                )
                lora_sd.update(params)
                all_weight_adapters.append(adapter)

                key = f"{n}.weight"
                # BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass
                # Only use bypass for adapters that have h() method (lora/lokr/oft)
                if isinstance(adapter, BiasDiff):
                    mp.add_weight_wrapper(key, adapter)
                    logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}")
                else:
                    bypass_manager.add_adapter(key, adapter, strength=1.0)
                    logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}")

            if hasattr(m, "bias") and m.bias is not None:
                # Bias adapters still use weight wrapper (bias is usually not quantized)
                bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
                lora_sd.update(bias_params)
                key = f"{n}.bias"
                mp.add_weight_wrapper(key, bias_adapter)
                all_weight_adapters.append(bias_adapter)
                logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}")

    return lora_sd, all_weight_adapters, bypass_manager


def _create_optimizer(optimizer_name, parameters, learning_rate):
    """Create optimizer based on name.

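The split described in the docstring (multi-dimensional weights get a forward-hook bypass path, 1D norm/bias-style weights keep direct weight wrapping) can be pictured with a small routing helper. The ndim-based rule and names below are illustrative assumptions, not ComfyUI's actual adapter classes:

```python
import torch.nn as nn

def route_adapters(model: nn.Module):
    """Split trainable weight keys into bypass-hook candidates and weight-wrapper candidates."""
    bypass_keys, wrapper_keys = [], []
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        key = f"{name}.weight"
        if weight.ndim < 2:
            wrapper_keys.append(key)   # e.g. LayerNorm -> direct weight wrapper
        else:
            bypass_keys.append(key)    # e.g. Linear/Conv -> forward-hook bypass
    return bypass_keys, wrapper_keys

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
print(route_adapters(model))  # (['0.weight'], ['1.weight'])
```

The design point is that a 1D scale or bias is cheap to rewrite in place even on an offloaded model, while large 2D/4D weights are exactly the tensors you do not want to touch when they are quantized.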
@@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode):
                    default=False,
                    tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
                ),
                io.Boolean.Input(
                    "bypass_mode",
                    default=False,
                    tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
                ),
            ],
            outputs=[
                io.Model.Output(
                    display_name="model", tooltip="Model with LoRA applied"
                ),
                io.Custom("LORA_MODEL").Output(
                    display_name="lora", tooltip="LoRA weights"
                ),
@@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode):
        gradient_checkpointing,
        existing_lora,
        bucket_mode,
        bypass_mode,
    ):
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
@@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode):
        gradient_checkpointing = gradient_checkpointing[0]
        existing_lora = existing_lora[0]
        bucket_mode = bucket_mode[0]
        bypass_mode = bypass_mode[0]

        # Process latents based on mode
        if bucket_mode:
@@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode):
        existing_weights, existing_steps = _load_existing_lora(existing_lora)

        # Setup LoRA adapters
-       lora_sd, all_weight_adapters = _setup_lora_adapters(
-           mp, existing_weights, algorithm, lora_dtype, rank
-       )
+       bypass_manager = None
+       if bypass_mode:
+           logging.debug("Using bypass mode for training")
+           lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass(
+               mp, existing_weights, algorithm, lora_dtype, rank
+           )
+       else:
+           lora_sd, all_weight_adapters = _setup_lora_adapters(
+               mp, existing_weights, algorithm, lora_dtype, rank
+           )

        # Create optimizer and loss function
        optimizer = _create_optimizer(
@@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode):
        guider = TrainGuider(mp)
        guider.set_conds(positive)

        # Inject bypass hooks if bypass mode is enabled
        bypass_injections = None
        if bypass_manager is not None:
            bypass_injections = bypass_manager.create_injections(mp.model)
            for injection in bypass_injections:
                injection.inject(mp)
            logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks")

        # Run training loop
        try:
            _run_training_loop(
@@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode):
                multi_res,
            )
        finally:
            # Eject bypass hooks if they were injected
            if bypass_injections is not None:
                for injection in bypass_injections:
                    injection.eject(mp)
                logging.debug("[BypassMode] Ejected bypass hooks")
            for m in mp.model.modules():
                unpatch(m)
            del train_sampler, optimizer
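The inject/train/eject pattern above maps cleanly onto plain PyTorch forward hooks: register before training, remove in a finally block so a failed run never leaves the model patched. A minimal sketch with stand-in names (not the real BypassInjectionManager API):

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 32)                 # stand-in for the patched base model
delta = nn.Linear(32, 32, bias=False)     # stand-in bypass adapter
nn.init.zeros_(delta.weight)

handles = []
try:
    # "inject": register the bypass hook before the training loop starts
    handles.append(model.register_forward_hook(
        lambda mod, inp, out: out + delta(inp[0])))
    for _ in range(3):                    # stand-in training loop
        loss = model(torch.randn(4, 32)).sum()
        loss.backward()
finally:
    # "eject": always remove the hooks, even if training raised an exception
    for h in handles:
        h.remove()
```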
@@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode):
        for param in lora_sd:
            lora_sd[param] = lora_sd[param].to(lora_dtype)

-       return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
+       # mp in the train node is highly specialized for training;
+       # using it for inference results in bad behavior, so we don't return it
+       return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader(io.ComfyNode):
nodes.py | 67
@@ -699,6 +699,69 @@ class LoraLoaderModelOnly(LoraLoader):
    def load_lora_model_only(self, model, lora_name, strength_model):
        return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)

class LoraLoaderBypass:
    """
    Apply LoRA in bypass mode without modifying base model weights.

    Bypass mode computes: output = base_forward(x) + lora_path(x)
    This is useful for training and when model weights are offloaded.
    """

    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
                "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
                "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP")
    OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
    FUNCTION = "load_lora"

    CATEGORY = "loaders"
    DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during the forward pass. Useful for training scenarios."

    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
        if strength_model == 0 and strength_clip == 0:
            return (model, clip)

        lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == lora_path:
                lora = self.loaded_lora[1]
            else:
                self.loaded_lora = None

        if lora is None:
            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
        return (model_lora, clip_lora)


class LoraLoaderBypassModelOnly(LoraLoaderBypass):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "lora_name": (folder_paths.get_filename_list("loras"), ),
                              "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora_model_only"

    def load_lora_model_only(self, model, lora_name, strength_model):
        return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)

class VAELoader:
    video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
    image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
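Outside the node graph, the new loader boils down to a short call sequence. A hedged sketch using only functions that appear in this diff (the model/clip objects and the LoRA filename are assumed to come from an existing ComfyUI setup; the helper name is hypothetical):

```python
import comfy.sd
import comfy.utils
import folder_paths

def apply_bypass_lora(model, clip, lora_name, strength_model=1.0, strength_clip=1.0):
    """Roughly what LoraLoaderBypass.load_lora does, minus the single-file cache."""
    lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
    lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
    # Injects the LoRA into the forward pass instead of patching base weights
    return comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
```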
@@ -2008,6 +2071,8 @@ NODE_CLASS_MAPPINGS = {
    "LatentFlip": LatentFlip,
    "LatentCrop": LatentCrop,
    "LoraLoader": LoraLoader,
    "LoraLoaderBypass": LoraLoaderBypass,
    "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
    "CLIPLoader": CLIPLoader,
    "UNETLoader": UNETLoader,
    "DualCLIPLoader": DualCLIPLoader,
@@ -2047,6 +2112,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderSimple": "Load Checkpoint",
    "VAELoader": "Load VAE",
    "LoraLoader": "Load LoRA",
    "LoraLoaderBypass": "Load LoRA (Bypass)",
    "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)",
    "CLIPLoader": "Load CLIP",
    "ControlNetLoader": "Load ControlNet Model",
    "DiffControlNetLoader": "Load ControlNet Model (diff)",