diff --git a/comfy/sd.py b/comfy/sd.py index b689c0dfc..fcd2b4b4a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert +import comfy.weight_adapter import yaml import math import os @@ -100,6 +101,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): return (new_modelpatcher, new_clip) +def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip): + """ + Load LoRA in bypass mode without modifying base model weights. + + Instead of patching weights, this injects the LoRA computation into the + forward pass: output = base_forward(x) + lora_path(x) + + Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches. + + This is useful for training and when model weights are offloaded. + """ + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries") + + lora = comfy.lora_convert.convert_lora(lora) + loaded = comfy.lora.load_lora(lora, key_map) + + logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries") + + # Separate adapters (for bypass) from other patches (for regular patching) + bypass_patches = {} # WeightAdapterBase instances -> bypass mode + regular_patches = {} # diff, set, bias patches -> regular weight patching + + for key, patch_data in loaded.items(): + if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase): + bypass_patches[key] = patch_data + else: + regular_patches[key] = patch_data + + logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches") + + k = set() + k1 = set() + + if model is not None: + new_modelpatcher = model.clone() + + # Apply regular patches (bias diff, weight diff, etc.) via normal patching + if regular_patches: + patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model) + k.update(patched_keys) + + # Apply adapter patches via bypass injection + manager = comfy.weight_adapter.BypassInjectionManager() + model_sd_keys = set(new_modelpatcher.model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in model_sd_keys: + manager.add_adapter(key, adapter, strength=strength_model) + k.add(key) + else: + logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}") + + injections = manager.create_injections(new_modelpatcher.model) + + if manager.get_hook_count() > 0: + new_modelpatcher.set_injections("bypass_lora", injections) + else: + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + + # Apply regular patches to clip + if regular_patches: + patched_keys = new_clip.add_patches(regular_patches, strength_clip) + k1.update(patched_keys) + + # Apply adapter patches via bypass injection + clip_manager = comfy.weight_adapter.BypassInjectionManager() + clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in clip_sd_keys: + clip_manager.add_adapter(key, adapter, strength=strength_clip) + k1.add(key) + + clip_injections = clip_manager.create_injections(new_clip.cond_stage_model) + if clip_manager.get_hook_count() > 0: + new_clip.patcher.set_injections("bypass_lora", clip_injections) + else: + new_clip = None + + for x in loaded: + if (x not in k) and (x not in k1): + patch_data = loaded[x] + patch_type = type(patch_data).__name__ + if isinstance(patch_data, tuple): + patch_type = f"tuple({patch_data[0]})" + logging.warning(f"NOT LOADED: {x} (type={patch_type})") + + return (new_modelpatcher, new_clip) + + class CLIP: def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 364804205..69a43c7ee 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers from comfy.weight_adapter import adapters, adapter_maps +from comfy.weight_adapter.bypass import BypassInjectionManager from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar @@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + for param_groups in self.optimizer.param_groups: + for param in param_groups["params"]: + if param.grad is None: + continue + param.grad.data = param.grad.data.to(param.data.dtype) self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) @@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): num_images = sum(t.shape[0] for t in latents) multi_res = False # Not using multi_res path in bucket mode - logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") for i, lat in enumerate(latents): - logging.info(f" Bucket {i}: shape {lat.shape}") + logging.debug(f" Bucket {i}: shape {lat.shape}") return latents, num_images, multi_res # Non-bucket mode @@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): latents = [t.to(dtype) for t in latents] for latent in latents: all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") + logging.debug(f"Latent shapes: {all_shapes}") if len(all_shapes) > 1: multi_res = True else: @@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode): if bucket_mode: return positive # Skip validation in bucket mode - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: return positive * num_images elif len(positive) != num_images: @@ -596,6 +602,8 @@ def _create_weight_adapter( shape = module.weight.shape lora_params = {} + logging.debug(f"Creating weight adapter for {key} with shape {shape}") + if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) dora_scale = existing_weights.get(f"{key}.dora_scale", None) @@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): return lora_sd, all_weight_adapters +def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup LoRA adapters in bypass mode. + + In bypass mode: + - Weight adapters (lora/lokr/oft) use bypass injection (forward hook) + - Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification) + + This is useful when the base model weights are quantized and cannot be + directly modified. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list, bypass_manager) + """ + lora_sd = {} + all_weight_adapters = [] + bypass_manager = BypassInjectionManager() + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + all_weight_adapters.append(adapter) + + key = f"{n}.weight" + # BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass + # Only use bypass for adapters that have h() method (lora/lokr/oft) + if isinstance(adapter, BiasDiff): + mp.add_weight_wrapper(key, adapter) + logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}") + else: + bypass_manager.add_adapter(key, adapter, strength=1.0) + logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}") + + if hasattr(m, "bias") and m.bias is not None: + # Bias adapters still use weight wrapper (bias is usually not quantized) + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}") + + return lora_sd, all_weight_adapters, bypass_manager + + def _create_optimizer(optimizer_name, parameters, learning_rate): """Create optimizer based on name. @@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode): default=False, tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", ), + io.Boolean.Input( + "bypass_mode", + default=False, + tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.", + ), ], outputs=[ - io.Model.Output( - display_name="model", tooltip="Model with LoRA applied" - ), io.Custom("LORA_MODEL").Output( display_name="lora", tooltip="LoRA weights" ), @@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing, existing_lora, bucket_mode, + bypass_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] + bypass_mode = bypass_mode[0] # Process latents based on mode if bucket_mode: @@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode): existing_weights, existing_steps = _load_existing_lora(existing_lora) # Setup LoRA adapters - lora_sd, all_weight_adapters = _setup_lora_adapters( - mp, existing_weights, algorithm, lora_dtype, rank - ) + bypass_manager = None + if bypass_mode: + logging.debug("Using bypass mode for training") + lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass( + mp, existing_weights, algorithm, lora_dtype, rank + ) + else: + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) # Create optimizer and loss function optimizer = _create_optimizer( @@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode): guider = TrainGuider(mp) guider.set_conds(positive) + # Inject bypass hooks if bypass mode is enabled + bypass_injections = None + if bypass_manager is not None: + bypass_injections = bypass_manager.create_injections(mp.model) + for injection in bypass_injections: + injection.inject(mp) + logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks") + # Run training loop try: _run_training_loop( @@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode): multi_res, ) finally: + # Eject bypass hooks if they were injected + if bypass_injections is not None: + for injection in bypass_injections: + injection.eject(mp) + logging.debug("[BypassMode] Ejected bypass hooks") for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer @@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode): for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) + # mp in train node is highly specialized for training + # use it in inference will result in bad behavior so we don't return it + return io.NodeOutput(lora_sd, loss_map, steps + existing_steps) class LoraModelLoader(io.ComfyNode):# diff --git a/nodes.py b/nodes.py index 5a9d42d4a..abad42a80 100644 --- a/nodes.py +++ b/nodes.py @@ -699,6 +699,69 @@ class LoraLoaderModelOnly(LoraLoader): def load_lora_model_only(self, model, lora_name, strength_model): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) +class LoraLoaderBypass: + """ + Apply LoRA in bypass mode without modifying base model weights. + + Bypass mode computes: output = base_forward(x) + lora_path(x) + This is useful for training and when model weights are offloaded. + """ + + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") + FUNCTION = "load_lora" + + CATEGORY = "loaders" + DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + self.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + + +class LoraLoaderBypassModelOnly(LoraLoaderBypass): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + class VAELoader: video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @@ -2008,6 +2071,8 @@ NODE_CLASS_MAPPINGS = { "LatentFlip": LatentFlip, "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, + "LoraLoaderBypass": LoraLoaderBypass, + "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, "CLIPLoader": CLIPLoader, "UNETLoader": UNETLoader, "DualCLIPLoader": DualCLIPLoader, @@ -2047,6 +2112,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA", + "LoraLoaderBypass": "Load LoRA (Bypass)", + "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", "DiffControlNetLoader": "Load ControlNet Model (diff)",