mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 22:00:19 +08:00
* Add API of bypass forward module * bypass implementation * add bypass fwd into nodes list/trainer
328 lines
11 KiB
Python
328 lines
11 KiB
Python
import logging
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import comfy.model_management
|
|
from .base import (
|
|
WeightAdapterBase,
|
|
WeightAdapterTrainBase,
|
|
weight_decompose,
|
|
factorization,
|
|
)
|
|
|
|
|
|
class OFTDiff(WeightAdapterTrainBase):
|
|
def __init__(self, weights):
|
|
super().__init__()
|
|
# Unpack weights tuple from OFTAdapter
|
|
blocks, rescale, alpha, _ = weights
|
|
|
|
# Create trainable parameters
|
|
self.oft_blocks = torch.nn.Parameter(blocks)
|
|
if rescale is not None:
|
|
self.rescale = torch.nn.Parameter(rescale)
|
|
self.rescaled = True
|
|
else:
|
|
self.rescaled = False
|
|
self.block_num, self.block_size, _ = blocks.shape
|
|
self.constraint = float(alpha)
|
|
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
|
|
|
def __call__(self, w):
|
|
org_dtype = w.dtype
|
|
I = torch.eye(self.block_size, device=self.oft_blocks.device)
|
|
|
|
## generate r
|
|
# for Q = -Q^T
|
|
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
|
|
normed_q = q
|
|
if self.constraint:
|
|
q_norm = torch.norm(q) + 1e-8
|
|
if q_norm > self.constraint:
|
|
normed_q = q * self.constraint / q_norm
|
|
# use float() to prevent unsupported type
|
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
|
|
|
## Apply chunked matmul on weight
|
|
_, *shape = w.shape
|
|
org_weight = w.to(dtype=r.dtype)
|
|
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
|
|
# Init R=0, so add I on it to ensure the output of step0 is original model output
|
|
weight = torch.einsum(
|
|
"k n m, k n ... -> k m ...",
|
|
r,
|
|
org_weight,
|
|
).flatten(0, 1)
|
|
if self.rescaled:
|
|
weight = self.rescale * weight
|
|
return weight.to(org_dtype)
|
|
|
|
def _get_orthogonal_matrix(self, device, dtype):
|
|
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
|
blocks = self.oft_blocks.to(device=device, dtype=dtype)
|
|
I = torch.eye(self.block_size, device=device, dtype=dtype)
|
|
|
|
# Q = blocks - blocks^T (skew-symmetric)
|
|
q = blocks - blocks.transpose(1, 2)
|
|
normed_q = q
|
|
|
|
# Apply constraint if set
|
|
if self.constraint:
|
|
q_norm = torch.norm(q) + 1e-8
|
|
if q_norm > self.constraint:
|
|
normed_q = q * self.constraint / q_norm
|
|
|
|
# Cayley transform: R = (I + Q)(I - Q)^-1
|
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
|
return r.to(dtype)
|
|
|
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
OFT has no additive component - returns zeros matching base_out shape.
|
|
|
|
OFT only transforms the output via g(), it doesn't add to it.
|
|
"""
|
|
return torch.zeros_like(base_out)
|
|
|
|
def g(self, y: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Output transformation for OFT: applies orthogonal rotation.
|
|
|
|
OFT transforms output channels using block-diagonal orthogonal matrices.
|
|
"""
|
|
r = self._get_orthogonal_matrix(y.device, y.dtype)
|
|
|
|
# Apply multiplier to interpolate between identity and full transform
|
|
multiplier = getattr(self, "multiplier", 1.0)
|
|
I = torch.eye(self.block_size, device=y.device, dtype=y.dtype)
|
|
r = r * multiplier + (1 - multiplier) * I
|
|
|
|
# Use module info from bypass injection
|
|
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
|
|
|
if is_conv:
|
|
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
|
y = y.transpose(1, -1)
|
|
|
|
# y now has channels in last dim
|
|
*batch_shape, out_features = y.shape
|
|
|
|
# Reshape to apply block-diagonal transform
|
|
# (*, out_features) -> (*, block_num, block_size)
|
|
y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size)
|
|
|
|
# Apply orthogonal transform: R @ y for each block
|
|
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
|
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
|
|
|
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
|
out = out_blocked.reshape(*batch_shape, out_features)
|
|
|
|
# Apply rescale if present
|
|
if self.rescaled:
|
|
rescale = self.rescale.to(device=y.device, dtype=y.dtype)
|
|
out = out * rescale.view(-1)
|
|
|
|
if is_conv:
|
|
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
|
out = out.transpose(1, -1)
|
|
|
|
return out
|
|
|
|
def passive_memory_usage(self):
|
|
"""Calculates memory usage of the trainable parameters."""
|
|
return sum(param.numel() * param.element_size() for param in self.parameters())
|
|
|
|
|
|
class OFTAdapter(WeightAdapterBase):
|
|
name = "oft"
|
|
|
|
def __init__(self, loaded_keys, weights):
|
|
self.loaded_keys = loaded_keys
|
|
self.weights = weights
|
|
|
|
@classmethod
|
|
def create_train(cls, weight, rank=1, alpha=1.0):
|
|
out_dim = weight.shape[0]
|
|
block_size, block_num = factorization(out_dim, rank)
|
|
block = torch.zeros(
|
|
block_num, block_size, block_size, device=weight.device, dtype=torch.float32
|
|
)
|
|
return OFTDiff((block, None, alpha, None))
|
|
|
|
def to_train(self):
|
|
return OFTDiff(self.weights)
|
|
|
|
@classmethod
|
|
def load(
|
|
cls,
|
|
x: str,
|
|
lora: dict[str, torch.Tensor],
|
|
alpha: float,
|
|
dora_scale: torch.Tensor,
|
|
loaded_keys: set[str] = None,
|
|
) -> Optional["OFTAdapter"]:
|
|
if loaded_keys is None:
|
|
loaded_keys = set()
|
|
blocks_name = "{}.oft_blocks".format(x)
|
|
rescale_name = "{}.rescale".format(x)
|
|
|
|
blocks = None
|
|
if blocks_name in lora.keys():
|
|
blocks = lora[blocks_name]
|
|
if blocks.ndim == 3:
|
|
loaded_keys.add(blocks_name)
|
|
else:
|
|
blocks = None
|
|
if blocks is None:
|
|
return None
|
|
|
|
rescale = None
|
|
if rescale_name in lora.keys():
|
|
rescale = lora[rescale_name]
|
|
loaded_keys.add(rescale_name)
|
|
|
|
weights = (blocks, rescale, alpha, dora_scale)
|
|
return cls(loaded_keys, weights)
|
|
|
|
def calculate_weight(
|
|
self,
|
|
weight,
|
|
key,
|
|
strength,
|
|
strength_model,
|
|
offset,
|
|
function,
|
|
intermediate_dtype=torch.float32,
|
|
original_weight=None,
|
|
):
|
|
v = self.weights
|
|
blocks = v[0]
|
|
rescale = v[1]
|
|
alpha = v[2]
|
|
if alpha is None:
|
|
alpha = 0
|
|
dora_scale = v[3]
|
|
|
|
blocks = comfy.model_management.cast_to_device(
|
|
blocks, weight.device, intermediate_dtype
|
|
)
|
|
if rescale is not None:
|
|
rescale = comfy.model_management.cast_to_device(
|
|
rescale, weight.device, intermediate_dtype
|
|
)
|
|
|
|
block_num, block_size, *_ = blocks.shape
|
|
|
|
try:
|
|
# Get r
|
|
I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
|
|
# for Q = -Q^T
|
|
q = blocks - blocks.transpose(1, 2)
|
|
normed_q = q
|
|
if alpha > 0: # alpha in oft/boft is for constraint
|
|
q_norm = torch.norm(q) + 1e-8
|
|
if q_norm > alpha:
|
|
normed_q = q * alpha / q_norm
|
|
# use float() to prevent unsupported type in .inverse()
|
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
|
r = r.to(weight)
|
|
# Create I in weight's dtype for the einsum
|
|
I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
|
|
_, *shape = weight.shape
|
|
lora_diff = torch.einsum(
|
|
"k n m, k n ... -> k m ...",
|
|
(r * strength) - strength * I_w,
|
|
weight.view(block_num, block_size, *shape),
|
|
).view(-1, *shape)
|
|
if dora_scale is not None:
|
|
weight = weight_decompose(
|
|
dora_scale,
|
|
weight,
|
|
lora_diff,
|
|
alpha,
|
|
strength,
|
|
intermediate_dtype,
|
|
function,
|
|
)
|
|
else:
|
|
weight += function((strength * lora_diff).type(weight.dtype))
|
|
except Exception as e:
|
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
|
return weight
|
|
|
|
def _get_orthogonal_matrix(self, device, dtype):
|
|
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
|
v = self.weights
|
|
blocks = v[0].to(device=device, dtype=dtype)
|
|
alpha = v[2]
|
|
if alpha is None:
|
|
alpha = 0
|
|
|
|
block_num, block_size, _ = blocks.shape
|
|
I = torch.eye(block_size, device=device, dtype=dtype)
|
|
|
|
# Q = blocks - blocks^T (skew-symmetric)
|
|
q = blocks - blocks.transpose(1, 2)
|
|
normed_q = q
|
|
|
|
# Apply constraint if alpha > 0
|
|
if alpha > 0:
|
|
q_norm = torch.norm(q) + 1e-8
|
|
if q_norm > alpha:
|
|
normed_q = q * alpha / q_norm
|
|
|
|
# Cayley transform: R = (I + Q)(I - Q)^-1
|
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
|
return r, block_num, block_size
|
|
|
|
def g(self, y: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Output transformation for OFT: applies orthogonal rotation to output.
|
|
|
|
OFT transforms the output channels using block-diagonal orthogonal matrices.
|
|
|
|
Reference: LyCORIS DiagOFTModule._bypass_forward
|
|
"""
|
|
v = self.weights
|
|
rescale = v[1]
|
|
|
|
r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
|
|
|
|
# Apply multiplier to interpolate between identity and full transform
|
|
multiplier = getattr(self, "multiplier", 1.0)
|
|
I = torch.eye(block_size, device=y.device, dtype=y.dtype)
|
|
r = r * multiplier + (1 - multiplier) * I
|
|
|
|
# Use module info from bypass injection to determine conv vs linear
|
|
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
|
|
|
if is_conv:
|
|
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
|
y = y.transpose(1, -1)
|
|
|
|
# y now has channels in last dim
|
|
*batch_shape, out_features = y.shape
|
|
|
|
# Reshape to apply block-diagonal transform
|
|
# (*, out_features) -> (*, block_num, block_size)
|
|
y_blocked = y.view(*batch_shape, block_num, block_size)
|
|
|
|
# Apply orthogonal transform: R @ y for each block
|
|
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
|
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
|
|
|
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
|
out = out_blocked.view(*batch_shape, out_features)
|
|
|
|
# Apply rescale if present
|
|
if rescale is not None:
|
|
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
|
out = out * rescale.view(-1)
|
|
|
|
if is_conv:
|
|
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
|
out = out.transpose(1, -1)
|
|
|
|
return out
|