[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)

* Add API for bypass forward module

* Bypass implementation

* Add bypass forward into node list/trainer
Kohaku-Blueleaf 2026-01-25 11:56:22 +08:00 committed by GitHub
parent 635406e283
commit a97c98068f
12 changed files with 2039 additions and 101 deletions

View File: comfy/sd.py

@@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import comfy.weight_adapter
import yaml
import math
import os
@@ -101,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
return (new_modelpatcher, new_clip)
def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
"""
Load LoRA in bypass mode without modifying base model weights.
Instead of patching weights, this injects the LoRA computation into the
forward pass: output = base_forward(x) + lora_path(x)
Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.
This is useful for training and when model weights are offloaded.
"""
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")
# Separate adapters (for bypass) from other patches (for regular patching)
bypass_patches = {} # WeightAdapterBase instances -> bypass mode
regular_patches = {} # diff, set, bias patches -> regular weight patching
for key, patch_data in loaded.items():
if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
bypass_patches[key] = patch_data
else:
regular_patches[key] = patch_data
logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")
k = set()
k1 = set()
if model is not None:
new_modelpatcher = model.clone()
# Apply regular patches (bias diff, weight diff, etc.) via normal patching
if regular_patches:
patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
k.update(patched_keys)
# Apply adapter patches via bypass injection
manager = comfy.weight_adapter.BypassInjectionManager()
model_sd_keys = set(new_modelpatcher.model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in model_sd_keys:
manager.add_adapter(key, adapter, strength=strength_model)
k.add(key)
else:
logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")
injections = manager.create_injections(new_modelpatcher.model)
if manager.get_hook_count() > 0:
new_modelpatcher.set_injections("bypass_lora", injections)
else:
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
# Apply regular patches to clip
if regular_patches:
patched_keys = new_clip.add_patches(regular_patches, strength_clip)
k1.update(patched_keys)
# Apply adapter patches via bypass injection
clip_manager = comfy.weight_adapter.BypassInjectionManager()
clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in clip_sd_keys:
clip_manager.add_adapter(key, adapter, strength=strength_clip)
k1.add(key)
clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
if clip_manager.get_hook_count() > 0:
new_clip.patcher.set_injections("bypass_lora", clip_injections)
else:
new_clip = None
for x in loaded:
if (x not in k) and (x not in k1):
patch_data = loaded[x]
patch_type = type(patch_data).__name__
if isinstance(patch_data, tuple):
patch_type = f"tuple({patch_data[0]})"
logging.warning(f"NOT LOADED: {x} (type={patch_type})")
return (new_modelpatcher, new_clip)
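# Usage sketch (illustrative, not part of this commit). Assumes `model` and
# `clip` come from a loaded checkpoint and `lora_path` points at a LoRA file:
#
# lora_sd = comfy.utils.load_torch_file(lora_path, safe_load=True)
# model_l, clip_l = load_bypass_lora_for_models(model, clip, lora_sd, 1.0, 1.0)
#
# The result behaves like load_lora_for_models' output, except the adapter math
# runs in the forward pass and the base weights stay untouched.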
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:

View File: comfy/weight_adapter/__init__.py

@@ -5,6 +5,11 @@ from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
from .oft import OFTAdapter
from .boft import BOFTAdapter
from .bypass import (
BypassInjectionManager,
BypassForwardHook,
create_bypass_injections_from_patches,
)
adapters: list[type[WeightAdapterBase]] = [
@@ -31,4 +36,7 @@ __all__ = [
"WeightAdapterTrainBase",
"adapters",
"adapter_maps",
"BypassInjectionManager",
"BypassForwardHook",
"create_bypass_injections_from_patches",
] + [a.__name__ for a in adapters]

View File: comfy/weight_adapter/base.py

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional
import torch
import torch.nn as nn
@@ -7,12 +7,35 @@ import comfy.model_management
class WeightAdapterBase:
"""
Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
Bypass Mode:
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
- h(x): Additive component (LoRA path). Returns delta to add to base output.
- g(y): Output transformation. Applied after base + h(x).
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
For OFT/BOFT: g = transform, h = 0
"""
name: str
loaded_keys: set[str]
weights: list[torch.Tensor]
# Attributes set by bypass system
multiplier: float = 1.0
shape: Optional[tuple] = None # (out_features, in_features) or (out_ch, in_ch, *kernel)
@classmethod
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
) -> Optional["WeightAdapterBase"]:
raise NotImplementedError
def to_train(self) -> "WeightAdapterTrainBase":
@@ -39,18 +62,202 @@
):
raise NotImplementedError
# ===== Bypass Mode Methods =====
#
# IMPORTANT: Bypass mode is designed for quantized models where original weights
# may not be accessible in a usable format. Therefore, h() and bypass_forward()
# do NOT take org_weight as a parameter. All necessary information (out_channels,
# in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook.
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component: h(x, base_out)
Computes the adapter's contribution to be added to base forward output.
For adapters that only transform output (OFT/BOFT), returns zeros.
Note:
This method does NOT access original model weights. Bypass mode is
designed for quantized models where weights may not be in a usable format.
All shape info comes from module attributes set by BypassForwardHook.
Args:
x: Input tensor
base_out: Output from base forward f(x), can be used for shape reference
Returns:
Delta tensor to add to base output. Shape matches base output.
Reference: LyCORIS LoConModule.bypass_forward_diff
"""
# Default: no additive component (for OFT/BOFT)
# Simply return zeros matching base_out shape
return torch.zeros_like(base_out)
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation: g(y)
Applied after base forward + h(x). For most adapters this is identity.
OFT/BOFT override this to apply orthogonal transformation.
Args:
y: Combined output (base + h(x))
Returns:
Transformed output
Reference: LyCORIS OFTModule applies orthogonal transform here
"""
# Default: identity (for LoRA/LoHa/LoKr)
return y
def bypass_forward(
self,
org_forward: Callable,
x: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
"""
Full bypass forward: g(f(x) + h(x, f(x)))
Note:
This method does NOT take org_weight/org_bias parameters. Bypass mode
is designed for quantized models where weights may not be accessible.
The original forward function handles weight access internally.
Args:
org_forward: Original module forward function
x: Input tensor
*args, **kwargs: Additional arguments for org_forward
Returns:
Output with adapter applied in bypass mode
Reference: LyCORIS LoConModule.bypass_forward
"""
# Base forward: f(x)
base_out = org_forward(x, *args, **kwargs)
# Additive component: h(x, base_out) - base_out provided for shape reference
h_out = self.h(x, base_out)
# Output transformation: g(base + h)
return self.g(base_out + h_out)
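# Minimal self-contained sketch (illustrative, not part of this commit): for a
# plain LoRA on nn.Linear with g = identity, bypass mode agrees with weight
# patching, since f(x) + h(x) = x @ (W + up @ down).T
import torch.nn.functional as F

lin_demo = nn.Linear(16, 8, bias=False)
down_demo, up_demo = torch.randn(4, 16), torch.randn(8, 4)
x_demo = torch.randn(2, 16)
bypass_out = lin_demo(x_demo) + F.linear(F.linear(x_demo, down_demo), up_demo)
patched_out = F.linear(x_demo, lin_demo.weight + up_demo @ down_demo)
assert torch.allclose(bypass_out, patched_out, atol=1e-5)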
class WeightAdapterTrainBase(nn.Module):
# We follow the scheme of PR #7032
"""
Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
Bypass Mode:
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
- h(x): Additive component (LoRA path). Returns delta to add to base output.
- g(y): Output transformation. Applied after base + h(x).
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
For OFT: g = transform, h = 0
Note:
Unlike WeightAdapterBase, TrainBase classes have simplified weight formats
with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition).
We follow the scheme of PR #7032
"""
# Attributes set by bypass system (BypassForwardHook)
# These are set before h()/g()/bypass_forward() are called
multiplier: float = 1.0
is_conv: bool = False
conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d
kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups
kernel_size: tuple = ()
in_channels: Optional[int] = None
out_channels: Optional[int] = None
def __init__(self):
super().__init__()
def __call__(self, w):
"""
w: The original weight tensor to be modified.
Weight modification mode: returns modified weight.
Args:
w: The original weight tensor to be modified.
Returns:
Modified weight tensor.
"""
raise NotImplementedError
# ===== Bypass Mode Methods =====
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component: h(x, base_out)
Computes the adapter's contribution to be added to base forward output.
For adapters that only transform output (OFT), returns zeros.
Args:
x: Input tensor
base_out: Output from base forward f(x), can be used for shape reference
Returns:
Delta tensor to add to base output. Shape matches base output.
Subclasses should override this method.
"""
raise NotImplementedError(
f"{self.__class__.__name__}.h() not implemented. "
"Subclasses must implement h() for bypass mode."
)
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation: g(y)
Applied after base forward + h(x). For most adapters this is identity.
OFT overrides this to apply orthogonal transformation.
Args:
y: Combined output (base + h(x))
Returns:
Transformed output
"""
# Default: identity (for LoRA/LoHa/LoKr)
return y
def bypass_forward(
self,
org_forward: Callable,
x: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
"""
Full bypass forward: g(f(x) + h(x, f(x)))
Args:
org_forward: Original module forward function
x: Input tensor
*args, **kwargs: Additional arguments for org_forward
Returns:
Output with adapter applied in bypass mode
"""
# Base forward: f(x)
base_out = org_forward(x, *args, **kwargs)
# Additive component: h(x, base_out) - base_out provided for shape reference
h_out = self.h(x, base_out)
# Output transformation: g(base + h)
return self.g(base_out + h_out)
def passive_memory_usage(self):
raise NotImplementedError("passive_memory_usage is not implemented")
@@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module):
return self.passive_memory_usage()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
def weight_decompose(
dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function
):
dora_scale = comfy.model_management.cast_to_device(
dora_scale, weight.device, intermediate_dtype
)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
@@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
raise ValueError(
"The new shape must be larger than the original tensor in all dimensions"
)
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
raise ValueError(
"The new shape must have the same number of dimensions as the original tensor"
)
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

View File: comfy/weight_adapter/boft.py

@@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase):
alpha = v[2]
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
blocks = comfy.model_management.cast_to_device(
blocks, weight.device, intermediate_dtype
)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
rescale = comfy.model_management.cast_to_device(
rescale, weight.device, intermediate_dtype
)
boft_m, block_num, boft_b, *_ = blocks.shape
@@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase):
# for Q = -Q^T
q = blocks - blocks.transpose(-1, -2)
normed_q = q
if alpha > 0: # alpha in boft/bboft is for constraint
if alpha > 0: # alpha in boft/bboft is for constraint
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
@@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase):
r = r.to(weight)
inp = org = weight
r_b = boft_b//2
r_b = boft_b // 2
for i in range(boft_m):
bi = r[i]
g = 2
k = 2**i * r_b
if strength != 1:
bi = bi * strength + (1-strength) * I
bi = bi * strength + (1 - strength) * I
inp = (
inp.unflatten(0, (-1, g, k))
.transpose(1, 2)
@@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase):
)
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
inp = (
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
inp.flatten(0, 1)
.unflatten(0, (-1, k, g))
.transpose(1, 2)
.flatten(0, 2)
)
if rescale is not None:
inp = inp * rescale
lora_diff = inp - org
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
lora_diff = comfy.model_management.cast_to_device(
lora_diff, weight.device, intermediate_dtype
)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def _get_orthogonal_matrices(self, device, dtype):
"""Compute the orthogonal rotation matrices R from BOFT blocks."""
v = self.weights
blocks = v[0].to(device=device, dtype=dtype)
alpha = v[2]
if alpha is None:
alpha = 0
boft_m, block_num, boft_b, _ = blocks.shape
I = torch.eye(boft_b, device=device, dtype=dtype)
# Q = blocks - blocks^T (skew-symmetric)
q = blocks - blocks.transpose(-1, -2)
normed_q = q
# Apply constraint if alpha > 0
if alpha > 0:
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# Cayley transform: R = (I + Q)(I - Q)^-1
# (computed in fp32 for stability/compatibility, then cast back to dtype so the
# einsum in g() does not mix dtypes on fp16/bf16 activations)
r = (I + normed_q).float() @ (I - normed_q).float().inverse()
r = r.to(dtype=dtype)
return r, boft_m, boft_b
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation for BOFT: applies butterfly orthogonal transform.
BOFT uses multiple stages of butterfly-structured orthogonal transforms.
Reference: LyCORIS ButterflyOFTModule._bypass_forward
"""
v = self.weights
rescale = v[1]
r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype)
r_b = boft_b // 2
# Apply multiplier
multiplier = getattr(self, "multiplier", 1.0)
I = torch.eye(boft_b, device=y.device, dtype=y.dtype)
# Use module info from bypass injection to determine conv vs linear
is_conv = getattr(self, "is_conv", y.dim() > 2)
if is_conv:
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
y = y.transpose(1, -1)
# Apply butterfly transform stages
inp = y
for i in range(boft_m):
bi = r[i] # (block_num, boft_b, boft_b)
g = 2
k = 2**i * r_b
# Interpolate with identity based on multiplier
if multiplier != 1:
bi = bi * multiplier + (1 - multiplier) * I
# Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten
inp = (
inp.unflatten(-1, (-1, g, k))
.transpose(-2, -1)
.flatten(-3)
.unflatten(-1, (-1, boft_b))
)
# Apply block-diagonal orthogonal transform
inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
# Reshape back
inp = (
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
)
# Apply rescale if present
if rescale is not None:
rescale = rescale.to(device=y.device, dtype=y.dtype)
inp = inp * rescale.transpose(0, -1)
if is_conv:
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
inp = inp.transpose(1, -1)
return inp
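# Illustrative check (not part of the commit): the Cayley transform of a
# skew-symmetric Q always yields orthogonal blocks, which is why g() is
# norm-preserving at multiplier == 1 (before any rescale).
blocks_demo = torch.randn(2, 3, 4, 4) # (boft_m, block_num, boft_b, boft_b)
q_demo = blocks_demo - blocks_demo.transpose(-1, -2) # skew-symmetric
eye_demo = torch.eye(4)
rot_demo = (eye_demo + q_demo) @ torch.linalg.inv(eye_demo - q_demo)
assert torch.allclose(
rot_demo @ rot_demo.transpose(-1, -2), eye_demo.expand_as(rot_demo), atol=1e-4
)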

View File: comfy/weight_adapter/bypass.py

@@ -0,0 +1,437 @@
"""
Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.)
Bypass mode applies adapters during forward pass without modifying base weights:
bypass(f)(x) = g(f(x) + h(x))
Where:
- f(x): Original layer forward
- h(x): Additive component from adapter (LoRA path)
- g(y): Output transformation (identity for most adapters)
This is useful for:
- Training with gradient checkpointing
- Avoiding weight modifications when weights are offloaded
- Supporting multiple adapters with different strengths dynamically
"""
import logging
from typing import Optional, Union
import torch
import torch.nn as nn
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
# Type alias for adapters that support bypass mode
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
def get_module_type_info(module: nn.Module) -> dict:
"""
Determine module type and extract conv parameters from module class.
This is more reliable than checking weight.ndim, especially for quantized layers
where weight shape might be different.
Returns:
dict with keys: is_conv, conv_dim, stride, padding, dilation, groups
"""
info = {
"is_conv": False,
"conv_dim": 0,
"stride": (1,),
"padding": (0,),
"dilation": (1,),
"groups": 1,
"kernel_size": (1,),
"in_channels": None,
"out_channels": None,
}
# Determine conv type
if isinstance(module, nn.Conv1d):
info["is_conv"] = True
info["conv_dim"] = 1
elif isinstance(module, nn.Conv2d):
info["is_conv"] = True
info["conv_dim"] = 2
elif isinstance(module, nn.Conv3d):
info["is_conv"] = True
info["conv_dim"] = 3
elif isinstance(module, nn.Linear):
info["is_conv"] = False
info["conv_dim"] = 0
else:
# Try to infer from class name for custom/quantized layers
class_name = type(module).__name__.lower()
if "conv3d" in class_name:
info["is_conv"] = True
info["conv_dim"] = 3
elif "conv2d" in class_name:
info["is_conv"] = True
info["conv_dim"] = 2
elif "conv1d" in class_name:
info["is_conv"] = True
info["conv_dim"] = 1
elif "conv" in class_name:
info["is_conv"] = True
info["conv_dim"] = 2
# Extract conv parameters if it's a conv layer
if info["is_conv"]:
# Try to get stride, padding, dilation, groups, kernel_size from module
info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"])
info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"])
info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"])
info["groups"] = getattr(module, "groups", 1)
info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"])
info["in_channels"] = getattr(module, "in_channels", None)
info["out_channels"] = getattr(module, "out_channels", None)
# Ensure they're tuples
if isinstance(info["stride"], int):
info["stride"] = (info["stride"],) * info["conv_dim"]
if isinstance(info["padding"], int):
info["padding"] = (info["padding"],) * info["conv_dim"]
if isinstance(info["dilation"], int):
info["dilation"] = (info["dilation"],) * info["conv_dim"]
if isinstance(info["kernel_size"], int):
info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"]
return info
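# Illustrative check (not part of the commit) of the extraction above for
# standard torch layers:
info_demo = get_module_type_info(nn.Conv2d(3, 16, 3, stride=2, padding=1))
assert info_demo["is_conv"] and info_demo["conv_dim"] == 2
assert info_demo["stride"] == (2, 2) and info_demo["padding"] == (1, 1)
assert info_demo["kernel_size"] == (3, 3)
assert info_demo["in_channels"] == 3 and info_demo["out_channels"] == 16
assert get_module_type_info(nn.Linear(64, 32))["conv_dim"] == 0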
class BypassForwardHook:
"""
Hook that wraps a layer's forward to apply adapter in bypass mode.
Stores the original forward and replaces it with bypass version.
Supports both:
- WeightAdapterBase: Inference adapters (uses self.weights tuple)
- WeightAdapterTrainBase: Training adapters (nn.Module with parameters)
"""
def __init__(
self,
module: nn.Module,
adapter: BypassAdapter,
multiplier: float = 1.0,
):
self.module = module
self.adapter = adapter
self.multiplier = multiplier
self.original_forward = None
# Determine layer type and conv params from module class (works for quantized layers)
module_info = get_module_type_info(module)
# Set multiplier and layer type info on adapter for use in h()
adapter.multiplier = multiplier
adapter.is_conv = module_info["is_conv"]
adapter.conv_dim = module_info["conv_dim"]
adapter.kernel_size = module_info["kernel_size"]
adapter.in_channels = module_info["in_channels"]
adapter.out_channels = module_info["out_channels"]
# Store kw_dict for conv operations (like LyCORIS extra_args)
if module_info["is_conv"]:
adapter.kw_dict = {
"stride": module_info["stride"],
"padding": module_info["padding"],
"dilation": module_info["dilation"],
"groups": module_info["groups"],
}
else:
adapter.kw_dict = {}
def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x))
Note:
Bypass mode does NOT access original model weights (org_weight).
This is intentional - bypass mode is designed for quantized models
where weights may not be in a usable format. All necessary shape
information is provided via adapter attributes set during inject().
"""
# Check if adapter has custom bypass_forward (e.g., GLoRA)
adapter_bypass = getattr(self.adapter, "bypass_forward", None)
if adapter_bypass is not None:
# Check if it's overridden (not the base class default)
# Need to check both base classes since adapter could be either type
adapter_type = type(self.adapter)
is_default_bypass = (
adapter_type.bypass_forward is WeightAdapterBase.bypass_forward
or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward
)
if not is_default_bypass:
return adapter_bypass(self.original_forward, x, *args, **kwargs)
# Default bypass: g(f(x) + h(x, f(x)))
base_out = self.original_forward(x, *args, **kwargs)
h_out = self.adapter.h(x, base_out)
return self.adapter.g(base_out + h_out)
def inject(self):
"""Replace module forward with bypass version."""
if self.original_forward is not None:
logging.debug(
f"[BypassHook] Already injected for {type(self.module).__name__}"
)
return # Already injected
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
device = None
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
device = self.module.weight.device
dtype = self.module.weight.dtype
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
device = self.module.W_q.device
dtype = self.module.W_q.dtype
if device is not None:
self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward
logging.debug(
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
)
def _move_adapter_weights_to_device(self, device, dtype=None):
"""Move adapter weights to specified device to avoid per-forward transfers.
Handles both:
- WeightAdapterBase: has self.weights tuple of tensors
- WeightAdapterTrainBase: nn.Module with parameters, uses .to() method
"""
adapter = self.adapter
# Check if adapter is an nn.Module (WeightAdapterTrainBase)
if isinstance(adapter, nn.Module):
# In training mode we don't touch dtype as trainer will handle it
adapter.to(device=device)
logging.debug(
f"[BypassHook] Moved training adapter (nn.Module) to {device}"
)
return
# WeightAdapterBase: handle self.weights tuple
if not hasattr(adapter, "weights") or adapter.weights is None:
return
weights = adapter.weights
if isinstance(weights, (list, tuple)):
new_weights = []
for w in weights:
if isinstance(w, torch.Tensor):
if dtype is not None:
new_weights.append(w.to(device=device, dtype=dtype))
else:
new_weights.append(w.to(device=device))
else:
new_weights.append(w)
adapter.weights = (
tuple(new_weights) if isinstance(weights, tuple) else new_weights
)
elif isinstance(weights, torch.Tensor):
if dtype is not None:
adapter.weights = weights.to(device=device, dtype=dtype)
else:
adapter.weights = weights.to(device=device)
logging.debug(f"[BypassHook] Moved adapter weights to {device}")
def eject(self):
"""Restore original module forward."""
if self.original_forward is None:
logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
return # Not injected
self.module.forward = self.original_forward
self.original_forward = None
logging.debug(
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
)
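# Manual usage sketch (illustrative; BypassInjectionManager below is the
# intended entry point):
#
# hook = BypassForwardHook(module, adapter, multiplier=0.8)
# hook.inject() # module.forward now computes g(f(x) + h(x))
# y = module(x)
# hook.eject() # original forward restored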
class BypassInjectionManager:
"""
Manages bypass mode injection for a collection of adapters.
Creates PatcherInjection objects that can be used with ModelPatcher.
Supports both inference adapters (WeightAdapterBase) and training adapters
(WeightAdapterTrainBase).
Usage:
manager = BypassInjectionManager()
manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8)
manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8)
injections = manager.create_injections(model)
model_patcher.set_injections("bypass_lora", injections)
"""
def __init__(self):
self.adapters: dict[str, tuple[BypassAdapter, float]] = {}
self.hooks: list[BypassForwardHook] = []
def add_adapter(
self,
key: str,
adapter: BypassAdapter,
strength: float = 1.0,
):
"""
Add an adapter for a specific weight key.
Args:
key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight")
adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.)
strength: Multiplier for adapter effect
"""
# Remove .weight suffix if present for module lookup
module_key = key
if module_key.endswith(".weight"):
module_key = module_key[:-7]
logging.debug(
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
)
self.adapters[module_key] = (adapter, strength)
logging.debug(
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
)
def clear_adapters(self):
"""Remove all adapters."""
self.adapters.clear()
def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]:
"""Get a submodule by dot-separated key."""
parts = key.split(".")
module = model
try:
for i, part in enumerate(parts):
if part.isdigit():
module = module[int(part)]
else:
module = getattr(module, part)
logging.debug(
f"[BypassManager] Found module for key {key}: {type(module).__name__}"
)
return module
except (AttributeError, IndexError, KeyError) as e:
logging.error(f"[BypassManager] Failed to find module for key {key}: {e}")
logging.error(
f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}"
)
return None
def create_injections(self, model: nn.Module) -> list[PatcherInjection]:
"""
Create PatcherInjection objects for all registered adapters.
Args:
model: The model to inject into (e.g., model_patcher.model)
Returns:
List of PatcherInjection objects to use with model_patcher.set_injections()
"""
self.hooks.clear()
logging.debug(
f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
)
logging.debug(f"[BypassManager] Model type: {type(model).__name__}")
for key, (adapter, strength) in self.adapters.items():
logging.debug(f"[BypassManager] Looking for module: {key}")
module = self._get_module_by_key(model, key)
if module is None:
logging.warning(f"[BypassManager] Module not found for key {key}")
continue
if not hasattr(module, "weight"):
logging.warning(
f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})"
)
continue
logging.debug(
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
)
hook = BypassForwardHook(module, adapter, multiplier=strength)
self.hooks.append(hook)
logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")
# Create single injection that manages all hooks
def inject_all(model_patcher):
logging.debug(
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
)
for hook in self.hooks:
hook.inject()
logging.debug(
f"[BypassManager] Injected hook for {type(hook.module).__name__}"
)
def eject_all(model_patcher):
logging.debug(
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
)
for hook in self.hooks:
hook.eject()
return [PatcherInjection(inject=inject_all, eject=eject_all)]
def get_hook_count(self) -> int:
"""Return number of hooks that will be/are injected."""
return len(self.hooks)
def create_bypass_injections_from_patches(
model: nn.Module,
patches: dict,
strength: float = 1.0,
) -> list[PatcherInjection]:
"""
Convenience function to create bypass injections from a patches dict.
This is useful when you have patches in the format used by model_patcher.add_patches()
and want to apply them in bypass mode instead.
Args:
model: The model to inject into
patches: Dict mapping weight keys to adapter data
strength: Global strength multiplier
Returns:
List of PatcherInjection objects
"""
manager = BypassInjectionManager()
for key, patch_list in patches.items():
if not patch_list:
continue
# patches format: list of (strength_patch, patch_data, strength_model, offset, function)
for patch in patch_list:
patch_strength, patch_data, strength_model, offset, function = patch
# patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple
if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)):
adapter = patch_data
else:
# Skip non-adapter patches
continue
combined_strength = strength * patch_strength
manager.add_adapter(key, adapter, strength=combined_strength)
return manager.create_injections(model)
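# Wiring sketch (illustrative): `patches` here is expected in ModelPatcher
# patches format, i.e. key -> list of
# (strength, adapter, strength_model, offset, function) tuples:
#
# injections = create_bypass_injections_from_patches(patcher.model, patches)
# patcher.set_injections("bypass_lora", injections)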

View File: comfy/weight_adapter/glora.py

@@ -1,7 +1,8 @@
import logging
from typing import Optional
from typing import Callable, Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
@@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase):
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
weights = (
lora[a1_name],
lora[a2_name],
lora[b1_name],
lora[b2_name],
alpha,
dora_scale,
)
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
@@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase):
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
if (
old_glora
and v[1].shape[0] == weight.shape[0]
and weight.shape[0] == weight.shape[1]
):
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
a1 = comfy.model_management.cast_to_device(
v[0].flatten(start_dim=1), weight.device, intermediate_dtype
)
a2 = comfy.model_management.cast_to_device(
v[1].flatten(start_dim=1), weight.device, intermediate_dtype
)
b1 = comfy.model_management.cast_to_device(
v[2].flatten(start_dim=1), weight.device, intermediate_dtype
)
b2 = comfy.model_management.cast_to_device(
v[3].flatten(start_dim=1), weight.device, intermediate_dtype
)
if v[4] is not None:
alpha = v[4] / rank
@@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase):
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
lora_diff = (
torch.mm(b2, b1)
+ torch.mm(
torch.mm(
weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2
),
a1,
)
).reshape(
weight.shape
) # old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff = torch.einsum(
"o i ..., i j -> o j ...",
torch.einsum(
"o i ..., i j -> o j ...",
weight.to(dtype=intermediate_dtype),
a1,
),
a2,
).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff = torch.mm(
torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def _compute_paths(self, x: torch.Tensor):
"""
Compute A path and B path outputs for GLoRA bypass.
GLoRA: f(x) = Wx + WAx + Bx
- A path: a1(a2(x)) - modifies input to base forward
- B path: b1(b2(x)) - additive component
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Returns: (a_out, b_out)
"""
v = self.weights
# v = (a1, a2, b1, b2, alpha, dora_scale)
a1 = v[0]
a2 = v[1]
b1 = v[2]
b2 = v[3]
alpha = v[4]
dtype = x.dtype
# Cast dtype (weights should already be on correct device from inject())
a1 = a1.to(dtype=dtype)
a2 = a2.to(dtype=dtype)
b1 = b1.to(dtype=dtype)
b2 = b2.to(dtype=dtype)
# Determine rank and scale
# Check for old vs new glora format
old_glora = False
if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
rank = a1.shape[0]
old_glora = True
if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
# Without the base weight we cannot repeat calculate_weight's square-weight
# check; a2.shape[0] == in_features serves as the old/new-format tiebreaker.
if not (old_glora and a2.shape[0] == x.shape[-1]):
old_glora = False
rank = a2.shape[0]
if alpha is not None:
scale = alpha / rank
else:
scale = 1.0
# Apply multiplier
multiplier = getattr(self, "multiplier", 1.0)
scale = scale * multiplier
# Use module info from bypass injection, not input tensor shape
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
if is_conv:
# Conv case - conv_dim is 1/2/3 for conv1d/2d/3d
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
# Get module's stride/padding for spatial dimension handling
module_stride = kw_dict.get("stride", (1,) * conv_dim)
module_padding = kw_dict.get("padding", (0,) * conv_dim)
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Ensure weights are in conv shape
# a1, a2, b1 are always 1x1 kernels
if a1.ndim == 2:
a1 = a1.view(*a1.shape, *([1] * conv_dim))
if a2.ndim == 2:
a2 = a2.view(*a2.shape, *([1] * conv_dim))
if b1.ndim == 2:
b1 = b1.view(*b1.shape, *([1] * conv_dim))
# b2 has actual kernel_size (like LoRA down)
if b2.ndim == 2:
if in_channels is not None:
b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
else:
b2 = b2.view(*b2.shape, *([1] * conv_dim))
# A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x
a2_out = conv_fn(x, a2)
a_out = conv_fn(a2_out, a1) * scale
# B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
b_out = conv_fn(b2_out, b1) * scale
else:
# Linear case
if old_glora:
# Old format: calculate_weight uses W @ a2 @ a1 and b2 @ b1,
# so a_out = x @ (a2 @ a1).T and b_out = x @ (b2 @ b1).T
a_out = F.linear(F.linear(x, a1), a2) * scale
b_out = F.linear(F.linear(x, b1), b2) * scale
else:
# New format: calculate_weight uses W @ a1 @ a2 and b1 @ b2
a_out = F.linear(F.linear(x, a2), a1) * scale
b_out = F.linear(F.linear(x, b2), b1) * scale
return a_out, b_out
def bypass_forward(
self,
org_forward: Callable,
x: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
"""
GLoRA bypass forward: f(x + a(x)) + b(x)
Unlike standard adapters, GLoRA modifies the input to the base forward
AND adds the B path output.
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Reference: LyCORIS GLoRAModule._bypass_forward
"""
a_out, b_out = self._compute_paths(x)
# Call base forward with modified input
base_out = org_forward(x + a_out, *args, **kwargs)
# Add B path
return base_out + b_out
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
For GLoRA, h() returns the B path output.
Note:
GLoRA's full bypass requires overriding bypass_forward() since
it also modifies the input to org_forward. This h() is provided for
compatibility but bypass_forward() should be used for correct behavior.
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
_, b_out = self._compute_paths(x)
return b_out
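# Shape sketch (illustrative, new-format linear GLoRA, features n -> m, rank r):
# a1: [n, r], a2: [r, n] -> a_out = x @ (a1 @ a2).T, shape [..., n]
# b1: [m, r], b2: [r, n] -> b_out = x @ (b1 @ b2).T, shape [..., m]
# so org_forward(x + a_out) is well-formed and b_out matches base_out's shape.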

View File: comfy/weight_adapter/loha.py

@@ -1,11 +1,22 @@
import logging
from functools import cache
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
@cache
def _warn_loha_bypass_inefficient():
"""One-time warning about LoHa bypass inefficiency."""
logging.warning(
"LoHa bypass mode is inefficient: full weight diff is computed each forward pass. "
"Consider using LoRA or LoKr for training with bypass mode."
)
class HadaWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
@@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase):
scale = self.alpha / self.rank
if self.use_tucker:
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
diff_weight = HadaWeightTucker.apply(
self.hada_t1,
self.hada_w1_a,
self.hada_w1_b,
self.hada_t2,
self.hada_w2_a,
self.hada_w2_b,
scale,
)
else:
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
diff_weight = HadaWeight.apply(
self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale
)
# Add the scaled difference to the original weight
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
@@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase):
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(
(mat1, mat2, alpha, mat3, mat4, None, None, None)
)
return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None))
def to_train(self):
return LohaDiff(self.weights)
@@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
weights = (
lora[hada_w1_a_name],
lora[hada_w1_b_name],
alpha,
lora[hada_w2_a_name],
lora[hada_w2_b_name],
hada_t1,
hada_t2,
dora_scale,
)
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase):
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
if v[5] is not None: # cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m1 = torch.einsum(
"i j k l, j r, i p -> p r k l",
comfy.model_management.cast_to_device(
t1, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1b, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1a, weight.device, intermediate_dtype
),
)
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
m2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
comfy.model_management.cast_to_device(
t2, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2b, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2a, weight.device, intermediate_dtype
),
)
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
m1 = torch.mm(
comfy.model_management.cast_to_device(
w1a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1b, weight.device, intermediate_dtype
),
)
m2 = torch.mm(
comfy.model_management.cast_to_device(
w2a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2b, weight.device, intermediate_dtype
),
)
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoHa: h(x) = diff_weight @ x
WARNING: Inefficient - computes full Hadamard product each forward.
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
Reference: LyCORIS functional/loha.py bypass_forward_diff
"""
_warn_loha_bypass_inefficient()
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
v = self.weights
# v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora
w1a = v[0]
w1b = v[1]
alpha = v[2]
w2a = v[3]
w2b = v[4]
t1 = v[5]
t2 = v[6]
# Compute scale
rank = w1b.shape[0]
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
self, "multiplier", 1.0
)
# Cast dtype
w1a = w1a.to(dtype=x.dtype)
w1b = w1b.to(dtype=x.dtype)
w2a = w2a.to(dtype=x.dtype)
w2b = w2b.to(dtype=x.dtype)
# Use module info from bypass injection, not weight dimension
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
# Compute diff weight using Hadamard product
if t1 is not None and t2 is not None:
t1 = t1.to(dtype=x.dtype)
t2 = t2.to(dtype=x.dtype)
m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a)
m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a)
diff_weight = (m1 * m2) * scale
else:
m1 = w1a @ w1b
m2 = w2a @ w2b
diff_weight = (m1 * m2) * scale
if is_conv:
op = FUNC_LIST[conv_dim + 2]
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Reshape 2D diff_weight to conv format using kernel_size
# diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size]
if diff_weight.dim() == 2:
if in_channels is not None:
diff_weight = diff_weight.view(
diff_weight.shape[0], in_channels, *kernel_size
)
else:
diff_weight = diff_weight.view(
*diff_weight.shape, *([1] * conv_dim)
)
else:
op = F.linear
kw_dict = {}
return op(x, diff_weight, **kw_dict)
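# Note (illustrative): in the linear, non-tucker case the delta applied above is
# delta_W = (w1a @ w1b) * (w2a @ w2b) * (alpha / rank) * multiplier
# i.e. h(x) == F.linear(x, delta_W), matching calculate_weight's lora_diff and
# motivating the one-time inefficiency warning.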

View File: comfy/weight_adapter/lokr.py

@@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import (
WeightAdapterBase,
@@ -14,7 +15,17 @@ from .base import (
class LokrDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
(
lokr_w1,
lokr_w2,
alpha,
lokr_w1_a,
lokr_w1_b,
lokr_w2_a,
lokr_w2_b,
lokr_t2,
dora_scale,
) = weights
self.use_tucker = False
if lokr_w1_a is not None:
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
@@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase):
if self.w2_rebuild:
if self.use_tucker:
w2 = torch.einsum(
'i j k l, j r, i p -> p r k l',
"i j k l, j r, i p -> p r k l",
self.lokr_t2,
self.lokr_w2_b,
self.lokr_w2_a
self.lokr_w2_a,
)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
@@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase):
return self.lokr_w2
def __call__(self, w):
diff = torch.kron(self.w1, self.w2)
w1 = self.w1
w2 = self.w2
# Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron)
for _ in range(w2.dim() - w1.dim()):
w1 = w1.unsqueeze(-1)
diff = torch.kron(w1, w2)
return w + diff.reshape(w.shape).to(w)
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoKr training: efficient Kronecker product.
Uses w1/w2 properties which handle both direct and decomposed cases.
For create_train (direct w1/w2), no alpha scaling in properties.
For to_train (decomposed), alpha/rank scaling is in properties.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
# Get w1, w2 from properties (handles rebuild vs direct)
w1 = self.w1
w2 = self.w2
# Multiplier from bypass injection
multiplier = getattr(self, "multiplier", 1.0)
# Get module info from bypass injection
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
# Efficient Kronecker application without materializing full weight
# kron(w1, w2) @ x can be computed as nested operations
# w1: [out_l, in_m], w2: [out_k, in_n, *k_size]
# Full weight would be [out_l*out_k, in_m*in_n, *k_size]
uq = w1.size(1) # in_m - inner grouping dimension
if is_conv:
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
B, C_in, *spatial = x.shape
# Reshape input for grouped application: [B * uq, C_in // uq, *spatial]
h_in_group = x.reshape(B * uq, -1, *spatial)
# Ensure w2 has conv dims
if w2.dim() == 2:
w2 = w2.view(*w2.shape, *([1] * conv_dim))
# Apply w2 path with stride/padding
hb = conv_fn(h_in_group, w2, **kw_dict)
# Reshape for cross-group operation
hb = hb.view(B, -1, *hb.shape[1:])
h_cross = hb.transpose(1, -1)
# Apply w1 (always 2D, applied as linear on channel dim)
hc = F.linear(h_cross, w1)
hc = hc.transpose(1, -1)
# Reshape to output
out = hc.reshape(B, -1, *hc.shape[3:])
else:
# Linear case
# Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n]
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
# Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k]
hb = F.linear(h_in_group, w2)
# Transpose for w1: [..., uq, out_k] -> [..., out_k, uq]
h_cross = hb.transpose(-1, -2)
# Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l]
hc = F.linear(h_cross, w1)
# Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k]
hc = hc.transpose(-1, -2)
out = hc.reshape(*hc.shape[:-2], -1)
return out * multiplier
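# Equivalence sketch (illustrative, not part of the commit): the grouped
# computation above matches materializing the full Kronecker product in the
# linear case:
w1_demo = torch.randn(4, 3) # [out_l, in_m]
w2_demo = torch.randn(5, 6) # [out_k, in_n]
x_demo = torch.randn(2, 3 * 6) # [B, in_m * in_n]
full_demo = F.linear(x_demo, torch.kron(w1_demo, w2_demo))
hb_demo = F.linear(x_demo.reshape(2, 3, 6), w2_demo) # [B, in_m, out_k]
hc_demo = F.linear(hb_demo.transpose(-1, -2), w1_demo) # [B, out_k, out_l]
grouped_demo = hc_demo.transpose(-1, -2).reshape(2, -1) # [B, out_l * out_k]
assert torch.allclose(full_demo, grouped_demo, atol=1e-5)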
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
@@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase):
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
in_dim = weight.shape[1] # Just in_channels, not flattened with kernel
k_size = weight.shape[2:] if weight.dim() > 2 else ()
out_l, out_k = factorization(out_dim, rank)
in_m, in_n = factorization(in_dim, rank)
# w1: [out_l, in_m]
mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32)
# w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear
mat2 = torch.empty(
out_k, in_n, *k_size, device=weight.device, dtype=torch.float32
)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(
(mat1, mat2, alpha, None, None, None, None, None, None)
)
return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None))
def to_train(self):
return LokrDiff(self.weights)
@@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase):
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
if (
(lokr_w1 is not None)
or (lokr_w2 is not None)
or (lokr_w1_a is not None)
or (lokr_w2_a is not None)
):
weights = (
lokr_w1,
lokr_w2,
alpha,
lokr_w1_a,
lokr_w1_b,
lokr_w2_a,
lokr_w2_b,
lokr_t2,
dora_scale,
)
return cls(loaded_keys, weights)
else:
return None
@@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase):
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
w1 = torch.mm(
comfy.model_management.cast_to_device(
w1_a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1_b, weight.device, intermediate_dtype
),
)
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
w1 = comfy.model_management.cast_to_device(
w1, weight.device, intermediate_dtype
)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
w2 = torch.mm(
comfy.model_management.cast_to_device(
w2_a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2_b, weight.device, intermediate_dtype
),
)
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
w2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
comfy.model_management.cast_to_device(
t2, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2_b, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2_a, weight.device, intermediate_dtype
),
)
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
w2 = comfy.model_management.cast_to_device(
w2, weight.device, intermediate_dtype
)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
@@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase):
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoKr: efficient Kronecker product application.
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
Reference: LyCORIS functional/lokr.py bypass_forward_diff
"""
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
v = self.weights
# v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora
w1 = v[0]
w2 = v[1]
alpha = v[2]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
use_w1 = w1 is not None
use_w2 = w2 is not None
tucker = t2 is not None
# Use module info from bypass injection, not weight dimension
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {}) if is_conv else {}
if is_conv:
op = FUNC_LIST[conv_dim + 2]
else:
op = F.linear
# Determine rank and scale; when both w1 and w2 are stored in full there is
# no rank dimension, so rank falls back to alpha and alpha / rank becomes 1.0
rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
self, "multiplier", 1.0
)
# Build c (w1)
if use_w1:
c = w1.to(dtype=x.dtype)
else:
c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype)
uq = c.size(1)
# Build w2 components
if use_w2:
ba = w2.to(dtype=x.dtype)
else:
a = w2_b.to(dtype=x.dtype)
b = w2_a.to(dtype=x.dtype)
if is_conv:
if tucker:
# Tucker: a, b get 1s appended (kernel is in t2)
if a.dim() == 2:
a = a.view(*a.shape, *([1] * conv_dim))
if b.dim() == 2:
b = b.view(*b.shape, *([1] * conv_dim))
else:
# Non-tucker conv: b may need 1s appended
if b.dim() == 2:
b = b.view(*b.shape, *([1] * conv_dim))
# Reshape input by uq groups
if is_conv:
B, _, *rest = x.shape
h_in_group = x.reshape(B * uq, -1, *rest)
else:
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
# Apply w2 path
if use_w2:
hb = op(h_in_group, ba, **kw_dict)
else:
if is_conv:
if tucker:
t = t2.to(dtype=x.dtype)
if t.dim() == 2:
t = t.view(*t.shape, *([1] * conv_dim))
ha = op(h_in_group, a)
ht = op(ha, t, **kw_dict)
hb = op(ht, b)
else:
ha = op(h_in_group, a, **kw_dict)
hb = op(ha, b)
else:
ha = op(h_in_group, a)
hb = op(ha, b)
# Reshape and apply c (w1)
if is_conv:
hb = hb.view(B, -1, *hb.shape[1:])
h_cross_group = hb.transpose(1, -1)
else:
h_cross_group = hb.transpose(-1, -2)
hc = F.linear(h_cross_group, c)
if is_conv:
hc = hc.transpose(1, -1)
out = hc.reshape(B, -1, *hc.shape[3:])
else:
hc = hc.transpose(-1, -2)
out = hc.reshape(*hc.shape[:-2], -1)
return out * scale

View File: comfy/weight_adapter/lora.py

@@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import (
WeightAdapterBase,
@@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase):
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
@@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase):
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoRA training: h(x) = up(down(x)) * scale
Simple implementation using the nn.Module weights directly; the optional
mid (tucker) layer is applied when present. No dora/reshape branches
(create_train doesn't create them).
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
# Compute scale = alpha / rank * multiplier
scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)
# Get module info from bypass injection
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
# Get weights (keep in original dtype for numerical stability)
down_weight = self.lora_down.weight
up_weight = self.lora_up.weight
if is_conv:
# Conv path: use functional conv
# conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
# Reshape 2D weights to conv format if needed
# down: [rank, in_features] -> [rank, in_channels, *kernel_size]
# up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
if down_weight.dim() == 2:
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
if in_channels is not None:
down_weight = down_weight.view(
down_weight.shape[0], in_channels, *kernel_size
)
else:
# Fallback: assume 1x1 kernel
down_weight = down_weight.view(
*down_weight.shape, *([1] * conv_dim)
)
if up_weight.dim() == 2:
# up always uses 1x1 kernel
up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))
# down conv uses stride/padding from module, up is 1x1
hidden = conv_fn(x, down_weight, **kw_dict)
# mid layer if exists (tucker decomposition)
if self.lora_mid is not None:
mid_weight = self.lora_mid.weight
if mid_weight.dim() == 2:
mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
hidden = conv_fn(hidden, mid_weight)
# up conv is always 1x1 (no stride/padding)
out = conv_fn(hidden, up_weight)
else:
# Linear path: simple matmul chain
hidden = F.linear(x, down_weight)
# mid layer if exists
if self.lora_mid is not None:
mid_weight = self.lora_mid.weight
hidden = F.linear(hidden, mid_weight)
out = F.linear(hidden, up_weight)
return out * scale
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
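
As a sanity check, the bypass path is mathematically the same delta a weight patch would add; a standalone sketch for the plain linear case (no mid layer, arbitrary shapes):

import torch
import torch.nn.functional as F

rank, in_dim, out_dim = 4, 8, 16
down = torch.randn(rank, in_dim)
up = torch.randn(out_dim, rank)
alpha = 4.0
scale = alpha / rank

x = torch.randn(2, in_dim)
h = F.linear(F.linear(x, down), up) * scale  # bypass path: up(down(x)) * scale
ref = F.linear(x, (up @ down) * scale)       # delta from patching W += scale * up @ down
assert torch.allclose(h, ref, atol=1e-5)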
@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase):
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff((mat1, mat2, alpha, None, None, None))
def to_train(self):
return LoraDiff(self.weights)
@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase):
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoRA: h(x) = up(down(x)) * scale
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
Reference: LyCORIS functional/locon.py bypass_forward_diff
"""
# FUNC_LIST is indexed by weight dim: [2]=F.linear, [3]=F.conv1d, [4]=F.conv2d, [5]=F.conv3d
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
v = self.weights
# v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
up = v[0]
down = v[1]
alpha = v[2]
mid = v[3]
# Compute scale = alpha / rank
rank = down.shape[0]
if alpha is not None:
scale = alpha / rank
else:
scale = 1.0
scale = scale * getattr(self, "multiplier", 1.0)
# Cast dtype
up = up.to(dtype=x.dtype)
down = down.to(dtype=x.dtype)
# Use module info from bypass injection, not weight dimension
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
if is_conv:
op = FUNC_LIST[
conv_dim + 2
] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5)
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Reshape 2D weights to conv format using kernel_size
# down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
# up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
if down.dim() == 2:
# down.shape[1] = in_channels * prod(kernel_size)
if in_channels is not None:
down = down.view(down.shape[0], in_channels, *kernel_size)
else:
# Fallback: assume 1x1 kernel if in_channels unknown
down = down.view(*down.shape, *([1] * conv_dim))
if up.dim() == 2:
# up always uses 1x1 kernel
up = up.view(*up.shape, *([1] * conv_dim))
if mid is not None:
mid = mid.to(dtype=x.dtype)
if mid.dim() == 2:
mid = mid.view(*mid.shape, *([1] * conv_dim))
else:
op = F.linear
kw_dict = {} # linear doesn't take stride/padding
# Simple chain: down -> mid (if tucker) -> up
if mid is not None:
if not is_conv:
mid = mid.to(dtype=x.dtype)
hidden = op(x, down)
hidden = op(hidden, mid, **kw_dict)
out = op(hidden, up)
else:
hidden = op(x, down, **kw_dict)
out = op(hidden, up)
return out * scale
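
To make the conv reshape concrete, a small self-contained sketch (shapes arbitrary; padding stands in for whatever kw_dict carries from the base module) of viewing the flattened down matrix as a kernel and applying up as a 1x1 conv:

import torch
import torch.nn.functional as F

rank, in_ch, out_ch, k = 2, 3, 5, 3
down = torch.randn(rank, in_ch * k * k)  # flattened: [rank, in_ch * kh * kw]
up = torch.randn(out_ch, rank)

x = torch.randn(1, in_ch, 8, 8)
hidden = F.conv2d(x, down.view(rank, in_ch, k, k), padding=1)  # down as real kernel
out = F.conv2d(hidden, up.view(out_ch, rank, 1, 1))            # up is always 1x1
print(out.shape)  # torch.Size([1, 5, 8, 8])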

View File

@ -3,13 +3,18 @@ from typing import Optional
import torch
import comfy.model_management
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
factorization,
)
class OFTDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from OFTAdapter
blocks, rescale, alpha, _ = weights
# Create trainable parameters
@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase):
weight = self.rescale * weight
return weight.to(org_dtype)
def _get_orthogonal_matrix(self, device, dtype):
"""Compute the orthogonal rotation matrix R from OFT blocks."""
blocks = self.oft_blocks.to(device=device, dtype=dtype)
I = torch.eye(self.block_size, device=device, dtype=dtype)
# Q = blocks - blocks^T (skew-symmetric)
q = blocks - blocks.transpose(1, 2)
normed_q = q
# Apply constraint if set
if self.constraint:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
# Cayley transform: R = (I + Q)(I - Q)^-1
# compute in float32: .inverse() and the matmul need a supported, matching dtype
r = (I + normed_q).float() @ (I - normed_q).float().inverse()
return r.to(dtype)
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
OFT has no additive component - returns zeros matching base_out shape.
OFT only transforms the output via g(), it doesn't add to it.
"""
return torch.zeros_like(base_out)
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation for OFT: applies orthogonal rotation.
OFT transforms output channels using block-diagonal orthogonal matrices.
"""
r = self._get_orthogonal_matrix(y.device, y.dtype)
# Apply multiplier to interpolate between identity and full transform
multiplier = getattr(self, "multiplier", 1.0)
I = torch.eye(self.block_size, device=y.device, dtype=y.dtype)
r = r * multiplier + (1 - multiplier) * I
# Use module info from bypass injection
is_conv = getattr(self, "is_conv", y.dim() > 2)
if is_conv:
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
y = y.transpose(1, -1)
# y now has channels in last dim
*batch_shape, out_features = y.shape
# Reshape to apply block-diagonal transform
# (*, out_features) -> (*, block_num, block_size)
y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size)
# Apply orthogonal transform: R @ y for each block
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
# Reshape back: (*, block_num, block_size) -> (*, out_features)
out = out_blocked.reshape(*batch_shape, out_features)
# Apply rescale if present
if self.rescaled:
rescale = self.rescale.to(device=y.device, dtype=y.dtype)
out = out * rescale.view(-1)
if is_conv:
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
out = out.transpose(1, -1)
return out
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
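
The blockwise einsum in g() is just a block-diagonal matmul in disguise; a quick standalone equivalence check (arbitrary shapes):

import torch

block_num, block_size = 3, 4
feat = block_num * block_size
r = torch.randn(block_num, block_size, block_size)
y = torch.randn(2, feat)

y_blocked = y.reshape(2, block_num, block_size)
out = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked).reshape(2, feat)

ref = y @ torch.block_diag(*r)  # one big block-diagonal matrix, same result
assert torch.allclose(out, ref, atol=1e-5)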
@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(
block_num, block_size, block_size, device=weight.device, dtype=torch.float32
)
return OFTDiff((block, None, alpha, None))
def to_train(self):
return OFTDiff(self.weights)
@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase):
alpha = 0
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(
blocks, weight.device, intermediate_dtype
)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(
rescale, weight.device, intermediate_dtype
)
block_num, block_size, *_ = blocks.shape
@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase):
# for Q = -Q^T
q = blocks - blocks.transpose(1, 2)
normed_q = q
if alpha > 0:  # alpha in oft/boft is for constraint
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# use float() to prevent unsupported type in .inverse()
r = (I + normed_q) @ (I - normed_q).float().inverse()
r = r.to(weight)
# Create I in weight's dtype for the einsum
I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
_, *shape = weight.shape
lora_diff = torch.einsum(
"k n m, k n ... -> k m ...",
(r * strength) - strength * I_w,
weight.view(block_num, block_size, *shape),
).view(-1, *shape)
if dora_scale is not None:
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def _get_orthogonal_matrix(self, device, dtype):
"""Compute the orthogonal rotation matrix R from OFT blocks."""
v = self.weights
blocks = v[0].to(device=device, dtype=dtype)
alpha = v[2]
if alpha is None:
alpha = 0
block_num, block_size, _ = blocks.shape
I = torch.eye(block_size, device=device, dtype=dtype)
# Q = blocks - blocks^T (skew-symmetric)
q = blocks - blocks.transpose(1, 2)
normed_q = q
# Apply constraint if alpha > 0
if alpha > 0:
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# Cayley transform: R = (I + Q)(I - Q)^-1
# compute in float32 (matching the matmul dtype), then cast back
r = ((I + normed_q).float() @ (I - normed_q).float().inverse()).to(dtype)
return r, block_num, block_size
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation for OFT: applies orthogonal rotation to output.
OFT transforms the output channels using block-diagonal orthogonal matrices.
Reference: LyCORIS DiagOFTModule._bypass_forward
"""
v = self.weights
rescale = v[1]
r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
# Apply multiplier to interpolate between identity and full transform
multiplier = getattr(self, "multiplier", 1.0)
I = torch.eye(block_size, device=y.device, dtype=y.dtype)
r = r * multiplier + (1 - multiplier) * I
# Use module info from bypass injection to determine conv vs linear
is_conv = getattr(self, "is_conv", y.dim() > 2)
if is_conv:
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
y = y.transpose(1, -1)
# y now has channels in last dim
*batch_shape, out_features = y.shape
# Reshape to apply block-diagonal transform
# (*, out_features) -> (*, block_num, block_size)
y_blocked = y.view(*batch_shape, block_num, block_size)
# Apply orthogonal transform: R @ y for each block
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
# Reshape back: (*, block_num, block_size) -> (*, out_features)
out = out_blocked.view(*batch_shape, out_features)
# Apply rescale if present
if rescale is not None:
rescale = rescale.to(device=y.device, dtype=y.dtype)
out = out * rescale.view(-1)
if is_conv:
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
out = out.transpose(1, -1)
return out
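
The Cayley transform at the heart of _get_orthogonal_matrix maps any skew-symmetric Q to an orthogonal R, so the transform preserves norms; a quick check (q scaled down, analogous to the alpha constraint):

import torch

n = 4
blocks = 0.1 * torch.randn(1, n, n)
q = blocks - blocks.transpose(1, 2)  # skew-symmetric: q == -q.transpose(1, 2)
I = torch.eye(n)
r = (I + q) @ (I - q).inverse()      # Cayley transform
assert torch.allclose(r @ r.transpose(1, 2), I.expand_as(r), atol=1e-4)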

View File

@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy.weight_adapter.bypass import BypassInjectionManager
from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar
@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
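
The grad-dtype cast above guards against parameter storage whose dtype no longer matches its gradient at step time (plausibly weights recast for offloading after backward); a contrived standalone illustration of the same realignment:

import torch

p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32))
p.grad = torch.ones_like(p)                  # grad produced in fp32
p.data = p.data.to(torch.bfloat16)           # param storage cast afterwards
p.grad.data = p.grad.data.to(p.data.dtype)   # realign before optimizer.step()
assert p.grad.dtype == p.dtype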
@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
logging.debug(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
logging.debug(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
@ -596,6 +602,8 @@ def _create_weight_adapter(
shape = module.weight.shape
lora_params = {}
logging.debug(f"Creating weight adapter for {key} with shape {shape}")
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
return lora_sd, all_weight_adapters
def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup LoRA adapters in bypass mode.
In bypass mode:
- Weight adapters (lora/lokr/oft) use bypass injection (forward hook)
- Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification)
This is useful when the base model weights are quantized and cannot be
directly modified.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list, bypass_manager)
"""
lora_sd = {}
all_weight_adapters = []
bypass_manager = BypassInjectionManager()
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
all_weight_adapters.append(adapter)
key = f"{n}.weight"
# BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass
# Only use bypass for adapters that have h() method (lora/lokr/oft)
if isinstance(adapter, BiasDiff):
mp.add_weight_wrapper(key, adapter)
logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}")
else:
bypass_manager.add_adapter(key, adapter, strength=1.0)
logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}")
if hasattr(m, "bias") and m.bias is not None:
# Bias adapters still use weight wrapper (bias is usually not quantized)
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}")
return lora_sd, all_weight_adapters, bypass_manager
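
Hypothetical end-to-end usage of the returned triple, mirroring the inject/eject flow in the training node below (a sketch, not additional API):

lora_sd, adapters, manager = _setup_lora_adapters_bypass(
    mp, existing_weights={}, algorithm="lora", lora_dtype=torch.bfloat16, rank=16
)
injections = manager.create_injections(mp.model)
for inj in injections:
    inj.inject(mp)   # attach bypass forward hooks
# ... run the training loop ...
for inj in injections:
    inj.eject(mp)    # detach hooks when finished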
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode):
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
io.Boolean.Input(
"bypass_mode",
default=False,
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
),
],
outputs=[
io.Custom("LORA_MODEL").Output(
display_name="lora", tooltip="LoRA weights"
),
@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode):
gradient_checkpointing,
existing_lora,
bucket_mode,
bypass_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode):
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
# Process latents based on mode
if bucket_mode:
@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode):
existing_weights, existing_steps = _load_existing_lora(existing_lora)
# Setup LoRA adapters
bypass_manager = None
if bypass_mode:
logging.debug("Using bypass mode for training")
lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass(
mp, existing_weights, algorithm, lora_dtype, rank
)
else:
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
# Create optimizer and loss function
optimizer = _create_optimizer(
@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode):
guider = TrainGuider(mp)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
bypass_injections = None
if bypass_manager is not None:
bypass_injections = bypass_manager.create_injections(mp.model)
for injection in bypass_injections:
injection.inject(mp)
logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks")
# Run training loop
try:
_run_training_loop(
@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
injection.eject(mp)
logging.debug("[BypassMode] Ejected bypass hooks")
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode):
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
# mp in the train node is highly specialized for training;
# using it for inference would misbehave, so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):

View File

@ -722,6 +722,69 @@ class LoraLoaderModelOnly(LoraLoader):
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class LoraLoaderBypass:
"""
Apply LoRA in bypass mode without modifying base model weights.
Bypass mode computes: output = base_forward(x) + lora_path(x)
This is useful for training and when model weights are offloaded.
"""
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP")
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
FUNCTION = "load_lora"
CATEGORY = "loaders"
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
self.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class LoraLoaderBypassModelOnly(LoraLoaderBypass):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@ -2067,6 +2130,8 @@ NODE_CLASS_MAPPINGS = {
"LatentFlip": LatentFlip,
"LatentCrop": LatentCrop,
"LoraLoader": LoraLoader,
"LoraLoaderBypass": LoraLoaderBypass,
"LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
"CLIPLoader": CLIPLoader,
"UNETLoader": UNETLoader,
"DualCLIPLoader": DualCLIPLoader,
@ -2106,6 +2171,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"LoraLoader": "Load LoRA",
"LoraLoaderBypass": "Load LoRA (Bypass)",
"LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)",
"CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model",
"DiffControlNetLoader": "Load ControlNet Model (diff)",