ComfyUI/comfy/weight_adapter/loha.py
Kohaku-Blueleaf a97c98068f
[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)
* Add API of bypass forward module

* bypass implementation

* add bypass fwd into nodes list/trainer
2026-01-24 22:56:22 -05:00

379 lines
13 KiB
Python

import logging
from functools import cache
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
@cache
def _warn_loha_bypass_inefficient():
"""One-time warning about LoHa bypass inefficiency."""
logging.warning(
"LoHa bypass mode is inefficient: full weight diff is computed each forward pass. "
"Consider using LoRA or LoKr for training with bypass mode."
)
class HadaWeight(torch.autograd.Function):
    """Autograd function for ``(w1u @ w1d) * (w2u @ w2d) * scale``.

    Only the four low-rank factors (and ``scale``) are saved for backward;
    the backward pass recomputes each matrix product from them instead of
    keeping the full-size products alive.
    """

    @staticmethod
    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
        """Return the scaled element-wise product of the two matrix products."""
        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
        first_product = w1u @ w1d
        second_product = w2u @ w2d
        return (first_product * second_product) * scale

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for the four factors; ``scale`` receives ``None``."""
        w1d, w1u, w2d, w2u, scale = ctx.saved_tensors
        scaled_grad = grad_out * scale

        # Gradient flowing into (w1u @ w1d): upstream grad times the other product.
        grad_first = scaled_grad * (w2u @ w2d)
        grad_w1u = grad_first @ w1d.T
        grad_w1d = w1u.T @ grad_first
        del grad_first  # release the full-size buffer before building the next one

        # Gradient flowing into (w2u @ w2d), by symmetry.
        grad_second = scaled_grad * (w1u @ w1d)
        grad_w2u = grad_second @ w2d.T
        grad_w2d = w2u.T @ grad_second
        del grad_second

        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
class HadaWeightTucker(torch.autograd.Function):
    """Autograd function for the Tucker-decomposed LoHa weight delta.

    Each branch rebuilds a full weight from a core tensor and two factor
    matrices, ``rebuild[p, r, ...] = sum_ij t[i, j, ...] * wd[j, r] * wu[i, p]``,
    and the result is ``rebuild1 * rebuild2 * scale``. As fixed by the einsum
    subscripts: ``t`` is ``(i, j, ...)``, ``wd`` is ``(j, r)``, ``wu`` is
    ``(i, p)`` and the output is ``(p, r, ...)``.
    """

    @staticmethod
    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
        """Return the scaled element-wise product of the two rebuilt weights."""
        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(t1, w1u, w1d, t2, w2u, w2d)``; ``scale`` gets ``None``."""
        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Core contracted with its own "down" factor, one per branch.
        # The gradient of a branch's "up" factor must be contracted with THAT
        # branch's intermediate; using the other branch's intermediate (as the
        # previous implementation did) yields wrong grad_w1u / grad_w2u.
        temp1 = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
        temp2 = torch.einsum("i j ..., j r -> i r ...", t2, w2d)

        # ---- Branch 1: d(loss)/d(rebuild1) = rebuild2 * grad_out ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp2, w2u)
        grad_w = rebuild * grad_out
        del rebuild

        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp1, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
        del grad_w

        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
        del grad_temp

        # ---- Branch 2: d(loss)/d(rebuild2) = rebuild1 * grad_out ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp1, w1u)
        grad_w = rebuild * grad_out
        del rebuild, temp1

        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp2, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
        del grad_w, temp2

        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
        del grad_temp

        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
class LohaDiff(WeightAdapterTrainBase):
    """Trainable LoHa adapter that adds a Hadamard-product delta to a weight.

    NOTE(review): ``WeightAdapterTrainBase`` is presumably a
    ``torch.nn.Module`` subclass (``self.parameters()`` is relied on below) -
    confirm against ``.base``.
    """

    def __init__(self, weights):
        """Build trainable parameters from a LoHa weights tuple.

        Args:
            weights: 8-tuple ``(w1a, w1b, alpha, w2a, w2b, t1, t2, _)``; the
                last element is ignored here. ``t1``/``t2`` may be ``None``
                (no Tucker cores) and ``alpha`` may be ``None``.
        """
        super().__init__()
        # Unpack weights tuple from LoHaAdapter
        w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights

        # Create trainable parameters
        self.hada_w1_a = torch.nn.Parameter(w1a)
        self.hada_w1_b = torch.nn.Parameter(w1b)
        self.hada_w2_a = torch.nn.Parameter(w2a)
        self.hada_w2_b = torch.nn.Parameter(w2b)

        self.use_tucker = t1 is not None and t2 is not None
        if self.use_tucker:
            self.hada_t1 = torch.nn.Parameter(t1)
            self.hada_t2 = torch.nn.Parameter(t2)
        else:
            # Keep the attributes for consistent access
            self.hada_t1 = None
            self.hada_t2 = None

        # Store rank and non-trainable alpha
        self.rank = w1b.shape[0]
        if alpha is None:
            # A missing alpha means "scale 1.0" elsewhere in this file
            # (LoHaAdapter.calculate_weight / LoHaAdapter.h); alpha == rank
            # gives the same scale here instead of torch.tensor(None) raising.
            alpha = float(self.rank)
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        """Return ``w`` plus the LoHa delta, in ``w``'s original dtype.

        The delta is scaled by ``alpha / rank`` and reshaped to ``w.shape``.
        """
        org_dtype = w.dtype
        scale = self.alpha / self.rank
        if self.use_tucker:
            diff_weight = HadaWeightTucker.apply(
                self.hada_t1,
                self.hada_w1_a,
                self.hada_w1_b,
                self.hada_t2,
                self.hada_w2_a,
                self.hada_w2_b,
                scale,
            )
        else:
            diff_weight = HadaWeight.apply(
                self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale
            )

        # Add the scaled difference to the original weight; ``w`` is first
        # converted to match ``diff_weight`` so the addition is well-defined.
        weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Return the total byte size of everything in ``self.parameters()``."""
        return sum(param.numel() * param.element_size() for param in self.parameters())
class LoHaAdapter(WeightAdapterBase):
    """Weight adapter for LoHa (low-rank Hadamard product) weights.

    ``self.weights`` is the 8-tuple
    ``(w1a, w1b, alpha, w2a, w2b, t1, t2, dora_scale)``. ``t1``/``t2`` are
    ``None`` unless the checkpoint provides Tucker core tensors; ``alpha``
    and ``dora_scale`` may also be ``None``.
    """

    name = "loha"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        """Create a freshly initialised trainable ``LohaDiff`` for ``weight``.

        Args:
            weight: Tensor whose first dimension is treated as the output
                dimension; all remaining dimensions are flattened together.
            rank: Inner dimension of the low-rank factors.
            alpha: Scaling numerator stored on the returned ``LohaDiff``.

        Returns:
            A ``LohaDiff`` without Tucker cores. ``mat2`` is zero-initialised,
            so the initial weight delta is exactly zero.
        """
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
        # ``normal_(tensor, mean, std)``: the second positional argument is the
        # MEAN. The values here are standard deviations, so pass them by keyword
        # (positionally they produced mean=0.1/0.01 with std=1.0).
        torch.nn.init.normal_(mat1, std=0.1)
        torch.nn.init.constant_(mat2, 0.0)
        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
        torch.nn.init.normal_(mat3, std=0.1)
        torch.nn.init.normal_(mat4, std=0.01)
        return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None))

    def to_train(self):
        """Wrap this adapter's weights tuple in a trainable ``LohaDiff``."""
        return LohaDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["LoHaAdapter"]:
        """Build an adapter from the ``{x}.hada_*`` entries of ``lora``.

        Args:
            x: Key prefix of the layer inside ``lora``.
            lora: Mapping from key to tensor.
            alpha: Value stored as-is in the weights tuple.
            dora_scale: Value stored as-is in the weights tuple.
            loaded_keys: Set that receives every consumed key; a new set is
                created when ``None``.

        Returns:
            A ``LoHaAdapter`` when ``{x}.hada_w1_a`` is present, else ``None``.

        Raises:
            KeyError: If ``{x}.hada_w1_a`` exists but one of the other three
                factor keys is missing, or ``hada_t1`` exists without
                ``hada_t2``.
        """
        if loaded_keys is None:
            loaded_keys = set()
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)

        if hada_w1_a_name not in lora:
            return None

        hada_t1 = None
        hada_t2 = None
        if hada_t1_name in lora:
            hada_t1 = lora[hada_t1_name]
            hada_t2 = lora[hada_t2_name]
            loaded_keys.add(hada_t1_name)
            loaded_keys.add(hada_t2_name)

        weights = (
            lora[hada_w1_a_name],
            lora[hada_w1_b_name],
            alpha,
            lora[hada_w2_a_name],
            lora[hada_w2_b_name],
            hada_t1,
            hada_t2,
            dora_scale,
        )
        loaded_keys.add(hada_w1_a_name)
        loaded_keys.add(hada_w1_b_name)
        loaded_keys.add(hada_w2_a_name)
        loaded_keys.add(hada_w2_b_name)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the LoHa delta to ``weight`` and return the result.

        Args:
            weight: Tensor to patch. Without ``dora_scale`` it is updated in
                place (``+=``).
            key: Only used in the error log message.
            strength: Multiplier applied to the delta.
            strength_model: Not used by this method.
            offset: Not used by this method.
            function: Callable applied to the scaled delta before it is added.
            intermediate_dtype: dtype the factors are cast to for the products.
            original_weight: Not used by this method.

        Returns:
            The patched weight. If reshaping or applying the delta raises, the
            error is logged and ``weight`` is returned as it stands.
        """
        v = self.weights
        w1a = v[0]
        w1b = v[1]
        if v[2] is not None:
            alpha = v[2] / w1b.shape[0]
        else:
            alpha = 1.0
        w2a = v[3]
        w2b = v[4]
        dora_scale = v[7]

        def _cast(tensor):
            # Bring a factor onto the target weight's device in the working dtype.
            return comfy.model_management.cast_to_device(
                tensor, weight.device, intermediate_dtype
            )

        if v[5] is not None:  # cp decomposition
            t1 = v[5]
            t2 = v[6]
            # Ellipsis instead of fixed "k l" so cores with any number of
            # trailing (kernel) dimensions work, matching HadaWeightTucker.
            m1 = torch.einsum(
                "i j ..., j r, i p -> p r ...", _cast(t1), _cast(w1b), _cast(w1a)
            )
            m2 = torch.einsum(
                "i j ..., j r, i p -> p r ...", _cast(t2), _cast(w2b), _cast(w2a)
            )
        else:
            m1 = torch.mm(_cast(w1a), _cast(w1b))
            m2 = torch.mm(_cast(w2a), _cast(w2b))

        try:
            lora_diff = (m1 * m2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            # Deliberate best-effort: a failing patch is logged, not raised.
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoHa: h(x) = diff_weight @ x

        WARNING: Inefficient - computes full Hadamard product each forward.

        Note:
            Does not access original model weights - bypass mode is designed
            for quantized models where weights may not be accessible.

        The optional instance attributes ``multiplier``, ``is_conv``,
        ``conv_dim``, ``kw_dict``, ``kernel_size`` and ``in_channels`` are read
        with ``getattr`` defaults; they are expected to be set by the bypass
        injection.

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, for API consistency)

        Returns:
            The result of applying ``diff_weight`` to ``x`` with ``F.linear``
            or the matching ``F.conv{1,2,3}d``.

        Reference: LyCORIS functional/loha.py bypass_forward_diff
        """
        _warn_loha_bypass_inefficient()

        # Indexed by ``conv_dim + 2``.
        FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]

        v = self.weights
        # v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora
        w1a = v[0]
        w1b = v[1]
        alpha = v[2]
        w2a = v[3]
        w2b = v[4]
        t1 = v[5]
        t2 = v[6]

        # Compute scale
        rank = w1b.shape[0]
        scale = (alpha / rank if alpha is not None else 1.0) * getattr(
            self, "multiplier", 1.0
        )

        # Cast dtype
        w1a = w1a.to(dtype=x.dtype)
        w1b = w1b.to(dtype=x.dtype)
        w2a = w2a.to(dtype=x.dtype)
        w2b = w2b.to(dtype=x.dtype)

        # Use module info from bypass injection, not weight dimension
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        # Compute diff weight using Hadamard product
        if t1 is not None and t2 is not None:
            t1 = t1.to(dtype=x.dtype)
            t2 = t2.to(dtype=x.dtype)
            # Ellipsis instead of fixed "k l": conv1d/conv3d cores are handled
            # too, in line with the ops available in FUNC_LIST.
            m1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
            m2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
        else:
            m1 = w1a @ w1b
            m2 = w2a @ w2b
        diff_weight = (m1 * m2) * scale

        if is_conv:
            op = FUNC_LIST[conv_dim + 2]
            kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
            in_channels = getattr(self, "in_channels", None)

            # Reshape 2D diff_weight to conv format using kernel_size
            # diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size]
            if diff_weight.dim() == 2:
                if in_channels is not None:
                    diff_weight = diff_weight.view(
                        diff_weight.shape[0], in_channels, *kernel_size
                    )
                else:
                    # Without channel info, treat the matrix as a 1x...x1 kernel.
                    diff_weight = diff_weight.view(
                        *diff_weight.shape, *([1] * conv_dim)
                    )
        else:
            op = F.linear
            kw_dict = {}

        return op(x, diff_weight, **kw_dict)