add bypass fwd into nodes list/trainer

This commit is contained in:
Kohaku-Blueleaf 2026-01-19 11:37:41 +08:00
parent 2a420dc4db
commit 18f4fe8567
3 changed files with 267 additions and 11 deletions

View File

@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import comfy.weight_adapter
import yaml
import math
import os
@ -100,6 +101,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
return (new_modelpatcher, new_clip)
def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
"""
Load LoRA in bypass mode without modifying base model weights.
Instead of patching weights, this injects the LoRA computation into the
forward pass: output = base_forward(x) + lora_path(x)
Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.
This is useful for training and when model weights are offloaded.
"""
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")
# Separate adapters (for bypass) from other patches (for regular patching)
bypass_patches = {} # WeightAdapterBase instances -> bypass mode
regular_patches = {} # diff, set, bias patches -> regular weight patching
for key, patch_data in loaded.items():
if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
bypass_patches[key] = patch_data
else:
regular_patches[key] = patch_data
logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")
k = set()
k1 = set()
if model is not None:
new_modelpatcher = model.clone()
# Apply regular patches (bias diff, weight diff, etc.) via normal patching
if regular_patches:
patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
k.update(patched_keys)
# Apply adapter patches via bypass injection
manager = comfy.weight_adapter.BypassInjectionManager()
model_sd_keys = set(new_modelpatcher.model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in model_sd_keys:
manager.add_adapter(key, adapter, strength=strength_model)
k.add(key)
else:
logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")
injections = manager.create_injections(new_modelpatcher.model)
if manager.get_hook_count() > 0:
new_modelpatcher.set_injections("bypass_lora", injections)
else:
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
# Apply regular patches to clip
if regular_patches:
patched_keys = new_clip.add_patches(regular_patches, strength_clip)
k1.update(patched_keys)
# Apply adapter patches via bypass injection
clip_manager = comfy.weight_adapter.BypassInjectionManager()
clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in clip_sd_keys:
clip_manager.add_adapter(key, adapter, strength=strength_clip)
k1.add(key)
clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
if clip_manager.get_hook_count() > 0:
new_clip.patcher.set_injections("bypass_lora", clip_injections)
else:
new_clip = None
for x in loaded:
if (x not in k) and (x not in k1):
patch_data = loaded[x]
patch_type = type(patch_data).__name__
if isinstance(patch_data, tuple):
patch_type = f"tuple({patch_data[0]})"
logging.warning(f"NOT LOADED: {x} (type={patch_type})")
return (new_modelpatcher, new_clip)
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:

View File

@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy.weight_adapter.bypass import BypassInjectionManager
from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar
@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
logging.debug(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
logging.debug(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
@ -596,6 +602,8 @@ def _create_weight_adapter(
shape = module.weight.shape
lora_params = {}
logging.debug(f"Creating weight adapter for {key} with shape {shape}")
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
return lora_sd, all_weight_adapters
def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup LoRA adapters in bypass mode.
In bypass mode:
- Weight adapters (lora/lokr/oft) use bypass injection (forward hook)
- Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification)
This is useful when the base model weights are quantized and cannot be
directly modified.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list, bypass_manager)
"""
lora_sd = {}
all_weight_adapters = []
bypass_manager = BypassInjectionManager()
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
all_weight_adapters.append(adapter)
key = f"{n}.weight"
# BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass
# Only use bypass for adapters that have h() method (lora/lokr/oft)
if isinstance(adapter, BiasDiff):
mp.add_weight_wrapper(key, adapter)
logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}")
else:
bypass_manager.add_adapter(key, adapter, strength=1.0)
logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}")
if hasattr(m, "bias") and m.bias is not None:
# Bias adapters still use weight wrapper (bias is usually not quantized)
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}")
return lora_sd, all_weight_adapters, bypass_manager
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode):
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
io.Boolean.Input(
"bypass_mode",
default=False,
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
),
],
outputs=[
io.Model.Output(
display_name="model", tooltip="Model with LoRA applied"
),
io.Custom("LORA_MODEL").Output(
display_name="lora", tooltip="LoRA weights"
),
@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode):
gradient_checkpointing,
existing_lora,
bucket_mode,
bypass_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode):
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
# Process latents based on mode
if bucket_mode:
@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode):
existing_weights, existing_steps = _load_existing_lora(existing_lora)
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
bypass_manager = None
if bypass_mode:
logging.debug("Using bypass mode for training")
lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass(
mp, existing_weights, algorithm, lora_dtype, rank
)
else:
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
# Create optimizer and loss function
optimizer = _create_optimizer(
@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode):
guider = TrainGuider(mp)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
bypass_injections = None
if bypass_manager is not None:
bypass_injections = bypass_manager.create_injections(mp.model)
for injection in bypass_injections:
injection.inject(mp)
logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks")
# Run training loop
try:
_run_training_loop(
@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
injection.eject(mp)
logging.debug("[BypassMode] Ejected bypass hooks")
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode):
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
# mp in train node is highly specialized for training
# use it in inference will result in bad behavior so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):#

View File

@ -699,6 +699,69 @@ class LoraLoaderModelOnly(LoraLoader):
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class LoraLoaderBypass:
"""
Apply LoRA in bypass mode without modifying base model weights.
Bypass mode computes: output = base_forward(x) + lora_path(x)
This is useful for training and when model weights are offloaded.
"""
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP")
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
FUNCTION = "load_lora"
CATEGORY = "loaders"
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
self.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class LoraLoaderBypassModelOnly(LoraLoaderBypass):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@ -2008,6 +2071,8 @@ NODE_CLASS_MAPPINGS = {
"LatentFlip": LatentFlip,
"LatentCrop": LatentCrop,
"LoraLoader": LoraLoader,
"LoraLoaderBypass": LoraLoaderBypass,
"LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
"CLIPLoader": CLIPLoader,
"UNETLoader": UNETLoader,
"DualCLIPLoader": DualCLIPLoader,
@ -2047,6 +2112,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"LoraLoader": "Load LoRA",
"LoraLoaderBypass": "Load LoRA (Bypass)",
"LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)",
"CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model",
"DiffControlNetLoader": "Load ControlNet Model (diff)",