from typing import Callable, Optional

import torch
import torch.nn as nn

import comfy.model_management


class WeightAdapterBase:
    """
    Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.)

    Bypass Mode:
        All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
        - h(x): Additive component (LoRA path). Returns delta to add to base output.
        - g(y): Output transformation. Applied after base + h(x).

        For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
        For OFT/BOFT: g = transform, h = 0
    """

    name: str
    loaded_keys: set[str]
    weights: list[torch.Tensor]

    # Attributes set by bypass system
    multiplier: float = 1.0
    shape: tuple = None  # (out_features, in_features) or (out_ch, in_ch, *kernel)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
    ) -> Optional["WeightAdapterBase"]:
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        raise NotImplementedError

    @classmethod
    def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
        """
        weight: The original weight tensor to be modified.
        *args: Additional arguments for configuration, such as rank, alpha etc.
        """
        raise NotImplementedError

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        raise NotImplementedError

    # ===== Bypass Mode Methods =====
    #
    # IMPORTANT: Bypass mode is designed for quantized models where original weights
    # may not be accessible in a usable format. Therefore, h() and bypass_forward()
    # do NOT take org_weight as a parameter. All necessary information (out_channels,
    # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook.

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component: h(x, base_out)

        Computes the adapter's contribution to be added to base forward output.
        For adapters that only transform output (OFT/BOFT), returns zeros.

        Note: This method does NOT access original model weights. Bypass mode is
        designed for quantized models where weights may not be in a usable format.
        All shape info comes from module attributes set by BypassForwardHook.

        Args:
            x: Input tensor
            base_out: Output from base forward f(x), can be used for shape reference

        Returns:
            Delta tensor to add to base output. Shape matches base output.

        Reference: LyCORIS LoConModule.bypass_forward_diff
        """
        # Default: no additive component (for OFT/BOFT)
        # Simply return zeros matching base_out shape
        return torch.zeros_like(base_out)

    def g(self, y: torch.Tensor) -> torch.Tensor:
        """
        Output transformation: g(y)

        Applied after base forward + h(x). For most adapters this is identity.
        OFT/BOFT override this to apply orthogonal transformation.

        Args:
            y: Combined output (base + h(x))

        Returns:
            Transformed output

        Reference: LyCORIS OFTModule applies orthogonal transform here
        """
        # Default: identity (for LoRA/LoHa/LoKr)
        return y

    def bypass_forward(
        self,
        org_forward: Callable,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Full bypass forward: g(f(x) + h(x, f(x)))

        Note: This method does NOT take org_weight/org_bias parameters. Bypass mode
        is designed for quantized models where weights may not be accessible. The
        original forward function handles weight access internally.

        Args:
            org_forward: Original module forward function
            x: Input tensor
            *args, **kwargs: Additional arguments for org_forward

        Returns:
            Output with adapter applied in bypass mode

        Reference: LyCORIS LoConModule.bypass_forward
        """
        # Base forward: f(x)
        base_out = org_forward(x, *args, **kwargs)
        # Additive component: h(x, base_out) - base_out provided for shape reference
        h_out = self.h(x, base_out)
        # Output transformation: g(base + h)
        return self.g(base_out + h_out)

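
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not one of the shipped adapter classes):
# a minimal LoRA-style subclass showing how an additive adapter fills in the
# bypass pattern bypass(f)(x) = g(f(x) + h(x)) for a Linear layer.  g() is
# inherited as identity; only h() is overridden.
class _ExampleLoRAAdapter(WeightAdapterBase):
    name = "example_lora"

    def __init__(self, up: torch.Tensor, down: torch.Tensor, scale: float):
        self.weights = [up, down]  # up: (out_features, rank), down: (rank, in_features)
        self.loaded_keys = set()
        self.scale = scale  # typically alpha / rank

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        up, down = self.weights
        up = up.to(dtype=x.dtype, device=x.device)
        down = down.to(dtype=x.dtype, device=x.device)
        # Low-rank path x @ down^T @ up^T; bypass_forward() adds this to f(x).
        delta = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
        return delta * (self.scale * self.multiplier)
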
class WeightAdapterTrainBase(nn.Module):
    """
    Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.)

    Bypass Mode:
        All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
        - h(x): Additive component (LoRA path). Returns delta to add to base output.
        - g(y): Output transformation. Applied after base + h(x).

        For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
        For OFT: g = transform, h = 0

    Note: Unlike WeightAdapterBase, TrainBase classes have simplified weight formats
    with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition).
    We follow the scheme of PR #7032
    """

    # Attributes set by bypass system (BypassForwardHook)
    # These are set before h()/g()/bypass_forward() are called
    multiplier: float = 1.0
    is_conv: bool = False
    conv_dim: int = 0  # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d
    kw_dict: dict = {}  # Conv kwargs: stride, padding, dilation, groups
    kernel_size: tuple = ()
    in_channels: int = None
    out_channels: int = None

    def __init__(self):
        super().__init__()

    def __call__(self, w):
        """
        Weight modification mode: returns modified weight.

        Args:
            w: The original weight tensor to be modified.

        Returns:
            Modified weight tensor.
        """
        raise NotImplementedError

    # ===== Bypass Mode Methods =====

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component: h(x, base_out)

        Computes the adapter's contribution to be added to base forward output.
        For adapters that only transform output (OFT), returns zeros.

        Args:
            x: Input tensor
            base_out: Output from base forward f(x), can be used for shape reference

        Returns:
            Delta tensor to add to base output. Shape matches base output.

        Subclasses should override this method.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__}.h() not implemented. "
            "Subclasses must implement h() for bypass mode."
        )

    def g(self, y: torch.Tensor) -> torch.Tensor:
        """
        Output transformation: g(y)

        Applied after base forward + h(x). For most adapters this is identity.
        OFT overrides this to apply orthogonal transformation.

        Args:
            y: Combined output (base + h(x))

        Returns:
            Transformed output
        """
        # Default: identity (for LoRA/LoHa/LoKr)
        return y

    def bypass_forward(
        self,
        org_forward: Callable,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Full bypass forward: g(f(x) + h(x, f(x)))

        Args:
            org_forward: Original module forward function
            x: Input tensor
            *args, **kwargs: Additional arguments for org_forward

        Returns:
            Output with adapter applied in bypass mode
        """
        # Base forward: f(x)
        base_out = org_forward(x, *args, **kwargs)
        # Additive component: h(x, base_out) - base_out provided for shape reference
        h_out = self.h(x, base_out)
        # Output transformation: g(base + h)
        return self.g(base_out + h_out)

    def passive_memory_usage(self):
        raise NotImplementedError("passive_memory_usage is not implemented")

    def move_to(self, device):
        self.to(device)
        return self.passive_memory_usage()

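
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not the shipped OFT trainer): a minimal
# trainable adapter on top of WeightAdapterTrainBase that only transforms the
# output, i.e. h = 0 and g = orthogonal rotation.  Orthogonality comes from
# the matrix exponential of a skew-symmetric parameter, which is just one
# simple parameterization picked for this sketch; multiplier/strength blending
# toward identity is omitted for brevity.
class _ExampleTrainOrthogonal(WeightAdapterTrainBase):
    def __init__(self, out_features: int):
        super().__init__()
        # Unconstrained parameter; (A - A^T) is skew-symmetric, so its matrix
        # exponential is orthogonal.  Zero init makes the rotation start at identity.
        self.raw = nn.Parameter(torch.zeros(out_features, out_features))

    def _rotation(self, dtype, device) -> torch.Tensor:
        skew = self.raw - self.raw.transpose(0, 1)
        return torch.linalg.matrix_exp(skew).to(dtype=dtype, device=device)

    def __call__(self, w):
        # Weight modification mode: rotate the output axis of the weight, R @ W.
        r = self._rotation(w.dtype, w.device)
        return (r @ w.reshape(w.shape[0], -1)).reshape(w.shape)

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        # No additive path for output-transform adapters.
        return torch.zeros_like(base_out)

    def g(self, y: torch.Tensor) -> torch.Tensor:
        # Bypass mode: rotate the already-computed base output.  Assumes the
        # feature axis is last, as for a Linear layer ((R @ W) @ x == R @ f(x)).
        r = self._rotation(y.dtype, y.device)
        return torch.nn.functional.linear(y, r)
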
def weight_decompose(
    dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function
):
    dora_scale = comfy.model_management.cast_to_device(
        dora_scale, weight.device, intermediate_dtype
    )
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)

    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
    if wd_on_output_axis:
        weight_norm = (
            weight.reshape(weight.shape[0], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
        )
    else:
        weight_norm = (
            weight_calc.transpose(0, 1)
            .reshape(weight_calc.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
            .transpose(0, 1)
        )
    weight_norm = weight_norm + torch.finfo(weight.dtype).eps

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * (weight_calc)
    else:
        weight[:] = weight_calc
    return weight


def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.

    Note:
        If the new shape is smaller than the original tensor in any dimension,
        a ValueError is raised.
    """
    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
        raise ValueError(
            "The new shape must be larger than the original tensor in all dimensions"
        )

    if len(new_shape) != len(tensor.shape):
        raise ValueError(
            "The new shape must have the same number of dimensions as the original tensor"
        )

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Create slicing tuples for both tensors
    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
    new_slices = tuple(slice(0, dim) for dim in tensor.shape)

    # Copy the original tensor into the new tensor
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor


def tucker_weight_from_conv(up, down, mid):
    up = up.reshape(up.size(0), up.size(1))
    down = down.reshape(down.size(0), down.size(1))
    return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)


def tucker_weight(wa, wb, t):
    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
    return torch.einsum("i j ..., i r -> r j ...", temp, wa)

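
# ---------------------------------------------------------------------------
# Illustrative usage sketch for the helpers above (all shapes are made up):
# reassemble a Tucker-decomposed conv kernel with tucker_weight_from_conv and
# zero-pad it to a larger output-channel count with pad_tensor_to_shape.
def _example_tucker_and_pad() -> torch.Tensor:
    rank, in_ch, out_ch, k = 4, 8, 16, 3
    up = torch.randn(out_ch, rank, 1, 1)   # collapsed to (out_ch, rank) internally
    down = torch.randn(rank, in_ch, 1, 1)  # collapsed to (rank, in_ch) internally
    mid = torch.randn(rank, rank, k, k)    # carries the spatial kernel
    full = tucker_weight_from_conv(up, down, mid)  # -> (out_ch, in_ch, k, k)
    # Pad the delta so it matches a hypothetical layer with 32 output channels;
    # the extra channels receive zero deltas.
    return pad_tensor_to_shape(full, [32, in_ch, k, k])
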
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """
    Return a tuple of two values decomposing the input dimension by the number
    closest to factor; the second value is greater than or equal to the first.

    examples)
    factor
        -1               2                4               8               16  ...
    127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
    128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
    250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
    360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
    512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
    1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
    """
    if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n

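
# ---------------------------------------------------------------------------
# Illustrative sketch of how factorization() is typically consumed by a
# Kronecker-product adapter such as LoKr (the wiring implied here is a
# simplified assumption, not the shipped implementation): the dimension is
# split into two factor sizes with m * n == dimension, and torch.kron of an
# m-sized block and an n-sized block rebuilds the full axis.
def _example_factorization_usage(out_features: int = 1024, factor: int = 8) -> tuple[int, int]:
    m, n = factorization(out_features, factor)  # per the table above: 1024, 8 -> (8, 128)
    assert m * n == out_features and m <= n
    return m, n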