From a97c98068f6301b1f87ce89e7bd942ee2db3155d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:56:22 +0800 Subject: [PATCH] [Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958) * Add API of bypass forward module * bypass implementation * add bypass fwd into nodes list/trainer --- comfy/sd.py | 100 +++++++ comfy/weight_adapter/__init__.py | 8 + comfy/weight_adapter/base.py | 231 +++++++++++++++- comfy/weight_adapter/boft.py | 119 ++++++++- comfy/weight_adapter/bypass.py | 437 +++++++++++++++++++++++++++++++ comfy/weight_adapter/glora.py | 219 +++++++++++++++- comfy/weight_adapter/loha.py | 186 +++++++++++-- comfy/weight_adapter/lokr.py | 311 ++++++++++++++++++++-- comfy/weight_adapter/lora.py | 165 +++++++++++- comfy/weight_adapter/oft.py | 186 ++++++++++++- comfy_extras/nodes_train.py | 111 +++++++- nodes.py | 67 +++++ 12 files changed, 2039 insertions(+), 101 deletions(-) create mode 100644 comfy/weight_adapter/bypass.py diff --git a/comfy/sd.py b/comfy/sd.py index ce7e6bcff..f627f7d55 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert +import comfy.weight_adapter import yaml import math import os @@ -101,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): return (new_modelpatcher, new_clip) +def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip): + """ + Load LoRA in bypass mode without modifying base model weights. + + Instead of patching weights, this injects the LoRA computation into the + forward pass: output = base_forward(x) + lora_path(x) + + Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches. + + This is useful for training and when model weights are offloaded. + """ + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries") + + lora = comfy.lora_convert.convert_lora(lora) + loaded = comfy.lora.load_lora(lora, key_map) + + logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries") + + # Separate adapters (for bypass) from other patches (for regular patching) + bypass_patches = {} # WeightAdapterBase instances -> bypass mode + regular_patches = {} # diff, set, bias patches -> regular weight patching + + for key, patch_data in loaded.items(): + if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase): + bypass_patches[key] = patch_data + else: + regular_patches[key] = patch_data + + logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches") + + k = set() + k1 = set() + + if model is not None: + new_modelpatcher = model.clone() + + # Apply regular patches (bias diff, weight diff, etc.) via normal patching + if regular_patches: + patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model) + k.update(patched_keys) + + # Apply adapter patches via bypass injection + manager = comfy.weight_adapter.BypassInjectionManager() + model_sd_keys = set(new_modelpatcher.model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in model_sd_keys: + manager.add_adapter(key, adapter, strength=strength_model) + k.add(key) + else: + logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}") + + injections = manager.create_injections(new_modelpatcher.model) + + if manager.get_hook_count() > 0: + new_modelpatcher.set_injections("bypass_lora", injections) + else: + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + + # Apply regular patches to clip + if regular_patches: + patched_keys = new_clip.add_patches(regular_patches, strength_clip) + k1.update(patched_keys) + + # Apply adapter patches via bypass injection + clip_manager = comfy.weight_adapter.BypassInjectionManager() + clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in clip_sd_keys: + clip_manager.add_adapter(key, adapter, strength=strength_clip) + k1.add(key) + + clip_injections = clip_manager.create_injections(new_clip.cond_stage_model) + if clip_manager.get_hook_count() > 0: + new_clip.patcher.set_injections("bypass_lora", clip_injections) + else: + new_clip = None + + for x in loaded: + if (x not in k) and (x not in k1): + patch_data = loaded[x] + patch_type = type(patch_data).__name__ + if isinstance(patch_data, tuple): + patch_type = f"tuple({patch_data[0]})" + logging.warning(f"NOT LOADED: {x} (type={patch_type})") + + return (new_modelpatcher, new_clip) + + class CLIP: def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index b40f920e4..b9fa8d5cf 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -5,6 +5,11 @@ from .lokr import LoKrAdapter from .glora import GLoRAAdapter from .oft import OFTAdapter from .boft import BOFTAdapter +from .bypass import ( + BypassInjectionManager, + BypassForwardHook, + create_bypass_injections_from_patches, +) adapters: list[type[WeightAdapterBase]] = [ @@ -31,4 +36,7 @@ __all__ = [ "WeightAdapterTrainBase", "adapters", "adapter_maps", + "BypassInjectionManager", + "BypassForwardHook", + "create_bypass_injections_from_patches", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 43644b106..bce89a0e2 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch import torch.nn as nn @@ -7,12 +7,35 @@ import comfy.model_management class WeightAdapterBase: + """ + Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT/BOFT: g = transform, h = 0 + """ + name: str loaded_keys: set[str] weights: list[torch.Tensor] + # Attributes set by bypass system + multiplier: float = 1.0 + shape: tuple = None # (out_features, in_features) or (out_ch, in_ch, *kernel) + @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + ) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": @@ -39,18 +62,202 @@ class WeightAdapterBase: ): raise NotImplementedError + # ===== Bypass Mode Methods ===== + # + # IMPORTANT: Bypass mode is designed for quantized models where original weights + # may not be accessible in a usable format. Therefore, h() and bypass_forward() + # do NOT take org_weight as a parameter. All necessary information (out_channels, + # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook. + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT/BOFT), returns zeros. + + Note: + This method does NOT access original model weights. Bypass mode is + designed for quantized models where weights may not be in a usable format. + All shape info comes from module attributes set by BypassForwardHook. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Reference: LyCORIS LoConModule.bypass_forward_diff + """ + # Default: no additive component (for OFT/BOFT) + # Simply return zeros matching base_out shape + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT/BOFT override this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + + Reference: LyCORIS OFTModule applies orthogonal transform here + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Note: + This method does NOT take org_weight/org_bias parameters. Bypass mode + is designed for quantized models where weights may not be accessible. + The original forward function handles weight access internally. + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + + Reference: LyCORIS LoConModule.bypass_forward + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + class WeightAdapterTrainBase(nn.Module): - # We follow the scheme of PR #7032 + """ + Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT: g = transform, h = 0 + + Note: + Unlike WeightAdapterBase, TrainBase classes have simplified weight formats + with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition). + + We follow the scheme of PR #7032 + """ + + # Attributes set by bypass system (BypassForwardHook) + # These are set before h()/g()/bypass_forward() are called + multiplier: float = 1.0 + is_conv: bool = False + conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d + kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups + kernel_size: tuple = () + in_channels: int = None + out_channels: int = None + def __init__(self): super().__init__() def __call__(self, w): """ - w: The original weight tensor to be modified. + Weight modification mode: returns modified weight. + + Args: + w: The original weight tensor to be modified. + + Returns: + Modified weight tensor. """ raise NotImplementedError + # ===== Bypass Mode Methods ===== + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT), returns zeros. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Subclasses should override this method. + """ + raise NotImplementedError( + f"{self.__class__.__name__}.h() not implemented. " + "Subclasses must implement h() for bypass mode." + ) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT overrides this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + def passive_memory_usage(self): raise NotImplementedError("passive_memory_usage is not implemented") @@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module): return self.passive_memory_usage() -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) +def weight_decompose( + dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function +): + dora_scale = comfy.model_management.cast_to_device( + dora_scale, weight.device, intermediate_dtype + ) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) @@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten the original tensor will be truncated in that dimension. """ if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): - raise ValueError("The new shape must be larger than the original tensor in all dimensions") + raise ValueError( + "The new shape must be larger than the original tensor in all dimensions" + ) if len(new_shape) != len(tensor.shape): - raise ValueError("The new shape must have the same number of dimensions as the original tensor") + raise ValueError( + "The new shape must have the same number of dimensions as the original tensor" + ) # Create a new tensor filled with zeros padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index b2a2f1bd4..02a8dc130 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase): alpha = v[2] dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) boft_m, block_num, boft_b, *_ = blocks.shape @@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(-1, -2) normed_q = q - if alpha > 0: # alpha in boft/bboft is for constraint + if alpha > 0: # alpha in boft/bboft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm @@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase): r = r.to(weight) inp = org = weight - r_b = boft_b//2 + r_b = boft_b // 2 for i in range(boft_m): bi = r[i] g = 2 k = 2**i * r_b if strength != 1: - bi = bi * strength + (1-strength) * I + bi = bi * strength + (1 - strength) * I inp = ( inp.unflatten(0, (-1, g, k)) .transpose(1, 2) @@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase): ) inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp) inp = ( - inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2) + inp.flatten(0, 1) + .unflatten(0, (-1, k, g)) + .transpose(1, 2) + .flatten(0, 2) ) if rescale is not None: inp = inp * rescale lora_diff = inp - org - lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype) + lora_diff = comfy.model_management.cast_to_device( + lora_diff, weight.device, intermediate_dtype + ) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrices(self, device, dtype): + """Compute the orthogonal rotation matrices R from BOFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + boft_m, block_num, boft_b, _ = blocks.shape + I = torch.eye(boft_b, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(-1, -2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, boft_m, boft_b + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for BOFT: applies butterfly orthogonal transform. + + BOFT uses multiple stages of butterfly-structured orthogonal transforms. + + Reference: LyCORIS ButterflyOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype) + r_b = boft_b // 2 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(boft_b, device=y.device, dtype=y.dtype) + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # Apply butterfly transform stages + inp = y + for i in range(boft_m): + bi = r[i] # (block_num, boft_b, boft_b) + g = 2 + k = 2**i * r_b + + # Interpolate with identity based on multiplier + if multiplier != 1: + bi = bi * multiplier + (1 - multiplier) * I + + # Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, boft_b)) + ) + # Apply block-diagonal orthogonal transform + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + # Reshape back + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + inp = inp * rescale.transpose(0, -1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + inp = inp.transpose(1, -1) + + return inp diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py new file mode 100644 index 000000000..d4aaf98ca --- /dev/null +++ b/comfy/weight_adapter/bypass.py @@ -0,0 +1,437 @@ +""" +Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.) + +Bypass mode applies adapters during forward pass without modifying base weights: + bypass(f)(x) = g(f(x) + h(x)) + +Where: + - f(x): Original layer forward + - h(x): Additive component from adapter (LoRA path) + - g(y): Output transformation (identity for most adapters) + +This is useful for: + - Training with gradient checkpointing + - Avoiding weight modifications when weights are offloaded + - Supporting multiple adapters with different strengths dynamically +""" + +import logging +from typing import Optional, Union + +import torch +import torch.nn as nn + +from .base import WeightAdapterBase, WeightAdapterTrainBase +from comfy.patcher_extension import PatcherInjection + +# Type alias for adapters that support bypass mode +BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] + + +def get_module_type_info(module: nn.Module) -> dict: + """ + Determine module type and extract conv parameters from module class. + + This is more reliable than checking weight.ndim, especially for quantized layers + where weight shape might be different. + + Returns: + dict with keys: is_conv, conv_dim, stride, padding, dilation, groups + """ + info = { + "is_conv": False, + "conv_dim": 0, + "stride": (1,), + "padding": (0,), + "dilation": (1,), + "groups": 1, + "kernel_size": (1,), + "in_channels": None, + "out_channels": None, + } + + # Determine conv type + if isinstance(module, nn.Conv1d): + info["is_conv"] = True + info["conv_dim"] = 1 + elif isinstance(module, nn.Conv2d): + info["is_conv"] = True + info["conv_dim"] = 2 + elif isinstance(module, nn.Conv3d): + info["is_conv"] = True + info["conv_dim"] = 3 + elif isinstance(module, nn.Linear): + info["is_conv"] = False + info["conv_dim"] = 0 + else: + # Try to infer from class name for custom/quantized layers + class_name = type(module).__name__.lower() + if "conv3d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 3 + elif "conv2d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + elif "conv1d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 1 + elif "conv" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + + # Extract conv parameters if it's a conv layer + if info["is_conv"]: + # Try to get stride, padding, dilation, groups, kernel_size from module + info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"]) + info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"]) + info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"]) + info["groups"] = getattr(module, "groups", 1) + info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"]) + info["in_channels"] = getattr(module, "in_channels", None) + info["out_channels"] = getattr(module, "out_channels", None) + + # Ensure they're tuples + if isinstance(info["stride"], int): + info["stride"] = (info["stride"],) * info["conv_dim"] + if isinstance(info["padding"], int): + info["padding"] = (info["padding"],) * info["conv_dim"] + if isinstance(info["dilation"], int): + info["dilation"] = (info["dilation"],) * info["conv_dim"] + if isinstance(info["kernel_size"], int): + info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"] + + return info + + +class BypassForwardHook: + """ + Hook that wraps a layer's forward to apply adapter in bypass mode. + + Stores the original forward and replaces it with bypass version. + + Supports both: + - WeightAdapterBase: Inference adapters (uses self.weights tuple) + - WeightAdapterTrainBase: Training adapters (nn.Module with parameters) + """ + + def __init__( + self, + module: nn.Module, + adapter: BypassAdapter, + multiplier: float = 1.0, + ): + self.module = module + self.adapter = adapter + self.multiplier = multiplier + self.original_forward = None + + # Determine layer type and conv params from module class (works for quantized layers) + module_info = get_module_type_info(module) + + # Set multiplier and layer type info on adapter for use in h() + adapter.multiplier = multiplier + adapter.is_conv = module_info["is_conv"] + adapter.conv_dim = module_info["conv_dim"] + adapter.kernel_size = module_info["kernel_size"] + adapter.in_channels = module_info["in_channels"] + adapter.out_channels = module_info["out_channels"] + # Store kw_dict for conv operations (like LyCORIS extra_args) + if module_info["is_conv"]: + adapter.kw_dict = { + "stride": module_info["stride"], + "padding": module_info["padding"], + "dilation": module_info["dilation"], + "groups": module_info["groups"], + } + else: + adapter.kw_dict = {} + + def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x)) + + Note: + Bypass mode does NOT access original model weights (org_weight). + This is intentional - bypass mode is designed for quantized models + where weights may not be in a usable format. All necessary shape + information is provided via adapter attributes set during inject(). + """ + # Check if adapter has custom bypass_forward (e.g., GLoRA) + adapter_bypass = getattr(self.adapter, "bypass_forward", None) + if adapter_bypass is not None: + # Check if it's overridden (not the base class default) + # Need to check both base classes since adapter could be either type + adapter_type = type(self.adapter) + is_default_bypass = ( + adapter_type.bypass_forward is WeightAdapterBase.bypass_forward + or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward + ) + if not is_default_bypass: + return adapter_bypass(self.original_forward, x, *args, **kwargs) + + # Default bypass: g(f(x) + h(x, f(x))) + base_out = self.original_forward(x, *args, **kwargs) + h_out = self.adapter.h(x, base_out) + return self.adapter.g(base_out + h_out) + + def inject(self): + """Replace module forward with bypass version.""" + if self.original_forward is not None: + logging.debug( + f"[BypassHook] Already injected for {type(self.module).__name__}" + ) + return # Already injected + + # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward + device = None + dtype = None + if hasattr(self.module, "weight") and self.module.weight is not None: + device = self.module.weight.device + dtype = self.module.weight.dtype + elif hasattr(self.module, "W_q"): # Quantized layers might use different attr + device = self.module.W_q.device + dtype = self.module.W_q.dtype + + if device is not None: + self._move_adapter_weights_to_device(device, dtype) + + self.original_forward = self.module.forward + self.module.forward = self._bypass_forward + logging.debug( + f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})" + ) + + def _move_adapter_weights_to_device(self, device, dtype=None): + """Move adapter weights to specified device to avoid per-forward transfers. + + Handles both: + - WeightAdapterBase: has self.weights tuple of tensors + - WeightAdapterTrainBase: nn.Module with parameters, uses .to() method + """ + adapter = self.adapter + + # Check if adapter is an nn.Module (WeightAdapterTrainBase) + if isinstance(adapter, nn.Module): + # In training mode we don't touch dtype as trainer will handle it + adapter.to(device=device) + logging.debug( + f"[BypassHook] Moved training adapter (nn.Module) to {device}" + ) + return + + # WeightAdapterBase: handle self.weights tuple + if not hasattr(adapter, "weights") or adapter.weights is None: + return + + weights = adapter.weights + if isinstance(weights, (list, tuple)): + new_weights = [] + for w in weights: + if isinstance(w, torch.Tensor): + if dtype is not None: + new_weights.append(w.to(device=device, dtype=dtype)) + else: + new_weights.append(w.to(device=device)) + else: + new_weights.append(w) + adapter.weights = ( + tuple(new_weights) if isinstance(weights, tuple) else new_weights + ) + elif isinstance(weights, torch.Tensor): + if dtype is not None: + adapter.weights = weights.to(device=device, dtype=dtype) + else: + adapter.weights = weights.to(device=device) + + logging.debug(f"[BypassHook] Moved adapter weights to {device}") + + def eject(self): + """Restore original module forward.""" + if self.original_forward is None: + logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}") + return # Not injected + + self.module.forward = self.original_forward + self.original_forward = None + logging.debug( + f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}" + ) + + +class BypassInjectionManager: + """ + Manages bypass mode injection for a collection of adapters. + + Creates PatcherInjection objects that can be used with ModelPatcher. + + Supports both inference adapters (WeightAdapterBase) and training adapters + (WeightAdapterTrainBase). + + Usage: + manager = BypassInjectionManager() + manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8) + manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8) + + injections = manager.create_injections(model) + model_patcher.set_injections("bypass_lora", injections) + """ + + def __init__(self): + self.adapters: dict[str, tuple[BypassAdapter, float]] = {} + self.hooks: list[BypassForwardHook] = [] + + def add_adapter( + self, + key: str, + adapter: BypassAdapter, + strength: float = 1.0, + ): + """ + Add an adapter for a specific weight key. + + Args: + key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight") + adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.) + strength: Multiplier for adapter effect + """ + # Remove .weight suffix if present for module lookup + module_key = key + if module_key.endswith(".weight"): + module_key = module_key[:-7] + logging.debug( + f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}" + ) + + self.adapters[module_key] = (adapter, strength) + logging.debug( + f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})" + ) + + def clear_adapters(self): + """Remove all adapters.""" + self.adapters.clear() + + def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]: + """Get a submodule by dot-separated key.""" + parts = key.split(".") + module = model + try: + for i, part in enumerate(parts): + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + logging.debug( + f"[BypassManager] Found module for key {key}: {type(module).__name__}" + ) + return module + except (AttributeError, IndexError, KeyError) as e: + logging.error(f"[BypassManager] Failed to find module for key {key}: {e}") + logging.error( + f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}" + ) + return None + + def create_injections(self, model: nn.Module) -> list[PatcherInjection]: + """ + Create PatcherInjection objects for all registered adapters. + + Args: + model: The model to inject into (e.g., model_patcher.model) + + Returns: + List of PatcherInjection objects to use with model_patcher.set_injections() + """ + self.hooks.clear() + + logging.debug( + f"[BypassManager] create_injections called with {len(self.adapters)} adapters" + ) + logging.debug(f"[BypassManager] Model type: {type(model).__name__}") + + for key, (adapter, strength) in self.adapters.items(): + logging.debug(f"[BypassManager] Looking for module: {key}") + module = self._get_module_by_key(model, key) + + if module is None: + logging.warning(f"[BypassManager] Module not found for key {key}") + continue + + if not hasattr(module, "weight"): + logging.warning( + f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})" + ) + continue + + logging.debug( + f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})" + ) + hook = BypassForwardHook(module, adapter, multiplier=strength) + self.hooks.append(hook) + + logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks") + + # Create single injection that manages all hooks + def inject_all(model_patcher): + logging.debug( + f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.inject() + logging.debug( + f"[BypassManager] Injected hook for {type(hook.module).__name__}" + ) + + def eject_all(model_patcher): + logging.debug( + f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.eject() + + return [PatcherInjection(inject=inject_all, eject=eject_all)] + + def get_hook_count(self) -> int: + """Return number of hooks that will be/are injected.""" + return len(self.hooks) + + +def create_bypass_injections_from_patches( + model: nn.Module, + patches: dict, + strength: float = 1.0, +) -> list[PatcherInjection]: + """ + Convenience function to create bypass injections from a patches dict. + + This is useful when you have patches in the format used by model_patcher.add_patches() + and want to apply them in bypass mode instead. + + Args: + model: The model to inject into + patches: Dict mapping weight keys to adapter data + strength: Global strength multiplier + + Returns: + List of PatcherInjection objects + """ + manager = BypassInjectionManager() + + for key, patch_list in patches.items(): + if not patch_list: + continue + + # patches format: list of (strength_patch, patch_data, strength_model, offset, function) + for patch in patch_list: + patch_strength, patch_data, strength_model, offset, function = patch + + # patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple + if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)): + adapter = patch_data + else: + # Skip non-adapter patches + continue + + combined_strength = strength * patch_strength + manager.add_adapter(key, adapter, strength=combined_strength) + + return manager.create_injections(model) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 939abbba5..d6b97a23b 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,7 +1,8 @@ import logging -from typing import Optional +from typing import Callable, Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, weight_decompose @@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + weights = ( + lora[a1_name], + lora[a2_name], + lora[b1_name], + lora[b2_name], + alpha, + dora_scale, + ) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase): old_glora = True if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + if ( + old_glora + and v[1].shape[0] == weight.shape[0] + and weight.shape[0] == weight.shape[1] + ): pass else: old_glora = False rank = v[1].shape[0] - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + a1 = comfy.model_management.cast_to_device( + v[0].flatten(start_dim=1), weight.device, intermediate_dtype + ) + a2 = comfy.model_management.cast_to_device( + v[1].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b1 = comfy.model_management.cast_to_device( + v[2].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b2 = comfy.model_management.cast_to_device( + v[3].flatten(start_dim=1), weight.device, intermediate_dtype + ) if v[4] is not None: alpha = v[4] / rank @@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase): try: if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + lora_diff = ( + torch.mm(b2, b1) + + torch.mm( + torch.mm( + weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2 + ), + a1, + ) + ).reshape( + weight.shape + ) # old lycoris glora else: if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.einsum( + "o i ..., i j -> o j ...", + torch.einsum( + "o i ..., i j -> o j ...", + weight.to(dtype=intermediate_dtype), + a1, + ), + a2, + ).reshape(weight.shape) else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.mm( + torch.mm(weight.to(dtype=intermediate_dtype), a1), a2 + ).reshape(weight.shape) lora_diff += torch.mm(b1, b2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _compute_paths(self, x: torch.Tensor): + """ + Compute A path and B path outputs for GLoRA bypass. + + GLoRA: f(x) = Wx + WAx + Bx + - A path: a1(a2(x)) - modifies input to base forward + - B path: b1(b2(x)) - additive component + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Returns: (a_out, b_out) + """ + v = self.weights + # v = (a1, a2, b1, b2, alpha, dora_scale) + a1 = v[0] + a2 = v[1] + b1 = v[2] + b2 = v[3] + alpha = v[4] + + dtype = x.dtype + + # Cast dtype (weights should already be on correct device from inject()) + a1 = a1.to(dtype=dtype) + a2 = a2.to(dtype=dtype) + b1 = b1.to(dtype=dtype) + b2 = b2.to(dtype=dtype) + + # Determine rank and scale + # Check for old vs new glora format + old_glora = False + if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]: + rank = a1.shape[0] + old_glora = True + + if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]: + if old_glora and a2.shape[0] == x.shape[-1] and x.shape[-1] == x.shape[-1]: + pass + else: + old_glora = False + rank = a2.shape[0] + + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + scale = scale * multiplier + + # Use module info from bypass injection, not input tensor shape + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + # Conv case - conv_dim is 1/2/3 for conv1d/2d/3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Get module's stride/padding for spatial dimension handling + module_stride = kw_dict.get("stride", (1,) * conv_dim) + module_padding = kw_dict.get("padding", (0,) * conv_dim) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Ensure weights are in conv shape + # a1, a2, b1 are always 1x1 kernels + if a1.ndim == 2: + a1 = a1.view(*a1.shape, *([1] * conv_dim)) + if a2.ndim == 2: + a2 = a2.view(*a2.shape, *([1] * conv_dim)) + if b1.ndim == 2: + b1 = b1.view(*b1.shape, *([1] * conv_dim)) + # b2 has actual kernel_size (like LoRA down) + if b2.ndim == 2: + if in_channels is not None: + b2 = b2.view(b2.shape[0], in_channels, *kernel_size) + else: + b2 = b2.view(*b2.shape, *([1] * conv_dim)) + + # A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x + a2_out = conv_fn(x, a2) + a_out = conv_fn(a2_out, a1) * scale + + # B path: b2(x) with kernel/stride/padding -> b1(...) 1x1 + b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding) + b_out = conv_fn(b2_out, b1) * scale + else: + # Linear case + if old_glora: + # Old format: a1 @ a2 @ x, b2 @ b1 + a_out = F.linear(F.linear(x, a2), a1) * scale + b_out = F.linear(F.linear(x, b1), b2) * scale + else: + # New format: x @ a1 @ a2, b1 @ b2 + a_out = F.linear(F.linear(x, a1), a2) * scale + b_out = F.linear(F.linear(x, b2), b1) * scale + + return a_out, b_out + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + GLoRA bypass forward: f(x + a(x)) + b(x) + + Unlike standard adapters, GLoRA modifies the input to the base forward + AND adds the B path output. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Reference: LyCORIS GLoRAModule._bypass_forward + """ + a_out, b_out = self._compute_paths(x) + + # Call base forward with modified input + base_out = org_forward(x + a_out, *args, **kwargs) + + # Add B path + return base_out + b_out + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + For GLoRA, h() returns the B path output. + + Note: + GLoRA's full bypass requires overriding bypass_forward() since + it also modifies the input to org_forward. This h() is provided for + compatibility but bypass_forward() should be used for correct behavior. + + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + _, b_out = self._compute_paths(x) + return b_out diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 0abb2d403..8007b7b44 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,11 +1,22 @@ import logging +from functools import cache from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose +@cache +def _warn_loha_bypass_inefficient(): + """One-time warning about LoHa bypass inefficiency.""" + logging.warning( + "LoHa bypass mode is inefficient: full weight diff is computed each forward pass. " + "Consider using LoRA or LoKr for training with bypass mode." + ) + + class HadaWeight(torch.autograd.Function): @staticmethod def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): @@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase): scale = self.alpha / self.rank if self.use_tucker: - diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeightTucker.apply( + self.hada_t1, + self.hada_w1_a, + self.hada_w1_b, + self.hada_t2, + self.hada_w2_a, + self.hada_w2_b, + scale, + ) else: - diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeight.apply( + self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale + ) # Add the scaled difference to the original weight weight = w.to(diff_weight) + diff_weight.reshape(w.shape) @@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase): mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) - return LohaDiff( - (mat1, mat2, alpha, mat3, mat4, None, None, None) - ) + return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None)) def to_train(self): return LohaDiff(self.weights) @@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + weights = ( + lora[hada_w1_a_name], + lora[hada_w1_b_name], + alpha, + lora[hada_w2_a_name], + lora[hada_w2_b_name], + hada_t1, + hada_t2, + dora_scale, + ) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase): w2a = v[3] w2b = v[4] dora_scale = v[7] - if v[5] is not None: #cp decomposition + if v[5] is not None: # cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + m1 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t1, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + ) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + m2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + ) else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + m1 = torch.mm( + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + ) + m2 = torch.mm( + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + ) try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoHa: h(x) = diff_weight @ x + + WARNING: Inefficient - computes full Hadamard product each forward. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/loha.py bypass_forward_diff + """ + _warn_loha_bypass_inefficient() + + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora + w1a = v[0] + w1b = v[1] + alpha = v[2] + w2a = v[3] + w2b = v[4] + t1 = v[5] + t2 = v[6] + + # Compute scale + rank = w1b.shape[0] + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Cast dtype + w1a = w1a.to(dtype=x.dtype) + w1b = w1b.to(dtype=x.dtype) + w2a = w2a.to(dtype=x.dtype) + w2b = w2b.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Compute diff weight using Hadamard product + if t1 is not None and t2 is not None: + t1 = t1.to(dtype=x.dtype) + t2 = t2.to(dtype=x.dtype) + m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a) + m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a) + diff_weight = (m1 * m2) * scale + else: + m1 = w1a @ w1b + m2 = w2a @ w2b + diff_weight = (m1 * m2) * scale + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D diff_weight to conv format using kernel_size + # diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size] + if diff_weight.dim() == 2: + if in_channels is not None: + diff_weight = diff_weight.view( + diff_weight.shape[0], in_channels, *kernel_size + ) + else: + diff_weight = diff_weight.view( + *diff_weight.shape, *([1] * conv_dim) + ) + else: + op = F.linear + kw_dict = {} + + return op(x, diff_weight, **kw_dict) diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 9b2aff2d7..b83750012 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -14,7 +15,17 @@ from .base import ( class LokrDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) = weights self.use_tucker = False if lokr_w1_a is not None: _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] @@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase): if self.w2_rebuild: if self.use_tucker: w2 = torch.einsum( - 'i j k l, j r, i p -> p r k l', + "i j k l, j r, i p -> p r k l", self.lokr_t2, self.lokr_w2_b, - self.lokr_w2_a + self.lokr_w2_a, ) else: w2 = self.lokr_w2_a @ self.lokr_w2_b @@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase): return self.lokr_w2 def __call__(self, w): - diff = torch.kron(self.w1, self.w2) + w1 = self.w1 + w2 = self.w2 + # Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron) + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + diff = torch.kron(w1, w2) return w + diff.reshape(w.shape).to(w) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr training: efficient Kronecker product. + + Uses w1/w2 properties which handle both direct and decomposed cases. + For create_train (direct w1/w2), no alpha scaling in properties. + For to_train (decomposed), alpha/rank scaling is in properties. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Get w1, w2 from properties (handles rebuild vs direct) + w1 = self.w1 + w2 = self.w2 + + # Multiplier from bypass injection + multiplier = getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Efficient Kronecker application without materializing full weight + # kron(w1, w2) @ x can be computed as nested operations + # w1: [out_l, in_m], w2: [out_k, in_n, *k_size] + # Full weight would be [out_l*out_k, in_m*in_n, *k_size] + + uq = w1.size(1) # in_m - inner grouping dimension + + if is_conv: + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + B, C_in, *spatial = x.shape + # Reshape input for grouped application: [B * uq, C_in // uq, *spatial] + h_in_group = x.reshape(B * uq, -1, *spatial) + + # Ensure w2 has conv dims + if w2.dim() == 2: + w2 = w2.view(*w2.shape, *([1] * conv_dim)) + + # Apply w2 path with stride/padding + hb = conv_fn(h_in_group, w2, **kw_dict) + + # Reshape for cross-group operation + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross = hb.transpose(1, -1) + + # Apply w1 (always 2D, applied as linear on channel dim) + hc = F.linear(h_cross, w1) + hc = hc.transpose(1, -1) + + # Reshape to output + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + # Linear case + # Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n] + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k] + hb = F.linear(h_in_group, w2) + + # Transpose for w1: [..., uq, out_k] -> [..., out_k, uq] + h_cross = hb.transpose(-1, -2) + + # Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l] + hc = F.linear(h_cross, w1) + + # Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k] + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * multiplier + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase): @classmethod def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] - in_dim = weight.shape[1:].numel() - out1, out2 = factorization(out_dim, rank) - in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) + in_dim = weight.shape[1] # Just in_channels, not flattened with kernel + k_size = weight.shape[2:] if weight.dim() > 2 else () + + out_l, out_k = factorization(out_dim, rank) + in_m, in_n = factorization(in_dim, rank) + + # w1: [out_l, in_m] + mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32) + # w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear + mat2 = torch.empty( + out_k, in_n, *k_size, device=weight.device, dtype=torch.float32 + ) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) - return LokrDiff( - (mat1, mat2, alpha, None, None, None, None, None, None) - ) + return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None)) def to_train(self): return LokrDiff(self.weights) @@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase): lokr_t2 = lora[lokr_t2_name] loaded_keys.add(lokr_t2_name) - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + if ( + (lokr_w1 is not None) + or (lokr_w2 is not None) + or (lokr_w1_a is not None) + or (lokr_w2_a is not None) + ): + weights = ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) return cls(loaded_keys, weights) else: return None @@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase): if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + w1 = torch.mm( + comfy.model_management.cast_to_device( + w1_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1_b, weight.device, intermediate_dtype + ), + ) else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + w1 = comfy.model_management.cast_to_device( + w1, weight.device, intermediate_dtype + ) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + w2 = torch.mm( + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + ) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + w2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + ) else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + w2 = comfy.model_management.cast_to_device( + w2, weight.device, intermediate_dtype + ) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase): try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr: efficient Kronecker product application. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/lokr.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora + w1 = v[0] + w2 = v[1] + alpha = v[2] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + + use_w1 = w1 is not None + use_w2 = w2 is not None + tucker = t2 is not None + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) if is_conv else {} + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + else: + op = F.linear + + # Determine rank and scale + rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Build c (w1) + if use_w1: + c = w1.to(dtype=x.dtype) + else: + c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype) + uq = c.size(1) + + # Build w2 components + if use_w2: + ba = w2.to(dtype=x.dtype) + else: + a = w2_b.to(dtype=x.dtype) + b = w2_a.to(dtype=x.dtype) + if is_conv: + if tucker: + # Tucker: a, b get 1s appended (kernel is in t2) + if a.dim() == 2: + a = a.view(*a.shape, *([1] * conv_dim)) + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + else: + # Non-tucker conv: b may need 1s appended + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + + # Reshape input by uq groups + if is_conv: + B, _, *rest = x.shape + h_in_group = x.reshape(B * uq, -1, *rest) + else: + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2 path + if use_w2: + hb = op(h_in_group, ba, **kw_dict) + else: + if is_conv: + if tucker: + t = t2.to(dtype=x.dtype) + if t.dim() == 2: + t = t.view(*t.shape, *([1] * conv_dim)) + ha = op(h_in_group, a) + ht = op(ha, t, **kw_dict) + hb = op(ht, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + else: + ha = op(h_in_group, a) + hb = op(ha, b) + + # Reshape and apply c (w1) + if is_conv: + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) + else: + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + + if is_conv: + hc = hc.transpose(1, -1) + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * scale diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 3cc60bb1b..bc4260a8f 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase): rank, in_dim = mat2.shape[0], mat2.shape[1] if mid is not None: convdim = mid.ndim - 2 - layer = ( - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.Conv3d - )[convdim] + layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim] else: layer = torch.nn.Linear self.lora_up = layer(rank, out_dim, bias=False) @@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase): weight = w + scale * diff.reshape(w.shape) return weight.to(org_dtype) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA training: h(x) = up(down(x)) * scale + + Simple implementation using the nn.Module weights directly. + No mid/dora/reshape branches (create_train doesn't create them). + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Compute scale = alpha / rank * multiplier + scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Get weights (keep in original dtype for numerical stability) + down_weight = self.lora_down.weight + up_weight = self.lora_up.weight + + if is_conv: + # Conv path: use functional conv + # conv_dim: 1=conv1d, 2=conv2d, 3=conv3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Reshape 2D weights to conv format if needed + # down: [rank, in_features] -> [rank, in_channels, *kernel_size] + # up: [out_features, rank] -> [out_features, rank, 1, 1, ...] + if down_weight.dim() == 2: + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + if in_channels is not None: + down_weight = down_weight.view( + down_weight.shape[0], in_channels, *kernel_size + ) + else: + # Fallback: assume 1x1 kernel + down_weight = down_weight.view( + *down_weight.shape, *([1] * conv_dim) + ) + if up_weight.dim() == 2: + # up always uses 1x1 kernel + up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim)) + + # down conv uses stride/padding from module, up is 1x1 + hidden = conv_fn(x, down_weight, **kw_dict) + + # mid layer if exists (tucker decomposition) + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + if mid_weight.dim() == 2: + mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim)) + hidden = conv_fn(hidden, mid_weight) + + # up conv is always 1x1 (no stride/padding) + out = conv_fn(hidden, up_weight) + else: + # Linear path: simple matmul chain + hidden = F.linear(x, down_weight) + + # mid layer if exists + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + hidden = F.linear(hidden, mid_weight) + + out = F.linear(hidden, up_weight) + + return out * scale + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase): mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) - return LoraDiff( - (mat1, mat2, alpha, None, None, None) - ) + return LoraDiff((mat1, mat2, alpha, None, None, None)) def to_train(self): return LoraDiff(self.weights) @@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase): except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA: h(x) = up(down(x)) * scale + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/locon.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape + up = v[0] + down = v[1] + alpha = v[2] + mid = v[3] + + # Compute scale = alpha / rank + rank = down.shape[0] + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + scale = scale * getattr(self, "multiplier", 1.0) + + # Cast dtype + up = up.to(dtype=x.dtype) + down = down.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + op = FUNC_LIST[ + conv_dim + 2 + ] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D weights to conv format using kernel_size + # down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size] + # up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel) + if down.dim() == 2: + # down.shape[1] = in_channels * prod(kernel_size) + if in_channels is not None: + down = down.view(down.shape[0], in_channels, *kernel_size) + else: + # Fallback: assume 1x1 kernel if in_channels unknown + down = down.view(*down.shape, *([1] * conv_dim)) + if up.dim() == 2: + # up always uses 1x1 kernel + up = up.view(*up.shape, *([1] * conv_dim)) + if mid is not None: + mid = mid.to(dtype=x.dtype) + if mid.dim() == 2: + mid = mid.view(*mid.shape, *([1] * conv_dim)) + else: + op = F.linear + kw_dict = {} # linear doesn't take stride/padding + + # Simple chain: down -> mid (if tucker) -> up + if mid is not None: + if not is_conv: + mid = mid.to(dtype=x.dtype) + hidden = op(x, down) + hidden = op(hidden, mid, **kw_dict) + out = op(hidden, up) + else: + hidden = op(x, down, **kw_dict) + out = op(hidden, up) + + return out * scale diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index c0aab9635..bc83cf8e8 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,13 +3,18 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) class OFTDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - # Unpack weights tuple from LoHaAdapter + # Unpack weights tuple from OFTAdapter blocks, rescale, alpha, _ = weights # Create trainable parameters @@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase): weight = self.rescale * weight return weight.to(org_dtype) + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + blocks = self.oft_blocks.to(device=device, dtype=dtype) + I = torch.eye(self.block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if set + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r.to(dtype) + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + OFT has no additive component - returns zeros matching base_out shape. + + OFT only transforms the output via g(), it doesn't add to it. + """ + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation. + + OFT transforms output channels using block-diagonal orthogonal matrices. + """ + r = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(self.block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.reshape(*batch_shape, out_features) + + # Apply rescale if present + if self.rescaled: + rescale = self.rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + out = out.transpose(1, -1) + + return out + def passive_memory_usage(self): """Calculates memory usage of the trainable parameters.""" return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) - return OFTDiff( - (block, None, alpha, None) + block = torch.zeros( + block_num, block_size, block_size, device=weight.device, dtype=torch.float32 ) + return OFTDiff((block, None, alpha, None)) def to_train(self): return OFTDiff(self.weights) @@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase): alpha = 0 dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) block_num, block_size, *_ = blocks.shape @@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(1, 2) normed_q = q - if alpha > 0: # alpha in oft/boft is for constraint + if alpha > 0: # alpha in oft/boft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm # use float() to prevent unsupported type in .inverse() r = (I + normed_q) @ (I - normed_q).float().inverse() r = r.to(weight) + # Create I in weight's dtype for the einsum + I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype) _, *shape = weight.shape lora_diff = torch.einsum( "k n m, k n ... -> k m ...", - (r * strength) - strength * I, + (r * strength) - strength * I_w, weight.view(block_num, block_size, *shape), ).view(-1, *shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + block_num, block_size, _ = blocks.shape + I = torch.eye(block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, block_num, block_size + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation to output. + + OFT transforms the output channels using block-diagonal orthogonal matrices. + + Reference: LyCORIS DiagOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.view(*batch_shape, block_num, block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.view(*batch_shape, out_features) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + out = out.transpose(1, -1) + + return out diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 68a73cf13..024a89391 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers from comfy.weight_adapter import adapters, adapter_maps +from comfy.weight_adapter.bypass import BypassInjectionManager from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar @@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + for param_groups in self.optimizer.param_groups: + for param in param_groups["params"]: + if param.grad is None: + continue + param.grad.data = param.grad.data.to(param.data.dtype) self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) @@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): num_images = sum(t.shape[0] for t in latents) multi_res = False # Not using multi_res path in bucket mode - logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") for i, lat in enumerate(latents): - logging.info(f" Bucket {i}: shape {lat.shape}") + logging.debug(f" Bucket {i}: shape {lat.shape}") return latents, num_images, multi_res # Non-bucket mode @@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): latents = [t.to(dtype) for t in latents] for latent in latents: all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") + logging.debug(f"Latent shapes: {all_shapes}") if len(all_shapes) > 1: multi_res = True else: @@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode): if bucket_mode: return positive # Skip validation in bucket mode - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: return positive * num_images elif len(positive) != num_images: @@ -596,6 +602,8 @@ def _create_weight_adapter( shape = module.weight.shape lora_params = {} + logging.debug(f"Creating weight adapter for {key} with shape {shape}") + if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) dora_scale = existing_weights.get(f"{key}.dora_scale", None) @@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): return lora_sd, all_weight_adapters +def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup LoRA adapters in bypass mode. + + In bypass mode: + - Weight adapters (lora/lokr/oft) use bypass injection (forward hook) + - Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification) + + This is useful when the base model weights are quantized and cannot be + directly modified. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list, bypass_manager) + """ + lora_sd = {} + all_weight_adapters = [] + bypass_manager = BypassInjectionManager() + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + all_weight_adapters.append(adapter) + + key = f"{n}.weight" + # BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass + # Only use bypass for adapters that have h() method (lora/lokr/oft) + if isinstance(adapter, BiasDiff): + mp.add_weight_wrapper(key, adapter) + logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}") + else: + bypass_manager.add_adapter(key, adapter, strength=1.0) + logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}") + + if hasattr(m, "bias") and m.bias is not None: + # Bias adapters still use weight wrapper (bias is usually not quantized) + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}") + + return lora_sd, all_weight_adapters, bypass_manager + + def _create_optimizer(optimizer_name, parameters, learning_rate): """Create optimizer based on name. @@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode): default=False, tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", ), + io.Boolean.Input( + "bypass_mode", + default=False, + tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.", + ), ], outputs=[ - io.Model.Output( - display_name="model", tooltip="Model with LoRA applied" - ), io.Custom("LORA_MODEL").Output( display_name="lora", tooltip="LoRA weights" ), @@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing, existing_lora, bucket_mode, + bypass_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] + bypass_mode = bypass_mode[0] # Process latents based on mode if bucket_mode: @@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode): existing_weights, existing_steps = _load_existing_lora(existing_lora) # Setup LoRA adapters - lora_sd, all_weight_adapters = _setup_lora_adapters( - mp, existing_weights, algorithm, lora_dtype, rank - ) + bypass_manager = None + if bypass_mode: + logging.debug("Using bypass mode for training") + lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass( + mp, existing_weights, algorithm, lora_dtype, rank + ) + else: + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) # Create optimizer and loss function optimizer = _create_optimizer( @@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode): guider = TrainGuider(mp) guider.set_conds(positive) + # Inject bypass hooks if bypass mode is enabled + bypass_injections = None + if bypass_manager is not None: + bypass_injections = bypass_manager.create_injections(mp.model) + for injection in bypass_injections: + injection.inject(mp) + logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks") + # Run training loop try: _run_training_loop( @@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode): multi_res, ) finally: + # Eject bypass hooks if they were injected + if bypass_injections is not None: + for injection in bypass_injections: + injection.eject(mp) + logging.debug("[BypassMode] Ejected bypass hooks") for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer @@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode): for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) + # mp in train node is highly specialized for training + # use it in inference will result in bad behavior so we don't return it + return io.NodeOutput(lora_sd, loss_map, steps + existing_steps) class LoraModelLoader(io.ComfyNode):# diff --git a/nodes.py b/nodes.py index b75247665..8a8df9246 100644 --- a/nodes.py +++ b/nodes.py @@ -722,6 +722,69 @@ class LoraLoaderModelOnly(LoraLoader): def load_lora_model_only(self, model, lora_name, strength_model): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) +class LoraLoaderBypass: + """ + Apply LoRA in bypass mode without modifying base model weights. + + Bypass mode computes: output = base_forward(x) + lora_path(x) + This is useful for training and when model weights are offloaded. + """ + + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") + FUNCTION = "load_lora" + + CATEGORY = "loaders" + DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + self.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + + +class LoraLoaderBypassModelOnly(LoraLoaderBypass): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + class VAELoader: video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @@ -2067,6 +2130,8 @@ NODE_CLASS_MAPPINGS = { "LatentFlip": LatentFlip, "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, + "LoraLoaderBypass": LoraLoaderBypass, + "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, "CLIPLoader": CLIPLoader, "UNETLoader": UNETLoader, "DualCLIPLoader": DualCLIPLoader, @@ -2106,6 +2171,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA", + "LoraLoaderBypass": "Load LoRA (Bypass)", + "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", "DiffControlNetLoader": "Load ControlNet Model (diff)",