# ComfyUI/comfy/weight_adapter/glora.py
import logging
from typing import Callable, Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
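
# GLoRA augments a layer weight W with two low-rank terms: an input-side term
# A (factors a1, a2) that multiplies W, and an output-side term B (factors
# b1, b2) that is added directly:
#     W' = W + W @ A + B
# calculate_weight() merges this delta into W; bypass_forward() applies the
# same update without materializing it, as f(x + a(x)) + b(x), where a and b
# apply A and B to the input.
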
class GLoRAAdapter(WeightAdapterBase):
name = "glora"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["GLoRAAdapter"]:
if loaded_keys is None:
loaded_keys = set()
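        # Each GLoRA layer ships four factor matrices in the lora state dict,
        # under "<name>.a1/.a2/.b1/.b2.weight"; all four must be present.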
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
weights = (
lora[a1_name],
lora[a2_name],
lora[b1_name],
lora[b2_name],
alpha,
dora_scale,
)
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
dora_scale = v[5]
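        # self.weights layout (see load()): (a1, a2, b1, b2, alpha, dora_scale)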
old_glora = False
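        # Two checkpoint layouts exist: old LyCORIS GLoRA stores its factors so
        # that delta = b2 @ b1 + W @ a2 @ a1, while the current layout uses
        # delta = b1 @ b2 + W @ a1 @ a2. Which dimensions share the rank tells
        # the two apart.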
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if (
old_glora
and v[1].shape[0] == weight.shape[0]
and weight.shape[0] == weight.shape[1]
):
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(
v[0].flatten(start_dim=1), weight.device, intermediate_dtype
)
a2 = comfy.model_management.cast_to_device(
v[1].flatten(start_dim=1), weight.device, intermediate_dtype
)
b1 = comfy.model_management.cast_to_device(
v[2].flatten(start_dim=1), weight.device, intermediate_dtype
)
b2 = comfy.model_management.cast_to_device(
v[3].flatten(start_dim=1), weight.device, intermediate_dtype
)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
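        # Materialize the dense delta for this layer: the input-side term W @ A
        # plus the output-side term B, then merge it into the weight below.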
try:
if old_glora:
lora_diff = (
torch.mm(b2, b1)
+ torch.mm(
torch.mm(
weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2
),
a1,
)
).reshape(
weight.shape
) # old lycoris glora
else:
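                # Current layout: contract W with a1 and then a2 along the input
                # dimension; conv weights (dim > 2) keep their spatial axes via einsum.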
if weight.dim() > 2:
lora_diff = torch.einsum(
"o i ..., i j -> o j ...",
torch.einsum(
"o i ..., i j -> o j ...",
weight.to(dtype=intermediate_dtype),
a1,
),
a2,
).reshape(weight.shape)
else:
lora_diff = torch.mm(
torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
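            # With a dora_scale present, weight_decompose merges the delta using the
            # DoRA-style magnitude/direction split; otherwise the scaled delta is
            # added to the weight directly.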
if dora_scale is not None:
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def _compute_paths(self, x: torch.Tensor):
"""
Compute A path and B path outputs for GLoRA bypass.
GLoRA: f(x) = Wx + WAx + Bx
- A path: a1(a2(x)) - modifies input to base forward
- B path: b1(b2(x)) - additive component
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Returns: (a_out, b_out)
"""
v = self.weights
# v = (a1, a2, b1, b2, alpha, dora_scale)
a1 = v[0]
a2 = v[1]
b1 = v[2]
b2 = v[3]
alpha = v[4]
dtype = x.dtype
# Cast dtype (weights should already be on correct device from inject())
a1 = a1.to(dtype=dtype)
a2 = a2.to(dtype=dtype)
b1 = b1.to(dtype=dtype)
b2 = b2.to(dtype=dtype)
# Determine rank and scale
# Check for old vs new glora format
old_glora = False
if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
rank = a1.shape[0]
old_glora = True
if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
            # The square-weight check used in calculate_weight() is unavailable here
            # (bypass mode never touches the base weight), so only the input feature
            # dimension is used to resolve the ambiguous case.
            if old_glora and a2.shape[0] == x.shape[-1]:
pass
else:
old_glora = False
rank = a2.shape[0]
if alpha is not None:
scale = alpha / rank
else:
scale = 1.0
# Apply multiplier
multiplier = getattr(self, "multiplier", 1.0)
scale = scale * multiplier
# Use module info from bypass injection, not input tensor shape
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
if is_conv:
# Conv case - conv_dim is 1/2/3 for conv1d/2d/3d
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
# Get module's stride/padding for spatial dimension handling
module_stride = kw_dict.get("stride", (1,) * conv_dim)
module_padding = kw_dict.get("padding", (0,) * conv_dim)
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Ensure weights are in conv shape
# a1, a2, b1 are always 1x1 kernels
if a1.ndim == 2:
a1 = a1.view(*a1.shape, *([1] * conv_dim))
if a2.ndim == 2:
a2 = a2.view(*a2.shape, *([1] * conv_dim))
if b1.ndim == 2:
b1 = b1.view(*b1.shape, *([1] * conv_dim))
# b2 has actual kernel_size (like LoRA down)
if b2.ndim == 2:
if in_channels is not None:
b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
else:
b2 = b2.view(*b2.shape, *([1] * conv_dim))
# A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x
a2_out = conv_fn(x, a2)
a_out = conv_fn(a2_out, a1) * scale
# B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
b_out = conv_fn(b2_out, b1) * scale
else:
# Linear case
            if old_glora:
                # Old format: delta = W @ a2 @ a1 + b2 @ b1,
                # so a(x) = a2 @ a1 @ x and b(x) = b2 @ b1 @ x.
                a_out = F.linear(F.linear(x, a1), a2) * scale
                b_out = F.linear(F.linear(x, b1), b2) * scale
            else:
                # New format: delta = W @ a1 @ a2 + b1 @ b2,
                # so a(x) = a1 @ a2 @ x and b(x) = b1 @ b2 @ x.
                a_out = F.linear(F.linear(x, a2), a1) * scale
                b_out = F.linear(F.linear(x, b2), b1) * scale
return a_out, b_out
def bypass_forward(
self,
org_forward: Callable,
x: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
"""
GLoRA bypass forward: f(x + a(x)) + b(x)
Unlike standard adapters, GLoRA modifies the input to the base forward
AND adds the B path output.
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Reference: LyCORIS GLoRAModule._bypass_forward
"""
a_out, b_out = self._compute_paths(x)
# Call base forward with modified input
base_out = org_forward(x + a_out, *args, **kwargs)
# Add B path
return base_out + b_out
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
For GLoRA, h() returns the B path output.
Note:
GLoRA's full bypass requires overriding bypass_forward() since
it also modifies the input to org_forward. This h() is provided for
compatibility but bypass_forward() should be used for correct behavior.
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
_, b_out = self._compute_paths(x)
return b_out
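

if __name__ == "__main__":
    # Minimal usage sketch, not part of ComfyUI's API: exercises bypass_forward()
    # against a plain nn.Linear with randomly initialized factors in the current
    # layout (a1: in x rank, a2: rank x in, b1: out x rank, b2: rank x in).
    # The names below (lin, rank, weights, in_features, out_features) are
    # illustrative only.
    import torch.nn as nn

    in_features, out_features, rank = 64, 32, 4
    lin = nn.Linear(in_features, out_features)
    weights = (
        torch.randn(in_features, rank) * 0.01,   # a1
        torch.randn(rank, in_features) * 0.01,   # a2
        torch.randn(out_features, rank) * 0.01,  # b1
        torch.randn(rank, in_features) * 0.01,   # b2
        None,  # alpha (None -> scale of 1.0)
        None,  # dora_scale
    )
    adapter = GLoRAAdapter(set(), weights)

    x = torch.randn(2, in_features)
    y = adapter.bypass_forward(lin.forward, x)  # f(x + a(x)) + b(x)
    print(y.shape)  # torch.Size([2, 32])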