mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 00:30:21 +08:00
bypass implementation
This commit is contained in:
parent
aa77a8a461
commit
2a420dc4db
@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
alpha = v[2]
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
blocks = comfy.model_management.cast_to_device(
|
||||
blocks, weight.device, intermediate_dtype
|
||||
)
|
||||
if rescale is not None:
|
||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||
rescale = comfy.model_management.cast_to_device(
|
||||
rescale, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
boft_m, block_num, boft_b, *_ = blocks.shape
|
||||
|
||||
@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
# for Q = -Q^T
|
||||
q = blocks - blocks.transpose(-1, -2)
|
||||
normed_q = q
|
||||
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
r = r.to(weight)
|
||||
inp = org = weight
|
||||
|
||||
r_b = boft_b//2
|
||||
r_b = boft_b // 2
|
||||
for i in range(boft_m):
|
||||
bi = r[i]
|
||||
g = 2
|
||||
k = 2**i * r_b
|
||||
if strength != 1:
|
||||
bi = bi * strength + (1-strength) * I
|
||||
bi = bi * strength + (1 - strength) * I
|
||||
inp = (
|
||||
inp.unflatten(0, (-1, g, k))
|
||||
.transpose(1, 2)
|
||||
@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
)
|
||||
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
|
||||
inp = (
|
||||
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
|
||||
inp.flatten(0, 1)
|
||||
.unflatten(0, (-1, k, g))
|
||||
.transpose(1, 2)
|
||||
.flatten(0, 2)
|
||||
)
|
||||
|
||||
if rescale is not None:
|
||||
inp = inp * rescale
|
||||
|
||||
lora_diff = inp - org
|
||||
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
|
||||
lora_diff = comfy.model_management.cast_to_device(
|
||||
lora_diff, weight.device, intermediate_dtype
|
||||
)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def _get_orthogonal_matrices(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrices R from BOFT blocks."""
|
||||
v = self.weights
|
||||
blocks = v[0].to(device=device, dtype=dtype)
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
|
||||
boft_m, block_num, boft_b, _ = blocks.shape
|
||||
I = torch.eye(boft_b, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(-1, -2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if alpha > 0
|
||||
if alpha > 0:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r, boft_m, boft_b
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for BOFT: applies butterfly orthogonal transform.
|
||||
|
||||
BOFT uses multiple stages of butterfly-structured orthogonal transforms.
|
||||
|
||||
Reference: LyCORIS ButterflyOFTModule._bypass_forward
|
||||
"""
|
||||
v = self.weights
|
||||
rescale = v[1]
|
||||
|
||||
r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype)
|
||||
r_b = boft_b // 2
|
||||
|
||||
# Apply multiplier
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(boft_b, device=y.device, dtype=y.dtype)
|
||||
|
||||
# Use module info from bypass injection to determine conv vs linear
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# Apply butterfly transform stages
|
||||
inp = y
|
||||
for i in range(boft_m):
|
||||
bi = r[i] # (block_num, boft_b, boft_b)
|
||||
g = 2
|
||||
k = 2**i * r_b
|
||||
|
||||
# Interpolate with identity based on multiplier
|
||||
if multiplier != 1:
|
||||
bi = bi * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten
|
||||
inp = (
|
||||
inp.unflatten(-1, (-1, g, k))
|
||||
.transpose(-2, -1)
|
||||
.flatten(-3)
|
||||
.unflatten(-1, (-1, boft_b))
|
||||
)
|
||||
# Apply block-diagonal orthogonal transform
|
||||
inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
|
||||
# Reshape back
|
||||
inp = (
|
||||
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
|
||||
)
|
||||
|
||||
# Apply rescale if present
|
||||
if rescale is not None:
|
||||
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
||||
inp = inp * rescale.transpose(0, -1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
inp = inp.transpose(1, -1)
|
||||
|
||||
return inp
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
|
||||
@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase):
|
||||
b1_name = "{}.b1.weight".format(x)
|
||||
b2_name = "{}.b2.weight".format(x)
|
||||
if a1_name in lora:
|
||||
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
||||
weights = (
|
||||
lora[a1_name],
|
||||
lora[a2_name],
|
||||
lora[b1_name],
|
||||
lora[b2_name],
|
||||
alpha,
|
||||
dora_scale,
|
||||
)
|
||||
loaded_keys.add(a1_name)
|
||||
loaded_keys.add(a2_name)
|
||||
loaded_keys.add(b1_name)
|
||||
@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase):
|
||||
old_glora = True
|
||||
|
||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||
if (
|
||||
old_glora
|
||||
and v[1].shape[0] == weight.shape[0]
|
||||
and weight.shape[0] == weight.shape[1]
|
||||
):
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = v[1].shape[0]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a1 = comfy.model_management.cast_to_device(
|
||||
v[0].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
a2 = comfy.model_management.cast_to_device(
|
||||
v[1].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
b1 = comfy.model_management.cast_to_device(
|
||||
v[2].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
b2 = comfy.model_management.cast_to_device(
|
||||
v[3].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / rank
|
||||
@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase):
|
||||
|
||||
try:
|
||||
if old_glora:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||
lora_diff = (
|
||||
torch.mm(b2, b1)
|
||||
+ torch.mm(
|
||||
torch.mm(
|
||||
weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2
|
||||
),
|
||||
a1,
|
||||
)
|
||||
).reshape(
|
||||
weight.shape
|
||||
) # old lycoris glora
|
||||
else:
|
||||
if weight.dim() > 2:
|
||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff = torch.einsum(
|
||||
"o i ..., i j -> o j ...",
|
||||
torch.einsum(
|
||||
"o i ..., i j -> o j ...",
|
||||
weight.to(dtype=intermediate_dtype),
|
||||
a1,
|
||||
),
|
||||
a2,
|
||||
).reshape(weight.shape)
|
||||
else:
|
||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff = torch.mm(
|
||||
torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
|
||||
).reshape(weight.shape)
|
||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def _compute_paths(self, x: torch.Tensor):
|
||||
"""
|
||||
Compute A path and B path outputs for GLoRA bypass.
|
||||
|
||||
GLoRA: f(x) = Wx + WAx + Bx
|
||||
- A path: a1(a2(x)) - modifies input to base forward
|
||||
- B path: b1(b2(x)) - additive component
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Returns: (a_out, b_out)
|
||||
"""
|
||||
v = self.weights
|
||||
# v = (a1, a2, b1, b2, alpha, dora_scale)
|
||||
a1 = v[0]
|
||||
a2 = v[1]
|
||||
b1 = v[2]
|
||||
b2 = v[3]
|
||||
alpha = v[4]
|
||||
|
||||
dtype = x.dtype
|
||||
|
||||
# Cast dtype (weights should already be on correct device from inject())
|
||||
a1 = a1.to(dtype=dtype)
|
||||
a2 = a2.to(dtype=dtype)
|
||||
b1 = b1.to(dtype=dtype)
|
||||
b2 = b2.to(dtype=dtype)
|
||||
|
||||
# Determine rank and scale
|
||||
# Check for old vs new glora format
|
||||
old_glora = False
|
||||
if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
|
||||
rank = a1.shape[0]
|
||||
old_glora = True
|
||||
|
||||
if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
|
||||
if old_glora and a2.shape[0] == x.shape[-1] and x.shape[-1] == x.shape[-1]:
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = a2.shape[0]
|
||||
|
||||
if alpha is not None:
|
||||
scale = alpha / rank
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
# Apply multiplier
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
scale = scale * multiplier
|
||||
|
||||
# Use module info from bypass injection, not input tensor shape
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
if is_conv:
|
||||
# Conv case - conv_dim is 1/2/3 for conv1d/2d/3d
|
||||
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||
|
||||
# Get module's stride/padding for spatial dimension handling
|
||||
module_stride = kw_dict.get("stride", (1,) * conv_dim)
|
||||
module_padding = kw_dict.get("padding", (0,) * conv_dim)
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
|
||||
# Ensure weights are in conv shape
|
||||
# a1, a2, b1 are always 1x1 kernels
|
||||
if a1.ndim == 2:
|
||||
a1 = a1.view(*a1.shape, *([1] * conv_dim))
|
||||
if a2.ndim == 2:
|
||||
a2 = a2.view(*a2.shape, *([1] * conv_dim))
|
||||
if b1.ndim == 2:
|
||||
b1 = b1.view(*b1.shape, *([1] * conv_dim))
|
||||
# b2 has actual kernel_size (like LoRA down)
|
||||
if b2.ndim == 2:
|
||||
if in_channels is not None:
|
||||
b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
|
||||
else:
|
||||
b2 = b2.view(*b2.shape, *([1] * conv_dim))
|
||||
|
||||
# A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x
|
||||
a2_out = conv_fn(x, a2)
|
||||
a_out = conv_fn(a2_out, a1) * scale
|
||||
|
||||
# B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
|
||||
b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
|
||||
b_out = conv_fn(b2_out, b1) * scale
|
||||
else:
|
||||
# Linear case
|
||||
if old_glora:
|
||||
# Old format: a1 @ a2 @ x, b2 @ b1
|
||||
a_out = F.linear(F.linear(x, a2), a1) * scale
|
||||
b_out = F.linear(F.linear(x, b1), b2) * scale
|
||||
else:
|
||||
# New format: x @ a1 @ a2, b1 @ b2
|
||||
a_out = F.linear(F.linear(x, a1), a2) * scale
|
||||
b_out = F.linear(F.linear(x, b2), b1) * scale
|
||||
|
||||
return a_out, b_out
|
||||
|
||||
def bypass_forward(
|
||||
self,
|
||||
org_forward: Callable,
|
||||
x: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
GLoRA bypass forward: f(x + a(x)) + b(x)
|
||||
|
||||
Unlike standard adapters, GLoRA modifies the input to the base forward
|
||||
AND adds the B path output.
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Reference: LyCORIS GLoRAModule._bypass_forward
|
||||
"""
|
||||
a_out, b_out = self._compute_paths(x)
|
||||
|
||||
# Call base forward with modified input
|
||||
base_out = org_forward(x + a_out, *args, **kwargs)
|
||||
|
||||
# Add B path
|
||||
return base_out + b_out
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
For GLoRA, h() returns the B path output.
|
||||
|
||||
Note:
|
||||
GLoRA's full bypass requires overriding bypass_forward() since
|
||||
it also modifies the input to org_forward. This h() is provided for
|
||||
compatibility but bypass_forward() should be used for correct behavior.
|
||||
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
"""
|
||||
_, b_out = self._compute_paths(x)
|
||||
return b_out
|
||||
|
||||
@ -1,11 +1,22 @@
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
||||
|
||||
|
||||
@cache
|
||||
def _warn_loha_bypass_inefficient():
|
||||
"""One-time warning about LoHa bypass inefficiency."""
|
||||
logging.warning(
|
||||
"LoHa bypass mode is inefficient: full weight diff is computed each forward pass. "
|
||||
"Consider using LoRA or LoKr for training with bypass mode."
|
||||
)
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
||||
@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase):
|
||||
|
||||
scale = self.alpha / self.rank
|
||||
if self.use_tucker:
|
||||
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
diff_weight = HadaWeightTucker.apply(
|
||||
self.hada_t1,
|
||||
self.hada_w1_a,
|
||||
self.hada_w1_b,
|
||||
self.hada_t2,
|
||||
self.hada_w2_a,
|
||||
self.hada_w2_b,
|
||||
scale,
|
||||
)
|
||||
else:
|
||||
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
diff_weight = HadaWeight.apply(
|
||||
self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale
|
||||
)
|
||||
|
||||
# Add the scaled difference to the original weight
|
||||
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
||||
@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.normal_(mat3, 0.1)
|
||||
torch.nn.init.normal_(mat4, 0.01)
|
||||
return LohaDiff(
|
||||
(mat1, mat2, alpha, mat3, mat4, None, None, None)
|
||||
)
|
||||
return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None))
|
||||
|
||||
def to_train(self):
|
||||
return LohaDiff(self.weights)
|
||||
@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
||||
weights = (
|
||||
lora[hada_w1_a_name],
|
||||
lora[hada_w1_b_name],
|
||||
alpha,
|
||||
lora[hada_w2_a_name],
|
||||
lora[hada_w2_b_name],
|
||||
hada_t1,
|
||||
hada_t2,
|
||||
dora_scale,
|
||||
)
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
if v[5] is not None: # cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||
m1 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
comfy.model_management.cast_to_device(
|
||||
t1, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1b, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1a, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||
m2 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
comfy.model_management.cast_to_device(
|
||||
t2, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2b, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2a, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||
m1 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w1a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
m2 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w2a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoHa: h(x) = diff_weight @ x
|
||||
|
||||
WARNING: Inefficient - computes full Hadamard product each forward.
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
|
||||
Reference: LyCORIS functional/loha.py bypass_forward_diff
|
||||
"""
|
||||
_warn_loha_bypass_inefficient()
|
||||
|
||||
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
|
||||
v = self.weights
|
||||
# v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
alpha = v[2]
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
|
||||
# Compute scale
|
||||
rank = w1b.shape[0]
|
||||
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
|
||||
self, "multiplier", 1.0
|
||||
)
|
||||
|
||||
# Cast dtype
|
||||
w1a = w1a.to(dtype=x.dtype)
|
||||
w1b = w1b.to(dtype=x.dtype)
|
||||
w2a = w2a.to(dtype=x.dtype)
|
||||
w2b = w2b.to(dtype=x.dtype)
|
||||
|
||||
# Use module info from bypass injection, not weight dimension
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
# Compute diff weight using Hadamard product
|
||||
if t1 is not None and t2 is not None:
|
||||
t1 = t1.to(dtype=x.dtype)
|
||||
t2 = t2.to(dtype=x.dtype)
|
||||
m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a)
|
||||
m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a)
|
||||
diff_weight = (m1 * m2) * scale
|
||||
else:
|
||||
m1 = w1a @ w1b
|
||||
m2 = w2a @ w2b
|
||||
diff_weight = (m1 * m2) * scale
|
||||
|
||||
if is_conv:
|
||||
op = FUNC_LIST[conv_dim + 2]
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
|
||||
# Reshape 2D diff_weight to conv format using kernel_size
|
||||
# diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size]
|
||||
if diff_weight.dim() == 2:
|
||||
if in_channels is not None:
|
||||
diff_weight = diff_weight.view(
|
||||
diff_weight.shape[0], in_channels, *kernel_size
|
||||
)
|
||||
else:
|
||||
diff_weight = diff_weight.view(
|
||||
*diff_weight.shape, *([1] * conv_dim)
|
||||
)
|
||||
else:
|
||||
op = F.linear
|
||||
kw_dict = {}
|
||||
|
||||
return op(x, diff_weight, **kw_dict)
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
@ -14,7 +15,17 @@ from .base import (
|
||||
class LokrDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
|
||||
(
|
||||
lokr_w1,
|
||||
lokr_w2,
|
||||
alpha,
|
||||
lokr_w1_a,
|
||||
lokr_w1_b,
|
||||
lokr_w2_a,
|
||||
lokr_w2_b,
|
||||
lokr_t2,
|
||||
dora_scale,
|
||||
) = weights
|
||||
self.use_tucker = False
|
||||
if lokr_w1_a is not None:
|
||||
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
||||
@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase):
|
||||
if self.w2_rebuild:
|
||||
if self.use_tucker:
|
||||
w2 = torch.einsum(
|
||||
'i j k l, j r, i p -> p r k l',
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
self.lokr_t2,
|
||||
self.lokr_w2_b,
|
||||
self.lokr_w2_a
|
||||
self.lokr_w2_a,
|
||||
)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase):
|
||||
return self.lokr_w2
|
||||
|
||||
def __call__(self, w):
|
||||
diff = torch.kron(self.w1, self.w2)
|
||||
w1 = self.w1
|
||||
w2 = self.w2
|
||||
# Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron)
|
||||
for _ in range(w2.dim() - w1.dim()):
|
||||
w1 = w1.unsqueeze(-1)
|
||||
diff = torch.kron(w1, w2)
|
||||
return w + diff.reshape(w.shape).to(w)
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoKr training: efficient Kronecker product.
|
||||
|
||||
Uses w1/w2 properties which handle both direct and decomposed cases.
|
||||
For create_train (direct w1/w2), no alpha scaling in properties.
|
||||
For to_train (decomposed), alpha/rank scaling is in properties.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
"""
|
||||
# Get w1, w2 from properties (handles rebuild vs direct)
|
||||
w1 = self.w1
|
||||
w2 = self.w2
|
||||
|
||||
# Multiplier from bypass injection
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
|
||||
# Get module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
# Efficient Kronecker application without materializing full weight
|
||||
# kron(w1, w2) @ x can be computed as nested operations
|
||||
# w1: [out_l, in_m], w2: [out_k, in_n, *k_size]
|
||||
# Full weight would be [out_l*out_k, in_m*in_n, *k_size]
|
||||
|
||||
uq = w1.size(1) # in_m - inner grouping dimension
|
||||
|
||||
if is_conv:
|
||||
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||
|
||||
B, C_in, *spatial = x.shape
|
||||
# Reshape input for grouped application: [B * uq, C_in // uq, *spatial]
|
||||
h_in_group = x.reshape(B * uq, -1, *spatial)
|
||||
|
||||
# Ensure w2 has conv dims
|
||||
if w2.dim() == 2:
|
||||
w2 = w2.view(*w2.shape, *([1] * conv_dim))
|
||||
|
||||
# Apply w2 path with stride/padding
|
||||
hb = conv_fn(h_in_group, w2, **kw_dict)
|
||||
|
||||
# Reshape for cross-group operation
|
||||
hb = hb.view(B, -1, *hb.shape[1:])
|
||||
h_cross = hb.transpose(1, -1)
|
||||
|
||||
# Apply w1 (always 2D, applied as linear on channel dim)
|
||||
hc = F.linear(h_cross, w1)
|
||||
hc = hc.transpose(1, -1)
|
||||
|
||||
# Reshape to output
|
||||
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||
else:
|
||||
# Linear case
|
||||
# Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n]
|
||||
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||
|
||||
# Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k]
|
||||
hb = F.linear(h_in_group, w2)
|
||||
|
||||
# Transpose for w1: [..., uq, out_k] -> [..., out_k, uq]
|
||||
h_cross = hb.transpose(-1, -2)
|
||||
|
||||
# Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l]
|
||||
hc = F.linear(h_cross, w1)
|
||||
|
||||
# Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k]
|
||||
hc = hc.transpose(-1, -2)
|
||||
out = hc.reshape(*hc.shape[:-2], -1)
|
||||
|
||||
return out * multiplier
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
out1, out2 = factorization(out_dim, rank)
|
||||
in1, in2 = factorization(in_dim, rank)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
|
||||
in_dim = weight.shape[1] # Just in_channels, not flattened with kernel
|
||||
k_size = weight.shape[2:] if weight.dim() > 2 else ()
|
||||
|
||||
out_l, out_k = factorization(out_dim, rank)
|
||||
in_m, in_n = factorization(in_dim, rank)
|
||||
|
||||
# w1: [out_l, in_m]
|
||||
mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32)
|
||||
# w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear
|
||||
mat2 = torch.empty(
|
||||
out_k, in_n, *k_size, device=weight.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||
torch.nn.init.constant_(mat1, 0.0)
|
||||
return LokrDiff(
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None))
|
||||
|
||||
def to_train(self):
|
||||
return LokrDiff(self.weights)
|
||||
@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
lokr_t2 = lora[lokr_t2_name]
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
||||
if (
|
||||
(lokr_w1 is not None)
|
||||
or (lokr_w2 is not None)
|
||||
or (lokr_w1_a is not None)
|
||||
or (lokr_w2_a is not None)
|
||||
):
|
||||
weights = (
|
||||
lokr_w1,
|
||||
lokr_w2,
|
||||
alpha,
|
||||
lokr_w1_a,
|
||||
lokr_w1_b,
|
||||
lokr_w2_a,
|
||||
lokr_w2_b,
|
||||
lokr_t2,
|
||||
dora_scale,
|
||||
)
|
||||
return cls(loaded_keys, weights)
|
||||
else:
|
||||
return None
|
||||
@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||
w1 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w1_a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1_b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||
w1 = comfy.model_management.cast_to_device(
|
||||
w1, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||
w2 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||
w2 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
comfy.model_management.cast_to_device(
|
||||
t2, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_b, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_a, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||
w2 = comfy.model_management.cast_to_device(
|
||||
w2, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoKr: efficient Kronecker product application.
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
|
||||
Reference: LyCORIS functional/lokr.py bypass_forward_diff
|
||||
"""
|
||||
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
|
||||
v = self.weights
|
||||
# v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
alpha = v[2]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
|
||||
use_w1 = w1 is not None
|
||||
use_w2 = w2 is not None
|
||||
tucker = t2 is not None
|
||||
|
||||
# Use module info from bypass injection, not weight dimension
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {}) if is_conv else {}
|
||||
|
||||
if is_conv:
|
||||
op = FUNC_LIST[conv_dim + 2]
|
||||
else:
|
||||
op = F.linear
|
||||
|
||||
# Determine rank and scale
|
||||
rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha
|
||||
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
|
||||
self, "multiplier", 1.0
|
||||
)
|
||||
|
||||
# Build c (w1)
|
||||
if use_w1:
|
||||
c = w1.to(dtype=x.dtype)
|
||||
else:
|
||||
c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype)
|
||||
uq = c.size(1)
|
||||
|
||||
# Build w2 components
|
||||
if use_w2:
|
||||
ba = w2.to(dtype=x.dtype)
|
||||
else:
|
||||
a = w2_b.to(dtype=x.dtype)
|
||||
b = w2_a.to(dtype=x.dtype)
|
||||
if is_conv:
|
||||
if tucker:
|
||||
# Tucker: a, b get 1s appended (kernel is in t2)
|
||||
if a.dim() == 2:
|
||||
a = a.view(*a.shape, *([1] * conv_dim))
|
||||
if b.dim() == 2:
|
||||
b = b.view(*b.shape, *([1] * conv_dim))
|
||||
else:
|
||||
# Non-tucker conv: b may need 1s appended
|
||||
if b.dim() == 2:
|
||||
b = b.view(*b.shape, *([1] * conv_dim))
|
||||
|
||||
# Reshape input by uq groups
|
||||
if is_conv:
|
||||
B, _, *rest = x.shape
|
||||
h_in_group = x.reshape(B * uq, -1, *rest)
|
||||
else:
|
||||
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||
|
||||
# Apply w2 path
|
||||
if use_w2:
|
||||
hb = op(h_in_group, ba, **kw_dict)
|
||||
else:
|
||||
if is_conv:
|
||||
if tucker:
|
||||
t = t2.to(dtype=x.dtype)
|
||||
if t.dim() == 2:
|
||||
t = t.view(*t.shape, *([1] * conv_dim))
|
||||
ha = op(h_in_group, a)
|
||||
ht = op(ha, t, **kw_dict)
|
||||
hb = op(ht, b)
|
||||
else:
|
||||
ha = op(h_in_group, a, **kw_dict)
|
||||
hb = op(ha, b)
|
||||
else:
|
||||
ha = op(h_in_group, a)
|
||||
hb = op(ha, b)
|
||||
|
||||
# Reshape and apply c (w1)
|
||||
if is_conv:
|
||||
hb = hb.view(B, -1, *hb.shape[1:])
|
||||
h_cross_group = hb.transpose(1, -1)
|
||||
else:
|
||||
h_cross_group = hb.transpose(-1, -2)
|
||||
|
||||
hc = F.linear(h_cross_group, c)
|
||||
|
||||
if is_conv:
|
||||
hc = hc.transpose(1, -1)
|
||||
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||
else:
|
||||
hc = hc.transpose(-1, -2)
|
||||
out = hc.reshape(*hc.shape[:-2], -1)
|
||||
|
||||
return out * scale
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase):
|
||||
rank, in_dim = mat2.shape[0], mat2.shape[1]
|
||||
if mid is not None:
|
||||
convdim = mid.ndim - 2
|
||||
layer = (
|
||||
torch.nn.Conv1d,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.Conv3d
|
||||
)[convdim]
|
||||
layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
|
||||
else:
|
||||
layer = torch.nn.Linear
|
||||
self.lora_up = layer(rank, out_dim, bias=False)
|
||||
@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase):
|
||||
weight = w + scale * diff.reshape(w.shape)
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoRA training: h(x) = up(down(x)) * scale
|
||||
|
||||
Simple implementation using the nn.Module weights directly.
|
||||
No mid/dora/reshape branches (create_train doesn't create them).
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
"""
|
||||
# Compute scale = alpha / rank * multiplier
|
||||
scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)
|
||||
|
||||
# Get module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
# Get weights (keep in original dtype for numerical stability)
|
||||
down_weight = self.lora_down.weight
|
||||
up_weight = self.lora_up.weight
|
||||
|
||||
if is_conv:
|
||||
# Conv path: use functional conv
|
||||
# conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
|
||||
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||
|
||||
# Reshape 2D weights to conv format if needed
|
||||
# down: [rank, in_features] -> [rank, in_channels, *kernel_size]
|
||||
# up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
|
||||
if down_weight.dim() == 2:
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
if in_channels is not None:
|
||||
down_weight = down_weight.view(
|
||||
down_weight.shape[0], in_channels, *kernel_size
|
||||
)
|
||||
else:
|
||||
# Fallback: assume 1x1 kernel
|
||||
down_weight = down_weight.view(
|
||||
*down_weight.shape, *([1] * conv_dim)
|
||||
)
|
||||
if up_weight.dim() == 2:
|
||||
# up always uses 1x1 kernel
|
||||
up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))
|
||||
|
||||
# down conv uses stride/padding from module, up is 1x1
|
||||
hidden = conv_fn(x, down_weight, **kw_dict)
|
||||
|
||||
# mid layer if exists (tucker decomposition)
|
||||
if self.lora_mid is not None:
|
||||
mid_weight = self.lora_mid.weight
|
||||
if mid_weight.dim() == 2:
|
||||
mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
|
||||
hidden = conv_fn(hidden, mid_weight)
|
||||
|
||||
# up conv is always 1x1 (no stride/padding)
|
||||
out = conv_fn(hidden, up_weight)
|
||||
else:
|
||||
# Linear path: simple matmul chain
|
||||
hidden = F.linear(x, down_weight)
|
||||
|
||||
# mid layer if exists
|
||||
if self.lora_mid is not None:
|
||||
mid_weight = self.lora_mid.weight
|
||||
hidden = F.linear(hidden, mid_weight)
|
||||
|
||||
out = F.linear(hidden, up_weight)
|
||||
|
||||
return out * scale
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
return LoraDiff(
|
||||
(mat1, mat2, alpha, None, None, None)
|
||||
)
|
||||
return LoraDiff((mat1, mat2, alpha, None, None, None))
|
||||
|
||||
def to_train(self):
|
||||
return LoraDiff(self.weights)
|
||||
@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoRA: h(x) = up(down(x)) * scale
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
|
||||
Reference: LyCORIS functional/locon.py bypass_forward_diff
|
||||
"""
|
||||
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
|
||||
v = self.weights
|
||||
# v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
|
||||
up = v[0]
|
||||
down = v[1]
|
||||
alpha = v[2]
|
||||
mid = v[3]
|
||||
|
||||
# Compute scale = alpha / rank
|
||||
rank = down.shape[0]
|
||||
if alpha is not None:
|
||||
scale = alpha / rank
|
||||
else:
|
||||
scale = 1.0
|
||||
scale = scale * getattr(self, "multiplier", 1.0)
|
||||
|
||||
# Cast dtype
|
||||
up = up.to(dtype=x.dtype)
|
||||
down = down.to(dtype=x.dtype)
|
||||
|
||||
# Use module info from bypass injection, not weight dimension
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
if is_conv:
|
||||
op = FUNC_LIST[
|
||||
conv_dim + 2
|
||||
] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5)
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
|
||||
# Reshape 2D weights to conv format using kernel_size
|
||||
# down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
|
||||
# up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
|
||||
if down.dim() == 2:
|
||||
# down.shape[1] = in_channels * prod(kernel_size)
|
||||
if in_channels is not None:
|
||||
down = down.view(down.shape[0], in_channels, *kernel_size)
|
||||
else:
|
||||
# Fallback: assume 1x1 kernel if in_channels unknown
|
||||
down = down.view(*down.shape, *([1] * conv_dim))
|
||||
if up.dim() == 2:
|
||||
# up always uses 1x1 kernel
|
||||
up = up.view(*up.shape, *([1] * conv_dim))
|
||||
if mid is not None:
|
||||
mid = mid.to(dtype=x.dtype)
|
||||
if mid.dim() == 2:
|
||||
mid = mid.view(*mid.shape, *([1] * conv_dim))
|
||||
else:
|
||||
op = F.linear
|
||||
kw_dict = {} # linear doesn't take stride/padding
|
||||
|
||||
# Simple chain: down -> mid (if tucker) -> up
|
||||
if mid is not None:
|
||||
if not is_conv:
|
||||
mid = mid.to(dtype=x.dtype)
|
||||
hidden = op(x, down)
|
||||
hidden = op(hidden, mid, **kw_dict)
|
||||
out = op(hidden, up)
|
||||
else:
|
||||
hidden = op(x, down, **kw_dict)
|
||||
out = op(hidden, up)
|
||||
|
||||
return out * scale
|
||||
|
||||
@ -3,13 +3,18 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
WeightAdapterTrainBase,
|
||||
weight_decompose,
|
||||
factorization,
|
||||
)
|
||||
|
||||
|
||||
class OFTDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
# Unpack weights tuple from OFTAdapter
|
||||
blocks, rescale, alpha, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase):
|
||||
weight = self.rescale * weight
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def _get_orthogonal_matrix(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||
blocks = self.oft_blocks.to(device=device, dtype=dtype)
|
||||
I = torch.eye(self.block_size, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if set
|
||||
if self.constraint:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > self.constraint:
|
||||
normed_q = q * self.constraint / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r.to(dtype)
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
OFT has no additive component - returns zeros matching base_out shape.
|
||||
|
||||
OFT only transforms the output via g(), it doesn't add to it.
|
||||
"""
|
||||
return torch.zeros_like(base_out)
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for OFT: applies orthogonal rotation.
|
||||
|
||||
OFT transforms output channels using block-diagonal orthogonal matrices.
|
||||
"""
|
||||
r = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||
|
||||
# Apply multiplier to interpolate between identity and full transform
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(self.block_size, device=y.device, dtype=y.dtype)
|
||||
r = r * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Use module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# y now has channels in last dim
|
||||
*batch_shape, out_features = y.shape
|
||||
|
||||
# Reshape to apply block-diagonal transform
|
||||
# (*, out_features) -> (*, block_num, block_size)
|
||||
y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size)
|
||||
|
||||
# Apply orthogonal transform: R @ y for each block
|
||||
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||
|
||||
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||
out = out_blocked.reshape(*batch_shape, out_features)
|
||||
|
||||
# Apply rescale if present
|
||||
if self.rescaled:
|
||||
rescale = self.rescale.to(device=y.device, dtype=y.dtype)
|
||||
out = out * rescale.view(-1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
out = out.transpose(1, -1)
|
||||
|
||||
return out
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase):
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
block_size, block_num = factorization(out_dim, rank)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
|
||||
return OFTDiff(
|
||||
(block, None, alpha, None)
|
||||
block = torch.zeros(
|
||||
block_num, block_size, block_size, device=weight.device, dtype=torch.float32
|
||||
)
|
||||
return OFTDiff((block, None, alpha, None))
|
||||
|
||||
def to_train(self):
|
||||
return OFTDiff(self.weights)
|
||||
@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase):
|
||||
alpha = 0
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
blocks = comfy.model_management.cast_to_device(
|
||||
blocks, weight.device, intermediate_dtype
|
||||
)
|
||||
if rescale is not None:
|
||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||
rescale = comfy.model_management.cast_to_device(
|
||||
rescale, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
block_num, block_size, *_ = blocks.shape
|
||||
|
||||
@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase):
|
||||
# for Q = -Q^T
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
# use float() to prevent unsupported type in .inverse()
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
r = r.to(weight)
|
||||
# Create I in weight's dtype for the einsum
|
||||
I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
|
||||
_, *shape = weight.shape
|
||||
lora_diff = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
(r * strength) - strength * I,
|
||||
(r * strength) - strength * I_w,
|
||||
weight.view(block_num, block_size, *shape),
|
||||
).view(-1, *shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def _get_orthogonal_matrix(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||
v = self.weights
|
||||
blocks = v[0].to(device=device, dtype=dtype)
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
|
||||
block_num, block_size, _ = blocks.shape
|
||||
I = torch.eye(block_size, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if alpha > 0
|
||||
if alpha > 0:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r, block_num, block_size
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for OFT: applies orthogonal rotation to output.
|
||||
|
||||
OFT transforms the output channels using block-diagonal orthogonal matrices.
|
||||
|
||||
Reference: LyCORIS DiagOFTModule._bypass_forward
|
||||
"""
|
||||
v = self.weights
|
||||
rescale = v[1]
|
||||
|
||||
r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||
|
||||
# Apply multiplier to interpolate between identity and full transform
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(block_size, device=y.device, dtype=y.dtype)
|
||||
r = r * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Use module info from bypass injection to determine conv vs linear
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# y now has channels in last dim
|
||||
*batch_shape, out_features = y.shape
|
||||
|
||||
# Reshape to apply block-diagonal transform
|
||||
# (*, out_features) -> (*, block_num, block_size)
|
||||
y_blocked = y.view(*batch_shape, block_num, block_size)
|
||||
|
||||
# Apply orthogonal transform: R @ y for each block
|
||||
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||
|
||||
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||
out = out_blocked.view(*batch_shape, out_features)
|
||||
|
||||
# Apply rescale if present
|
||||
if rescale is not None:
|
||||
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
||||
out = out * rescale.view(-1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
out = out.transpose(1, -1)
|
||||
|
||||
return out
|
||||
|
||||
Loading…
Reference in New Issue
Block a user