import logging
from typing import Optional

import torch
import torch.nn.functional as F

import comfy.model_management
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    pad_tensor_to_shape,
    tucker_weight_from_conv,
)

class LoraDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        mat1, mat2, alpha, mid, dora_scale, reshape = weights
        out_dim, rank = mat1.shape[0], mat1.shape[1]
        rank, in_dim = mat2.shape[0], mat2.shape[1]
        if mid is not None:
            convdim = mid.ndim - 2
            layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
        else:
            layer = torch.nn.Linear
        self.lora_up = layer(rank, out_dim, bias=False)
        self.lora_down = layer(in_dim, rank, bias=False)
        self.lora_up.weight.data.copy_(mat1)
        self.lora_down.weight.data.copy_(mat2)
        if mid is not None:
            self.lora_mid = layer(mid, rank, bias=False)
            self.lora_mid.weight.data.copy_(mid)
        else:
            self.lora_mid = None
        self.rank = rank
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype
        if self.lora_mid is None:
            diff = self.lora_up.weight @ self.lora_down.weight
        else:
            diff = tucker_weight_from_conv(
                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
            )
        scale = self.alpha / self.rank
        weight = w + scale * diff.reshape(w.shape)
        return weight.to(org_dtype)
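
    # Illustrative note (not from the original file): for the Linear case the merge
    # above is the standard LoRA update
    #     W' = W + (alpha / rank) * (lora_up.weight @ lora_down.weight)
    # e.g. a hypothetical 64x64 layer with rank 8 adds a 64x8 @ 8x64 product to W.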

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoRA training: h(x) = up(down(x)) * scale

        Simple implementation using the nn.Module weights directly.
        No mid/dora/reshape branches (create_train doesn't create them).

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, for API consistency)
        """
        # Compute scale = alpha / rank * multiplier
        scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)

        # Get module info from bypass injection
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        # Get weights (keep in original dtype for numerical stability)
        down_weight = self.lora_down.weight
        up_weight = self.lora_up.weight

        if is_conv:
            # Conv path: use functional conv
            # conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
            conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]

            # Reshape 2D weights to conv format if needed
            # down: [rank, in_features] -> [rank, in_channels, *kernel_size]
            # up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
            if down_weight.dim() == 2:
                kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
                in_channels = getattr(self, "in_channels", None)
                if in_channels is not None:
                    down_weight = down_weight.view(
                        down_weight.shape[0], in_channels, *kernel_size
                    )
                else:
                    # Fallback: assume 1x1 kernel
                    down_weight = down_weight.view(
                        *down_weight.shape, *([1] * conv_dim)
                    )
            if up_weight.dim() == 2:
                # up always uses 1x1 kernel
                up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))

            # down conv uses stride/padding from module, up is 1x1
            hidden = conv_fn(x, down_weight, **kw_dict)

            # mid layer if it exists (tucker decomposition)
            if self.lora_mid is not None:
                mid_weight = self.lora_mid.weight
                if mid_weight.dim() == 2:
                    mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
                hidden = conv_fn(hidden, mid_weight)

            # up conv is always 1x1 (no stride/padding)
            out = conv_fn(hidden, up_weight)
        else:
            # Linear path: simple matmul chain
            hidden = F.linear(x, down_weight)

            # mid layer if it exists
            if self.lora_mid is not None:
                mid_weight = self.lora_mid.weight
                hidden = F.linear(hidden, mid_weight)

            out = F.linear(hidden, up_weight)

        return out * scale
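
    # Illustrative sketch (assumption, not from the original file): a bypass-injected
    # forward would typically add h() on top of the frozen layer's output, e.g.
    #     def patched_forward(x):
    #         base_out = original_forward(x)  # hypothetical handle to the frozen op
    #         return base_out + adapter.h(x, base_out)
    # so the quantized/frozen weights are never touched during training.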

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoRAAdapter(WeightAdapterBase):
    name = "lora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
        torch.nn.init.constant_(mat2, 0.0)
        return LoraDiff((mat1, mat2, alpha, None, None, None))
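
    # Example (illustrative only, not from the original file): create a trainable
    # rank-8 adapter for a 320x768 linear weight; lora_down starts at zero, so the
    # initial weight diff and bypass output are both exactly zero.
    #     weight = torch.randn(320, 768)
    #     train_module = LoRAAdapter.create_train(weight, rank=8, alpha=8.0)
    #     merged = train_module(weight)  # equal to weight before any training step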

    def to_train(self):
        return LoraDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

        reshape_name = "{}.reshape_weight".format(x)
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        qwen_default_lora = "{}.lora_B.default.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None
        elif qwen_default_lora in lora.keys():
            A_name = qwen_default_lora
            B_name = "{}.lora_A.default.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            reshape = None
            if reshape_name in lora.keys():
                try:
                    reshape = lora[reshape_name].tolist()
                    loaded_keys.add(reshape_name)
                except:
                    pass
            weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            return cls(loaded_keys, weights)
        else:
            return None
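
    # Illustrative note (the key prefix below is a hypothetical example): for a prefix
    # such as x = "diffusion_model.input_blocks.1.1.proj_in", load() probes, in order:
    #     "<x>.lora_up.weight"  / "<x>.lora_down.weight"   (regular LoRA)
    #     "<x>_lora.up.weight"  / "<x>_lora.down.weight"   (diffusers)
    #     "<x>.lora_B.weight"   / "<x>.lora_A.weight"      (diffusers2)
    # plus the diffusers3, mochi, transformers and qwen-default variants above,
    # and returns None if no supported naming scheme matches.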

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)

        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
            # locon mid weights, hopefully the math is fine because I didn't properly test it
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )
        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
            del mat1, mat2
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
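
    # Illustrative note (not from the original file): with no dora_scale and no mid
    # matrix, the try-block above reduces to the standard LoRA merge
    #     weight += function(strength * (alpha / rank) * (mat1 @ mat2).reshape(weight.shape))
    # where mat1 is the stored "up"/B matrix, mat2 the "down"/A matrix, and rank = mat2.shape[0].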

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoRA: h(x) = up(down(x)) * scale

        Note:
            Does not access original model weights - bypass mode is designed
            for quantized models where weights may not be accessible.

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, for API consistency)

        Reference: LyCORIS functional/locon.py bypass_forward_diff
        """
        # Index 2 = linear, 3/4/5 = conv1d/conv2d/conv3d
        FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]

        v = self.weights
        # v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
        up = v[0]
        down = v[1]
        alpha = v[2]
        mid = v[3]

        # Compute scale = alpha / rank
        rank = down.shape[0]
        if alpha is not None:
            scale = alpha / rank
        else:
            scale = 1.0
        scale = scale * getattr(self, "multiplier", 1.0)

        # Cast dtype
        up = up.to(dtype=x.dtype)
        down = down.to(dtype=x.dtype)

        # Use module info from bypass injection, not weight dimension
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        if is_conv:
            # conv_dim 1 -> conv1d (3), 2 -> conv2d (4), 3 -> conv3d (5)
            op = FUNC_LIST[conv_dim + 2]
            kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
            in_channels = getattr(self, "in_channels", None)

            # Reshape 2D weights to conv format using kernel_size
            # down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
            # up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
            if down.dim() == 2:
                # down.shape[1] = in_channels * prod(kernel_size)
                if in_channels is not None:
                    down = down.view(down.shape[0], in_channels, *kernel_size)
                else:
                    # Fallback: assume 1x1 kernel if in_channels is unknown
                    down = down.view(*down.shape, *([1] * conv_dim))
            if up.dim() == 2:
                # up always uses 1x1 kernel
                up = up.view(*up.shape, *([1] * conv_dim))
            if mid is not None:
                mid = mid.to(dtype=x.dtype)
                if mid.dim() == 2:
                    mid = mid.view(*mid.shape, *([1] * conv_dim))
        else:
            op = F.linear
            kw_dict = {}  # linear doesn't take stride/padding

        # Simple chain: down -> mid (if tucker) -> up
        if mid is not None:
            if not is_conv:
                mid = mid.to(dtype=x.dtype)
            hidden = op(x, down)
            hidden = op(hidden, mid, **kw_dict)
            out = op(hidden, up)
        else:
            hidden = op(x, down, **kw_dict)
            out = op(hidden, up)

        return out * scale
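

if __name__ == "__main__":
    # Minimal sanity sketch (added illustration, not part of the upstream module).
    # Because of the relative imports above, run it as a module, e.g.
    # `python -m comfy.weight_adapter.lora` (module path assumed).
    # A freshly created trainable LoRA must leave the weight unchanged and its
    # bypass output must be zero, since lora_down is zero-initialized.
    w = torch.randn(16, 32)
    train_module = LoRAAdapter.create_train(w, rank=4, alpha=4.0)
    x = torch.randn(2, 32)
    assert torch.equal(train_module(w), w)
    assert torch.count_nonzero(train_module.h(x, None)) == 0
    print("LoRA train/bypass sanity checks passed")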